1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
|
{-# LANGUAGE LambdaCase #-}
module Minilang.Lambda.Infer where
import Control.Monad.State (MonadState (..), StateT, evalStateT, modify)
import Data.Bifunctor (first)
import Data.Char (chr, ord)
import qualified Data.List as List
import qualified Data.Map as Map
import Data.Text (Text, pack)
import qualified Minilang.Lambda.Eval as Eval
import Minilang.Lambda.Unify (Type (..), UnifyError, apply, unify)
import Prelude hiding (lookup)
-- | A lambda-calculus term whose every node carries an annotation of type
-- @ann@. 'fromTerm' builds @ATerm ()@ (no information yet) and 'annotate'
-- builds @ATerm Type@, attaching a type to each node.
data ATerm ann
  = -- | A variable occurrence: name and annotation.
    Var Text ann
  | -- | An abstraction: bound variable name, body, and annotation.
    Lam Text (ATerm ann) ann
  | -- | An application: function term, argument term, and annotation.
    App (ATerm ann) (ATerm ann) ann
  deriving (Eq, Show)
-- | Errors reported by type inference. At present the only failure mode is a
-- unification failure, carried through unchanged.
data TypeError
  = -- | Solving the collected type equations failed.
    UnifyError UnifyError
  deriving (Eq, Show)
-- | Lift a plain 'Eval.Term' into an 'ATerm' whose annotations are all @()@.
-- The structure of the term (names, nesting, order) is preserved exactly.
fromTerm :: Eval.Term -> ATerm ()
fromTerm term = case term of
  Eval.Var name -> Var name ()
  Eval.Lam name body -> Lam name (fromTerm body) ()
  Eval.App fun arg -> App (fromTerm fun) (fromTerm arg) ()
-- | Infer the type of a term.
--
-- The pipeline runs in four steps:
--
-- 1. Annotate every node with a type. Lambda parameters, application results
--    and free variables get fresh type variables.
-- 2. Collect one equation per application node.
-- 3. Solve the equations with 'unify'.
-- 4. Apply the resulting substitution to the root node's type.
--
-- A failure from 'unify' is wrapped in the 'UnifyError' constructor.
infer :: Eval.Term -> Either TypeError Type
infer term = do
  root <- annotate term
  equations <- collect [root] []
  solution <- either (Left . UnifyError) Right (unify equations)
  Right (apply solution (typeOf root))
-- | Walk a work-list of annotated terms and gather type equations.
--
-- The traversal is depth-first and left to right. Each application node
-- @App fun arg t@ contributes the equation
-- @(typeOf fun, typeOf arg :-> t)@, prepended to the accumulator. So the
-- result lists equations in reverse order of discovery, followed by whatever
-- was passed in.
--
-- Variables contribute nothing. A lambda contributes nothing itself, because
-- its arrow type was already built from its parts; only its body is visited.
-- Currently this always returns 'Right'.
collect :: [ATerm Type] -> [(Type, Type)] -> Either TypeError [(Type, Type)]
collect [] acc = Right acc
collect (term : pending) acc = case term of
  Var {} -> collect pending acc
  Lam _ body _ -> collect (body : pending) acc
  App fun arg resultTy ->
    collect (fun : arg : pending) ((typeOf fun, typeOf arg :-> resultTy) : acc)
-- | The annotation stored on the root node of an annotated term.
typeOf :: ATerm Type -> Type
typeOf term = case term of
  Var _ ty -> ty
  Lam _ _ ty -> ty
  App _ _ ty -> ty
-- | State threaded through term annotation.
data Env = Env
  { nextVar :: Int
  -- ^ Counter from which the next fresh type-variable name is derived.
  , typeEnv :: Map.Map Text Type
  -- ^ Types recorded for variables that were not bound by an enclosing lambda.
  }
  deriving (Eq, Show)
-- | Attach a type to every node of a term, starting from 'newEnv'.
--
-- How each node is typed:
--
-- * A lambda parameter gets a fresh type variable, visible only in its body.
--   The lambda itself gets @paramTy :-> bodyTy@.
-- * An application gets a fresh type variable for its result. The function
--   side is annotated first, then the argument, then the result variable is
--   drawn.
-- * A variable bound by an enclosing lambda uses that lambda's parameter type.
--   Otherwise it is treated as free. If it was already seen, it reuses the
--   type recorded in the state. If not, it gets a fresh variable, which is
--   recorded so later occurrences share it.
annotate :: Eval.Term -> Either TypeError (ATerm Type)
annotate term = evalStateT (walk [] (fromTerm term)) newEnv
  where
    -- Annotate one node; @scope@ lists lambda-bound names, innermost first.
    walk :: [(Text, Type)] -> ATerm a -> StateT Env (Either TypeError) (ATerm Type)
    walk scope = \case
      Var name _ -> Var name <$> resolve scope name
      Lam name body _ -> do
        paramTy <- freshVar
        body' <- walk ((name, paramTy) : scope) body
        pure (Lam name body' (paramTy :-> typeOf body'))
      App fun arg _ -> do
        fun' <- walk scope fun
        arg' <- walk scope arg
        App fun' arg' <$> freshVar

    -- Find a variable's type: lambda scope first, then previously seen free
    -- variables, otherwise allocate and record a new one.
    resolve :: [(Text, Type)] -> Text -> StateT Env (Either TypeError) Type
    resolve scope name = case List.lookup name scope of
      Just ty -> pure ty
      Nothing -> lookup name >>= maybe (freeVar name) pure

    -- Allocate a fresh type for a newly seen free variable and remember it.
    freeVar :: Text -> StateT Env (Either TypeError) Type
    freeVar name = do
      ty <- freshVar
      bind name ty
      pure ty
-- | Initial annotation state: counter at zero, no free variables recorded.
newEnv :: Env
newEnv = Env {nextVar = 0, typeEnv = Map.empty}
-- | Record @ty@ as the type of variable @name@ in the state's 'typeEnv',
-- replacing any earlier entry. The counter is left unchanged.
bind :: Text -> Type -> StateT Env (Either TypeError) ()
bind name ty = modify $ \env -> env {typeEnv = Map.insert name ty (typeEnv env)}
-- | The type recorded for variable @name@ in the state's 'typeEnv', if any.
-- The state is only read, never modified.
lookup :: Text -> StateT Env (Either TypeError) (Maybe Type)
lookup name = Map.lookup name . typeEnv <$> get
-- | Produce a type variable named after the current counter and increment
-- the counter. 'typeEnv' is left unchanged.
--
-- Naming scheme:
--
-- * Counters below 26 map to a single letter, so 0 is @a@ and 25 is @z@.
-- * A larger counter @k@ maps to the letter for @k mod 26@ followed by
--   @k div 26@ in decimal. For example, 26 is @a1@, 27 is @b1@ and 52 is @a2@.
--
-- So every non-negative counter value yields a distinct name.
freshVar :: (Monad m) => StateT Env m Type
freshVar = state $ \(Env counter env) ->
  (TyVar (nameOf counter), Env (counter + 1) env)
  where
    nameOf :: Int -> Text
    nameOf k
      | k < 26 = pack [letterFor k]
      | otherwise =
          let (suffix, idx) = k `divMod` 26
           in pack (letterFor idx : show suffix)

    -- Offset from 'a'.
    letterFor :: Int -> Char
    letterFor i = chr (ord 'a' + i)
|