8/31/07

Generalizing Hindley-Milner Type Inference Algorithms

This is a small OCaml experiment that follows the paper "Generalizing Hindley-Milner Type Inference Algorithms" by Heeren, Hage, and Swierstra.

Please expect some bugs. This is a small study for my forthcoming (polite cough) language, which has a more elaborate, sequent-calculus-based type inferencer.


(********************************************************************)

(* [string_eq s0 s1] is true iff [s0] and [s1] contain the same characters.
   Uses structural equality directly rather than [compare s0 s1 == 0],
   whose [==] is physical equality and only happens to work on ints. *)
let string_eq: string -> string -> bool =
  fun s0 s1 -> s0 = s1

(********************************************************************)

(* [list_concat xx yy] is [xx] followed by [yy].  Implemented with
   [List.rev_append] so it runs in constant stack space; the previous
   direct recursion used stack proportional to the length of [xx]. *)
let list_concat: 'a list -> 'a list -> 'a list =
  fun xx yy -> List.rev_append (List.rev xx) yy

(* [list_prepend x yy] is [yy] with [x] added at the front. *)
let list_prepend: 'a -> 'a list -> 'a list =
  fun x yy -> x :: yy

(* [list_postpend xx y] is [xx] with [y] added at the end (linear in [xx]). *)
let list_postpend: 'a list -> 'a -> 'a list =
  fun xx y -> list_concat xx [y]

(* [list_map f xs] applies [f] to each element of [xs] from front to back
   and returns the results in the original order.  The front-to-back
   application order matters because callers map printing functions. *)
let list_map: ('a -> 'b) -> 'a list -> 'b list =
  fun f xs ->
    let rec go acc = function
      | [] -> List.rev acc
      | x :: rest -> go (f x :: acc) rest
    in
    go [] xs

(* Lists used as sets, with membership decided by a caller-supplied
   equality [eq].  Results keep the element order of the first list;
   none of these functions removes duplicates within a single input. *)

(* [list_member eq xx e] is true iff some [x] in [xx] satisfies [eq e x]
   (elements are tested front to back, stopping at the first match). *)
let list_member: ('a -> 'a -> bool) -> 'a list -> 'a -> bool =
  fun eq xx e -> List.exists (fun x -> eq e x) xx

(* [list_difference eq xx yy] is the elements of [xx] not in [yy]. *)
let list_difference: ('a -> 'a -> bool) -> 'a list -> 'a list -> 'a list =
  fun eq xx yy -> List.filter (fun x -> not (list_member eq yy x)) xx

(* [list_union eq xx yy] is the elements of [xx] not in [yy], followed by
   all of [yy]. *)
let list_union: ('a -> 'a -> bool) -> 'a list -> 'a list -> 'a list =
  fun eq xx yy -> list_difference eq xx yy @ yy

(* [list_intersection eq xx yy] is the elements of [xx] that are in [yy]. *)
let list_intersection: ('a -> 'a -> bool) -> 'a list -> 'a list -> 'a list =
  fun eq xx yy -> List.filter (list_member eq yy) xx

(* [list_subset eq xx yy] is true iff every element of [xx] is in [yy]. *)
let list_subset: ('a -> 'a -> bool) -> 'a list -> 'a list -> bool =
  fun eq xx yy -> List.for_all (list_member eq yy) xx

(* [list_eq eq xx yy] is set equality: mutual inclusion, ignoring order and
   multiplicity.  (Not recursive, so no [rec].) *)
let list_eq: ('a -> 'a -> bool) -> 'a list -> 'a list -> bool =
  fun eq xx yy -> list_subset eq xx yy && list_subset eq yy xx

(* [list_flatten ll] concatenates all lists of [ll], preserving order. *)
let list_flatten: 'a list list -> 'a list =
  fun ll -> List.concat ll

