(* Utility functions for monomorphisation, mostly to do with rearranging
   the order of objects in monomorphised code. *)

open Normsyn Util

(* Report an internal error through Util.ierror, tagging the message with
   this file's name. *)
fun mu_error s =
    let
        val tagged = "[MonoUtil.sml]: " ^ s
    in
        Util.ierror tagged
    end

(* trimSuffix s: remove a trailing "_<digits>" (the digits may be absent)
   from s.  If, after skipping trailing digits, the preceding character is
   not an underscore, s is returned unchanged. *)
fun trimSuffix s =
    let
        (* Walk backwards from position i over any decimal digits; the result
           is the index just past the last non-digit character. *)
        fun skipDigits i =
            if i > 0 andalso Char.isDigit (String.sub (s, i - 1))
            then skipDigits (i - 1)
            else i

        val stemEnd = skipDigits (String.size s)
    in
        if stemEnd > 0 andalso String.sub (s, stemEnd - 1) = #"_"
        then String.substring (s, 0, stemEnd - 1)
        else s
    end

(* ---- Rearrange the monomorphised program into a reasonable order ---- *)


(* sort cmp l: insertion sort of l under the comparison cmp, which takes a
   pair and returns an order.  Elements are inserted starting from the right
   end of l, and an element is placed in front of the first already-sorted
   element it is not GREATER than, so elements that compare EQUAL keep their
   original relative order. *)
fun sort cmp l =
    let
        fun insert (x, []) = [x]
          | insert (x, sorted as y :: rest) =
            (case cmp (x, y) of
                 GREATER => y :: insert (x, rest)
               | _ => x :: sorted)
    in
        List.foldr insert [] l
    end

(* Name of a function definition: the first component of the pair that is
   the first field of FUNdef. *)
fun funName fdef =
    case fdef of
        FUNdef ((name, _), _, _, _, _) => name
(* Name of a type declaration: the first component of the pair that is the
   second field of TYPEdec. *)
fun decName tdec =
    case tdec of
        TYPEdec (_, (name, _), _, _) => name

(* Names of the VAR entries of an argument list, in order; UNITvar entries
   contribute nothing. *)
fun getArgnames l =
    let
        fun nameOf (VAR (s, _)) = SOME s
          | nameOf UNITvar = NONE
    in
        List.mapPartial nameOf l
    end

(* precedes a b l: order of a relative to b according to their positions in l.
     EQUAL    if a = b (l is not inspected);
     LESS     if a is met before b when scanning l from the front;
     GREATER  if b is met first.
   If neither name occurs in l the answer is LESS; this is deliberately not
   treated as an error. *)
fun precedes a b l =
    if a = b then EQUAL
    else
        case List.find (fn x => x = a orelse x = b) l of
            NONE => LESS
          | SOME first => if first = a then LESS else GREATER

(* The next bit of code tries to compare two monomorphised names according
   to the position of the original names in the list l.  We have to be careful
   in case the user's defined a function name which looks like a monomorphised
   name.  This is why we don't trim suffixes if the name is in the list of
   original names.  In fact, there'll probably be trouble if the user defines
   eg f and f_1. Maybe we should change the suffixes to preclude this. *)

(* compare_basenames f g l: order two (possibly monomorphised) names by the
   position of their base names in l.  A name that itself occurs in l is its
   own base name; otherwise the base name is obtained with trimSuffix.  When
   both base names coincide, the full names are compared as strings. *)
fun compare_basenames f g l =
    let
        fun base name = if member name l then name else trimSuffix name
        val (fBase, gBase) = (base f, base g)
    in
        case fBase = gBase of
            true => String.compare (f, g)
          | false => precedes fBase gBase l
    end

(* reorder new nameOf originalNames: sort the items of new so that their
   names (obtained with nameOf) follow the order given by originalNames,
   using compare_basenames as the comparison. *)
fun reorder new nameOf originalNames =
    let
        val byOriginalPosition =
            fn (x, y) => compare_basenames (nameOf x) (nameOf y) originalNames
    in
        sort byOriginalPosition new
    end


(* Arrange newDecs in the order in which their names appear in oldDecs. *)
fun reorderDecs newDecs oldDecs =
    let
        val originalOrder = map decName oldDecs
    in
        reorder newDecs decName originalOrder
    end

(* Arrange monofuns in the order in which their names appear in polyfuns. *)
fun reorderFuns monofuns polyfuns =
    let
        val originalOrder = map funName polyfuns
    in
        reorder monofuns funName originalOrder
    end

local
    (* Ordering on pairs of declarations: a VALdec sorts before a CLASSdec;
       two declarations of the same kind are ordered by name. *)
    fun vorder (VALdec _, CLASSdec _) = LESS
      | vorder (CLASSdec _, VALdec _) = GREATER
      | vorder (VALdec ((n1, _), _, _), VALdec ((n2, _), _, _)) =
        String.compare (n1, n2)
      | vorder (CLASSdec ((n1, _), _, _, _), CLASSdec ((n2, _), _, _, _)) =
        String.compare (n1, n2)
in

(* Sort a list of declarations with vorder. *)
fun reorderValdecs decs = sort vorder decs

end


(* ---------------- Rebuilding funblocks ---------------- *)

(* We create minimal funblocks for mutually recursive functions,
   ignoring the original block structure.  Suppose you have

		let f x = ... g x ...
                and g y = ... f y ...
		and h n = n+1

  If we monomorphise whole blocks at a time we could end up with two
  (identical) versions of h in different funblocks. Presumably this
  wouldn't actually cause problems since calls to h would resolve to
  calls to one of the new versions; however it still seems a little
  inelegant.  To avoid this we adopt the strategy mentioned above.
  Since the program's monomorphic we could safely put each function in
  a separate funblock; however, this confuses other people's (ie
  Michal's & Steffen's) programs. We could also create a single giant
  funblock, but this makes typechecking take forever. *)

