summaryrefslogtreecommitdiff
path: root/lambda-calcul/haskell/src/Minilang/Lambda/Infer.hs
blob: 415849a94208d87165a4c1a3415a812c097d3780 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
{-# LANGUAGE LambdaCase #-}

module Minilang.Lambda.Infer where

import Control.Monad.State (MonadState (..), StateT, evalStateT, modify)
import Data.Bifunctor (first)
import Data.Char (chr, ord)
import qualified Data.List as List
import qualified Data.Map as Map
import Data.Text (Text, pack)
import qualified Minilang.Lambda.Eval as Eval
import Minilang.Lambda.Unify (Type (..), UnifyError, apply, unify)
import Prelude hiding (lookup)

-- | Lambda-calculus terms in which every node carries an annotation of type
-- @ann@: @()@ for terms produced by 'fromTerm', a 'Type' for terms produced
-- by 'annotate'.
data ATerm ann
  = -- | A variable occurrence.
    Var Text ann
  | -- | An abstraction: binder name, then body.
    Lam Text (ATerm ann) ann
  | -- | An application: function, then argument.
    App (ATerm ann) (ATerm ann) ann
  deriving (Eq, Show)

-- | Failures reported by 'infer'.
data TypeError
  = -- | Solving the collected equality constraints failed.
    UnifyError UnifyError
  deriving (Eq, Show)

-- | Lift an untyped 'Eval.Term' into an 'ATerm' whose every node is
-- annotated with @()@.
fromTerm :: Eval.Term -> ATerm ()
fromTerm term = case term of
  Eval.Var name -> Var name ()
  Eval.Lam name body -> Lam name (fromTerm body) ()
  Eval.App fun arg -> App (fromTerm fun) (fromTerm arg) ()

-- | Infer the type of an untyped term.
--
-- The term is first decorated with type variables ('annotate'). Then one
-- equality constraint per application is gathered ('collect') and the
-- constraints are solved with 'unify'. The resulting substitution is applied
-- to the annotation at the root of the term. A unification failure is
-- reported wrapped in 'UnifyError'.
infer :: Eval.Term -> Either TypeError Type
infer term = do
  annotated <- annotate term
  subs <- collect [annotated] [] >>= first UnifyError . unify
  pure (apply subs (typeOf annotated))

-- | Gather the equality constraints implied by a list of annotated terms,
-- prepending them to the given accumulator.
--
-- Only applications produce constraints: for @App f x t@ the type of @f@
-- must equal @typeOf x ':->' t@. Terms are visited left to right and each
-- term in preorder (an application's own constraint first, then its function
-- subtree, then its argument subtree). Every new constraint is consed onto
-- the front of the accumulator. This implementation never returns 'Left'.
collect :: [ATerm Type] -> [(Type, Type)] -> Either TypeError [(Type, Type)]
collect terms constraints = Right (visitAll terms constraints)
  where
    -- Process the terms in order, threading the accumulator through.
    visitAll :: [ATerm Type] -> [(Type, Type)] -> [(Type, Type)]
    visitAll [] acc = acc
    visitAll (term : more) acc = visitAll more (visit term acc)

    -- Preorder walk of a single term.
    visit :: ATerm Type -> [(Type, Type)] -> [(Type, Type)]
    visit (Var _ _) acc = acc
    visit (Lam _ body _) acc = visit body acc
    visit (App fun arg res) acc =
      visit arg (visit fun ((typeOf fun, typeOf arg :-> res) : acc))

-- | The annotation stored at the root of a typed term.
typeOf :: ATerm Type -> Type
typeOf term = case term of
  Var _ ty -> ty
  Lam _ _ ty -> ty
  App _ _ ty -> ty

-- | State threaded through 'annotate'.
data Env = Env
  { -- | Counter that 'freshVar' turns into the next type-variable name.
    nextVar :: Int,
    -- | Types given to free variables by 'bind' and read back by 'lookup'.
    typeEnv :: Map.Map Text Type
  }
  deriving (Eq, Show)

-- | Decorate every node of a term with a type, initially a type variable.
--
-- * A lambda gets a fresh variable @a@ for its binder and the type
--   @a ':->' b@, where @b@ is the type of its body.
-- * A variable bound by an enclosing lambda reuses that binder's type. The
--   innermost binder wins, because the scope list is extended at the front.
-- * A free variable is looked up in the state's free-variable map. On first
--   sight it is given a fresh variable that is recorded there with 'bind',
--   so all of its occurrences share one type.
-- * An application gets a fresh variable for its result.
--
-- Fresh variables are allocated in a fixed order: binder before body, and
-- for an application the function, then the argument, then the result.
annotate :: Eval.Term -> Either TypeError (ATerm Type)
annotate term =
  evalStateT (annot (fromTerm term) []) newEnv
  where
    annot :: ATerm a -> [(Text, Type)] -> StateT Env (Either TypeError) (ATerm Type)
    annot node scope = case node of
      Var x _ -> Var x <$> resolve x scope
      Lam x body _ -> do
        argTy <- freshVar
        body' <- annot body ((x, argTy) : scope)
        pure $ Lam x body' (argTy :-> typeOf body')
      App f a _ ->
        -- Applicative sequencing runs left to right: function, argument,
        -- then the fresh result variable.
        App <$> annot f scope <*> annot a scope <*> freshVar

    -- Type of a variable occurrence: lambda-bound scope first, then free.
    resolve :: Text -> [(Text, Type)] -> StateT Env (Either TypeError) Type
    resolve x scope = maybe (freeVar x) pure (List.lookup x scope)

    -- Type of a free variable, allocating and recording one on first use.
    freeVar :: Text -> StateT Env (Either TypeError) Type
    freeVar x =
      lookup x >>= \case
        Just ty -> pure ty
        Nothing -> do
          ty <- freshVar
          bind x ty
          pure ty

-- | Initial annotation state: the counter at zero and no free variables
-- recorded.
newEnv :: Env
newEnv = Env 0 mempty

-- | Record @ty@ as the type of free variable @x@, overwriting any previous
-- entry for @x@. The fresh-variable counter is left untouched.
bind :: Text -> Type -> StateT Env (Either TypeError) ()
bind x ty = modify $ \(Env counter vars) -> Env counter (Map.insert x ty vars)

-- | The type previously recorded for free variable @x@ (see 'bind'), or
-- 'Nothing' if none has been recorded yet. The state is not modified.
lookup :: Text -> StateT Env (Either TypeError) (Maybe Type)
lookup x = get >>= \(Env _ vars) -> pure (Map.lookup x vars)

-- | Allocate a new type variable named after the current counter, and
-- increment the counter.
--
-- Counter values 0..25 map to @a@..@z@. Larger values are written as the
-- letter for the remainder mod 26 followed by the quotient, e.g.
-- 26 -> @a1@, 27 -> @b1@, 52 -> @a2@. Distinct non-negative counters
-- therefore give distinct names.
freshVar :: (Monad m) => StateT Env m Type
freshVar = state $ \(Env n vars) -> (TyVar (nameOf n), Env (n + 1) vars)
  where
    nameOf :: Int -> Text
    nameOf k
      | k < 26 = pack [letter k]
      | otherwise = pack (letter r : show q)
      where
        (q, r) = k `divMod` 26

    letter :: Int -> Char
    letter i = chr (ord 'a' + i)