summaryrefslogtreecommitdiff
path: root/lambda-calcul/haskell/src/Minilang/Lambda
diff options
context:
space:
mode:
authorArnaud Bailly <arnaud@pankzsoft.com>2025-10-13 09:27:07 +0200
committerArnaud Bailly <arnaud@pankzsoft.com>2025-10-13 09:27:07 +0200
commit3a67e69bfe9492d2a2fc5e4b07cc8c909a346064 (patch)
tree8b4b0281dffe166f5c8495b1a5b892c1f8285870 /lambda-calcul/haskell/src/Minilang/Lambda
parent21befc8c8ab2e91632f5341b4fa9425cf3c815ff (diff)
downloadlambda-nantes-3a67e69bfe9492d2a2fc5e4b07cc8c909a346064.tar.gz
add minimal evaluator and type inference
Diffstat (limited to 'lambda-calcul/haskell/src/Minilang/Lambda')
-rw-r--r--lambda-calcul/haskell/src/Minilang/Lambda/Eval.hs36
-rw-r--r--lambda-calcul/haskell/src/Minilang/Lambda/Infer.hs97
-rw-r--r--lambda-calcul/haskell/src/Minilang/Lambda/Unify.hs57
3 files changed, 190 insertions, 0 deletions
diff --git a/lambda-calcul/haskell/src/Minilang/Lambda/Eval.hs b/lambda-calcul/haskell/src/Minilang/Lambda/Eval.hs
new file mode 100644
index 0000000..68b01be
--- /dev/null
+++ b/lambda-calcul/haskell/src/Minilang/Lambda/Eval.hs
@@ -0,0 +1,36 @@
+module Minilang.Lambda.Eval where
+
+import Data.Text (Text)
+
+data Term
+ = Var Text
+ | Lam Text Term
+ | App Term Term
+ deriving (Show, Eq)
+
+type Env = [(Text, Term)]
+
+-- call-by-value evaluator
+eval :: Term -> Env -> Term
+eval (Var x) env = case lookup x env of
+ Just v -> v -- we do not need to eval v again
+ Nothing -> Var x
+eval (Lam x body) _env = Lam x body
+eval (App f a) env =
+ -- we need to force evaluation of the argument
+ -- because haskell's default semantics is non-strict
+ -- so if a' is never used, it will not be evaluated!
+ let a' = eval a env
+ in seq a' $ case eval f env of
+ Lam x body -> eval body ((x, a') : env)
+ f' -> App f' a'
+
+evalNeed :: Term -> Env -> Term
+evalNeed (Var x) env = case lookup x env of
+ Just v -> evalNeed v env -- we need to eval v because it might be a redex
+ Nothing -> Var x
+evalNeed (Lam x body) _env = Lam x body
+evalNeed (App f a) env =
+ case evalNeed f env of
+ Lam x body -> evalNeed body ((x, a) : env)
+ f' -> App f' a
diff --git a/lambda-calcul/haskell/src/Minilang/Lambda/Infer.hs b/lambda-calcul/haskell/src/Minilang/Lambda/Infer.hs
new file mode 100644
index 0000000..415849a
--- /dev/null
+++ b/lambda-calcul/haskell/src/Minilang/Lambda/Infer.hs
@@ -0,0 +1,97 @@
+{-# LANGUAGE LambdaCase #-}
+
+module Minilang.Lambda.Infer where
+
+import Control.Monad.State (MonadState (..), StateT, evalStateT, modify)
+import Data.Bifunctor (first)
+import Data.Char (chr, ord)
+import qualified Data.List as List
+import qualified Data.Map as Map
+import Data.Text (Text, pack)
+import qualified Minilang.Lambda.Eval as Eval
+import Minilang.Lambda.Unify (Type (..), UnifyError, apply, unify)
+import Prelude hiding (lookup)
+
+data ATerm ann
+ = Var Text ann
+ | Lam Text (ATerm ann) ann
+ | App (ATerm ann) (ATerm ann) ann
+ deriving (Show, Eq)
+
+data TypeError
+ = UnifyError UnifyError
+ deriving (Show, Eq)
+
+fromTerm :: Eval.Term -> ATerm ()
+fromTerm (Eval.Var x) = Var x ()
+fromTerm (Eval.Lam x body) = Lam x (fromTerm body) ()
+fromTerm (Eval.App a b) = App (fromTerm a) (fromTerm b) ()
+
+infer :: Eval.Term -> Either TypeError Type
+infer t = do
+ annotated <- annotate t
+ constraints <- collect [annotated] []
+ subs <- first UnifyError $ unify constraints
+ pure $ apply subs (typeOf annotated)
+
+collect :: [ATerm Type] -> [(Type, Type)] -> Either TypeError [(Type, Type)]
+collect terms constraints = case terms of
+ [] -> Right constraints
+ (Var _ _) : rest -> collect rest constraints
+ (Lam _ body _) : rest -> collect (body : rest) constraints
+ (App x y t) : rest ->
+ let (f, b) = (typeOf x, typeOf y)
+ in collect (x : y : rest) ((f, b :-> t) : constraints)
+
+typeOf :: ATerm Type -> Type
+typeOf (Var _ t) = t
+typeOf (Lam _ _ t) = t
+typeOf (App _ _ t) = t
+
+data Env = Env {nextVar :: Int, typeEnv :: Map.Map Text Type}
+ deriving (Eq, Show)
+
+annotate :: Eval.Term -> Either TypeError (ATerm Type)
+annotate t =
+ evalStateT (go (fromTerm t) []) newEnv
+ where
+ go :: ATerm a -> [(Text, Type)] -> StateT Env (Either TypeError) (ATerm Type)
+ go t bounds = case t of
+ Var x _ -> case List.lookup x bounds of
+ Just ty -> pure $ Var x ty
+ Nothing ->
+ lookup x >>= \case
+ Just ty -> pure $ Var x ty
+ Nothing -> do
+ v <- freshVar
+ bind x v
+ pure $ Var x v
+ Lam x body _ -> do
+ ty <- freshVar
+ body' <- go body ((x, ty) : bounds)
+ pure $ Lam x body' (ty :-> typeOf body')
+ App a b _ ->
+ App <$> go a bounds <*> go b bounds <*> freshVar
+
+newEnv :: Env
+newEnv = Env 0 Map.empty
+
+bind :: Text -> Type -> StateT Env (Either TypeError) ()
+bind x ty = do
+ modify (\(Env n env) -> Env n (Map.insert x ty env))
+
+lookup :: Text -> StateT Env (Either TypeError) (Maybe Type)
+lookup x = do
+ Env _ env <- get
+ pure $ Map.lookup x env
+
+freshVar :: (Monad m) => StateT Env m Type
+freshVar = do
+ get >>= \(Env n _) ->
+ modify (\(Env n' env') -> Env (n' + 1) env') >> pure (TyVar $ nameOf n)
+ where
+ nameOf n
+ | n < 26 = pack [chr (n + ord 'a')]
+ | otherwise =
+ let (d, r) = n `divMod` 26
+ in pack $ chr (r + ord 'a') : show d
diff --git a/lambda-calcul/haskell/src/Minilang/Lambda/Unify.hs b/lambda-calcul/haskell/src/Minilang/Lambda/Unify.hs
new file mode 100644
index 0000000..ad06552
--- /dev/null
+++ b/lambda-calcul/haskell/src/Minilang/Lambda/Unify.hs
@@ -0,0 +1,57 @@
+-- ported from https://www.cs.cornell.edu/courses/cs3110/2011sp/Lectures/lec26-type-inference/type-inference.htm
+module Minilang.Lambda.Unify where
+
+import Data.Text (Text)
+
+type Identifier = Text
+
+data Type
+ = TyVar Text
+ | Type :-> Type
+ deriving (Show, Eq)
+
+infixr 5 :->
+
+-- invariant: no identifier on the left hand-side can appear in
+-- an earlier term in the list, ie. the list is in "dependency order"
+-- and forms a DAG.
+type Substitution = [(Identifier, Type)]
+
+occurs :: Identifier -> Type -> Bool
+occurs x (TyVar y) = x == y
+occurs x (f :-> g) = occurs x f || occurs x g
+
+subst :: Type -> Identifier -> Type -> Type
+subst s x t@(TyVar y)
+ | x == y = s
+ | otherwise = t
+subst s x (f :-> g) = subst s x f :-> subst s x g
+
+apply :: Substitution -> Type -> Type
+apply subs t =
+ foldr (\(x, s) -> subst s x) t subs
+
+data UnifyError
+ = MismatchHead Text Text [Type] [Type]
+ | Circularity Identifier Type
+ deriving (Show, Eq)
+
+unifyOne :: Type -> Type -> Either UnifyError Substitution
+unifyOne (TyVar x) (TyVar y) = if x == y then Right [] else Right [(x, TyVar y)]
+unifyOne (f :-> g) (f' :-> g') =
+ unify [(f, f'), (g, g')]
+unifyOne (TyVar x) t =
+ if occurs x t
+ then Left $ Circularity x t
+ else Right [(x, t)]
+unifyOne t (TyVar x) =
+ if occurs x t
+ then Left $ Circularity x t
+ else Right [(x, t)]
+
+unify :: [(Type, Type)] -> Either UnifyError Substitution
+unify [] = Right []
+unify ((s, t) : rest) = do
+ sub2 <- unify rest
+ sub1 <- unifyOne (apply sub2 s) (apply sub2 t)
+ return $ sub1 ++ sub2