(*
   Author:   Steffen Jost <jost@informatik.uni-muenchen.de>
   Name:     $Name:  $
   File:     $RCSfile: typcheck.ml,v $
   Id:       $Id: typcheck.ml,v 1.1 2004/12/07 17:30:41 sjost Exp $ 

	
   What this File is all about:
   ----------------------------
   Type checking an expression and defining/delivering a typed expression.

   PROBLEM:
   Full type inference seemed too much work to implement, hence we
   decided to do type checking only. Plain checking would not suffice:
   bidirectional type checking would be needed to check terms like
   "fun x->x" and "let rec x = x in x". Since Camelot is available, we
   also decided against implementing bidirectional type checking and
   instead insist on type annotations on all "fun", "let" and "let rec"
   terms.


   ToDos: 

   - the type checker is only partially prepared to handle the type TvarTyp. It does not feed back any instantiations! See 'equal'. 

   - should we move the whole bunch to syntax.ml, wrapping the actual typchecker within its own module?
*)

open Support
open Common
open Types
open Syntax


(* A typed expression: an [expression] together with its type [t] and the
   variable context [ctxt] in which it was checked.
   [info] and [p_t_expr] are handed unchanged to the inherited [expression]
   class; [v_ctxt] and [ty] become [ctxt] and [t]. *)