local
    (* Short names for the set and hash-table libraries used below. *)
    structure S = Splayset
    structure P = Polyhash
    (* Exception handed to P.mkTable for both tables; the handlers on P.find
       below catch it when a key is absent. *)
    exception missing

    (* Both tables are created once, when this local block is elaborated.
       NOTE(review): nothing in this file removes entries from them, so
       entries persist across successive calls of makeFunBlocks -- confirm
       that this is intended. *)

    val C: (string, string list) P.hash_table     (* f -> fns reachable from f *)
      = P.mkTable(P.hash, op=)(20, missing)

    val F: (string, Annotation FunDef) P.hash_table    (* fname -> fdef *)
      = P.mkTable(P.hash, op=)(20, missing)

    (* Look up the definition stored for fname in table F; an absent name is
       reported as an internal error through mu_error. *)
    fun getFdef fname =
        case P.peek F fname of
            SOME fdef => fdef
          | NONE => mu_error ("Couldn't find fundef for '" ^ fname ^ "'")

    (* Raw lookup in table F.  Unlike getFdef this lets the table's own
       exception escape, so callers can attach their own error message. *)
    val getFdef' = fn fname => P.find F fname

    (* True iff table F has an entry for x. *)
    fun isFun x = Option.isSome (P.peek F x)

    (* Look up the list stored for f in table C; an absent name is reported
       as an internal error through mu_error. *)
    fun getCalls f =
        case P.peek C f of
            SOME callees => callees
          | NONE => mu_error ("Couldn't find '" ^ f ^ "' while rebuilding funblocks")


    (* a butnot b: true exactly when a holds and b does not. *)
    infix 4 butnot
    fun a butnot b = if b then false else a

    (* f calls g: g occurs in the list recorded for f in table C. *)
    infix 5 calls
    fun f calls g =
        let
            val callees = getCalls f
        in
            member g callees
        end

    (* x :? l: cons x onto l unless it is already a member, in which case l
       is returned unchanged. *)
    infix 5 :?
    fun x :? l =
        case member x l of
            true => l
          | false => x :: l


    (* The code below probably does a lot of repetition;
       I'm sure it could be improved substantially. *)

    (* reachable e acc done argnames funnames extra

       Collect the names of the global functions reachable from expression e,
       following references transitively into the bodies stored in table F.

         e        - expression to walk
         acc      - names collected so far (a duplicate-free list; new names
                    are added at the front with :?)
         done     - set of names whose bodies have already been entered, so
                    that each body is walked at most once
         argnames - names treated as parameters of the enclosing function; a
                    GLOBAL reference or application whose name is in this list
                    is skipped
         funnames,
         extra    - a GLOBAL variable passed as an argument of an application
                    is recorded only if it occurs in one of these two lists

       Returns the updated pair (acc, done).

       A GLOBAL name that is not in argnames and has no entry in F is reported
       through Util.error at the location of the reference.

       FIX: dangerous for hofs *)
    fun reachable e acc done argnames funnames extra =
    let
        fun reachableExp e acc done =
        case e of
            VALexp (VARval (x, GLOBAL, loc), _) =>
            if member x argnames then (acc, done)
            else
                let
                    val acc' = x :? acc
                    (* Every non-parameter global reference is looked up as a
                       function. *)
                    val FUNdef (_, _, _, xbody, _) =
                        (getFdef' x
                         handle missing => Util.error loc ("Missing function"))
                in
                    if S.member (done, x) then
                        (acc', done)
                    else
                        (* The body of x is walked with an empty parameter list.
                           FIX: check this *)
                        reachable xbody acc' (S.add (done, x)) [] funnames extra
                end
          | VALexp (VARval (_, _, _), _) => (acc, done)
          | APPexp ((x, xloc), eargs, GLOBAL, _) =>
            if member x argnames then (acc, done)
            else
                let
                    (* Only the table's own "absent key" exception means that
                       the function is unknown; anything else must propagate
                       instead of being reported as a missing function. *)
                    val FUNdef (_, args, _, xbody, _) =
                        (getFdef' x
                         handle missing =>
                                Util.error xloc ("Can't find anything to apply ["
                                                 ^ x ^ "].  "
                                                 ^ "Is the enclosing function ever used?"))

                    (* Record global variables passed as arguments when they
                       name a known function (funnames) or an extra name. *)
                    fun q l acc =
                        (case l of
                             [] => acc
                           | VARval (h, GLOBAL, _) :: t =>
                             if member h funnames orelse member h extra
                             then q t (h :? acc)
                             else q t acc
                           | _ :: t => q t acc)

                    val acc'' = q eargs (x :? acc)
                in
                    if S.member (done, x) then
                        (acc'', done)
                    else
                        reachable xbody acc'' (S.add (done, x)) (getArgnames args) funnames extra
                end
          | APPexp (_, _, _, _) => (acc, done)
          | IFexp (_, e1, e2, _) =>
            let
                val (acc', done') = reachableExp e1 acc done
            in
                reachableExp e2 acc' done'
            end
          | MATCHexp (_, rules, _) => do_rules rules acc done
          | LETexp (_, e, e', _) =>
            let
                val (acc', done') = reachableExp e acc done
            in
                reachableExp e' acc' done'
            end
          | ASSERTexp (e, _, _, _) => reachableExp e acc done
          | TYPEDexp  (e, _, _)   => reachableExp e acc done
          | COERCEexp (e, _, _)   => reachableExp e acc done  (* what if we've got fns in constrs ? *)
          | _ => (acc, done)  (* no subexpressions *)

    (* Walk the right-hand side of a single match rule. *)
    and rule_reachable r acc done =
        case r of
            MATCHrule (_, _, _, e, _) => reachableExp e acc done
          | OOMATCHrule (_, e, _) => reachableExp e acc done

    (* Walk the rules from left to right, threading (acc, done) through. *)
    and do_rules rules acc done =
        case rules of
            [] => (acc, done)
          | h :: t =>
            let
                val (acc', done') = rule_reachable h acc done
            in
                do_rules t acc' done'
            end
    in
        reachableExp e acc done
    end
(*    and do_rules rules acc done = foldl (fn (r, (a, d)) => rule_reachable r a d)
					 (acc,done) rules
*)

    (* getReachable fname e argnames funnames extra: names reachable from the
       body e of function fname.  The walk starts with fname itself already
       marked as done; the resulting list is also written to the debug
       output. *)
    fun getReachable fname e argnames funnames extra  =
        let
            val start = S.singleton String.compare fname
            val (found, _) = reachable e [] start argnames funnames extra
        in
            debugPrint ("Reachable: " ^ fname ^" -> " ^ listToString id " " found ^ "\n");
            found
        end

    (* mutrec fdef: names of the functions which are (maybe indirectly)
       mutually recursive with fdef -- those recorded as reachable from it
       which in turn have it recorded as reachable -- together with the name
       of fdef itself. *)
    fun mutrec fdef =
        let
            val self = funName fdef
            fun callsBack g = g calls self
            val partners = List.filter callsBack (getCalls self)
        in
            self :? partners
        end

in

(* makeFunBlocks fdefs all extra

   Group the function definitions fdefs into blocks of mutually recursive
   functions and return the blocks as a list of FUNblock values.

     fdefs - the definitions to be grouped; each is entered into table F
     all   - definitions whose names (origFnames) give the reference order
             used when arranging the functions inside a block
     extra - additional names passed through to getReachable

   Side effects: fills tables F and C, and writes the blocks and the contents
   of C to the debug output. *)
fun makeFunBlocks fdefs all extra =
    let
	(* Local identity function: used as the name projection for reorder
	   (the blocks hold plain names) and as the printer for listToString. *)
	fun id x = x

	val origFnames = map funName all

	(* Register every definition in F before any reachability is computed;
	   getReachable looks functions up in F. *)
	val () = app (fn fd => P.insert F (funName fd, fd)) fdefs

	(* Record in C the names reachable from the body of one definition. *)
	fun doFundef (FUNdef((fname,_),args,_,body,_)) =
	    P.insert C (fname, getReachable fname body (getArgnames args) origFnames extra)

	val () = app doFundef fdefs

	(* Compare two blocks by their first elements: a block whose head calls
	   the head of the other sorts after it; unrelated blocks fall back to
	   compare_basenames over the name list l. *)
	fun cmpBlks l (b1, b2) =
	    let
		val f = hd b1
		val g = hd b2
	    in
		if f = g then EQUAL
		else if f calls g then GREATER
		else if g calls f then LESS
		else compare_basenames f g l
	    end

       (* We're going to use this shortly to rearrange the funblocks.  We have to
          do it this way in case the user's done something like "let f = ... and
          g = ... and ..." where f isn't called by any of the following functions
          in the block (but calls g say):  we mustn't output "let f = ...;
          let g = ... and ..." because g will be out of scope in the defn of f.
        *)
        (* [The keywords in that comment really confuse the xemacs sml-mode indentation.] *)




	(* Walk the definitions in l; for each one whose name is not yet in
	   done, build its block of mutually recursive names and mark all of
	   them as done.  Blocks are consed onto acc. *)
	fun makeBlocks l done acc =
	    case l of
		[] => acc
	      | h::t =>
		if S.member (done, funName h)
		then
		    makeBlocks t done acc
		else
		    let
			val blk = mutrec h
			val oblk = reorder blk id origFnames
			(* Arrange contents of block to conform with original order *)
		    in
			makeBlocks t (S.addList(done,oblk)) (oblk::acc)
		    end

	val nameBlocks = makeBlocks fdefs (S.empty (String.compare)) []

	(* Debug output: the blocks before they are ordered. *)
	val () = app (fn b => debugPrint ("[" ^ listToString id ", " b ^ "]\n")) nameBlocks

	(* Names of all definitions, sorted so that a function which calls
	   another is placed after it; used as the name list for cmpBlks. *)
	val allNames = map funName fdefs
	fun fncmp (f,g) = if f=g then EQUAL else if f calls g then GREATER else LESS
	val sortedNames = sort fncmp allNames



        (* Rearrange the blocks into a sensible order *)
	val ordBlocks = sort (cmpBlks sortedNames) nameBlocks
	(* Debug output: the blocks after ordering. *)
	val () = app (fn b => debugPrint ("[" ^ listToString id ", " b ^ "]\n")) ordBlocks

	(* Turn a block of names back into a block of definitions. *)
	fun getBlock b = FUNblock (map (fn fname => getFdef fname) b)

	(* Debug output: the contents of table C. *)
	val () = P.apply (fn (s,l) => debugPrint (s ^ " calls " ^ listToString id ", " l ^ "\n")) C
   in
      map getBlock ordBlocks
   end
end

