local
    open GrailAbsyn  Util
    (* Per-method state; both are reset at the start of optTailMeth. *)
    (* calls: incremented each time optTailPrimRes rewrites a self-call. *)
    val calls = ref 0
    (* nontail: set by optTailFun when it finds a let-bound self-call. *)
    val nontail = ref false
    (* Per-class statistics; reset in optTailClass, printed by report. *)
    (* methcnt: number of methods processed. *)
    val methcnt = ref 0
    (* tailmeths: methods with rewritten calls and nontail unset. *)
    val tailmeths = ref 0
    (* nonrecmeths: methods with no rewritten calls and nontail unset. *)
    val nonrecmeths = ref 0
    (* badoptmeths: methods with rewritten calls and nontail set. *)
    val badoptmeths = ref 0

    (* Print the per-class statistics held in the counters, one line per
       category.  The "normal recursive" figure is not counted directly:
       it is whatever remains of methcnt once the non-recursive,
       tail-recursive and "tail calls but not tail recursive" methods
       have been subtracted. *)
    fun report () =
        let
            val total = Int.toString (!methcnt)
            val normalRec =
                !methcnt - !nonrecmeths - !tailmeths - !badoptmeths
            fun line (count, label) =
                debugPrint (Int.toString count ^ " of " ^ total ^ label)
        in
            List.app line
                [(!tailmeths,   " tail-recursive.\n"),
                 (!badoptmeths, " with tail calls but not tail. rec.\n"),
                 (normalRec,    " normal recursive.\n"),
                 (!nonrecmeths, " non-recursive.\n")]
        end

    (* Return the text after the last "." of a dotted name; a name with
       no dot is returned whole.  String.tokens produces no tokens for ""
       or for a string consisting only of dots, and List.last would then
       raise Empty, so in that case the input is returned unchanged. *)
    fun lastPart x =
        case String.tokens (fn c => c = #".") x of
            [] => x
          | parts => List.last parts


(* Rewrite a primitive result that is a call of the current method to
   itself into a jump to the local function "f_" ^ name.

   name     - unqualified name of the method being optimised
   formals  - the method's formals; only the second component of each
              pair is used (it is compared with variable arguments and
              used as the target of the rebinding)
   fformals - formals of the enclosing local function; their second
              components become the arguments of the generated FUNres
   r        - the primitive result to examine

   If r is an INVOKESTATICop whose method name (after the last ".") is
   `name`, or an INVOKEVIRTUALop on "this" with such a name, the result
   is (decs, FUNres ...) where decs rebind the formals to the call's
   arguments, and the `calls` counter is incremented.  In every other
   case the result is ([], r). *)
fun optTailPrimRes name formals fformals r =
let
    (* Pair each formal with its new value, skipping arguments that are
       just the formal itself (nothing to rebind).  A length mismatch
       between formals and arguments is reported and the pairs collected
       so far are returned. *)
    fun getChanged [] [] acc = rev acc
      | getChanged ((_,a)::t1) (v::t2) acc =
        (case v of
             VARval x =>
             if x=a then getChanged t1 t2 acc
             else getChanged t1 t2 ((a,v)::acc)
           | _ => getChanged t1 t2 ((a,v)::acc))
      | getChanged _ _ acc =
        let val () = print "WARNING: length mismatch in Optimise.optTailPrimRes\n"
        in rev acc
        end

    (* Simultaneous assignment of vectors: formals := vals
       = non-simultaneous temps := vals, then formals := temps.
       The two phases must not be interleaved per formal: a later
       argument may mention an earlier formal (e.g. m(a,b) calling
       m(b,a)) and has to see that formal's old value.
       Allows val v = v and other suboptimalities, but later
       optimisations can kill that. *)
    fun makeDecs changed =
        let
            (* map works left to right, so temp names are drawn in the
               same order as the formals. *)
            val withTemps =
                map (fn (a, v) => (a, v, Normalise.tempName ())) changed
            val saves =
                map (fn (_, v, t) => VALdec (t, VALop v)) withTemps
            val rebinds =
                map (fn (a, _, t) => VALdec (a, VALop (VARval t))) withTemps
        in
            saves @ rebinds
        end

    (* Common rewrite for the static and the virtual self-call. *)
    fun toJump vals =
        let
            val newDecs = makeDecs (getChanged formals vals [])
            val () = calls := !calls + 1
        in
            (newDecs, FUNres ("f_" ^ name, map #2 fformals))
        end

in case r of
       OPres (INVOKESTATICop (MDESC (_, invname, _), vals)) =>
       if lastPart invname = name then toJump vals
       else ([], r)

(* NW - 27/4/4 - tail call elimination for virtual methods, i.e.
   method m ... = .... in this#m

   WARNING: counting will/may not work for virtual methods. Perhaps that
   code doesn't really need to be here anymore anyway.
*)

     | OPres (INVOKEVIRTUALop (obj, MDESC (_, invname, _), vals)) =>
       if obj = "this" andalso lastPart invname = name then toJump vals
       else ([], r)
     | _ => ([], r)

end

(* Apply self-call rewriting to the result `r` of a local function.

   fname    - name of the local function; ":t" / ":f" are appended to it
              to name the branch functions created for a CHOICEres
   name, formals, fformals - passed through to optTailPrimRes
   letdecs  - the function's existing let-declarations

   Returns (body, newFuns): the new FUNbody and any extra FDECs that had
   to be created.  A PRIMres gets the rebinding declarations appended to
   letdecs.  For a CHOICEres each branch is rewritten; if either branch
   produced declarations, both branches are moved into their own
   functions and the choice jumps to them (the original author marked
   this test as "a bit dodgy").  CASEres is left untouched with a
   warning. *)
fun optTailRes fname name formals fformals letdecs r =
    let
        fun rewrite pr = optTailPrimRes name formals fformals pr

        (* A function holding one rewritten branch of a choice. *)
        fun branchFun (label, decs, res) =
            FDEC (label, fformals, FUNbody (decs, PRIMres res))

        fun jumpTo label = FUNres (label, map #2 fformals)
    in
        case r of
            PRIMres pr =>
            let
                val (extraDecs, res) = rewrite pr
            in
                (FUNbody (letdecs @ extraDecs, PRIMres res), [])
            end
          | CHOICEres (value, test, value', prT, prF) =>
            let
                val (decsT, resT) = rewrite prT
                val (decsF, resF) = rewrite prF
            in
                case (decsT, decsF) of
                    ([], []) =>
                    (* no new let declarations. common case, not in report. *)
                    (FUNbody (letdecs,
                              CHOICEres (value, test, value', resT, resF)),
                     [])
                  | _ =>
                    let
                        val labelT = fname ^ ":t"
                        val labelF = fname ^ ":f"
                    in
                        (FUNbody (letdecs,
                                  CHOICEres (value, test, value',
                                             jumpTo labelT, jumpTo labelF)),
                         [branchFun (labelT, decsT, resT),
                          branchFun (labelF, decsF, resF)])
                    end
            end
          | CASEres _ =>
            (Util.warn (Normsyn.LOC Loc.nilLocation)
                       "No tail-call elimination for Grail cases yet";
             (FUNbody (letdecs, r), [])) (* FIX *)
    end

(* This should be OK, since we can only do funcalls in cases.
   One of the funcalls might be to a function which only does
   a recursive method call, but this will be separately optimised.
   In this case we'll have a jump which could be optimised away,
   but it's not too important since it can't contribute to stack
   overflow, for example. *)

(* The use of fformals is a hack, but this gets fixed later on. *)

(* Optimise one local function of the method `name`.

   Returns the rewritten function followed by any functions optTailRes
   had to create.  As a side effect, sets `nontail` when one of the
   function's let-declarations binds the result of a call to the method
   itself, i.e. a self-call that is not the function's result. *)
fun optTailFun name formals (FDEC(fname, fformals,
                                  (FUNbody(letdecs, result)))) =
    let
        val (body, funs) =
            optTailRes fname name formals fformals letdecs result

        (* A let-bound self-call.  Virtual calls on "this" are recognised
           as well as static ones, so that this test agrees with what
           optTailPrimRes accepts as a self-call. *)
        fun isNonTailSelfCall dec =
            case dec of
                VALdec (_, INVOKESTATICop (MDESC (_, inv, _), _)) =>
                lastPart inv = name
              | VALdec (_, INVOKEVIRTUALop (obj, MDESC (_, inv, _), _)) =>
                obj = "this" andalso lastPart inv = name
              | _ => false

        val () = if List.exists isNonTailSelfCall letdecs
                 then nontail := true
                 else ()
    in
        FDEC (fname, fformals, body) :: funs
    end

(* Optimise every local function of one method and record the method in
   the class statistics.

   `calls` and `nontail` are cleared first, then filled in while the
   functions are processed.  Afterwards the method is counted in methcnt
   and classified:
     rewritten calls, no let-bound self-call -> tailmeths
     no rewritten calls, no let-bound self-call -> nonrecmeths
     rewritten calls and a let-bound self-call -> badoptmeths
   The remaining combination is not counted here; report derives it.
   The method's own letdecs and result are returned unchanged. *)
fun optTailMeth (MDEF(flags, rty, name, formals,
                      MBODY(letdecs, funs, result))) =
    let
        fun bump counter = counter := !counter + 1

        val () = calls := 0
        val () = nontail := false

        val funs' = List.concat (map (optTailFun name formals) funs)

        val () = bump methcnt
        val () =
            case (!nontail, !calls = 0) of
                (false, false) => bump tailmeths
              | (false, true)  => bump nonrecmeths
              | (true,  false) => bump badoptmeths
              | (true,  true)  => ()
    in
        MDEF (flags, rty, name, formals,
              MBODY (letdecs, funs', result))
    end


(* Tail-recursion optimisation for a whole class: every method of the
   class is passed through optTailMeth.  The per-class counters are
   zeroed beforehand and the statistics are printed by report once all
   methods have been processed. *)
fun optTailClass (CDEF(flags, name, super, intfs, fields, meths, layout)) =
    let
        val () = List.app (fn counter => counter := 0)
                          [methcnt, tailmeths, nonrecmeths, badoptmeths]
        val meths' = map optTailMeth meths
    in
        report ();
        CDEF (flags, name, super, intfs, fields, meths', layout)
    end


in

(* Entry point: tail-call optimisation of a class definition. *)
val optimise = optTailClass

end
