local
open GrailAbsyn
structure P=Polyhash
in

(* printing and error functions *)

(* Report a fatal internal error on standard error and terminate the
   process.  The exit status is OS.Process.failure so that callers
   (scripts, build tools) can see that compilation did not succeed.
   Never returns normally, so its result type is polymorphic and it can
   be used in any expression context. *)
fun ierror s = 
    let
	val () = TextIO.output(TextIO.stdErr,"Grail error " ^ s ^ "\n")
    in
	OS.Process.exit OS.Process.failure
    end

(* Passed to mkPolyTable as the "missing item" exception for the
   auxiliary tables built in this file; the string names the table. *)
exception GraphException of string
(* Passed to mkPolyTable as the "missing item" exception for the flow graph. *)
exception MissingNode

(* Report a flow-graph error through ierror, tagging it with this
   file's name.  ierror terminates the process, so this never returns. *)
fun GraphError msg =
    let
	val origin = "[FlowGraph.sml]: "
    in
	ierror (origin ^ msg)
    end

(* top level function *)

(* Build the call/flow graph of a module body.

   The result is a hash table mapping each function name f to a pair
   (v_in, v_out): v_in lists the functions that can transfer control to
   f, v_out lists the functions f can transfer control to.  A fake
   function "*" whose result is the module's main result (mresult) acts
   as the root of the graph.  Names that are targets of calls but are
   not declared in funs still get an entry (with an empty v_out).

   Declaring the same function name twice is reported via GraphError. *)
fun flowGraph (MBODY(lets, funs, mresult)) =
let
    val fg = P.mkPolyTable(30, MissingNode)

    (* Names of the declarations already added.  Needed because an empty
       v_out cannot tell "declared with no successors" apart from
       "only seen as a callee so far". *)
    val declared = P.mkPolyTable(30, MissingNode)

    (* Prepend the callee of a primitive result, if it is a call, to acc. *)
    fun pnext p acc = 
	case p of 
	    FUNres (f,_) => f::acc
	  | _ => acc

    (* Names of the functions a result expression can continue with. *)
    fun next r =
	case r of PRIMres p => pnext p []
		| CHOICEres (_,_,_,p1,p2) => pnext p1 (pnext p2 [])
		| CASEres (_,_,_,l) => map (fn (_,f,_) => f) l

    (* Record f as a predecessor of node, creating node's entry if needed. *)
    fun insPred f node = 
	case P.peek fg node of
	    NONE => P.insert fg (node, ([f], []))
	  | SOME (v_in, v_out) => P.insert fg (node, (f::v_in, v_out))

    (* Add the declaration of fname: set its successor list and register
       fname as a predecessor of every successor. *)
    fun insertNode (FDEC(fname,_,FUNbody(_,result))) =
	let 
	    val () =
		case P.peek declared fname of
		    SOME () => raise GraphError ("Trying to insert duplicate node [" ^ fname ^ "] in flow graph")
		  | NONE => P.insert declared (fname, ())
	    val v_out = next result
	    (* Done before reading fname's own entry, so that a self-call
	       shows up in fname's v_in. *)
	    val () = app (insPred fname) v_out
	    val v_in = case P.peek fg fname of
			   NONE => []
			 | SOME (preds, _) => preds
	in
	    P.insert fg (fname, (v_in, v_out))
	end

    val root = FDEC("*", [],FUNbody([],mresult))  (* fake fundec for root *)
    (* If funs is empty the graph just contains the root node (plus any
       callees named in mresult). *)
    val () = app insertNode (root::funs)
in
    fg
end

(* Prune the flow graph htable (as built by flowGraph) down to the nodes
   reachable from the fake root "*", and keep only the declarations in
   funs whose names are reachable.

   Reachability is a depth-first walk along the v_out (successor) lists
   starting at "*"; the root itself is always kept.  A node only reached
   from an unreachable cycle is correctly discarded.

   Returns (goodfuns, goodnodes): goodfuns are the reachable declarations
   of funs in their original order; goodnodes is the result of P.filter
   on htable (NOTE(review): presumably unit with htable pruned in place,
   as in Moscow ML's Polyhash -- confirm). *)
fun deleteUnreachableNodes funs htable =
    let
	val seen = P.mkPolyTable (30, GraphException "reachability")

	fun visit n =
	    case P.peek seen n of
		SOME () => ()
	      | NONE =>
		(P.insert seen (n, ());
		 case P.peek htable n of
		     SOME (_, v_out) => app visit v_out
		   | NONE => ())

	val () = visit "*"

	fun isReachable f = isSome (P.peek seen f)

	val goodnodes = P.filter (fn (f, _) => isReachable f) htable
	val goodfuns = List.filter (fn FDEC(name,_,_) => isReachable name) funs
    in
	(goodfuns, goodnodes)
    end

(* Debug dump of a flow graph: first a line listing the merge points
   (nodes with more than one entry in their predecessor list), then, for
   every node, one line with its predecessors and one with its
   successors. *)
fun printGraph fg =
    let
	fun printWords words = app (fn w => print (w ^ " ")) words

	fun showNode (name, (preds, succs)) =
	    (print (name ^ " in:  ");
	     printWords preds;
	     print "\n";
	     print (name ^ " out: ");
	     printWords succs;
	     print "\n")

	fun showIfMerge (name, (preds, _)) =
	    case preds of
		_ :: _ :: _ => print (name ^ " ")
	      | _ => ()
    in
	print "Merge points: ";
	P.apply showIfMerge fg;
	print "\n";
	P.apply showNode fg
    end

(* ------------------------ Dominators ------------------------ *)

(* Calculate the immediate dominators of the nodes in a flow graph.
   We use the Lengauer-Tarjan algorithm as presented on pp 441-443 of
   Appel.  The current implementation is highly imperative and would
   benefit from rewriting.  It may well still contain bugs. *)


open Polyhash

(* Compute the immediate dominators of the flow graph fg (as built by
   flowGraph, rooted at "*") using the Lengauer-Tarjan algorithm with
   the simple (non-path-compressing) ancestor search, following Appel.

   Returns a table mapping every node reachable from "*", other than
   "*" itself, to its immediate dominator.  Nodes not reachable from the
   root get no entry; predecessors that are unreachable are ignored,
   since they lie on no path from the root. *)
fun dominators fg = 
    let
	val numNodes = numItems fg

	fun newtable s = mkPolyTable (30, GraphException ("Missing item in " ^ s))

	(* Per-node tables; names follow Appel's presentation. *)
	val dfnum    = newtable "DFS"      (* node -> depth-first number *)
	val semi     = newtable "semi"     (* node -> semidominator *)
	val ancestor = newtable "ancestor" (* node -> ancestor in the link forest *)
	val idom     = newtable "idom"     (* node -> immediate dominator (result) *)
	val samedom  = newtable "samedom"  (* node -> node known to have the same idom *)
	val parent   = newtable "parent"   (* node -> parent in the spanning tree *)
	val vertex   = newtable "vertex"   (* depth-first number -> node *)
	val best     = newtable "best"     (* only used by the disabled fast awls *)

	(* node -> nodes whose semidominator is that node *)
	val bucket   = Polyhash.map (fn _ => []) fg

	(* Next depth-first number; after mkDFS it is the number of nodes
	   reachable from the root. *)
	val N = ref 0

	fun mkDFS p n = (* Depth-first spanning tree *)
	    if peek dfnum n = NONE then 
		(
(*		 print ("inserting " ^ n ^ " at " ^ Int.toString (!N) ^ "\n");*)
		 insert dfnum (n, !N);
		 insert vertex (!N, n);
		 insert parent (n,p);
		 N := !N+1;
		 let val (_,v_out) = find fg n
		 in 
		     app (mkDFS n) v_out 
		 end
		)
	    else ()
		  
	(* Warn when some nodes of fg were not reached from the root. *)
	fun checkDFS () =
	    if !N <> numNodes 
	    then print "WARNING: spanning tree wrong size?\n" 
	    else ();

(*	val () = print ("N -> " ^ Int.toString (!N) ^ "\n")*)


(* Fast, but doesn't work

	fun awls v = 
	    let 
		val () = print ("awls: " ^ v ^ "\n")
		val a' = case peek ancestor v of 
			     NONE => v
			   | SOME a0 => a0
		val () = case peek ancestor a' of 
			     SOME a => 
			     let
				 val () = print ("ANCESTOR " ^ a ^" \n")
				 val b = awls a 
				 val () = insert ancestor (v, find ancestor a)
			     in
				 if find dfnum (find semi b) < find dfnum (find semi (find best v))
				 then insert best (v,b) else ()
			     end
			   | NONE => ()
		val () = print ("awls-\n")
	    in
		find best v
	    end

	fun link p n = (insert ancestor (n,p); insert best (n,n))

*)

        (* Slower, but does work *)

	fun awls' u v = (* ancestor with least semidominator *)
	    case peek ancestor v of 
		SOME a  => if find dfnum (find semi v) < find dfnum (find semi u) 
			   then awls' v a else awls' u a 
	      | NONE => u

	fun awls v = awls' v v

	fun link p n = insert ancestor (n,p)

	(* Semidominator of n, folding over its predecessor list l with
	   current candidate s.  Predecessors never numbered by mkDFS are
	   unreachable from the root and cannot contribute, so skip them. *)
	fun semidominator n s l = 
	    case l of 
		[] => s
	      | v::t =>
		(case peek dfnum v of
		     NONE => semidominator n s t
		   | SOME dv =>
		     let
			 val s' = 
			     if dv <= find dfnum n then v
			     else find semi (awls v)
			 val s'' = if find dfnum s' < find dfnum s 
				   then s' else s
		     in
			 semidominator n s'' t
		     end)

	fun provisionalDominator p v =
	    let 
		val y = awls v
	    in 
		if find semi y = find semi v then 
		    insert idom (v,p)   (* we know the dominator for certain *)
		else
		    insert samedom(v,y) (* wait until y's dominator is known *)
	    end
			
	fun updatebucket s n = 
	    let val v = find bucket s
	    in insert bucket (s, n::v)
	    end
	    
	(* Process reachable nodes in decreasing depth-first order, skipping
	   the root (number 0). *)
	fun phase1 i = 
	    if i <= 0 then ()
	    else
	    let 
		val n = find vertex i
		val p = find parent n
		val (v_in, _) = find fg n
		val s = semidominator n p v_in
	    in
		insert semi (n, s);
		updatebucket s n;
		link p n;
		app (provisionalDominator p) (find bucket p);
		insert bucket (p,[]);
		phase1 (i-1)
	    end
	    
	(* Complete the deferred calculations: idom(n) = idom(samedom(n)).
	   Increasing depth-first order guarantees samedom(n), which has a
	   smaller number than n, already has its final idom. *)
	fun fixDominators i =
	    if i >= !N then ()
	    else 
		let 
		    val n = find vertex i
		    val () = case peek samedom n of 
				 NONE => ()
			       | SOME y => insert idom (n, find idom y)
		in
		    fixDominators (i+1)
		end
    in
	mkDFS "" "*"; (* "*" is the fake root node added by flowGraph *)
	checkDFS ();
	phase1 (!N - 1);
	fixDominators 1;
	idom
    end		 


(* Print every (node, dominator) pair of table t, one per line, in the
   form "** node -> dominator". *)
fun pidom t =
    let
	fun showPair (node, dom) =
	    let
		val line = "** " ^ node ^ " -> " ^ dom ^ "\n"
	    in
		print line
	    end
    in
	Polyhash.apply showPair t
    end

(* Invert a node -> immediate-dominator table h into a list of
   (dom, nodes) pairs, where nodes are the nodes whose immediate
   dominator is dom.  Any entry for the root "*" is dropped (whether
   that is desirable is an open question). *)
fun invertDoms h =
    let
	val inverted = mkPolyTable (30, GraphException "inversion")

	fun addEdge (node, dom) =
	    let
		val sofar = case peek inverted dom of
				SOME nodes => nodes
			      | NONE => []
	    in
		insert inverted (dom, node :: sofar)
	    end

	val () = apply addEdge h
	val _ = remove inverted "*" handle _ => []
    in
	listItems inverted
    end
			
(* Identity function; used as the element printer for listToString. *)
val id = fn x => x

(* top level function; called from Compile.compile_mdef *) 
(* Top-level entry point (called from Compile.compile_mdef).  Builds the
   flow graph and dominator table for mbody and prints two definitions
   on standard output: isMergePoint_def, listing the nodes with more
   than one predecessor, and dominates_def, a chain of "if f = d then
   [...] else" cases mapping each dominator to the nodes it immediately
   dominates. *)
fun makeDefs mbody =
    let
	val graph = flowGraph mbody
	val idoms = dominators graph

	fun isMerge (_, (preds, _)) = length preds > 1
	val merges = List.map (fn (name, _) => name)
			      (List.filter isMerge (listItems graph))
	val mergeText =
	    let
		val members = listToString id ", " merges
	    in
		"isMergePoint_def: \"isMergePoint f == f \\<in> {" ^ members ^ "}\"\n"
	    end

	fun caseFor (dom, nodes) =
	    let
		val dominated = listToString id ", " (rev nodes)
	    in
		"if f = " ^ dom ^ " then [" ^ dominated ^ "] else"
	    end
	val cases = List.map caseFor (invertDoms idoms)

	val rhs =
	    case cases of
		[] => "[]\""
	      | _ => "(" ^ listToString id "\n                 " cases ^ " [])\""
	val domText = "dominates_def:\n\"dominates f == " ^ rhs
    in
	print mergeText;
	print "\n";
	print domText;
	print "\n\n"
    end
				   
end
				   




(* ================================ Not needed (maybe) ================================ *)


(*
datatype dfsTree = Empty | Node of string * int * dfsTree list

fun none _ = false
fun update f s x = if x = s then true else f x

fun makeDFSTree fg = (* make depth-first spanning tree from flow graph *)
    let 
	fun doNode fname n marked = 
	    let 
		val out = case P.peek fg fname of 
			      NONE => GraphError ("Missing node [" 
						  ^ fname ^ "] in flow graph")
			    | SOME (_, out) => out

		val (nxt,n',marked') = doNxt n (update marked fname) out []
	    in
		(Node (fname, n, rev nxt), n', marked')
	    end

	and doNxt n marked todo done = 
	    case todo of [] => (done, n, marked)
		    | f::t => 
		      if marked f then doNxt n marked t done
		      else 
			  let
			      val () = print ("Marking node " ^ f ^ "(" ^ Int.toString n ^ ")\n")
			      val (branch, n', marked') = doNode f (n+1) (update marked f)
			  in
			      doNxt n' marked' t (branch::done)
			  end
    in
	doNode "*" 0 none 
    end
end

fun printTree indent Empty = print (indent ^ "$")
  | printTree indent (Node(f,n,nxt)) = 
    (print (indent ^ f ^ ": " ^ Int.toString n ^ "\n");
     app (printTree (indent ^ " ")) nxt)
*)