(* Infix shorthands for the list helpers:
     xs ++ ys   concatenation        (list_concat)
     xs +< y    append one element   (list_postpend)
     x >+ ys    prepend one element  (list_prepend) *)
let ( ++ ) xs ys = list_concat xs ys
let ( +< ) xs y = list_postpend xs y
let ( >+ ) x ys = list_prepend x ys

(********************************************************************)

(* [list_domain xx] is the first component of every pair in [xx], in order. *)
let list_domain: ('a * 'b) list -> 'a list =
  fun xx -> List.map fst xx

(* [list_range xx] is the second component of every pair in [xx], in order. *)
let list_range: ('a * 'b) list -> 'b list =
  fun xx -> List.map snd xx

(********************************************************************)

(* [flatten_strings ss] joins the strings of [ss] with ", " between
   consecutive elements; the empty list gives the empty string. *)
let flatten_strings: string list -> string =
  fun ss -> String.concat ", " ss

(********************************************************************)

(* Terms of the input language: lambda calculus with let.  The variable
   names true, false and 0 are treated as literals by the constraint
   generator. *)
type xp =
  | Var of string             (* variable occurrence *)
  | App of xp * xp            (* application  (e0 e1) *)
  | Abs of string * xp        (* lambda       (\x -> e) *)
  | Let of string * xp * xp   (* let x = e1 in e2 *)

(* [xp_to_string e] renders [e] fully parenthesized, e.g.
   (let id = (\x -> x) in (id true)). *)
let rec xp_to_string: xp -> string =
  function
  | Var s -> s
  | App (f, a) -> "(" ^ xp_to_string f ^ " " ^ xp_to_string a ^ ")"
  | Abs (x, body) -> "(\\" ^ x ^ " -> " ^ xp_to_string body ^ ")"
  | Let (x, def, body) ->
    "(let " ^ x ^ " = " ^ xp_to_string def ^ " in " ^ xp_to_string body ^ ")"

(* [xp_eq e0 e1] is syntactic equality of terms: same constructors, same
   names, equal subterms.  It is NOT equality up to renaming of bound
   variables, so (\x -> x) and (\y -> y) are different. *)
let rec xp_eq: xp -> xp -> bool =
  fun e0 e1 ->
    match e0, e1 with
    | Var s0, Var s1 -> s0 = s1
    | App (f0, a0), App (f1, a1) -> xp_eq f0 f1 && xp_eq a0 a1
    | Abs (x0, b0), Abs (x1, b1) -> x0 = x1 && xp_eq b0 b1
    | Let (x0, d0, b0), Let (x1, d1, b1) ->
      x0 = x1 && xp_eq d0 d1 && xp_eq b0 b1
    | (Var _ | App _ | Abs _ | Let _), _ -> false

(********************************************************************)

(* Types: variables, named constants (bool, int, ...), arrows, and type
   schemes.  [TScheme (vs, t)] quantifies the variables [vs] (expected to
   be [TVar]s) over [t]. *)
type tp =
  | TVar of string
  | TConst of string
  | TArrow of tp * tp
  | TScheme of tp list * tp

(* [tp_to_string t] renders [t]; arrows and schemes are parenthesized and
   a scheme prints as (A v1, v2 . body).
   [tps_to_string ts] renders [ts] separated by ", " (empty list: ""). *)
let rec tp_to_string: tp -> string =
  function
  | TVar s | TConst s -> s
  | TArrow (dom, cod) ->
    "(" ^ tp_to_string dom ^ " -> " ^ tp_to_string cod ^ ")"
  | TScheme (vs, body) ->
    "(A " ^ tps_to_string vs ^ " . " ^ tp_to_string body ^ ")"

and tps_to_string: tp list -> string =
  fun ts -> String.concat ", " (List.map tp_to_string ts)

(* [tp_eq t0 t1] is syntactic equality of types.  Schemes are equal when
   they list the same quantified variables in the same order and have
   equal bodies (no equality up to renaming of bound variables).
   Previously every pair of schemes compared unequal, even a scheme
   with itself. *)
