{-# LANGUAGE LambdaCase #-}

-- | Type inference for lambda terms from "Minilang.Lambda.Eval": every node
-- is annotated with a type, equality constraints are gathered from
-- applications, and the constraints are solved with "Minilang.Lambda.Unify".
module Minilang.Lambda.Infer where

import Control.Monad.State (MonadState (..), StateT, evalStateT, modify)
import Data.Bifunctor (first)
import Data.Char (chr, ord)
import qualified Data.List as List
import qualified Data.Map as Map
import Data.Text (Text, pack)
import qualified Minilang.Lambda.Eval as Eval
import Minilang.Lambda.Unify (Type (..), UnifyError, apply, unify)
import Prelude hiding (lookup)

-- | A lambda term whose every node carries an annotation of type @ann@.
data ATerm ann
  = Var Text ann
  | Lam Text (ATerm ann) ann
  | App (ATerm ann) (ATerm ann) ann
  deriving (Show, Eq)

-- | Failures reported by 'infer'. Only unification errors exist so far.
data TypeError = UnifyError UnifyError
  deriving (Show, Eq)

-- | Mirror an 'Eval.Term' as an 'ATerm' with a unit annotation on every node.
fromTerm :: Eval.Term -> ATerm ()
fromTerm = \case
  Eval.Var name -> Var name ()
  Eval.Lam name body -> Lam name (fromTerm body) ()
  Eval.App fun arg -> App (fromTerm fun) (fromTerm arg) ()

-- | Infer the type of a term.
--
-- The term is first annotated with type variables ('annotate'), then one
-- equality constraint per application is gathered ('collect'), the
-- constraints are solved with 'unify', and the resulting substitution is
-- applied to the type of the root node. A unification failure is reported
-- as 'UnifyError'.
infer :: Eval.Term -> Either TypeError Type
infer term = do
  annotated <- annotate term
  constraints <- collect [annotated] []
  substitution <- first UnifyError (unify constraints)
  pure (apply substitution (typeOf annotated))

-- | Gather equality constraints from a worklist of annotated terms.
--
-- Terms are visited depth-first, function before argument. Each 'App' node
-- contributes @(type of function, type of argument :-> type of the node)@.
-- Variables and lambdas contribute nothing, because 'annotate' already gives
-- a lambda the arrow type @binder :-> body@. New constraints are consed onto
-- the accumulator, so the result lists them in reverse visiting order.
-- Currently this always returns 'Right'.
collect :: [ATerm Type] -> [(Type, Type)] -> Either TypeError [(Type, Type)]
collect [] acc = Right acc
collect (Var _ _ : pending) acc = collect pending acc
collect (Lam _ body _ : pending) acc = collect (body : pending) acc
collect (App fun arg result : pending) acc =
  collect (fun : arg : pending) ((typeOf fun, typeOf arg :-> result) : acc)

-- | The annotation stored on the root node of a term.
typeOf :: ATerm Type -> Type
typeOf = \case
  Var _ ty -> ty
  Lam _ _ ty -> ty
  App _ _ ty -> ty

-- | State threaded through 'annotate'.
data Env = Env
  { nextVar :: Int -- ^ index used to name the next fresh type variable
  , typeEnv :: Map.Map Text Type -- ^ types assigned to free variables seen so far
  }
  deriving (Eq, Show)

-- | Give every node of a term a type, starting from 'newEnv'.
--
-- * A variable bound by an enclosing lambda gets that lambda's binder type
--   (the innermost binder wins).
-- * A free variable gets a fresh type variable the first time it is seen
--   and reuses it on every later occurrence.
-- * A lambda gets @binder :-> body@, where the binder type is fresh.
-- * An application gets a fresh type variable for its result, allocated
--   after both sub-terms have been annotated.
annotate :: Eval.Term -> Either TypeError (ATerm Type)
annotate source = evalStateT (go (fromTerm source) []) newEnv
  where
    go :: ATerm a -> [(Text, Type)] -> StateT Env (Either TypeError) (ATerm Type)
    go node scope = case node of
      Var name _ ->
        Var name <$> maybe (resolveFree name) pure (List.lookup name scope)
      Lam name body _ -> do
        binderTy <- freshVar
        body' <- go body ((name, binderTy) : scope)
        pure (Lam name body' (binderTy :-> typeOf body'))
      App fun arg _ -> do
        fun' <- go fun scope
        arg' <- go arg scope
        App fun' arg' <$> freshVar

    -- Reuse the type already recorded for a free variable, or record a new one.
    resolveFree :: Text -> StateT Env (Either TypeError) Type
    resolveFree name =
      lookup name >>= \case
        Just known -> pure known
        Nothing -> do
          fresh <- freshVar
          bind name fresh
          pure fresh

-- | Initial annotation state: counter at zero, no free variables recorded.
newEnv :: Env
newEnv = Env {nextVar = 0, typeEnv = Map.empty}

-- | Record (or overwrite) the type associated with a free variable.
bind :: Text -> Type -> StateT Env (Either TypeError) ()
bind x ty = modify (\env -> env {typeEnv = Map.insert x ty (typeEnv env)})

-- | The type recorded for a free variable, if any.
lookup :: Text -> StateT Env (Either TypeError) (Maybe Type)
lookup x = Map.lookup x . typeEnv <$> get

-- | Allocate a fresh type variable and advance the counter.
--
-- Names run @a@ .. @z@, then @a1@ .. @z1@, then @a2@ .. and so on.
freshVar :: (Monad m) => StateT Env m Type
freshVar = state $ \env ->
  let counter = nextVar env
   in (TyVar (render counter), env {nextVar = counter + 1})
  where
    render k
      | k < 26 = pack [chr (ord 'a' + k)]
      | otherwise = case k `divMod` 26 of
          (cycleNo, offset) -> pack (chr (ord 'a' + offset) : show cycleNo)