class typed_expression info p_t_expr v_ctxt ty =
  object (self: 'self)
    inherit expression info p_t_expr
    (* NOTE(review): the type is kept in a [val]; [method t] could return
       [ty] directly just as well. *)
    val  v_t: typ             = ty
    method t: typ             = v_t
    method ctxt: typ context  = v_ctxt

    (* Print the whole expression tree to stdout, starting at indentation 0.
       Every node is followed by " : <type>" and a newline. *)
    method pretty_print: unit = self#pretty_print_aux 0

    (* [pretty_print_aux i] prints this node indented by [i] spaces.
       Sub-expressions are printed [step] spaces deeper; closing keywords
       ("endfun", "endlet", ...) are aligned with the opening line. *)
    method private pretty_print_aux: int -> unit =
      let step: int = 2 in
      let space: int -> unit =
        fun i -> print_string (String.make i ' ')
      in
      fun i ->
        begin
          space i;
          begin
            match v_e with
            | ValueExp(_)                         ->
                print_string("val(..)")
            | ConstrExp(constr, _, _)             ->
                print_string("con("^constr^"[..])")
            | FunExp(fid, _, fbody)               ->
                begin
                  print_string ("fun("^fid^")->\n");
                  fbody#pretty_print_aux (i+step);
                  space(i);
                  print_string ("endfun("^fid^")")
                end
            | AppExp(fid, argid)                  ->
                begin
                  print_string ("app("^fid^","^argid^")")
                end
            | LetExp(letvar, _, expra, exprb)     ->
                begin
                  print_string ("let("^letvar^")=\n");
                  expra#pretty_print_aux (i+step);
                  space(i);
                  print_string "in\n";
                  exprb#pretty_print_aux (i+step);
                  space(i);
                  print_string ("endlet("^letvar^")")
                end
            | SeqExp(expra, exprb)                ->
                begin
                  (* The newline is needed so the first component starts on
                     its own, indented line like all other children. *)
                  print_string ("seq\n");
                  expra#pretty_print_aux (i+step);
                  space(i);
                  print_string ";\n";
                  exprb#pretty_print_aux (i+step);
                  space(i);
                  print_string ("endseq")
                end
            | RecExp(recvar, _, expra, exprb)     ->
                begin
                  print_string ("rec("^recvar^")=\n");
                  expra#pretty_print_aux (i+step);
                  space(i);
                  print_string "in\n";
                  exprb#pretty_print_aux (i+step);
                  space(i);
                  print_string ("endrec("^recvar^")")
                end
            | AndExp(deflist,bodyexpr) ->
                (* One "and(x)=" header per mutually recursive definition. *)
                let print_def: variable * (typ option) * 'self -> unit =
                  fun (recvar, _, recexpr) ->
                    begin
                      space(i);
                      print_string ("and("^recvar^")=\n");
                      recexpr#pretty_print_aux(i+step);
                    end
                in
                begin
                  print_string ("let rec \n");
                  List.iter print_def deflist;
                  space(i);
                  print_string ("in\n");
                  bodyexpr#pretty_print_aux (i+step);
                  space(i);
                  print_string ("endand")
                end
            | IfExp(_, thenexpr, elseexpr)        ->
                begin
                  print_string ("if (..) then \n");
                  thenexpr#pretty_print_aux (i+step);
                  space(i);
                  print_string "else\n";
                  elseexpr#pretty_print_aux (i+step);
                  space(i);
                  print_string ("endif")
                end
            | LinIExp(expra,exprb)                ->
                begin
                  print_string ("[|\n");
                  expra#pretty_print_aux (i+step);
                  space(i);
                  (* Separator on its own line, so the second component is
                     indented like the first. *)
                  print_string "|\n";
                  exprb#pretty_print_aux (i+step);
                  space(i);
                  print_string "|]"
                end
            | LinEExp(fst, var)                   ->
                begin
                  if fst
                  then print_string ("fst("^var^")")
                  else print_string ("snd("^var^")")
                end
            | MatchExp(mvar, mrules)              ->
                let print_mrule (Matchrule(_,constr,_,_,mexpr)) =
                  begin
                    space(i);
                    print_string ("|"^constr^"(..)->\n");
                    mexpr#pretty_print_aux(i+step);
                  end
                in
                begin
                  print_string ("match ("^mvar^") with\n");
                  (List.iter print_mrule mrules);
                  space(i);
                  print_string"endmatch"
                end
          end;
          print_string " : ";
          print_string (Types.to_string v_t);
          print_newline ()
        end

  end

(* Forget the typing information of a typed expression by upcasting it to a
   plain [expression]. OCaml needs the coercion to be written explicitly;
   it is not used by this module itself but is provided for callers. *)
let expression_of_typed_expression: typed_expression -> expression =
  fun te -> (te : typed_expression :> expression)

(* typed_expression_of_expression: context -> expression -> typed_expression = expression   which is defined below *)


(* Basic type checking *)    

(* [unop uop] returns the typing rule of the unary operator [uop]: a function
   that takes the operand's plain type and returns the result type.
   Every unary operator maps a fixed type to itself, so the rule only accepts
   an operand equal (by [Types.equal]) to that type and raises
   [Invalid_argument] otherwise.
   The inference assumes that no built-in operator affects the heap. *)
let unop: unaryoperator -> (plain_typ -> plain_typ) =
  fun uop ->
    let expected: plain_typ =
      match uop.v with
      | NotOp     -> BoolTyp
      | UMinusOp  -> IntTyp
      | UFminusOp -> FloatTyp
    in
    fun operand ->
      if not (Types.equal (fakeinfo expected) (fakeinfo operand))
      then raise (Invalid_argument "Unary operator applied to wrong typ.")
      else expected
	  
(* [binop bop] returns the typing rule of the binary operator [bop]: a
   function that takes the plain types of both operands and returns the
   result type, raising [Invalid_argument] if the operands are not admissible.
   The inference assumes that no built-in operator affects the heap. For that
   reason comparisons must reject user-defined types, function types and
   linear pairs on EITHER operand, not just the first: [Types.equal] is only
   partially prepared for type variables (see ToDo above), so the first
   operand alone cannot be trusted to exclude such types from the second. *)
let binop: binaryoperator -> ((plain_typ * plain_typ) -> plain_typ) =
  (* Both operands must have type [param]; the result is [param] as well. *)
  let simple_bop: plain_typ -> (plain_typ * plain_typ) -> plain_typ =
    fun param (tya,tyb) ->
      if   (Types.equal (fakeinfo param) (fakeinfo tya)) && (Types.equal (fakeinfo param) (fakeinfo tyb))
      then param
      else raise (Invalid_argument "Binary operator applied to wrong types.")
  in
  (* Raise if a comparison operand has a type that must not be compared.
     These checks are essential to the inference (see above). *)
  let reject_incomparable: plain_typ -> unit = function
    | ConTyp(_,_)     -> raise (Invalid_argument "Equation operator cannot be applied to user defined types.")
    | ArrowTyp(_,_)   -> raise (Invalid_argument "Equation operator cannot be applied to function types.")
    | LinPairTyp(_,_) -> raise (Invalid_argument "Equation operator cannot be applied to linear pairs.")
    | _               -> ()
  in
  fun bop ->
    match bop.v with
    | TimesOp
    | DivOp
    | PlusOp
    | MinusOp   -> simple_bop IntTyp
    | FtimesOp
    | FdivOp
    | FplusOp
    | FminusOp  -> simple_bop FloatTyp
    | LessOp
    | LteqOp
    | GreaterOp
    | GteqOp
    | EqualOp   ->
        (fun (tya, tyb) ->
          (* The first operand is checked first, so its error message wins,
             exactly as before. *)
          reject_incomparable tya;
          reject_incomparable tyb;
          if Types.equal (fakeinfo tya) (fakeinfo tyb)
          then BoolTyp
          else raise (Invalid_argument "Equation operator applied to incompatible types.")
        )
    | AppendOp  ->
        (* Strings and characters may be appended freely; the result is a string. *)
        (fun (tya, tyb) ->
          let is_cs: plain_typ -> bool =
            fun t -> (Types.equal (fakeinfo StringTyp) (fakeinfo t)) || (Types.equal (fakeinfo CharTyp) (fakeinfo t))
          in
          if (is_cs tya) && (is_cs tyb)
          then StringTyp
          else raise (Invalid_argument "Append operator applied to wrong types.")
        )
    | AndOp
    | OrOp
    | AndalsoOp
    | OrelseOp  -> simple_bop BoolTyp
    | ModOp     -> simple_bop IntTyp


(* [valu ctxt vl] computes the type of the simple value [vl] in context [ctxt].
   The resulting type carries the location of [vl], not that of whatever
   binding a variable refers to. Operator applications are typed by
   [unop]/[binop], which raise [Invalid_argument] on ill-typed operands. *)
let rec valu: (typ #context) -> valu -> typ =
  fun ctxt vl ->
    (* Plain (location-free) type of a sub-value. *)
    let plain_of sub = stripinfo (valu ctxt sub) in
    let plain_result =
      match vl.v with
      | UnitVal     -> UnitTyp
      | BoolVal _   -> BoolTyp
      | IntVal _    -> IntTyp
      | FloatVal _  -> FloatTyp
      | CharVal _   -> CharTyp
      | StringVal _ -> StringTyp
      | VarVal(name) -> stripinfo (ctxt#lookup name)
      | UnaryOpVal(op, arg) ->
          let arg_ty = plain_of arg in
          (unop op) arg_ty
      | BinaryOpVal(op, lhs, rhs) ->
          binop op (plain_of lhs, plain_of rhs)
    in
    { i = vl.i; v = plain_result }
      

(* [expression ctxt expr] type checks [expr] in the variable context [ctxt]
   and returns the corresponding [typed_expression]; its sub-expressions are
   rebuilt as typed expressions as well.
   "fun" and "let rec"/"and" require type annotations (reported via [errAt]
   if missing); every given annotation, including those on "let", is checked
   against the type of the annotated expression.
   Any failure inside is reported through [try_withinfo] with the location
   of [expr]. *)
let rec expression: (typ #context) -> expression -> typed_expression =
  fun ctxt expr ->
    try_withinfo expr#i "Typecheck failed:" begin lazy
      begin match expr#e with
      | ValueExp(v) as tpe_e ->  (* Note that tpe_e might be of type typed_expression plain_expression, while expr#e is not. *)
          new typed_expression expr#i tpe_e ctxt (valu ctxt v)
      | ConstrExp(constr, valus, dia) as tpe_e ->
          let ci = !the_contab#find constr in
          (* A reused cell must be a diamond. *)
          let () =
            match dia with
            | New -> ()
            | Reuse(dv) ->
                if not (Types.equal (ctxt#lookup dv) (fakeinfo DiamondTyp))
                then errAt expr#i ("Identifier '"^dv^"' is not of type diamond.")
          in
          let arg_typs = ci#arg_typs in
          (* Report an arity mismatch explicitly instead of letting
             List.iter2 fail with an unlocated Invalid_argument. *)
          if List.length valus <> List.length arg_typs
          then errAt expr#i ("Constructor '"^constr^"' expects "^(string_of_int (List.length arg_typs))
                             ^" argument(s), but is given "^(string_of_int (List.length valus))^".");
          List.iter2
            (fun vl t ->
              if not (Types.equal (valu ctxt vl) t)
              then errAt vl.i ("Constructor argument is of incompatible type.")
            ) valus arg_typs;
          new typed_expression expr#i tpe_e ctxt ci#own_typ
      | FunExp(fvar, dom_opt, fbody) ->
          begin match dom_opt with
          | None -> errAt expr#i ("Function definition without type declaration for '"^fvar^"'.")
          | Some(dom) ->
              let rng_expr = expression (ctxt#bind fvar dom) fbody in
              let tpe_ty   = addinfo expr#i (ArrowTyp(dom,rng_expr#t)) in
              new typed_expression expr#i (FunExp(fvar, dom_opt, rng_expr)) ctxt tpe_ty
          end
      | AppExp(fid, argid) as tpe_e ->
          let fty =
            if   ctxt#mem fid
            then ctxt#lookup fid (* defined function *)
            else Builtin.typ fid (* built-in function *)
          in
          begin match fty.v with
          | ArrowTyp(dom,rng) ->
              let argty = (ctxt#lookup argid) in
              if   Types.equal dom argty
              then new typed_expression expr#i tpe_e ctxt rng
              else errAt expr#i ("Argument for '"^fid^"' is of type '"^(string_of_typ argty)^"',"^Support.exn_sep^"but must be of type '"^(string_of_typ dom)^"'.")
          | _ -> errAt expr#i ("Identifier '"^fid^"' is of type '"^(string_of_typ fty)^"', but is applied like a function.")
          end
      | LetExp(letvar, ty_opt, expra, exprb) ->
          let tpe_a = expression ctxt expra in
          let tpe_b = expression (ctxt#bind letvar tpe_a#t) exprb in
          (* An optional annotation must agree with the inferred type. *)
          let () =
            match ty_opt with
            | None      -> ()
            | Some(lvt) ->
                if Types.equal lvt (tpe_a#t)
                then ()
                else errAt expr#i ("Given type for '"^letvar^"':'"^(Types.to_string lvt)^"' does not match '"^(Types.to_string tpe_a#t)^"'.")
          in
          let tpe_e = LetExp(letvar, Some(tpe_a#t), tpe_a, tpe_b) in
          new typed_expression expr#i tpe_e ctxt tpe_b#t
      | SeqExp(expra, exprb) ->
          let tpe_a = expression ctxt expra in
          let tpe_b = expression ctxt exprb in
          new typed_expression expr#i (SeqExp(tpe_a,tpe_b)) ctxt tpe_b#t
      | RecExp(recvar, ty_opt, expra, exprb) ->
          begin match ty_opt with
          | None      -> errAt expr#i ("No type declaration for recursive identifier '"^recvar^"'.")
          | Some(rvt) ->
              let ctxt_rv = ctxt#bind recvar rvt in
              let tpe_a = expression ctxt_rv expra in
              (* The declared type was assumed while checking the body, so it
                 must actually be the body's type; otherwise the recursion is
                 ill-typed. *)
              if not (Types.equal rvt tpe_a#t)
              then errAt expr#i ("Given type for recursive identifier '"^recvar^"':'"^(Types.to_string rvt)^"' does not match '"^(Types.to_string tpe_a#t)^"'.");
              let tpe_b = expression ctxt_rv exprb in
              let tpe_e = RecExp(recvar, Some(tpe_a#t), tpe_a, tpe_b) in
              new typed_expression expr#i tpe_e ctxt tpe_b#t
          end
      | AndExp(deflist,bodyexpr) ->
          (* Add all mutually recursive definitions to the context first. *)
          let ctxt_rvs: (typ #context) =
            let collect_rectypes: (typ #context) -> (variable * (typ option) * 'self) -> (typ #context) =
              fun ctxt (recvar, recty_opt, recbody) ->
                begin match recty_opt with
                | None      -> errAt recbody#i ("No type declaration for recursive identifier '"^recvar^"'.")
                | Some(rvt) -> ctxt#bind recvar rvt
                end
            in List.fold_left collect_rectypes ctxt deflist
          in
          (* Type check every definition and verify it against its declaration. *)
          let typed_deflist =
            let type_def: (variable * (typ option) * expression) -> (variable * (typ option) * typed_expression) =
              fun (recvar, recty_opt, recbody) ->
                let tpe = expression ctxt_rvs recbody in
                begin match recty_opt with
                | None      -> ()  (* cannot occur: already rejected by collect_rectypes *)
                | Some(rvt) ->
                    if not (Types.equal rvt tpe#t)
                    then errAt recbody#i ("Given type for recursive identifier '"^recvar^"':'"^(Types.to_string rvt)^"' does not match '"^(Types.to_string tpe#t)^"'.")
                end;
                (recvar, Some(tpe#t), tpe)
            in List.map type_def deflist
          in
          let tpe_body = expression ctxt_rvs bodyexpr in
          let tpe_e    = AndExp(typed_deflist,tpe_body) in
          new typed_expression expr#i tpe_e ctxt tpe_body#t
      | IfExp(ifvalu, thenexpr, elseexpr) ->
          begin match (valu ctxt ifvalu).v with
          | BoolTyp ->
              let tpe_then = expression ctxt thenexpr in
              let tpe_else = expression ctxt elseexpr in
              if Types.equal tpe_then#t tpe_else#t
              then
                let tpe_e = IfExp(ifvalu, tpe_then, tpe_else) in
                new typed_expression expr#i tpe_e ctxt tpe_then#t
              else errAt expr#i ("If-branches are not of compatible type.")
          | _       -> errAt expr#i ("If-value is not of boolean type.")
          end
      | LinIExp(expra,exprb) ->
          let tpe_a  = expression ctxt expra in
          let tpe_b  = expression ctxt exprb in
          let tpe_e  = LinIExp(tpe_a,tpe_b) in
          let tpe_ty = {i=expr#i; v=(LinPairTyp(tpe_a#t, tpe_b#t))} in
          new typed_expression expr#i tpe_e ctxt tpe_ty
      | LinEExp(fstsnd, var) as tpe_e ->
          begin match (ctxt#lookup var).v with
          | LinPairTyp(fst_ty,snd_ty) ->
              if fstsnd
              then (* fst() *)
                new typed_expression expr#i tpe_e ctxt fst_ty
              else (* snd() *)
                new typed_expression expr#i tpe_e ctxt snd_ty
          | _ -> errAt expr#i ("Identifier '"^var^"' is not of linear pair type.")
          end
      | MatchExp(mvar, mrules) ->
          (* Type check one match rule in the context extended by its pattern. *)
          let mrule: expression matchrule -> typed_expression matchrule =
            function (Matchrule(info, cnstr, vars, dia_opt, mrexpr)) ->
              let ci       = !the_contab#find cnstr in
              let arg_typs = ci#arg_typs in
              (* Report an arity mismatch explicitly instead of letting
                 List.fold_left2 fail with an unlocated Invalid_argument. *)
              if List.length vars <> List.length arg_typs
              then errAt info ("Constructor '"^cnstr^"' expects "^(string_of_int (List.length arg_typs))
                               ^" argument(s), but the pattern binds "^(string_of_int (List.length vars))^".");
              (* A destructive match consumes [mvar]. Remove it BEFORE binding
                 the pattern variables, so a pattern variable that shares its
                 name is not removed along with it. *)
              let ctxt_base =
                match dia_opt with
                | None    -> ctxt              (* read_only *)
                | Some(_) -> ctxt#remove mvar  (* destructive *)
              in
              let ctxt_vars =
                List.fold_left2
                  (fun acc_ctxt v t -> acc_ctxt#bind v t)
                  ctxt_base
                  vars
                  arg_typs
              in
              let ctxt' =
                match dia_opt with
                | None | Some(New) -> ctxt_vars   (* read_only / destructive, anonymous *)
                | Some(Reuse(d))   ->             (* destructive, named *)
                    ctxt_vars#bind d (addinfo info DiamondTyp)
              in
              Matchrule(info, cnstr, vars, dia_opt, expression ctxt' mrexpr)
          in
          let tpe_mrules = List.map mrule mrules in
          let tpe_e  = MatchExp(mvar, tpe_mrules) in
          (* All rules must agree on their result type. *)
          let tpe_ty =
            match tpe_mrules with
            | (Matchrule(_, _, _, _, first_mrexpr))::tl_mrules ->
                List.fold_left
                  (fun acc_ty (Matchrule(info, _, _, _, t_mrexpr)) ->
                    if   Types.equal acc_ty t_mrexpr#t
                    then t_mrexpr#t
                    else errAt info ("Type of matchrule not compatible with previous matchrule.")
                  )
                  first_mrexpr#t
                  tl_mrules
            | [] -> errAt expr#i ("There must be at least one matchrule per match.")
          in
          new typed_expression expr#i tpe_e ctxt tpe_ty
      end
    end