let rec tp_eq: tp -> tp -> bool =
  fun t0 t1 ->
    match t0, t1 with
    | TVar s0, TVar s1 -> s0 = s1
    | TConst s0, TConst s1 -> s0 = s1
    | TArrow (a0, r0), TArrow (a1, r1) -> tp_eq a0 a1 && tp_eq r0 r1
    | TScheme (vs0, b0), TScheme (vs1, b1) ->
      (* length check first: List.for_all2 raises on unequal lengths *)
      List.length vs0 = List.length vs1
      && List.for_all2 tp_eq vs0 vs1
      && tp_eq b0 b1
    | (TVar _ | TConst _ | TArrow _ | TScheme _), _ -> false

(* [tp_freevars t] lists the type variables occurring free in [t], each as
   a [TVar], without duplicates.  Variables quantified by a scheme are
   bound and are removed from the variables of its body. *)
let rec tp_freevars: tp -> tp list =
  fun t ->
    match t with
    | TConst _ -> []
    | TVar _ -> [t]
    | TArrow (dom, cod) ->
      list_union tp_eq (tp_freevars dom) (tp_freevars cod)
    | TScheme (bound, body) ->
      list_difference tp_eq (tp_freevars body) bound

(********************************************************************)

(* A substitution: a list of (type variable, replacement) pairs.
   [tp_sub] applies the pairs starting from the END of the list. *)
type subs = (tp * tp) list

(* [sub_to_string (v, t)] renders a single binding as [v := t]. *)
let sub_to_string: (tp * tp) -> string =
  fun (var, replacement) ->
    String.concat ""
      ["["; tp_to_string var; " := "; tp_to_string replacement; "]"]

(* [tp_sub_one v r t] replaces every occurrence of [v] in [t] by [r]; if
   [t] as a whole equals [v] the result is [r].  A scheme that quantifies
   [v] is returned unchanged, since [v] is bound inside it.
   NOTE(review): the substitution is not capture-avoiding -- a variable
   of [r] that a scheme in [t] quantifies would be captured.  Presumably
   safe while quantified variables are fresh names; confirm. *)
let tp_sub_one: tp -> tp -> tp -> tp =
  fun var repl ->
    let rec go ty =
      if tp_eq var ty then repl
      else
        match ty with
        | TVar _ | TConst _ -> ty
        | TArrow (dom, cod) -> TArrow (go dom, go cod)
        | TScheme (bound, body) ->
          if list_member tp_eq bound var then ty
          else TScheme (bound, go body)
    in
    go

(* [tp_sub ss t] applies every binding of [ss] to [t], beginning with the
   LAST pair of the list and ending with the first. *)
let tp_sub: subs -> tp -> tp =
  fun ss t0 ->
    List.fold_right (fun (var, repl) acc -> tp_sub_one var repl acc) ss t0

(* [sub_sub (v, r) (lhs, rhs)] applies the single binding v := r to both
   sides of the pair (lhs, rhs). *)
let sub_sub: (tp * tp) -> (tp * tp) -> (tp * tp) =
  fun (var, repl) (lhs, rhs) ->
    let apply = tp_sub_one var repl in
    (apply lhs, apply rhs)

(* [subs_sub ss] unfolds a substitution list: the head binding is kept
   as is and is applied (with [sub_sub]) to every pair of the recursively
   unfolded tail. *)
let subs_sub: subs -> subs =
  fun ss ->
    List.fold_right (fun s acc -> s :: List.map (sub_sub s) acc) ss []

(********************************************************************)

(* Counter used to name fresh type variables; [tp_fresh_var] increments it
   on every call. *)
let free_tick: int ref = ref 0

(* [tp_fresh_var ()] returns the type variable "v#n", where n is the
   current counter value, and advances the counter. *)
