(* Convert Camelot abstract syntax to normalised form *)
(* Don't do anything about higher-order functions  yet *)


local
open Util;
open Absyn;
structure N = Normsyn
val () = Normsyn.required   (* for mosmldep *)
val getLoc = Asyntfn.getU

(* Report an internal error via Util.ierror, tagging the message with this file's name. *)
fun NormaliseError msg =
    let val tagged = "[Normalise.sml]: " ^ msg
    in Util.ierror tagged end
(* Report (via Util.error, at location loc) that a lambda expression was found
   where it should already have been lambda-lifted. *)
fun lamError loc =
    let
	val msg = "[Normalise.sml]: " ^ "this expression should have been lambda-lifted by now"
    in
	Util.error loc msg
    end


(* Build a value expression that refers to the variable x, using loc for both
   the variable and the enclosing expression.  The variable is always tagged LOCAL. *)
fun valExp x loc =
    let val var = VARval (x, LOCAL, loc)   (* FIX LOCAL *)
    in VALexp (var, loc) end
in


(* There was a problem when re-ingesting normalised programs in that
   they contained names which could clash with newly-generated tempnames.
   Normalisation won't generate such names, but there are calls to
   tempName elsewhere (eg Optimise.sml) which could cause problems.

   To guard against this we keep a record of the names in use (see below),
   and any name beginning with '?' is always treated as already used. *)

(* True iff the name x begins with the character '?'.
   An empty name is reported through NormaliseError. *)
fun startsWithQuery x =
    if String.size x = 0 then NormaliseError "empty name"
    else String.sub (x, 0) = #"?"


(* The names of all functions defined in one function block, in definition order. *)
fun getFBlockNames (FUNblock defs) =
    let fun defName (FUNdef ((fname, _), _, _, _, _)) = fname
    in map defName defs end
(* The names of all functions defined in a list of function blocks,
   block by block, in order. *)
fun getGlobalNames funBlocks =
    foldr (fn (blk, acc) => getFBlockNames blk @ acc) [] funBlocks


local
    (* The set of names currently regarded as being in use. *)
    val used = ref (Binaryset.empty String.compare)

    fun record s = used := Binaryset.add (!used, s)

    (* Try base$n, base$(n+1), ... until a name not in the set is found;
       record that name and return it. *)
    fun fresh base n =
	let
	    val candidate = base ^ "$" ^ Int.toString n
	in
	    if not (Binaryset.member (!used, candidate))
	    then (record candidate; candidate)
	    else fresh base (n+1)
	end
in
    (* Forget all recorded names and start again from the global function names. *)
    fun resetNames globals =
	used := Binaryset.addList (Binaryset.empty String.compare, getGlobalNames globals)

    (* All recorded names, as a list. *)
    fun getNames () = Binaryset.listItems (!used)

    (* Record s as being in use. *)
    fun addName s = record s

    (* A name counts as used if it has been recorded or if it begins with '?'. *)
    fun nameUsed s = Binaryset.member (!used, s) orelse startsWithQuery s

    (* Return s itself if it is unused (recording it), otherwise a fresh
       variant of the form s$n (also recorded). *)
    fun newName s =
	if nameUsed s then fresh s 0
	else (addName s; s)
end

local
    (* Index to be used for the next temporary name. *)
    val next = ref 0
in
    (* Restart temporary-name numbering from zero. *)
    fun resetLocalNames () = next := 0

    (* Generate the next temporary name "?tN", advance the counter,
       and record the name as used. *)
    fun tempName () =
	let
	    val n = !next
	    val name = "?t" ^ Int.toString n
	in
	    next := n + 1;
	    addName name;
	    name
	end
end

(* Replace all free occurrences of var v in e by w *)

(* True iff x is the name of one of the VAR entries in the argument list l.
   UNITvar entries never match. *)
fun isArg x l =
    List.exists (fn VAR (y, _) => x = y
		  | UNITvar => false) l

infix 5 ==> (* substitute in expression *)
infix 5 --> (* substitute in name *)
infix 5 --?> (* substitute in name option *)
(* (v==>w) expr : replace the free occurrences of the variable v in expr by w.

   Binders are treated as follows:
     - LETexp:  if the let binds v itself, only the bound expression is
                rewritten; the body is left alone.
     - LAMexp:  if v is one of the lambda's arguments the whole lambda is
                left alone.
     - MATCHrule: if v is one of the rule's pattern variables the rule is
                left alone (the pattern rebinds v, so occurrences in the rule
                body are not free).  Otherwise the rule body is rewritten and
                an address annotation named v is renamed to w.
     - OOMATCHrule with a CLASSpat: the pattern's object name is renamed
                along with the rule body.

   An ASSERTexp has only its expression rewritten; a warning is issued
   because the assertions themselves are not adjusted. *)
fun (v==>w) expr =
    case expr of
	VALexp (e,loc) =>
	    let in
		case e of
		    VARval (x,ext,l) => if x=v then VALexp (VARval (w,ext,l), loc) else expr
		  | _ => expr
	    end
      | UNARYexp (u,e, loc) => UNARYexp(u, (v==>w)e, loc)
      | BINexp (b,e1,e2,loc) => BINexp(b, (v==>w)e1, (v==>w)e2, loc)
      | IFexp (e, e1, e2,loc) => IFexp ((v==>w)e, (v==>w)e1, (v==>w)e2,loc)
      | MATCHexp (e,l, loc) =>
	let
	    (* Does the pattern variable list rebind v? *)
	    fun rebinds vs = List.exists (fn (x,_) => x=v) vs

	    fun substRule (rule as MATCHrule(var,vs,a,f,rl)) =
		(* Don't substitute under a pattern that shadows v: doing so
		   would rename bound occurrences without renaming the binder. *)
		if rebinds vs then rule
		else MATCHrule(var,vs,substAddr v w a, (v==>w)f,rl)
	      | substRule (OOMATCHrule(ANYCLASSpat, e, rl)) =
		OOMATCHrule(ANYCLASSpat, (v==>w)e, rl)
	      | substRule (OOMATCHrule(CLASSpat((x,xl), c), e, rl)) =
		OOMATCHrule(CLASSpat(((v-->w)x,xl), c), (v==>w)e, rl)
	in
	    MATCHexp((v==>w)e, map substRule l, loc)
	end
      | CONexp(c,l,a, loc) => CONexp(c, map (v==>w) l, (v--?>w)a, loc)
      | APPexp (e, l,ext, loc) => APPexp ((v==>w)e, map (v==>w) l,ext, loc)
      | LETexp ((x,xl), e, e', loc) =>
	if x=v then
	    LETexp((x,xl), (v==>w)e, e', loc)  (* WATCH OUT: x shadows v in e' *)
	else
	    LETexp((x,xl), (v==>w)e, (v==>w)e', loc)
      | TYPEDexp (e,t, loc) => TYPEDexp((v==>w)e, t, loc)
      | COERCEexp (e,t,loc) => COERCEexp((v==>w)e, t, loc)
      | INVOKEexp (obj, mname, es, loc) => INVOKEexp ((v==>w) obj, mname, map (v==>w) es, loc)
      | NEWexp (class, es, loc) => NEWexp (class, map (v==>w) es, loc)
      | SUPERMAKERexp (es, loc) => SUPERMAKERexp (map (v==>w) es, loc)
      | UPDATEexp (x, e, loc) => UPDATEexp (x, (v==>w) e, loc) (* FIX: rename fields? *)
      | GETexp (obj, x, loc) => GETexp ((v==>w) obj, x, loc) (* FIX: rename fields? *)
      | SGETexp x => expr
      | ASSERTexp (e, as1, as2, loc) =>
	let
	    val () = warn loc "Assertion not adjusted during normalisation"
	in
	    ASSERTexp((v==>w)e, as1, as2, loc)
	end
      | LAMexp (largs, e, loc) =>
	if isArg v largs then expr else LAMexp(largs, (v==>w)e, loc)

(* (v--?>w) a : apply the name substitution to an optional (name, location) pair. *)
and (v--?>w) a = Option.map (fn (x,l) => ((v-->w)x,l)) a

(* (v-->w) a : w if the name a is v, otherwise a unchanged. *)
and (v-->w) a = if a=v then w else a

(* Rename the name carried by a SOMEWHERE address annotation from v to w;
   NOWHERE and DISPOSE are returned unchanged. *)
and substAddr v w a =
    case a of
	NOWHERE => NOWHERE
      | SOMEWHERE (n,l) => if n=v then SOMEWHERE (w,l) else a
      | DISPOSE => DISPOSE

(* Maybe we need to take care here: what if x=w ? *)
(* Answer: this could be dangerous, but we only ever replace names
   with newly-generated names which can't occur as let-bound variables
   (because they're of a form which isn't accepted by the lexer). *)


(* Make sure that all let-bound variables in expression have unique names *)

(* uniq expr : rename bound variables so that every binder in expr has a
   name that has not been seen before.

   Binders handled: let-bound variables, match-rule pattern variables,
   match-rule address names (SOMEWHERE), the object name of a class pattern,
   and lambda arguments.  For each binder:
     - if the name is already used (nameUsed), a replacement is obtained from
       newName and substituted into the binder's scope with ==> before that
       scope is itself processed;
     - otherwise the name is recorded with addName and kept.

   The set of used names is global mutable state (resetNames / addName),
   so the result depends on what has been recorded before the call. *)
fun uniq expr =
    case expr of
    VALexp (v,loc) => VALexp (v, loc)
  | UNARYexp(oper,e, loc) => UNARYexp(oper, uniq e, loc)
  | BINexp(oper,e,e',loc) => BINexp (oper, uniq e, uniq e',loc)
  | IFexp (e,e1,e2,loc) =>  IFexp (uniq e, uniq e1, uniq e2,loc)
  | MATCHexp(e,l,loc) =>
	let
	    fun uniqMrule (MATCHrule(v,vs,a,f,rloc)) =
		let
		    (* Rename the pattern variables one at a time, threading the
		       (progressively renamed) rule body through. *)
		    fun fixVars [] exp = ([], exp)
		      | fixVars ((h,hl)::t) exp =
			(* ensure that names of variables bound in matches are unique *)
			if nameUsed h then
			    let
				val h' = newName h
				val exp' = (h==>h')exp
				val (t', exp'') = fixVars t exp'
			    in
				((h',hl)::t', exp'')
			    end
			else
			    let
				val () = addName h
				val (t', exp') = fixVars t exp
			    in
				((h,hl)::t', exp')
			    end

		    val (vs', f') = fixVars vs f
		in
		    (* The address name, if any, is a binder too. *)
		    case a of
			SOMEWHERE (w,l) =>
			if nameUsed w then
			    let
				val w' = newName w
			    in
				(MATCHrule(v,vs',SOMEWHERE (w',l), uniq ((w==>w')f'), rloc))
			    end
			else
			    let
				val () = addName w
			    in
				(MATCHrule(v,vs',a, uniq f', rloc))
			    end
		      | _ => (MATCHrule(v,vs',a,uniq f',rloc))
		end
	      | uniqMrule (OOMATCHrule(CLASSpat((obj,ol), class), e, rloc)) =
		let
		    val (obj', e') =
			if nameUsed obj then
			    let val obj' = newName obj in
			      (obj', (obj==>obj')e)
			    end
			else
			    (* Record the binder, as for every other kind of binder,
			       so that a later binding of the same name is renamed. *)
			    (addName obj; (obj, e))
		in
		    (OOMATCHrule(CLASSpat((obj',ol), class), uniq e', rloc))
		end
	      | uniqMrule (OOMATCHrule(ANYCLASSpat, e, rloc)) =
		(OOMATCHrule(ANYCLASSpat, uniq e, rloc))
	in
	    MATCHexp(uniq e, map uniqMrule l, loc)
	end
  | CONexp(c,l,a,loc) => CONexp(c, map uniq l, a, loc) (* NEED CODE FOR @ ARGS ??*)
  | APPexp(f,p,ext, loc) => APPexp(uniq f, map uniq p, ext, loc)
  | TYPEDexp (e,t, loc) => TYPEDexp(uniq e, t, loc)
  | COERCEexp (e,t, loc) => COERCEexp(uniq e, t, loc)
  | NEWexp (class, es, loc) => NEWexp(class, map uniq es, loc)
  | SUPERMAKERexp (es, loc) => SUPERMAKERexp(map uniq es, loc)
  | INVOKEexp (obj, mname, es, loc) => INVOKEexp(uniq obj, mname, map uniq es, loc)
  | UPDATEexp (x,e, loc) => UPDATEexp(x, uniq e, loc)
  | GETexp (obj,x, loc) => GETexp(uniq obj, x, loc)
  | SGETexp x => SGETexp x
  | ASSERTexp (e,as1,as2, loc) => ASSERTexp(uniq e,as1,as2, loc)
  | LETexp(X as (x,xl),e1,e2,loc) =>
	if nameUsed x then (* clashes can happen with -a1/2 input *)
	    let
		val x'  = newName x
		val e2' = (x==>x')e2
	    in
		LETexp((x',xl), uniq e1, uniq e2',loc)
	    end
	else
	    let
		val () = addName x
	    in
		LETexp(X, uniq e1, uniq e2, loc)
	    end
  | LAMexp(args, e, loc) =>
    let
	(* Walk the argument list, renaming clashing VAR arguments in the body;
	   acc collects the (possibly renamed) arguments in reverse. *)
	fun doArgs l e acc =
	    case l of
		[] => (e, rev acc)
	      | h::t =>
		case h of
		    UNITvar => doArgs t e (h::acc)
		  | VAR(x,ty) =>
		    if nameUsed x then
			let
			    val x'  = newName x
			    val e' = (x==>x')e
			in
			   doArgs t e' (VAR(x',ty)::acc)
			end
		    else
			let val () = addName x
			in
			    doArgs t e (h::acc)
			end

	val (e', args') = doArgs args e []
    in
	LAMexp(args', uniq e', loc)
    end

(* An expression is atomic exactly when it is a value expression. *)
fun atomic e =
    case e of
	VALexp _ => true     (* ???????????????? *)
      | _ => false

(*
fun normal (loc, expr) =
  case expr of
    VALexp _ => true
  | UNARYexp(_,e) => atomic e
  | BINexp(_,e,e') => (atomic e) andalso (atomic e')
  | IFexp (e,e1,e2) => List.all normal [e,e1,e2]
  | MATCHexp(e,l) => (normal e) andalso (List.all normalMrule l)
  | APPexp(f,p,ext) => (normal f) andalso (List.all atomic p)
  | CONexp(c,l,a) => List.all atomic l   (* ??????????? *)
  | TYPEDexp(e,t) => normal e
  | COERCEexp(e,t) => normal e
  | NEWexp(class,es) => List.all atomic es
  | SUPERMAKERexp(es) => List.all atomic es
  | INVOKEexp(obj,mname,es) => (atomic obj)  andalso (List.all atomic es)
  | UPDATEexp(x,e) => normal e (* ? *)
  | GETexp (obj,x) => normal obj
  | SGETexp x => true
  | ASSERTexp(e,as1,as2) => normal e
  | LETexp(_,(_,e),_) =>
    let in case e of
	LETexp _ => false
      | _ => true
    end
  | LAMexp (_,e,_) => normal e

and normalMrule (_, MATCHrule(c,p,a,e)) = normal e  (* WHAT ABOUT PATTERNS ? *)
*)



(* Translate an abstract-syntax value into the normalised-syntax constructor
   of the same name, passing the payload through unchanged. *)
fun normVal (VARval    q) = N.VARval q
  | normVal (CHARval   q) = N.CHARval q
  | normVal (INTval    q) = N.INTval q
  | normVal (FLOATval  q) = N.FLOATval q
  | normVal (STRINGval q) = N.STRINGval q
  | normVal (BOOLval   q) = N.BOOLval q
  | normVal (UNITval   q) = N.UNITval q
  | normVal (NULLval   q) = N.NULLval q

(* Normalise an atomic expression to a normalised value.
   Anything other than a value expression is an internal error. *)
fun normAtom (VALexp (v, _)) = normVal v
  | normAtom _ = NormaliseError "normAtom applied to non-atomic value"

(* normList es b fres : normalise a list of argument expressions.

   The expressions in es are processed left to right.  A value expression is
   converted with normVal and pushed onto the accumulator b (kept in reverse
   order).  Any other expression is normalised with normExpr and bound to a
   fresh temporary by an N.LETexp whose body is the result of processing the
   rest of the list; a reference to the temporary is pushed onto b.  When the
   list is exhausted, fres is applied to the accumulated values in their
   original order to build the innermost expression. *)
fun normList es b fres =
    let
        fun normL [] b = fres (List.rev b)
          | normL (e::t) b =
            case e of
	        VALexp (v,_) => normL t (normVal v::b)
              | _ =>
	        let
	            val u = tempName()
		    val loc = getLoc e
	        in
	            N.LETexp ((u,loc), normExpr e, normL t ((N.VARval (u, LOCAL, loc))::b), loc)
	        end
    in
        normL es b
    end

(* normExpr expr : convert an abstract-syntax expression to normalised syntax.

   Operands that are not atomic (see atomic) are normalised and bound to
   fresh temporaries (tempName) by enclosing N.LETexp's, so that the operator
   itself is applied only to values.  Temporaries are generated in the order
   the tempName calls below are executed, so the statement order in this
   function determines the generated names. *)
and normExpr expr =
    (* Convert to normalised Camelot syntax *)
    (* case let (* val () = TextIO.print "\n----\n" val () = Asyntfn.printExp Expr*) in expr end of*)
    case expr of
	VALexp (v, loc) => N.VALexp (normVal v, loc)
      | UNARYexp (oper, e, loc) =>
	if atomic e then N.UNARYexp (oper, normAtom e, loc)
	else let
		val x = tempName()
		val eloc = getLoc e
	    in  (* Make a token attempt to produce plausible locations for new exprs *)
		N.LETexp ((x,eloc), normExpr e, N.UNARYexp(oper, N.VARval (x,LOCAL,eloc), eloc), loc)
	    end
      (* Binary operator: four cases according to which operands are atomic.
         When both need temporaries, the left operand is bound first. *)
      | BINexp (oper, e, e', loc) =>
	if atomic e then
	    if atomic e' then N.BINexp (oper, normAtom e, normAtom e', loc)
	    else
		let
		    val y = tempName ()
		    val eloc' = getLoc e'
		in
		    (N.LETexp ((y,eloc'), normExpr e',
			       (N.BINexp(oper, normAtom e, N.VARval (y, LOCAL, eloc'), eloc')), loc))
		end
	else
	    let
		val eloc = getLoc e
		val x = tempName()
	    in
		if atomic e' then
		    N.LETexp ((x,eloc),
			      normExpr e,
			      N.BINexp(oper, N.VARval (x,LOCAL,eloc), normAtom e', loc),
			      loc)
		else
		    let
			val eloc' = getLoc e'
			val y = tempName ()
		    in
			N.LETexp ((x,eloc), normExpr e,
				N.LETexp ((y,eloc'), normExpr e',
				        N.BINexp(oper, N.VARval (x, LOCAL, eloc),
						 N.VARval (y,LOCAL, eloc'), loc), loc),
				loc)

		    end
	    end

      (* Conditional: the condition becomes an N.TEST on values.  The two
         branches are normalised in place and never hoisted. *)
      | IFexp (e, e1, e2, iloc)  =>
	    (* DO NOT lift E1 & E2 above the "if" *)
        let
	    val eloc  = getLoc e
	    val eloc1 = getLoc e1
	    val eloc2 = getLoc e2
	in
	    case e of
		(* A plain value is tested for equality with true *)
		VALexp _ => N.IFexp (N.TEST(EQUALSop, normAtom e, N.BOOLval (true, eloc), eloc),
				     normExpr e1, normExpr e2, iloc)
		(* A negated condition is handled by swapping the branches *)
	      | UNARYexp (NOTop, b, _) => normExpr (IFexp(b, e2, e1, iloc))

		(* A binary operation supplies the test operator directly;
		   non-atomic operands are bound to temporaries first *)
	      | BINexp (oper, f1, f2, eloc) =>
		if atomic f1 then
		    if atomic f2 then
			N.IFexp (N.TEST(oper, normAtom f1, normAtom f2, eloc),
				 normExpr e1, normExpr e2, iloc)
		    else
			let
			    val b2 = tempName ()
			    val f2loc = getLoc f2
			in
			     N.LETexp((b2,f2loc), normExpr f2,
				      N.IFexp(
				      N.TEST(oper, normAtom f1, N.VARval (b2, LOCAL, f2loc), eloc),
				      normExpr e1, normExpr e2, iloc),
				      f2loc)  (* LOCATIONS COULD BE WRONG HERE *)
			end
		else
		    let
 			val b1 = tempName ()
			val f1loc = getLoc f1
			val f2loc = getLoc f2
		    in
			if atomic f2 then
			    N.LETexp((b1,f1loc),
				     normExpr f1,
				     N.IFexp(N.TEST(oper, N.VARval (b1, LOCAL, f1loc), normAtom f2, eloc),
					   normExpr e1, normExpr e2, iloc), iloc)
			else
			    let
				val b2 = tempName ()
			    in
				N.LETexp ((b1,f1loc),
					  normExpr f1,
					  N.LETexp ((b2,f2loc),
						    normExpr f2,
						    N.IFexp (N.TEST(oper, N.VARval (b1, LOCAL, f1loc),
								    N.VARval (b2, LOCAL, f2loc) ,eloc),
							     normExpr e1, normExpr e2,
							     iloc),
						    iloc),
					  iloc)
			    end
		    end

	   (* Default case:  this isn't too efficient for stuff like LETexp's (and also
              can lift variables into scope where they shoudn't be visible - should be OK
              due to renaming.)  Look at it again sometime. (eg if let b = x<5 in b then 0 else 1) *)
	      | _  =>
		let
		    val b = tempName()
		    val eloc = getLoc e
		in
		    N.LETexp((b,eloc),
			   normExpr e,
			   N.IFexp(N.TEST(EQUALSop, N.VARval (b, LOCAL, eloc),
					  N.BOOLval (true, eloc), eloc),
				   normExpr e1, normExpr e2, eloc),
			   eloc)
		end
	end

      | CONexp (c,l,a, loc) => normList l [] (fn res=>N.CONexp(c,res,a,loc))
      (* Application: the function position must end up as a variable name;
         anything else is bound to a temporary before the arguments are processed *)
      | APPexp (e, es, ext, loc)  =>
	let in
	    case e of
		VALexp (VARval (v, ext1, vl),_) => (* What about INTval &c? *)
		normList es [] (fn res=>N.APPexp((v,vl), res, ext1, loc)) (* ext was default; discard it *)
	      | _ =>
		let
		    val g = tempName()
		    val eloc = getLoc e
		in
		    N.LETexp((g,eloc),
			     normExpr e,
			     normList es [] (fn res=>N.APPexp((g,eloc), res, LOCAL, loc)),
				      eloc)
		end
	end
      | NEWexp (class,es,loc) => normList  es [] (fn res=>N.NEWexp(class,res,loc))
      | SUPERMAKERexp (es, loc) => normList es [] (fn res=>N.SUPERMAKERexp(res, loc))
      (* Method invocation: a non-variable receiver is bound to a temporary and
         the invocation is then re-normalised with the temporary as receiver *)
      | INVOKEexp (ob, mname, es, loc) =>
        let in case ob of
	    VALexp (VARval (v,ext,vl), _) => normList es [] (fn res=>N.INVOKEexp((v,vl),mname,res,loc))
	  | _ =>
            let
		val x = tempName()
	    in
		N.LETexp ((x,loc), normExpr ob,
			 normExpr (INVOKEexp(valExp x loc ,mname,es, loc)), loc)
		(* CORRECT NORMALISATION??? *)
	    end
	end
      | UPDATEexp (X as (x,xl), e, loc) =>
	if atomic e then N.UPDATEexp (X, normAtom e, loc)
	else
	    let
		val y = tempName()
	    in
		N.LETexp ((y,xl), normExpr e, N.UPDATEexp(X, N.VARval (y, LOCAL, loc), loc), loc)
	    end
      | GETexp (obj,x,loc) =>
	let in case
		obj of
		VALexp (VARval (v,_,l), _) => N.GETexp ((v,l), x, loc)
	      | _ =>
		let
		    val y = tempName()
		in
		    N.LETexp ((y,loc), normExpr obj, N.GETexp((y,loc), x, loc), loc)
		end
	end
      | SGETexp x => N.SGETexp x
      | TYPEDexp (e,t, loc) => N.TYPEDexp(normExpr e, t, loc)
      | COERCEexp (e,t, loc) => N.COERCEexp(normExpr e, t, loc)
      | ASSERTexp (e, as1, as2, loc) => N.ASSERTexp(normExpr e, as1, as2, loc)
      (* Match: the scrutinee must be a variable; any non-value scrutinee is
         bound to a temporary first.  A value that is not a variable is rejected. *)
      | MATCHexp (e, mrs, loc) => (* have to lift stuff out of e *)
	    let
		(* Normalise the body of one rule.  The first argument x is
		   currently unused (see the commented-out alternative below). *)
		fun normMrule x (MATCHrule(c,p,a,e',loc')) =
		    (case a of _ => N.MATCHrule(c,p,a,normExpr e', loc')
		(*  | SOME (l,a') => (loc', MATCHrule (c,p,a, (fst E', LETexp (a', x, normExpr E'))))*)
		    )
		  | normMrule x (OOMATCHrule(pat, e', loc')) =
		    N.OOMATCHrule(pat, normExpr e', loc')
	    in case e of
		   VALexp(VARval (x,ext,l), xloc) =>
		      N.MATCHexp ((x,l), map (normMrule (valExp x xloc)) mrs, loc)
		 | VALexp _ =>
                   Util.exit "Invalid match expression: trying to assign to non-variable"
		 | _ =>
		   let
		       val x = tempName()
		       val eloc = getLoc e
		   in
		       N.LETexp((x,eloc), normExpr e,
			      N.MATCHexp((x,eloc),
				       map (normMrule (valExp x loc)) mrs, eloc),eloc)
		   end
	    end
      | LETexp (x, d, e, loc) => N.LETexp(x, normExpr d, normExpr e, loc)
      (* A lambda is first passed to Lambda.addlam and the resulting expression
         is normalised (presumably addlam performs the lambda-lifting -- confirm
         against Lambda.sml) *)
      | LAMexp _ => normExpr (Lambda.addlam expr)


(* float_lets expr : remove lets from the binding position of other lets.

   The rewrite applied is

       let v = (let w = f in f') in e'   ==>   let w = f in (let v = f' in e')

   repeated until no let is bound directly to another let.  The function
   descends through the branches of conditionals, the bodies of match rules,
   and typed/coerced expressions; every other form is returned unchanged.

   NOTE(review): the rewrite widens the scope of w to include e'; this is
   only safe if bound names are unique (presumably guaranteed by running
   uniq first, as normFunDef does). *)
fun float_lets expr =
    let fun float_mrule_lets (N.MATCHrule (v, args, diam, e, mloc)) =
	    N.MATCHrule (v,args, diam, float_lets e, mloc)
	  | float_mrule_lets (N.OOMATCHrule (pat, e, mloc)) =
	    N.OOMATCHrule (pat, float_lets e, mloc)
    in
    case expr of
	N.VALexp (v,loc) => N.VALexp(float_lets_v v, loc)
      | N.UNARYexp _ => expr
      | N.BINexp _ => expr
      | N.IFexp (test, e, f, loc)
	=> N.IFexp (test, float_lets e, float_lets f, loc)
      | N.MATCHexp (v, rules, loc) => N.MATCHexp (v, map float_mrule_lets rules, loc)
      | N.LETexp (v, e, e',loc) =>
	let val e1 = float_lets e in
	    case e1 of
		N.LETexp (w, f, f',loc1) =>
		      N.LETexp (w, f, float_lets (N.LETexp (v, f', e',loc)), loc1)
		(* e1 has already been processed; running float_lets over it
		   again would give the same term at the cost of a second traversal *)
	      | _ => N.LETexp(v, e1, float_lets e',loc)
	end
      | N.APPexp _ => expr
      | N.CONexp _ => expr
      | N.TYPEDexp (e, ty, loc) => N.TYPEDexp (float_lets e, ty,loc)
      | N.COERCEexp (e, t,loc) => N.COERCEexp (float_lets e, t, loc)
      | N.INVOKEexp _ => expr
      | N.NEWexp _ => expr
      | N.SUPERMAKERexp _ => expr
      | N.UPDATEexp _ => expr
      | N.GETexp _ => expr
      | N.SGETexp _ => expr
      | N.ASSERTexp _ => expr
    end

(* Values contain no lets, so this is the identity. *)
and float_lets_v v = v  (* FIX: get rid of this *)



(* Extract the value from a value expression; any other expression is
   reported as an error at that expression's location. *)
fun getVal (VALexp (v, _)) = v
  | getVal e = Util.error (getLoc e) "Supermaker arguments can only be values"

(* Normalise one function definition.

   The set of used names is first reset to the global function names and the
   function's own argument names are added to it.  The body is then put
   through uniq, normExpr and float_lets, in that order.  The temporary-name
   counter is deliberately not reset here (why not?). *)
fun normFunDef globals (FUNdef(fname, args, inst, body, loc)) =
    ( resetNames globals
    ; app addName (Util.getArgNames args)
    ; N.FUNdef(fname, args, inst, float_lets (normExpr (uniq body)), loc)
    )


(* During normExpr we may have generated some lambda-lifted functions.  The bodies
   of these may contain further lamexps,  so we now have to lift them as well.  We
   keep doing this till there are no new functions. *)

(* Repeatedly fetch the pending functions from Lambda.getLams and normalise
   them, prepending each normalised batch to acc, until getLams returns
   nothing more. *)
fun normLams globals acc =
    case Lambda.getLams () of
	[] => acc
      | pending => normLams globals (map (normFunDef globals) pending @ acc)

(* Normalise every function definition in a function block. *)
fun normFunBlock globals (FUNblock defs) =
    let val defs' = map (normFunDef globals) defs
    in N.FUNblock defs' end

(* Normalise the functions of a class definition.  The class's own functions,
   wrapped as a single function block, serve as the globals for name generation. *)
fun normClassDef (CLASSdef(cname,super,intfs,vals,funs)) =
    let
	val classGlobals = [FUNblock funs]
	val funs' = map (normFunDef classGlobals) funs
    in
	N.CLASSdef(cname, super, intfs, vals, funs')
    end

(* Normalise a whole program: class definitions first, then the function
   blocks, then whatever functions Lambda.getLams has accumulated meanwhile
   (each placed in a block of its own, ahead of the original blocks).
   The result is finally passed through NAsyntfn.fixLocalFuns. *)
fun normProg (PROG(typeDecs, valDecs, classDefs, funBlocks)) =
    let
	val normClasses  = map normClassDef classDefs
	val normBlocks   = map (normFunBlock funBlocks) funBlocks
	val liftedBlocks = map (fn d => N.FUNblock [d]) (normLams funBlocks [])
    in
	NAsyntfn.fixLocalFuns
	    (N.PROG(typeDecs, valDecs, normClasses, liftedBlocks @ normBlocks))
    end



(* Normalise the class definitions and function blocks of a program.
   Unlike normProg, this neither collects the functions accumulated by
   Lambda.getLams nor calls NAsyntfn.fixLocalFuns on the result.
   (An unused local helper that applied uniq to a function body has been removed;
   normFunDef already runs uniq on every body.) *)
fun uprog (PROG(typeDecs, valDecs, classDefs, funBlocks)) = (* IS THIS EVER USED ?? *)
    N.PROG(typeDecs, valDecs, map normClassDef classDefs, map (normFunBlock funBlocks) funBlocks)

end (* local *)

