open Absyn Util

(* lookup key pairs: the value paired with the FIRST occurrence of key in the
   association list pairs, or NONE when key does not occur as a key. *)
fun lookup key pairs =
    case List.find (fn (k, _) => k = key) pairs of
        NONE => NONE
      | SOME (_, v) => SOME v


(* chkUniqL l: l is a list of (name, location) pairs.  Reports an error at the
   location of the later occurrence if any name other than "_" appears twice;
   "_" entries are ignored.  Returns () when all names are distinct. *)
fun chkUniqL l =
    let
        (* the first later entry carrying the same name, if any *)
        fun dupOf name rest = List.find (fn (other, _) => other = name) rest

        fun go [] = ()
          | go (("_", _) :: rest) = go rest
          | go ((name, _) :: rest) =
            (case dupOf name rest of
                 NONE => go rest
               | SOME (v, loc) => error loc ("Repeated variable " ^ v))
    in
        go l
    end

(* Check that datatype definition doesn't involve units *)

(* unitFree d: reports an error at the argument location of any constructor of
   datatype declaration d whose argument-type list contains UNITty.
   Only top-level argument types are inspected. *)
fun unitFree (TYPEdec (_, _, constrs, _)) =
    app (fn TYPEcon (_, (argTys, argLoc), _, _) =>
            if not (member UNITty argTys) then ()
            else error argLoc "The unit type is not allowed in type definitions")
        constrs

(* Check that datatype definitions don't involve arrow types *)

(* arrowFree d: reports an error at the argument location of any constructor of
   datatype declaration d that has an ARROWty among its argument types.
   NOTE(review): only top-level argument types are inspected; an arrow nested
   inside another type (e.g. a product or a type application) is not detected. *)
fun arrowFree (TYPEdec (_, _, constrs, _)) =
    let
        fun isArrow ty = case ty of ARROWty _ => true | _ => false

        fun check (TYPEcon (_, (argTys, argLoc), _, _)) =
            if not (List.exists isArrow argTys) then ()
            else error argLoc "Arrow types are not allowed in type definitions (yet)"
    in
        app check constrs
    end


(* Warn of unused type parameters (could still be fooled by mutually recursive types) *)


(* chkTyParams datatypes: checks the type parameters of every datatype declaration.
   - Reports an error (via chkUniqL) if a declaration repeats a type parameter name.
   - Reports "Unknown type" if a constructor argument mentions a type constructor
     that is not among the given datatypes, and "Arity mismatch" if it is applied
     to the wrong number of type arguments.
   - Warns about each declared type parameter that no constructor argument uses. *)
fun chkTyParams datatypes =
    let
        (* (datatype name, number of type parameters); also rejects repeated parameters *)
        fun arityOf (TYPEdec (params, tname, _, _)) =
            (chkUniqL params; (nameOf tname, length params))

        (* foldl keeps the original left-to-right checking order and reversed result *)
        val arities = List.foldl (fn (d, acc) => arityOf d :: acc) [] datatypes

        (* Add the names of all type variables occurring in ty to seen (no duplicates);
           loc is used for error reporting. *)
        fun typesIn loc (ty, seen) =
            let
                fun collect (TVARty v, found) =
                    let val a = #name v
                    in if member a found then found else a :: found end
                  | collect (ARRAYty elt, found) = collect (elt, found)
                  | collect (PRODUCTty parts, found) = List.foldl collect found parts
                  | collect (CONty (args, tname), found) =
                    (case lookup tname arities of
                         NONE => error loc "Unknown type"
                       | SOME n =>
                         if n = length args
                         then List.foldl collect found args
                         else error loc "Arity mismatch")
                  | collect (ARROWty (dom, cod), found) = collect (cod, collect (dom, found))
                  | collect (_, found) = found
            in
                collect (ty, seen)
            end

        (* Warn about every declared type parameter that no constructor uses. *)
        fun typesBound (TYPEdec (params, _, constrs, _)) =
            let
                fun addCon (TYPEcon (_, (tys, tloc), _, _), acc) =
                    List.foldl (typesIn tloc) acc tys

                val used = List.foldl addCon [] constrs

                fun report (a, l) =
                    if member a used then ()
                    else warn l ("type variable " ^ a ^ " unused in type declaration")
            in
                app report params
            end
    in
        app typesBound datatypes
    end

(* Check that all constructors and functions are defined and are applied
   to the correct number of arguments (some of this is done elsewhere, but never
   mind).  Also check that all match statements are exhaustive and irredundant. *)

(* We could also check that all datatypes and internal classes have been
   defined in the program,  so that things like "fun f(x:int): intt" will be
   caught early.  At present you get a type error,  but it could be more
   informative. *)


(* checkArgs prog: walks every function body of prog and reports errors for
   repeated formal arguments, unknown or wrongly-applied constructors and
   built-in functions, invalid heap annotations, and non-exhaustive or redundant
   datatype matches.  Bodies of object-oriented match rules are checked too.
   NOTE(review): error is assumed not to return (it is used at arbitrary result
   types below). *)
fun checkArgs (PROG(datatypes, valdecs, classdefs, fblocks)) =
    let
        val allfuns = Asyntfn.collapse fblocks

        (* (fname, #args) for user functions and for values declared with an arrow type.
           NOTE(review): currently unused in this block -- the arity of GLOBAL/LOCAL
           calls is not checked here. *)
        val funsizes =
            map (fn FUNdef(name,args,i,_,_) => (nameOf name, length args)) allfuns
            @ List.mapPartial (fn (VALdec(name, ty, _)) =>
                                  SOME (nameOf name, length (Util.arrowTyToTyList ty) - 1)
                                | _ => NONE)
                              valdecs

        (* (built-in function name, #args) list *)
        val builtin_funsizes = Perv.builtinArgSizes ()


        (* Check for repetition of formal arguments; used holds the names seen so far. *)

        fun no_repetitions [] _ _ = ()
          | no_repetitions (UNITvar::t) used loc = no_repetitions t used loc
          | no_repetitions ((VAR(h,_))::t) used loc =
            if member h used then error loc ("Error: repeated formal argument " ^ h)
            else no_repetitions t (h::used) loc

        val () = app (fn FUNdef(_, args,_, _, loc) => no_repetitions args [] loc) allfuns


        (* Collect information about datatype declarations *)

        val allcons = List.foldl (fn (TYPEdec(_,_, cons,_),acc) => cons@acc) [] datatypes

        (* (constructor name, #args) list *)
        val consizes = map (fn TYPEcon(name, (args,_), _, _) => (nameOf name, length args)) allcons

        (* (constructor name, heap usage: HEAP or NOHEAP) list *)
        val conspace = map (fn TYPEcon(name, _, usage, _) => (nameOf name, usage)) allcons

        (* (datatype name, constructor name list) list *)
        val typecons = map (fn TYPEdec(_,name, cons, _) =>
                               (nameOf name, map (fn TYPEcon (cname,_,_,_)
                                                     => nameOf cname) cons)) datatypes

        (* (constructor name, datatype name) list *)
        val contypes = List.concat (map (fn (x,l) => map (fn a => (a, x)) l) typecons)


        (* Values contain no subexpressions, so there is nothing to check. *)
        fun checkVal v = ()

        (* Check an expression and, recursively, all of its subexpressions. *)
        and checkExp expr =
            case expr of
                VALexp (w,_) => checkVal w
              | UNARYexp (oper, e, _) => checkExp e
              | BINexp (oper, e, e', _) => (checkExp e; checkExp e')
              | IFexp (e, e1, e2, _) => (checkExp e; checkExp e1; checkExp e2)
              | MATCHexp (e,l,loc) =>
                (checkExp e;
                 case l of
                     [] => error loc "Match statement with no rules"
                   | _ => checkMatchRules loc l)
              | LETexp (v, e, e',_) => (checkExp e; checkExp e')
              | APPexp (e, args, GLOBAL, loc) => (checkExp e; app checkExp args)
              | APPexp (e, args, LOCAL, loc) => (checkExp e; app checkExp args)
              | APPexp (e, args, BUILTIN, loc) =>
                (* After normalisation a built-in call must name the function directly. *)
                let in
                    case e of
                        VALexp (VARval (f,_,_),loc') =>
                        let in
                            case lookup f builtin_funsizes of
                                NONE => error loc' ("Undefined built-in function '" ^ f ^ "'")
                              | SOME n =>
                                if n = length args then app checkExp args
                                else error loc' ("Built-in function '" ^ f ^
                                                 "' requires " ^ (plural n "argument")
                                                 ^ " (not " ^ (Int.toString (length args) ^ ")"))
                        end
                      | _ => error loc "Non-value applied as built-in function: normalisation error?"
                end
              | APPexp (e, args, EXTERN, loc') =>
                let in
                    case e of
                        VALexp (VARval f, loc) => app checkExp args
                      | _ => error loc' "Non-value applied as function: normalisation error?"
                end
              | CONexp (v, args, addr, loc) =>
                let
                    val () =
                        case lookup (nameOf v) consizes of
                            NONE => error loc ("Unknown constructor " ^ nameOf v)
                          | SOME n =>
                            if n = length args then app checkExp args
                            (* message format: "requires N arguments, not M" *)
                            else error loc ("Constructor " ^ nameOf v ^
                                            " requires " ^ (plural n "argument")
                                            ^ ", not " ^ Int.toString (length args))
                in
                    (* Check that we don't attempt to say X@d where X is heap-free *)
                    case addr of
                        NONE => ()
                      | SOME _ =>
                        let in
                            case lookup (nameOf v) conspace of
                                NONE => error loc ("Unknown constructor " ^ nameOf v)
                              | SOME HEAP => ()
                              | SOME NOHEAP => error loc ("ERROR: constructor is heap-free")
                        end
                end
              | TYPEDexp (e,t,_) => checkExp e
              | COERCEexp (e,t,_) => checkExp e
              | NEWexp (class,es,_) => app checkExp es
              | SUPERMAKERexp (es,_) => app checkExp es
              | INVOKEexp (obj, mname, es,_) => app checkExp (obj::es)
              | UPDATEexp (x,e,_) => checkExp e
              | GETexp (obj,x,_) => checkExp obj
              | SGETexp x => ()
              | ASSERTexp (e,as1,as2,_) => checkExp e
              | LAMexp (v,e,_) => checkExp e


        (* Check the (non-empty) rule list of a match statement located at matchLoc. *)
        and checkMatchRules matchLoc rules =
            let
                (* Check one datatype rule: distinct pattern variables, diamond not
                   clashing with a pattern variable, correct constructor arity, valid
                   heap annotation, and that the constructor belongs to datatype tname
                   (whose constructors are Tcons).  Then check the rule body. *)
                fun checkrule tname Tcons (MATCHrule ((con, cloc), args, d, e, loc)) =
                    let
                        val () = chkUniqL args
                        val () = case d of
                                     SOMEWHERE (n,nloc) =>
                                     if memberL n args
                                     then error nloc "Diamond bound in pattern"
                                     else ()
                                   | _ => ()

                        val () = case lookup con consizes of
                                     NONE => error cloc ("Unknown constructor " ^ con)
                                   | SOME n => if n = length args then ()
                                               else error loc ("Constructor " ^ con ^
                                                               " requires " ^ (plural n "argument"))


                        (* For Steffen's benefit we allow expressions such as "match l with []@_ -> ..."
                           for heap-free constructors.  We definitely don't allow diamonds in
                           such contexts. *)

                        val () = case lookup con conspace of
                                     NONE => error loc ("Unknown constructor " ^ con)
                                   | SOME HEAP => ()
                                   | SOME NOHEAP =>
                                     let in
                                         case d of
                                             NOWHERE => ()
                                           | SOMEWHERE _ => error loc (
                                                            "invalid annotation: constructor "
                                                            ^ con ^ " is heap-free")
                                           | DISPOSE => warn loc ("CONSTRUCTOR " ^ con ^ " IS HEAP-FREE")
                                     end

                        val () =
                            if not (member con Tcons)
                            then error loc ("Constructor " ^ con ^ " belongs to type "
                                            ^ (valOf (lookup con contypes))
                                            ^ ",  but this match rule is for type "
                                            ^ tname)
                            else ()
                    in
                        checkExp e
                    end
                  | checkrule tname Tcons (OOMATCHrule(pat, e, _)) = ()
            in
                case hd rules of
                    (MATCHrule(con,_,_,_,loc)) =>
                    let
                        (* The datatype being matched is determined by the first rule. *)
                        val T = case lookup (nameOf con) contypes of
                                    NONE => error loc ("Unknown constructor " ^ (nameOf con))
                                  | SOME t => t

                        val Tcons = valOf (lookup T typecons)

                        val () = app (checkrule T Tcons) rules

                        fun conOf (MATCHrule (con, _, _, _, _)) = nameOf con
                          | conOf (OOMATCHrule (_,_,l)) = Util.error l "Malformed match rule"

                        (* Irredundancy: no constructor may have more than one rule. *)
                        fun checkset [] = ()
                          | checkset (m::t) =
                            let val thisCon = conOf m
                            in
                                case List.find (fn n => (conOf n = thisCon)) t of
                                    SOME (MATCHrule(_,_,_,_,duploc)) =>
                                    error duploc ("Duplicate rule for constructor " ^ thisCon)
                                  | SOME _ => error loc "Whoa!!"
                                  | NONE => checkset t
                            end

                        val () = checkset rules

                        (* Exhaustiveness: every constructor of T must have a rule. *)
                        fun used con =
                            if List.exists (fn m => (con = conOf m)) rules
                            then ()
                            else error matchLoc ("Inexhaustive match:  no rule for constructor " ^ con)

                        val () = app used Tcons
                    in () end
                  | (OOMATCHrule _) =>
                    (* No exhaustiveness/redundancy check for object matches, but the
                       rule bodies may themselves contain matches and constructor
                       applications, so they must still be checked. *)
                    app (fn OOMATCHrule (_, e, _) => checkExp e
                          | MATCHrule (_, _, _, e, _) => checkExp e) rules
            end

        fun checkFunDef (FUNdef(_,_,_,body,_)) = checkExp body
        fun checkFunBlock (FUNblock b) = app checkFunDef b

    in
        app checkFunBlock fblocks
    end


(* syncheck prog: runs the syntactic checks over a whole program, in order:
   arrow-free datatype definitions, datatype type-parameter checks, then
   argument/match checks on all function bodies.
   The unit-type check (app unitFree datatypes) is currently disabled. *)
fun syncheck (prog as PROG (datatypes, _, _, _)) =
    (app arrowFree datatypes;
     chkTyParams datatypes;
     checkArgs prog)