let tp_fresh_var: unit -> tp =
  fun () ->
    let n = !free_tick in
    incr free_tick;
    TVar ("v#" ^ string_of_int n)

(* [generalize_free t] quantifies [t] over all of its free variables.
   Unlike [generalize], it always builds a scheme, even with no variables.
   Not called in the code visible here. *)
let generalize_free: tp -> tp =
  fun t ->
    let vars = tp_freevars t in
    TScheme (vars, t)

(* [generalize m t] quantifies [t] over its free variables that are not in
   the monomorphic set [m].  If no variable remains, [t] is returned
   unchanged rather than wrapped in an empty scheme. *)
let generalize: tp list -> tp -> tp =
  fun m t ->
    match list_difference tp_eq (tp_freevars t) m with
    | [] -> t
    | vars -> TScheme (vars, t)

(* [instantiate t] replaces each quantified variable of the scheme [t] by
   a fresh type variable (allocated in quantifier order) and returns the
   resulting body.  Non-scheme types are returned unchanged. *)
let instantiate: tp -> tp =
  fun t ->
    match t with
    | TScheme (bound, body) ->
      List.fold_left
        (fun acc var -> tp_sub [(var, tp_fresh_var ())] acc)
        body bound
    | TVar _ | TConst _ | TArrow _ -> t

(********************************************************************)

(* Raised by the unifier [tp_mgu] when two types cannot be unified; the
   string describes the failure and names the offending types. *)
exception XXX of string

(* [tp_occurs v t] is true iff [v] occurs in [t], including the case where
   [t] itself equals [v].  Occurrences inside a scheme that quantifies [v]
   do not count. *)
let rec tp_occurs: tp -> tp -> bool =
  fun v t ->
    tp_eq v t
    || (match t with
        | TVar _ | TConst _ -> false
        | TArrow (dom, cod) -> tp_occurs v dom || tp_occurs v cod
        | TScheme (bound, body) ->
          (not (list_member tp_eq bound v)) && tp_occurs v body)

(* [tp_mgu t0 t1] computes a most general unifier of [t0] and [t1], as a
   substitution list in the order [tp_sub] expects: the binding to apply
   first is LAST in the list.  Unifying a variable with itself yields the
   empty substitution.
   @raise XXX on an occurs-check failure, on distinct type constants, or
   when no rule applies (e.g. a constant against an arrow, or a scheme
   against anything other than a type variable). *)
let rec tp_mgu: tp -> tp -> subs =
  fun t0 t1 ->
    (* Bind variable [v] to [t].  Checking equality first matters: [v]
       trivially occurs in itself, so v ~ v must not reach the occurs
       check. *)
    let bind v t =
      if tp_eq v t then []
      else if tp_occurs v t then
        raise (XXX ("occurs check failed on " ^ tp_to_string v ^ " "
                    ^ tp_to_string t))
      else [(v, t)]
    in
    match t0, t1 with
    | TVar _, _ -> bind t0 t1
    | _, TVar _ -> bind t1 t0
    | TConst c0, TConst c1 ->
      if string_eq c0 c1 then []
      else
        raise (XXX ("cannot unify constants " ^ tp_to_string t0
                    ^ " " ^ tp_to_string t1))
    | TArrow (dom0, cod0), TArrow (dom1, cod1) ->
      let s_dom = tp_mgu dom0 dom1 in
      let s_cod = tp_mgu (tp_sub s_dom cod0) (tp_sub s_dom cod1) in
      (* s_dom must act before s_cod, and tp_sub applies the list from the
         end, so s_dom goes last. *)
      s_cod ++ s_dom
    | _, _ ->
      raise (XXX ("no unification rule for " ^ tp_to_string t0
                  ^ " " ^ tp_to_string t1))

(********************************************************************)

