OK, so now that I've wowed you with my awesome code, let's talk about what's going on. Type inference breaks down into essentially two components:
- Constraint Generation
- Unification
We inspect the program we're trying to infer a type for and generate a bunch of statements (constraints) of the form "this type is equal to this type".
These types have "unification variables" in them. These aren't normal ML type variables. They're generated by the compiler, for the compiler, and will eventually be filled in with either
- A rigid polymorphic variable
- A normal concrete type
For example, if we're looking at the expression `f a`, we first just say that `f : 'f`, where `'f` is one of those unification variables I mentioned. Next we say that `a : 'a`. Since we're applying `f` to `a`, we can generate the constraints

```
'f ~ 'x -> 'y
'a ~ 'x
```

We then unify these constraints to produce `f : 'a -> 'y` and `a : 'a`, so the whole expression `f a` has type `'y`. We'd then use the surrounding constraints to produce more information about what exactly `'a` and `'y` might be.
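To make the example concrete, here is a minimal sketch in Standard ML. It assumes a hypothetical representation of types with named unification variables (`UVar` and `Arrow` are illustrative, not the post's types) and performs by hand the one substitution a unifier would find: one valid way to solve `'a ~ 'x` is to bind `'x := 'a`, and applying that binding to `'x -> 'y` gives `f` its type.

```sml
(* Hypothetical types with named unification variables, for this example only. *)
datatype ty = UVar of string
            | Arrow of ty * ty

(* The two constraints generated for `f a`:  'f ~ 'x -> 'y  and  'a ~ 'x *)
val constraints : (ty * ty) list =
  [ (UVar "f", Arrow (UVar "x", UVar "y"))
  , (UVar "a", UVar "x") ]

(* subst t v ty replaces every occurrence of the variable v in ty with t. *)
fun subst t v (UVar v') = if v = v' then t else UVar v'
  | subst t v (Arrow (l, r)) = Arrow (subst t v l, subst t v r)

(* Solve 'a ~ 'x by binding 'x := 'a, then apply that binding to the
   right-hand side of the first constraint to get f's type. *)
val fTy = subst (UVar "a") "x" (#2 (hd constraints))
(* fTy = Arrow (UVar "a", UVar "y"), i.e. f : 'a -> 'y *)
```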
Now onto some specifics
In order to actually talk about type inference we first have to define our language. We have the abstract syntax tree:
```sml
type tvar = int

local val freshSource = ref 0 in
  fun fresh () : tvar =
    !freshSource before freshSource := !freshSource + 1
end

datatype monotype = TBool
                  | TArr of monotype * monotype
                  | TVar of tvar

datatype polytype = PolyType of int list * monotype

datatype exp = True
             | False
             | Var of int
             | App of exp * exp
             | Let of exp * exp
             | Fn of exp
             | If of exp * exp * exp
```
First we have type variables, which are globally unique integers. To give us a way of actually producing them we have `fresh`, which uses a ref cell to never return the same result twice. From there we have monotypes. These are normal ML types without any polymorphism: type variables, booleans, and functions. Polytypes are just monotypes with an extra `forall` at the front. This is where we get polymorphism from. A polytype binds a number of type variables, stored in this representation as an `int list`.
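As a small usage sketch, this repeats the definitions above and builds the polytype of the identity function, `forall 'a. 'a -> 'a`, drawing its bound variable from `fresh`. The names `a` and `idTy` are illustrative, not from the post.

```sml
type tvar = int

local val freshSource = ref 0 in
  (* Hand out the current counter value, then bump it. *)
  fun fresh () : tvar =
    !freshSource before freshSource := !freshSource + 1
end

datatype monotype = TBool
                  | TArr of monotype * monotype
                  | TVar of tvar
datatype polytype = PolyType of int list * monotype

(* forall 'a. 'a -> 'a: the forall binds one variable, recorded in the list. *)
val idTy =
  let val a = fresh ()
  in PolyType ([a], TArr (TVar a, TVar a)) end
```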
Finally, we have expressions. Aside from the normal constants, we have variables, lambdas, applications, `let`, and `if`. We represent variables here with de Bruijn indices: a variable is a number that tells you how many binders sit between it and the binder that introduced it. For example, `const` (that is, `fn x => fn y => x`) would be written `Fn (Fn (Var 1))` in this representation.
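To see how named binders map onto de Bruijn indices, here is a minimal sketch of a converter. The `named` datatype, its constructors, `index` and `toDeBruijn` are hypothetical, introduced only for this illustration; `Let` is treated as binding one variable in its body only, which matches how `Let (e, body)` is handled by the type checker later on.

```sml
datatype exp = True
             | False
             | Var of int
             | App of exp * exp
             | Let of exp * exp
             | Fn of exp
             | If of exp * exp * exp

(* Hypothetical named syntax, used only for this illustration. *)
datatype named = NTrue
               | NFalse
               | NVar of string
               | NApp of named * named
               | NLet of string * named * named
               | NFn of string * named
               | NIf of named * named * named

exception Unbound of string

(* scope lists the names bound so far, innermost first, so a variable's
   index is simply its position in that list. *)
fun index (x, [], _) = raise Unbound x
  | index (x, y :: ys, i) = if x = y then i else index (x, ys, i + 1)

fun toDeBruijn scope term =
  case term of
      NTrue => True
    | NFalse => False
    | NVar x => Var (index (x, scope, 0))
    | NApp (l, r) => App (toDeBruijn scope l, toDeBruijn scope r)
    | NLet (x, e, body) =>
        Let (toDeBruijn scope e, toDeBruijn (x :: scope) body)
    | NFn (x, body) => Fn (toDeBruijn (x :: scope) body)
    | NIf (c, t, e) =>
        If (toDeBruijn scope c, toDeBruijn scope t, toDeBruijn scope e)

(* const = fn x => fn y => x *)
val const = toDeBruijn [] (NFn ("x", NFn ("y", NVar "x")))
```

The innermost binder is at the front of `scope`, so a shadowed name automatically resolves to the nearest binding.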
With this in mind we define some helpful utility functions. When type checking, we have a context full of information. The two kinds of facts we can know about a variable are:
```sml
datatype info = PolyTypeVar of polytype
              | MonoTypeVar of monotype
type context = info list
```
The ith element of a context is the piece of information we know about the ith de Bruijn variable. We'll also need to substitute a type for a type variable, and to find all the free variables in a type.
```sml
fun subst ty' var ty =
  case ty of
      TVar var' => if var = var' then ty' else TVar var'
    | TArr (l, r) => TArr (subst ty' var l, subst ty' var r)
    | TBool => TBool

fun freeVars t =
  case t of
      TVar v => [v]
    | TArr (l, r) => freeVars l @ freeVars r
    | TBool => []
```
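Here is a quick sketch of what these two functions return on small inputs. The definitions are repeated so the snippet runs on its own; the names `example1` and `example2` are just for illustration.

```sml
datatype monotype = TBool
                  | TArr of monotype * monotype
                  | TVar of int

(* subst ty' var ty replaces every occurrence of TVar var in ty with ty'. *)
fun subst ty' var ty =
  case ty of
      TVar var' => if var = var' then ty' else TVar var'
    | TArr (l, r) => TArr (subst ty' var l, subst ty' var r)
    | TBool => TBool

(* Collects variables left to right, keeping duplicates. *)
fun freeVars t =
  case t of
      TVar v => [v]
    | TArr (l, r) => freeVars l @ freeVars r
    | TBool => []

(* 'a -> 'b with 'a := bool becomes bool -> 'b *)
val example1 = subst TBool 0 (TArr (TVar 0, TVar 1))
(* 'a -> 'a mentions 'a twice, so freeVars reports it twice: [0, 0] *)
val example2 = freeVars (TArr (TVar 0, TVar 0))
```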
Both of these functions just recurse over types and do some work at the variable case. Note that `freeVars` can contain duplicates; this turns out not to matter in all cases except one: `generalizeMonoType`. The basic idea is that, given a monotype with a bunch of unification variables and a surrounding context, we figure out which variables can be bound up in a polymorphic type. If they don't appear in the surrounding context, we generalize them by binding them in a new polytype's `forall`.
```sml
fun dedup [] = []
  | dedup (x :: xs) =
      if List.exists (fn y => x = y) xs
      then dedup xs
      else x :: dedup xs

fun generalizeMonoType ctx ty =
  let fun notMem xs x = List.all (fn y => x <> y) xs
      fun free (MonoTypeVar m) = freeVars m
        | free (PolyTypeVar (PolyType (bs, m))) =
            List.filter (notMem bs) (freeVars m)
      val ctxVars = List.concat (List.map free ctx)
      val polyVars = List.filter (notMem ctxVars) (freeVars ty)
  in PolyType (dedup polyVars, ty) end
```
Here the bulk of the code is deciding whether or not a variable is free in the surrounding context. `free` looks at a piece of info to determine which variables occur free in it (variables bound by a polytype's own `forall` don't count). We then accumulate all of these variables into `ctxVars` and use this list to decide what to generalize.
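Here is a small, self-contained sketch of `generalizeMonoType` at work, repeating the definitions above. The contexts are made up for illustration: a variable free in the context stays a unification variable, a variable that appears only in the type gets bound by the `forall`, and a variable bound by a context entry's own `forall` does not block generalization.

```sml
datatype monotype = TBool
                  | TArr of monotype * monotype
                  | TVar of int
datatype polytype = PolyType of int list * monotype
datatype info = PolyTypeVar of polytype
              | MonoTypeVar of monotype

fun freeVars t =
  case t of
      TVar v => [v]
    | TArr (l, r) => freeVars l @ freeVars r
    | TBool => []

fun dedup [] = []
  | dedup (x :: xs) =
      if List.exists (fn y => x = y) xs
      then dedup xs
      else x :: dedup xs

fun generalizeMonoType ctx ty =
  let fun notMem xs x = List.all (fn y => x <> y) xs
      fun free (MonoTypeVar m) = freeVars m
        | free (PolyTypeVar (PolyType (bs, m))) =
            List.filter (notMem bs) (freeVars m)
      val ctxVars = List.concat (List.map free ctx)
      val polyVars = List.filter (notMem ctxVars) (freeVars ty)
  in PolyType (dedup polyVars, ty) end

(* 0 is free in the context, so only 1 is generalized. *)
val g1 = generalizeMonoType [MonoTypeVar (TVar 0)] (TArr (TVar 0, TVar 1))
(* In an empty context 'a -> 'a becomes forall 'a. 'a -> 'a; dedup stops
   the repeated variable from being bound twice. *)
val g2 = generalizeMonoType [] (TArr (TVar 1, TVar 1))
(* 0 is bound by the context entry's own forall, so it is not free there. *)
val g3 = generalizeMonoType [PolyTypeVar (PolyType ([0], TVar 0))] (TVar 0)
```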
Next we need to take a polytype to a monotype. This is the specialization of a polymorphic type that we rely on whenever we, say, use `map` on a function of type `int -> real`. It works by replacing each bound variable with a fresh unification variable. This is nicely handled by a fold!
```sml
fun mintNewMonoType (PolyType (ls, ty)) =
  foldl (fn (v, t) => subst (TVar (fresh ())) v t) ty ls
```
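A minimal sketch of instantiation in action, repeating the needed definitions: instantiating the identity's polytype twice yields two monotypes of the shape `'b -> 'b`, each with its own fresh variable, so every use of a polymorphic value can be specialized independently. `idTy`, `inst1` and `inst2` are illustrative names.

```sml
local val freshSource = ref 0 in
  fun fresh () : int =
    !freshSource before freshSource := !freshSource + 1
end

datatype monotype = TBool
                  | TArr of monotype * monotype
                  | TVar of int
datatype polytype = PolyType of int list * monotype

fun subst ty' var ty =
  case ty of
      TVar var' => if var = var' then ty' else TVar var'
    | TArr (l, r) => TArr (subst ty' var l, subst ty' var r)
    | TBool => TBool

(* Replace each bound variable with a brand-new unification variable. *)
fun mintNewMonoType (PolyType (ls, ty)) =
  foldl (fn (v, t) => subst (TVar (fresh ())) v t) ty ls

(* forall 'a. 'a -> 'a, with 'a drawn from the same counter *)
val idTy = let val a = fresh () in PolyType ([a], TArr (TVar a, TVar a)) end

val inst1 = mintNewMonoType idTy
val inst2 = mintNewMonoType idTy
```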
Last but not least, we have a function that takes a context and a variable and gives us a monotype corresponding to it. If the variable has a polytype, this mints a fresh monotype from it.
```sml
exception UnboundVar of int

fun lookupVar var ctx =
  case (List.nth (ctx, var) handle Subscript => raise UnboundVar var) of
      PolyTypeVar pty => mintNewMonoType pty
    | MonoTypeVar mty => mty
```
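A short sketch of `lookupVar` against a made-up two-entry context (the names `idTy` and `ctx` are illustrative): a monotype entry comes back unchanged, a polytype entry is instantiated afresh on every lookup, and an out-of-range index raises `UnboundVar`.

```sml
local val freshSource = ref 0 in
  fun fresh () : int =
    !freshSource before freshSource := !freshSource + 1
end

datatype monotype = TBool
                  | TArr of monotype * monotype
                  | TVar of int
datatype polytype = PolyType of int list * monotype
datatype info = PolyTypeVar of polytype
              | MonoTypeVar of monotype

fun subst ty' var ty =
  case ty of
      TVar var' => if var = var' then ty' else TVar var'
    | TArr (l, r) => TArr (subst ty' var l, subst ty' var r)
    | TBool => TBool

fun mintNewMonoType (PolyType (ls, ty)) =
  foldl (fn (v, t) => subst (TVar (fresh ())) v t) ty ls

exception UnboundVar of int

(* Monotypes are returned as is; polytypes are instantiated on every lookup. *)
fun lookupVar var ctx =
  case (List.nth (ctx, var) handle Subscript => raise UnboundVar var) of
      PolyTypeVar pty => mintNewMonoType pty
    | MonoTypeVar mty => mty

(* Index 0 is the innermost binder, a bool; index 1 is a let-bound identity. *)
val idTy = let val a = fresh () in PolyType ([a], TArr (TVar a, TVar a)) end
val ctx = [MonoTypeVar TBool, PolyTypeVar idTy]
```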
For the sake of nice error messages, we also raise `UnboundVar` instead of just `Subscript` in the error case. Now that we've gone through all of the utility functions, on to unification!
A large part of this program is basically "I'll give you a list of constraints and you give me the solution". The function that solves them proceeds by pattern matching on the constraints.
In the empty case, we have no constraints so we give back the empty solution.
```sml
fun unify [] = []
```
In the next case we actually have to look at what constraint we're trying to solve.
```sml
  | unify (c :: constrs) =
      case c of
```
If we're lucky, we're just trying to unify `TBool` with `TBool`. This does nothing, since these types have no variables and are equal, so we just recurse.
```sml
(TBool, TBool) => unify constrs
```
If we've got two function types, we just constrain their domains and ranges to be the same and continue on unifying things.
```sml
| (TArr (l, r), TArr (l', r')) => unify ((l, l') :: (r, r') :: constrs)
```
Now we have to deal with finding a variable. We definitely want to avoid adding `(TVar v, TVar v)` to our solution, so we'll have a special case for trying to unify two variables.
```sml
| (TVar i, TVar j) =>
    if i = j
    then unify constrs
    else addSol i (TVar j) (unify (substConstrs (TVar j) i constrs))
```
This is our first time actually adding something to our solution, so there are several new elements here. The first is the function `addSol`. It's defined as
```sml
fun addSol v ty sol = (v, applySol sol ty) :: sol
```
So, to make sure our solution is internally consistent, it's important that whenever we add a type to our solution we first apply the solution to it. This ensures that we can substitute a variable in our solution for its corresponding type and not worry about whether we need to do something further. Additionally, whenever we add a new binding we substitute for it in the constraints we have left, so that we never end up with an inconsistent solution. This prevents us from unifying `v ~ TBool` and `v ~ TArr (TBool, TBool)` in the same solution! The code that does this is the `substConstrs (TVar j) i constrs` bit.
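Here is a minimal sketch of both ideas on made-up data. The post uses `substConstrs` without showing its definition, so the one below is an assumption: it substitutes on both sides of every remaining constraint. The names `sol` and `leftover` are illustrative.

```sml
datatype monotype = TBool
                  | TArr of monotype * monotype
                  | TVar of int

fun subst ty' var ty =
  case ty of
      TVar var' => if var = var' then ty' else TVar var'
    | TArr (l, r) => TArr (subst ty' var l, subst ty' var r)
    | TBool => TBool

fun applySol sol ty =
  foldl (fn ((v, ty), ty') => subst ty v ty') ty sol

(* Apply the existing solution to ty before recording v := ty. *)
fun addSol v ty sol = (v, applySol sol ty) :: sol

(* Assumed definition: substitute ty for v on both sides of each constraint. *)
fun substConstrs ty v constrs =
  map (fn (l, r) => (subst ty v l, subst ty v r)) constrs

(* The rest of the solution already says 1 := bool, so adding
   0 := 'b -> 'b records 0 := bool -> bool. *)
val sol = addSol 0 (TArr (TVar 1, TVar 1)) [(1, TBool)]

(* Binding 0 := bool turns a leftover constraint 0 ~ bool -> bool into
   bool ~ bool -> bool, which unification then rejects. *)
val leftover = substConstrs TBool 0 [(TVar 0, TArr (TBool, TBool))]
```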
The next case is the general case for unifying a variable with some type. It looks very similar to this one.
```sml
| ((TVar i, ty) | (ty, TVar i)) =>
    if occursIn i ty
    then raise UnificationError c
    else addSol i ty (unify (substConstrs ty i constrs))
```
Here we have the critical `occursIn` check. It checks whether a variable appears in a type, and prevents us from making erroneous unifications like `TVar a ~ TArr (TVar a, TVar a)`, which could only be solved by an infinite type. This occurs check is actually very easy to implement:
```sml
fun occursIn v ty = List.exists (fn v' => v = v') (freeVars ty)
```
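A tiny sketch of the occurs check on two hand-picked types, with the definitions repeated from above: variable 0 occurs in `'a -> 'a` (encoded as `TArr (TVar 0, TVar 0)`), so binding it to that type would be rejected, whereas `'b -> bool` does not mention it.

```sml
datatype monotype = TBool
                  | TArr of monotype * monotype
                  | TVar of int

fun freeVars t =
  case t of
      TVar v => [v]
    | TArr (l, r) => freeVars l @ freeVars r
    | TBool => []

(* True when the variable v appears anywhere in ty. *)
fun occursIn v ty = List.exists (fn v' => v = v') (freeVars ty)

val rejected = occursIn 0 (TArr (TVar 0, TVar 0))   (* true *)
val accepted = occursIn 0 (TArr (TVar 1, TBool))    (* false *)
```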
Finally we have one last case: the failure case. This is the catch-all case for if we try to unify two things that are obviously incompatible.
```sml
| _ => raise UnificationError c
```
All together, that code was:

```sml
type constr = monotype * monotype
type sol = (tvar * monotype) list

exception UnificationError of constr

fun applySol sol ty =
  foldl (fn ((v, ty), ty') => subst ty v ty') ty sol

fun applySolCxt sol cxt =
  let fun applyInfo i =
        case i of
            PolyTypeVar (PolyType (bs, m)) =>
              PolyTypeVar (PolyType (bs, (applySol sol m)))
          | MonoTypeVar m => MonoTypeVar (applySol sol m)
  in map applyInfo cxt end

fun addSol v ty sol = (v, applySol sol ty) :: sol

fun occursIn v ty = List.exists (fn v' => v = v') (freeVars ty)

(* Substitute ty for v on both sides of every remaining constraint. *)
fun substConstrs ty v constrs =
  map (fn (l, r) => (subst ty v l, subst ty v r)) constrs

fun unify ([] : constr list) : sol = []
  | unify (c :: constrs) =
      case c of
          (TBool, TBool) => unify constrs
        | (TVar i, TVar j) =>
            if i = j
            then unify constrs
            else addSol i (TVar j) (unify (substConstrs (TVar j) i constrs))
          (* Or-patterns are an SML/NJ extension, not part of SML '97. *)
        | ((TVar i, ty) | (ty, TVar i)) =>
            if occursIn i ty
            then raise UnificationError c
            else addSol i ty (unify (substConstrs ty i constrs))
        | (TArr (l, r), TArr (l', r')) =>
            unify ((l, l') :: (r, r') :: constrs)
        | _ => raise UnificationError c
```
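To see the unifier run end to end, here is a self-contained sketch that feeds it the constraints from the `f a` example at the start, numbering the variables 0 for `'f`, 1 for `'a`, 2 for `'x` and 3 for `'y`. Two assumptions: `substConstrs` substitutes on both sides of each constraint, and the or-pattern is split into two cases so the code is plain Standard ML.

```sml
datatype monotype = TBool
                  | TArr of monotype * monotype
                  | TVar of int
type constr = monotype * monotype
type sol = (int * monotype) list
exception UnificationError of constr

fun subst ty' var ty =
  case ty of
      TVar var' => if var = var' then ty' else TVar var'
    | TArr (l, r) => TArr (subst ty' var l, subst ty' var r)
    | TBool => TBool

fun freeVars (TVar v) = [v]
  | freeVars (TArr (l, r)) = freeVars l @ freeVars r
  | freeVars TBool = []

fun applySol sol ty = foldl (fn ((v, ty), ty') => subst ty v ty') ty sol
fun addSol v ty sol = (v, applySol sol ty) :: sol
fun occursIn v ty = List.exists (fn v' => v = v') (freeVars ty)
fun substConstrs ty v constrs =
  map (fn (l, r) => (subst ty v l, subst ty v r)) constrs

fun unify ([] : constr list) : sol = []
  | unify (c :: constrs) =
      case c of
          (TBool, TBool) => unify constrs
        | (TVar i, TVar j) =>
            if i = j
            then unify constrs
            else addSol i (TVar j) (unify (substConstrs (TVar j) i constrs))
        (* The or-pattern, written as two separate cases. *)
        | (TVar i, ty) =>
            if occursIn i ty
            then raise UnificationError c
            else addSol i ty (unify (substConstrs ty i constrs))
        | (ty, TVar i) =>
            if occursIn i ty
            then raise UnificationError c
            else addSol i ty (unify (substConstrs ty i constrs))
        | (TArr (l, r), TArr (l', r')) =>
            unify ((l, l') :: (r, r') :: constrs)
        | _ => raise UnificationError c

(* 'f ~ 'x -> 'y  and  'a ~ 'x *)
val sol = unify [(TVar 0, TArr (TVar 2, TVar 3)), (TVar 1, TVar 2)]
val fTy = applySol sol (TVar 0)   (* TArr (TVar 2, TVar 3), i.e. 'x -> 'y *)
val aTy = applySol sol (TVar 1)   (* TVar 2, i.e. 'x *)
```

Here unification orients `'a ~ 'x` as `'a := 'x`, so we get `f : 'x -> 'y` and `a : 'x`, the same answer as in the introduction up to renaming.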
The other half of this algorithm is the constraint generation part. We generate constraints and use `unify` to turn them into solutions. This boils down to two functions. The first glues together solutions.
```sml
fun <+> (sol1, sol2) =
  let fun notInSol2 v = List.all (fn (v', _) => v <> v') sol2
      val sol1' = List.filter (fn (v, _) => notInSol2 v) sol1
  in
    map (fn (v, ty) => (v, applySol sol1 ty)) sol2 @ sol1'
  end
infixr 3 <+>
```
Given two solutions, we figure out which bindings in the first don't occur in the second solution. Next, we apply solution 1 everywhere in the second solution, giving a consistent solution which contains everything in `sol2`; finally we add in all the stuff that is in `sol1` but not in `sol2`. This doesn't check that the solutions are actually consistent; that is done elsewhere.
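A small sketch of `<+>` on made-up solutions, repeating the definitions it needs: the binding from the first solution is applied inside the second solution's type, and the first solution's own binding is kept because the second one doesn't bind variable 0. The name `combined` is illustrative.

```sml
datatype monotype = TBool
                  | TArr of monotype * monotype
                  | TVar of int

fun subst ty' var ty =
  case ty of
      TVar var' => if var = var' then ty' else TVar var'
    | TArr (l, r) => TArr (subst ty' var l, subst ty' var r)
    | TBool => TBool

fun applySol sol ty = foldl (fn ((v, ty), ty') => subst ty v ty') ty sol

fun <+> (sol1, sol2) =
  let fun notInSol2 v = List.all (fn (v', _) => v <> v') sol2
      val sol1' = List.filter (fn (v, _) => notInSol2 v) sol1
  in
    map (fn (v, ty) => (v, applySol sol1 ty)) sol2 @ sol1'
  end
infixr 3 <+>

(* [(1, TArr (TBool, TBool)), (0, TBool)] *)
val combined = [(0, TBool)] <+> [(1, TArr (TVar 0, TVar 0))]
```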
Next is the main function here, `constrain`. It generates a type and a solution given a context and an expression. The first few cases are nice and simple:
```sml
fun constrain ctx True = (TBool, [])
  | constrain ctx False = (TBool, [])
  | constrain ctx (Var i) = (lookupVar i ctx, [])
```
In these cases we don't infer any constraints; we just figure out types based on information we already know. Next, for `Fn` we generate a fresh variable to represent the argument's type and just constrain the body.
```sml
  | constrain ctx (Fn body) =
      let val argTy = TVar (fresh ())
          val (rTy, sol) = constrain (MonoTypeVar argTy :: ctx) body
      in (TArr (applySol sol argTy, rTy), sol) end
```
Once we have the solution for the body, we apply it to the argument type, which might replace it with a concrete type if the constraints we inferred for the body demand it. For `If` we do something similar, except that we add a few constraints of our own to solve.
```sml
  | constrain ctx (If (i, t, e)) =
      let val (iTy, sol1) = constrain ctx i
          val (tTy, sol2) = constrain (applySolCxt sol1 ctx) t
          val (eTy, sol3) = constrain (applySolCxt (sol1 <+> sol2) ctx) e
          val sol = sol1 <+> sol2 <+> sol3
          val sol = sol <+> unify [ (applySol sol iTy, TBool)
                                  , (applySol sol tTy, applySol sol eTy)]
      in
        (tTy, sol)
      end
```
Notice how we apply each solution to the context for the next thing we're constraining. This is how we ensure that each solution will be consistent. Once we've generated solutions to the constraints in each of the subterms, we smash them together to produce the first solution. Next, we ensure that the subcomponents have the right types by generating a few constraints: `iTy` must be a bool, and `tTy` and `eTy` (the types of the branches) must be the same. We have to carefully apply `sol` to each of these prior to unifying them to make sure our solution stays consistent.
The `App` case is practically the same:
```sml
  | constrain ctx (App (l, r)) =
      let val (domTy, ranTy) = (TVar (fresh ()), TVar (fresh ()))
          val (funTy, sol1) = constrain ctx l
          val (argTy, sol2) = constrain (applySolCxt sol1 ctx) r
          val sol = sol1 <+> sol2
          val sol = sol <+> unify [(applySol sol funTy,
                                    applySol sol (TArr (domTy, ranTy)))
                                  , (applySol sol argTy, applySol sol domTy)]
      in (ranTy, sol) end
```
The only real difference here is that we generate different constraints: we make sure we're applying a function whose domain is the same as the argument type.
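Putting these pieces together, here is a condensed, self-contained sketch of the monomorphic part of the algorithm: every case except `Let`, so contexts can simply be lists of monotypes instead of `info`. It follows the code above with a few assumptions: `substConstrs` is our own definition, the or-pattern is split into two cases sharing a hypothetical helper `bindVar`, and `applySolCtx` and this cut-down `infer` (which skips generalization) are simplified stand-ins rather than the post's functions.

```sml
datatype monotype = TBool
                  | TArr of monotype * monotype
                  | TVar of int
datatype exp = True | False | Var of int
             | App of exp * exp | Fn of exp | If of exp * exp * exp
exception UnificationError of monotype * monotype

local val freshSource = ref 0 in
  fun fresh () = !freshSource before freshSource := !freshSource + 1
end

fun subst ty' var (TVar v) = if var = v then ty' else TVar v
  | subst ty' var (TArr (l, r)) = TArr (subst ty' var l, subst ty' var r)
  | subst _ _ TBool = TBool

fun freeVars (TVar v) = [v]
  | freeVars (TArr (l, r)) = freeVars l @ freeVars r
  | freeVars TBool = []

fun applySol sol ty = foldl (fn ((v, t), t') => subst t v t') ty sol
fun addSol v ty sol = (v, applySol sol ty) :: sol
fun occursIn v ty = List.exists (fn v' => v = v') (freeVars ty)
fun substConstrs ty v cs = map (fn (l, r) => (subst ty v l, subst ty v r)) cs

fun unify [] = []
  | unify (c :: cs) =
      case c of
          (TBool, TBool) => unify cs
        | (TVar i, TVar j) =>
            if i = j then unify cs
            else addSol i (TVar j) (unify (substConstrs (TVar j) i cs))
        | (TVar i, ty) => bindVar c i ty cs
        | (ty, TVar i) => bindVar c i ty cs
        | (TArr (l, r), TArr (l', r')) => unify ((l, l') :: (r, r') :: cs)
        | _ => raise UnificationError c
(* Shared body of the two variable cases, including the occurs check. *)
and bindVar c i ty cs =
  if occursIn i ty then raise UnificationError c
  else addSol i ty (unify (substConstrs ty i cs))

fun <+> (sol1, sol2) =
  let fun notInSol2 v = List.all (fn (v', _) => v <> v') sol2
      val sol1' = List.filter (fn (v, _) => notInSol2 v) sol1
  in map (fn (v, ty) => (v, applySol sol1 ty)) sol2 @ sol1' end
infixr 3 <+>

fun applySolCtx sol ctx = map (applySol sol) ctx

(* ctx is a list of monotypes indexed by de Bruijn index. *)
fun constrain _ True = (TBool, [])
  | constrain _ False = (TBool, [])
  | constrain ctx (Var i) = (List.nth (ctx, i), [])
  | constrain ctx (Fn body) =
      let val argTy = TVar (fresh ())
          val (rTy, sol) = constrain (argTy :: ctx) body
      in (TArr (applySol sol argTy, rTy), sol) end
  | constrain ctx (If (i, t, e)) =
      let val (iTy, sol1) = constrain ctx i
          val (tTy, sol2) = constrain (applySolCtx sol1 ctx) t
          val (eTy, sol3) = constrain (applySolCtx (sol1 <+> sol2) ctx) e
          val sol = sol1 <+> sol2 <+> sol3
          val sol = sol <+> unify [ (applySol sol iTy, TBool)
                                  , (applySol sol tTy, applySol sol eTy) ]
      in (tTy, sol) end
  | constrain ctx (App (l, r)) =
      let val (domTy, ranTy) = (TVar (fresh ()), TVar (fresh ()))
          val (funTy, sol1) = constrain ctx l
          val (argTy, sol2) = constrain (applySolCtx sol1 ctx) r
          val sol = sol1 <+> sol2
          val sol = sol <+> unify [ (applySol sol funTy,
                                     applySol sol (TArr (domTy, ranTy)))
                                  , (applySol sol argTy, applySol sol domTy) ]
      in (ranTy, sol) end

(* Monomorphic driver: apply the final solution, no generalization. *)
fun infer e =
  let val (ty, sol) = constrain [] e in applySol sol ty end
```

For example, `infer (App (Fn (Var 0), True))` produces `TBool`, `infer (Fn (If (Var 0, Var 0, Var 0)))` produces `TArr (TBool, TBool)`, and applying `True` as if it were a function raises `UnificationError`.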
The most interesting case here is `Let`. This implements let generalization, which is how we actually get polymorphism. After inferring the type of the thing we're binding, we generalize it, giving us a polytype to use in the body of the let. The key to generalizing it is the `generalizeMonoType` function we wrote before.
```sml
  | constrain ctx (Let (e, body)) =
      let val (eTy, sol1) = constrain ctx e
          val ctx' = applySolCxt sol1 ctx
          val eTy' = generalizeMonoType ctx' (applySol sol1 eTy)
          val (rTy, sol2) = constrain (PolyTypeVar eTy' :: ctx') body
      in (rTy, sol1 <+> sol2) end
```
We do pretty much everything we did before, except that now we carefully apply the solution we get for the bound expression to the context, and then generalize its type with respect to that new context. This is how we actually get polymorphism: it assigns a proper polymorphic type to the let-bound variable.
That wraps up constraint generation. Now all that's left to see is the overall driver for type inference.
```sml
fun infer e =
  let val (ty, sol) = constrain [] e
  in generalizeMonoType [] (applySol sol ty) end
```
So all we do is infer a type and generalize it! And there you have it: that's the core of how ML and Haskell do type inference.