(* Constraints produced by [prop_bu] and consumed by [solve]:
   - Eq (t0, t1): t0 and t1 must be unified (printed t0 == t1).
   - Sub (t0, m, t1): implicit instance constraint; [solve] turns it into
     In (t0, generalize m t1) once it may be solved (printed t0 <m< t1).
   - In (t0, t1): explicit instance constraint; [solve] turns it into
     Eq (t0, instantiate t1) (printed t0 <= t1). *)
type prop = Eq of tp * tp
| Sub of tp * tp list * tp
| In of tp * tp


(* [prop_to_string p] renders a constraint in square brackets:
   [t0 == t1], [t0 <m< t1] with m as a comma-separated list, or
   [t0 <= t1]. *)
let prop_to_string: prop -> string =
  fun p ->
    let body =
      match p with
      | Eq (lhs, rhs) -> tp_to_string lhs ^ " == " ^ tp_to_string rhs
      | Sub (lhs, mono, rhs) ->
        tp_to_string lhs ^ " <" ^ tps_to_string mono ^ "< "
        ^ tp_to_string rhs
      | In (lhs, rhs) -> tp_to_string lhs ^ " <= " ^ tp_to_string rhs
    in
    "[" ^ body ^ "]"

(* [prop_activevars p] is the set of active type variables of [p]: every
   free variable of both sides for Eq and In; for Sub (t0, m, t1), the
   free variables of t0 plus the members of m that are free in t1. *)
let prop_activevars: prop -> tp list =
  fun p ->
    match p with
    | Eq (lhs, rhs) | In (lhs, rhs) ->
      list_union tp_eq (tp_freevars lhs) (tp_freevars rhs)
    | Sub (lhs, mono, rhs) ->
      list_union tp_eq (tp_freevars lhs)
        (list_intersection tp_eq mono (tp_freevars rhs))

(* [props_activevars pp] is the union of the active variables of all
   constraints in [pp]. *)
let props_activevars: prop list -> tp list =
  fun pp ->
    List.fold_right
      (fun p acc -> list_union tp_eq (prop_activevars p) acc)
      pp []

(* [prop_sub ss p] applies the substitution [ss] to every type in the
   constraint [p].  For a Sub constraint the monomorphic set m becomes the
   free variables of the substituted members of m (duplicates are kept). *)
let prop_sub: (tp * tp) list -> prop -> prop =
  let m_sub ss m =
    list_flatten (list_map (fun v -> tp_freevars (tp_sub ss v)) m) in
  fun ss p ->
    let sub = tp_sub ss in
    match p with
    | Eq (lhs, rhs) -> Eq (sub lhs, sub rhs)
    | Sub (lhs, mono, rhs) -> Sub (sub lhs, m_sub ss mono, sub rhs)
    | In (lhs, rhs) -> In (sub lhs, sub rhs)

(********************************************************************)

(* Assumption sets are lists of (term, type) pairs; [prop_bu] records one
   (Var x, fresh type) pair per unbound occurrence of x. *)

(* [ass_remove e0 l] drops every assumption whose term equals [e0],
   keeping the others in order. *)
let ass_remove: xp -> (xp * tp) list -> (xp * tp) list =
  fun e0 l -> List.filter (fun (e1, _) -> not (xp_eq e0 e1)) l

(* [ass_types e0 l] lists, in order, the types assumed for [e0] in [l]. *)
let ass_types: xp -> (xp * tp) list -> tp list =
  fun e0 l -> List.map snd (List.filter (fun (e1, _) -> xp_eq e0 e1) l)

(* Bottom-up constraint generation.  [prop_bu m e] returns (a, c, t):
   - a: assumptions, one (Var x, fresh type) pair per occurrence of a
     variable x that is not bound inside [e];
   - c: the constraints collected for [e];
   - t: the type assigned to [e], expressed with the variables used in c.
   [m] is the monomorphic set: the fresh types of the lambda-bound
   variables in scope (extended in the Abs case).  It is recorded in the
   Sub constraints produced for let-bound variables.
   The names "true", "false" and "0" are literals of type bool, bool and
   int and produce no assumption.
   The order of the tp_fresh_var calls below determines the printed
   variable names. *)
let rec prop_bu: tp list -> xp -> ((xp * tp) list * prop list * tp) =
function m -> function e -> match e with
| Var(s0) ->
if compare s0 "true" == 0 then
([], [], TConst("bool"))
else if compare s0 "false" == 0 then
([], [], TConst("bool"))
else if compare s0 "0" == 0 then
([], [], TConst("int"))
else
(* ordinary variable: fresh type, recorded as an assumption *)
let t = tp_fresh_var () in
([(e,t)], [], t)
| App(e0,e1) ->
(* fresh result type t; the function type must be t1 -> t *)
let t = tp_fresh_var () in
let (a0,c0,t0) = prop_bu m e0 in
let (a1,c1,t1) = prop_bu m e1 in
(a0 ++ a1, c0 ++ c1 ++ [Eq(t0, TArrow(t1, t))], t)
| Abs(s0,e0) ->
(* fresh parameter type t, monomorphic inside the body; every assumed
   type of s0 in the body must equal t, and those assumptions are
   discharged *)
let t = tp_fresh_var () in
let (a0,c0,t0) = prop_bu (m++[t]) e0 in
let a1 = ass_remove (Var(s0)) a0 in
let tt = ass_types (Var(s0)) a0 in
let c1 = list_map (fun t1 -> Eq(t, t1)) tt in
(a1, c0 ++ c1, TArrow(t,t0))
| Let(s0,e0,e1) ->
(* each use of s0 in the body gets an implicit instance constraint
   against the type of e0 (generalized later w.r.t. m); the body's
   assumptions about s0 are discharged.  The assumptions of e0 are kept
   as is, so s0 is not in scope inside e0 (non-recursive let). *)
let (a0,c0,t0) = prop_bu m e0 in
let (a1,c1,t1) = prop_bu m e1 in
let a2 = ass_remove (Var(s0)) a1 in
let tt = ass_types (Var(s0)) a1 in
let c2 = list_map (fun t -> Sub(t, m, t0)) tt in
(a0 ++ a2, c0 ++ c1 ++ c2, t1)

(* [prop_active p pp] is true iff [p] is a Sub (t0, m, t1) one of whose
   generalizable variables (free in t1, not in m) is active in [pp].
   Eq and In constraints are never reported as active.  Not called in the
   code visible here. *)
let prop_active: prop -> prop list -> bool =
  fun p pp ->
    match p with
    | Eq _ | In _ -> false
    | Sub (_, m, t1) ->
      let generalizable = list_difference tp_eq (tp_freevars t1) m in
      let clash =
        list_intersection tp_eq generalizable (props_activevars pp) in
      (match clash with
       | [] -> false
       | _ :: _ -> true)

(* [props_split av pp] picks the next constraint to solve from [pp] and
   returns it together with the remaining constraints, or [None] if none
   can be picked.  Eq and In constraints can always be picked.  A
   Sub (t0, m, t1) can be picked only when none of its generalizable
   variables (FV(t1) - m) is in [av].  Otherwise it is skipped, and the
   search continues.  Skipped constraints keep their relative order in
   the returned remainder.
   NOTE(review): [solve] passes the active variables of the whole list,
   including the candidate itself, whereas the side condition in the
   paper appears to use the other constraints only -- confirm intended. *)
let rec props_split: tp list -> prop list
  -> (prop * prop list) option =
  fun av pp ->
    match pp with
    | [] -> None
    | ((Eq _ | In _) as p) :: rest -> Some (p, rest)
    | (Sub (_, m, t1) as p) :: rest ->
      let generalizable = list_difference tp_eq (tp_freevars t1) m in
      (match list_intersection tp_eq generalizable av with
       | [] -> Some (p, rest)
       | _ :: _ ->
         (match props_split av rest with
          | None -> None
          | Some (chosen, others) -> Some (chosen, p :: others)))

(* [solve pp] solves the constraint list [pp] and returns the accumulated
   substitution in [tp_sub] order (bindings found first end up last).
   Each step prints the constraints it is about to work on.  If only
   blocked Sub constraints remain, [props_split] yields [None] and
   solving stops without further bindings.
   @raise XXX when unification fails. *)
let rec solve: prop list -> subs =
  fun pp ->
    print_string "solving: ";
    List.iter (fun p -> print_string (prop_to_string p)) pp;
    print_newline ();
    match props_split (props_activevars pp) pp with
    | None -> []
    | Some (Eq (t0, t1), rest) ->
      let ss = tp_mgu t0 t1 in
      solve (list_map (prop_sub ss) rest) ++ ss
    | Some (Sub (t0, m, t1), rest) ->
      solve (In (t0, generalize m t1) :: rest)
    | Some (In (t0, t1), rest) ->
      solve (Eq (t0, instantiate t1) :: rest)

(* Separator printed before and after each report by [xp_props]: a single
   line of 70 asterisks.  Built with String.make so that no raw
   (unescaped) newline ends up inside a string literal; the previous
   literal spanned two source lines and split the separator in two. *)
let stars = String.make 70 '*' ^ "\n"
(* [xp_props e] runs the whole pipeline on the term [e] and prints a report
   framed by [stars]: the term, any variables left unbound in [e], the
   generated constraints, the progress of [solve], the raw and unfolded
   substitutions, the type produced by constraint generation, and that
   type with the substitution applied.
   @raise XXX if the constraints cannot be solved. *)
let xp_props: xp -> unit =
  fun e ->
    let (a, c, t) = prop_bu [] e in
    let print_each to_string xs =
      List.iter (fun x -> print_string (to_string x)) xs in
    print_string stars;
    print_string "term: ";
    print_string (xp_to_string e);
    print_newline ();
    (* [a] holds the assumptions prop_bu could not discharge, i.e. the free
       variables of [e].  Nothing else inspects them, so report them here
       (once per name) instead of silently dropping them. *)
    let unbound =
      List.fold_left
        (fun names (v, _) ->
           let name = xp_to_string v in
           if list_member string_eq names name then names
           else names ++ [name])
        [] a
    in
    (match unbound with
     | [] -> ()
     | _ :: _ ->
       print_string ("unbound variables: " ^ flatten_strings unbound);
       print_newline ());
    print_string "constraints: ";
    print_each prop_to_string c;
    print_newline ();
    let ss = solve c in
    print_string "substitutions: ";
    print_each sub_to_string ss;
    print_newline ();
    let ss = subs_sub ss in
    print_string "unfolded substitutions: ";
    print_each sub_to_string ss;
    print_newline ();
    print_string "derived type: ";
    print_string (tp_to_string t);
    print_newline ();
    print_string "type: ";
    print_string (tp_to_string (tp_sub ss t));
    print_newline ();
    print_string stars

(********************************************************************)

(* Sample terms. *)

(* \m -> let y = m in let x = y true in x *)
let e0 =
  Abs ("m",
       Let ("y", Var "m",
            Let ("x", App (Var "y", Var "true"), Var "x")))

(* let id = \x -> (let y = x in y) in id id *)
let e1 =
  Let ("id",
       Abs ("x", Let ("y", Var "x", Var "y")),
       App (Var "id", Var "id"))

(* let id = \x -> (let y = x in y) in id true *)
let e2 =
  Let ("id",
       Abs ("x", Let ("y", Var "x", Var "y")),
       App (Var "id", Var "true"))

(* Run the pipeline on each sample term, printing one report per term.
   [let () =] checks that each call returns unit. *)
let () = xp_props e0
let () = xp_props e1
let () = xp_props e2