(* Stackdepth -- Peter Sestoft   1997-07-03, 1997-07-27
              -- Peter Bertelsen 1997-08-01, 1997-08-07, 1997-10-10

Computes the maximal stack depth of a well-formed JVM bytecode
sequence.

Restrictions on local subroutines (implemented by jsr and ret),
enforced by the code below:

  - a subroutine immediately stores the return address (from the stack)
    in a local variable;
  - the subroutine does not subsequently modify that variable (not even
    by calling other subroutines);
  - that variable must be used in all the subroutine's ret instructions;
  - subroutines can call subroutines, but they cannot be recursive (neither
    directly nor indirectly);
  - subroutines can be entered and exited only by means of jsr and ret;
  - hence every instruction belongs either to the main program or to a
    single subroutine;
  - this is the `color' of the instruction: NONE if it belongs to the
    main program, SOME lbl if it belongs to the subroutine that begins
    with label lbl;
  - a subroutine may fail to return by entering an infinite loop, by
    throwing an exception, or by terminating the containing method.
*)

(* This belongs elsewhere ******************************* *)

(* Raised (through `bug' below) when an internal invariant of the stack
   depth analysis is violated.  Ill-formed input bytecode is reported
   with Fail instead. *)
exception Impossible of string

(* End of This belongs elsewhere ******************************* *)

(* Signal an internal inconsistency; the message is prefixed with the
   module name so the failing function can be identified. *)
fun bug msg = raise Impossible (String.concat ["Stackdepth.", msg])

local
    open Jvmtype
in
    (* Number of stack words occupied by a field value of type t; this is
       the stack growth of a get and (negated) the shrink of a put *)
    fun fieldDelta t = width t

    (* Net change in stack size caused by a method invocation, not
       counting any receiver object: the return value's size (zero for
       a void method) minus the total size of the arguments *)
    fun methodDelta (argTs, retT) =
        let val argWords = List.foldr (fn (t, acc) => acc + width t) 0 argTs
            val retWords =
                case retT of
                    SOME t => width t
                  | NONE   => 0
        in
            retWords - argWords
        end
end


(* Manipulating sets of modified local variables *)

abstype lvarset = Lvarset of Intset.intset
with
    open Intset

    (* The set containing no local variables *)
    val emptyset = Lvarset empty

    (* Add a local variable occupying a single slot (used below for
       astore, istore and fstore targets) *)
    fun addtoset1 v (Lvarset set) =
        let val slot = Localvar.toInt v
        in
            Lvarset (add (set, slot))
        end

    (* Add a local variable occupying two slots, n and n+1 (used below
       for lstore and dstore targets) *)
    fun addtoset2 v (Lvarset set) =
        let val slot = Localvar.toInt v
            val withFirst = add (set, slot)
        in
            Lvarset (add (withFirst, slot + 1))
        end

    (* Membership test for local variable v *)
    fun inset v (Lvarset set) = Intset.member (set, Localvar.toInt v)

    (* Union of two sets of local variables *)
    fun union (Lvarset s1) (Lvarset s2) = Lvarset (Intset.union (s1, s2))
end

(* What is known about the code at a given label:
    PENDING     - label not reached
    LBLRESOLVED - label reached, stack depth and color resolved
    SRPENDING   - subroutine called, not returned from, still resolving
    SRPARTIAL   - subroutine called and returned from, still resolving
    SRRESOLVED  - subroutine called and fully resolved
*)

datatype lblinfo =
    (* Not yet reached; carries the instructions following the label *)
    PENDING of Bytecode.jvm_instr list
    (* Ordinary label: its stack depth and subroutine color *)
  | LBLRESOLVED of { col : Label.label option, dep : int }
    (* Subroutine being analysed, no ret seen yet *)
  | SRPENDING
    (* Subroutine being analysed; stack depth at its ret(s) and the
       locals modified on the paths to them so far *)
  | SRPARTIAL of { modifref : lvarset ref, after : int }
    (* Subroutine fully analysed: depth on entry, depth after return
       (NONE if it never returns), modified locals, maximal depth *)
  | SRRESOLVED of { maxdep : int, modif : lvarset,
                    after : int option, dep : int }

(* Compute the maximal operand-stack depth reached by the instruction
   list `code'.  Each label in `hdlrs' is analysed as if entered with
   stack depth 1 from the main program (the JVM enters an exception
   handler with only the exception object on the stack).

   Raises Fail when a label is reached with inconsistent stack depths or
   subroutine colors, or when the subroutine restrictions described at
   the top of this file are violated.  Raises Impossible (via bug) on an
   internal inconsistency. *)
fun maxdepth code hdlrs =
    let open Bytecode

        (* Table with information for all labels in the program.  Initially
           every label is PENDING with the code following it; looking up a
           label that is not in the table raises exnFind. *)

        val exnFind = Fail "Stackdepth.maxdepth: undefined label"
        val labelinfo = Polyhash.mkPolyTable(1021, exnFind)
        fun update lbl info = Polyhash.insert labelinfo (lbl, info)
        fun lookup lbl = Polyhash.find labelinfo lbl
        fun buildpending [] = ()
          | buildpending (Jlabel lbl :: rest) =
            (update lbl (PENDING rest); buildpending rest)
          | buildpending (_ :: rest) = buildpending rest
        val _ = buildpending code

        (* Record the stack depth `depth' at a ret from local subroutine
           lbl, together with the locals `modif' modified on the path to
           that ret.  All rets of one subroutine must agree on the depth. *)

        fun jsrupdate lbl depth modif =
            case lookup lbl of
                SRPARTIAL {after, modifref} =>
                    if depth <> after then
                        raise Fail "Inconsistent stack depths at ret"
                    else
                        modifref := union (!modifref) modif
              | SRPENDING =>
                    update lbl (SRPARTIAL {after=depth, modifref=ref modif})
              | _ => bug "jsrupdate"

        (* Color of code: NONE in the main program, SOME lbl inside the
           subroutine that begins at lbl *)

        fun color NONE = NONE
          | color (SOME(_, lbl)) = SOME lbl

        (* Record the stack depth at an ordinary label.  On the first
           visit, continue the analysis with the code after the label;
           later visits must agree on depth and color.  Returns the
           updated maximal depth. *)

        fun resolve depth srCol modif (lbl, maxdepth) =
            case lookup lbl of
                PENDING code =>
                    (update lbl (LBLRESOLVED {dep=depth, col=color srCol});
                     finddepth code srCol modif depth maxdepth)
              | LBLRESOLVED {dep, col} =>
                    if col <> color srCol then
                        raise Fail "Inconsistent subroutine colors at label"
                    else if depth <> dep then
                        raise Fail "Inconsistent stack depths at label"
                    else maxdepth
              | _ => raise Fail "Subroutine label used as ordinary label"

        (* Get the stack depth after subroutine lbl, called with stack
           depth `depth' (which includes the return address pushed by
           jsr).  Resolve the subroutine on its first call.  Returns
           (after, modif, maxdep), where:
             after  - SOME d if a ret is reachable (d = depth at the ret),
                      NONE otherwise;
             modif  - locals modified on the paths to its rets, plus the
                      subroutine's return-address variable;
             maxdep - maximal depth reached inside the subroutine. *)

        and jsrlookup depth lbl =
            case lookup lbl of
                PENDING [] => raise Fail "No code in subroutine"
              | PENDING (Jastore lvar :: rest) =>
                    let val _ = update lbl SRPENDING
                        (* the leading astore pops the return address *)
                        val maxdep = finddepth rest (SOME (lvar, lbl))
                                               emptyset (depth-1) depth
                        val (after, modif0) =
                            case lookup lbl of
                                SRPENDING =>            (* no reachable Jret *)
                                    (NONE, emptyset)
                              | SRPARTIAL {after, modifref} =>
                                    (SOME after, !modifref)
                              | _ => bug "jsrlookup"
                        val modif = addtoset1 lvar modif0
                    in
                        if inset lvar modif0 then
                            raise Fail "Subroutine overwrites return address"
                        else
                            (update lbl (SRRESOLVED {dep=depth, after=after,
                                                     modif=modif,
                                                     maxdep=maxdep});
                             (after, modif, maxdep))
                    end
              | PENDING _ => raise Fail "Subroutine should start with Jastore"
              | LBLRESOLVED _ =>
                    raise Fail "Ordinary label used as subroutine label"
              | SRRESOLVED { dep, after, modif, maxdep } =>
                    if depth <> dep then
                        raise Fail "Inconsistent stack depths at subroutine"
                    else
                        (after, modif, maxdep)
              | _ => raise Fail "Subroutine call loop"  (* SRPENDING/SRPARTIAL *)

        (* Walk the instructions starting at stack depth `depth', with
           subroutine context srCol (NONE in the main program, otherwise
           SOME (return-address variable, subroutine label)) and `modif'
           the locals modified so far.  Follow fall-through and branches,
           resolving target labels and subroutines, and return the maximal
           stack depth seen (at least `maxdepth'). *)

        and finddepth [] srCol modif depth maxdepth = maxdepth
          | finddepth (ins1 :: rest) srCol modif depth maxdepth =
            let (* Non-branching instruction: change the depth by delta
                   and continue with the next instruction *)
                fun finddepth1 modif delta =
                    let val depth = depth + delta
                        val maxdepth = Int.max(depth, maxdepth)
                    in
                        finddepth rest srCol modif depth maxdepth
                    end
                (* Conditional branch: change the depth by delta, resolve
                   the branch target, then continue with the fall-through *)
                fun finddepth2 modif delta lbl =
                    let val depth = depth + delta
                        val maxdepth = Int.max(depth, maxdepth)
                    in
                        finddepth rest srCol modif depth
                                  (resolve depth srCol modif (lbl, maxdepth))
                    end
            in
            case ins1 of
                 Jathrow   => maxdepth
               | Jreturn   => maxdepth
               | Jgoto lbl => resolve depth srCol modif (lbl, maxdepth)
               | Jlookupswitch {default, cases} =>
                     (* the switch key is popped before branching *)
                     let val depth' = depth - 1

                         fun resolve' ((_, lbl), maxdepth) =
                             resolve depth' srCol modif (lbl, maxdepth)

                         val maxdepth' = List.foldl resolve' maxdepth cases
                     in
                         resolve depth' srCol modif (default, maxdepth')
                     end
               | Jtableswitch {default, targets, ...} =>
                     (* the switch index is popped before branching *)
                     let val depth' = depth - 1
                         val maxdepth' =
                             Vector.foldl (resolve depth' srCol modif)
                                          maxdepth targets
                     in
                         resolve depth' srCol modif (default, maxdepth')
                     end
               | Jlabel lbl => resolve depth srCol modif (lbl, maxdepth)
               | Jjsr lbl =>
                     (* jsr pushes the return address; if the subroutine
                        never returns, this path ends here *)
                     let val (afterOpt, modif', maxdepth') =
                             jsrlookup (depth+1) lbl
                         val modif'' = union modif modif'
                         val maxdepth'' = Int.max(maxdepth, maxdepth')
                     in
                         case afterOpt of
                             SOME after =>
                                 finddepth rest srCol modif'' after maxdepth''
                           | NONE => maxdepth''
                     end
               | Jret lvar =>
                     (case srCol of
                          NONE => raise Fail "Local ret not within subroutine"
                        | SOME (lvar', lbl) =>
                              if lvar <> lvar' then
                                  raise Fail "Wrong lvar in local ret"
                              else
                                  (jsrupdate lbl depth modif;
                                   maxdepth))
               | Jastore j      => finddepth1 (addtoset1 j modif) ~1
               | Jdstore j      => finddepth1 (addtoset2 j modif) ~2
               | Jfstore j      => finddepth1 (addtoset1 j modif) ~1
               | Jistore j      => finddepth1 (addtoset1 j modif) ~1
               | Jlstore j      => finddepth1 (addtoset2 j modif) ~2
               | Jsconst _      => finddepth1 modif 1
               | Jaaload        => finddepth1 modif ~1
               | Jaastore       => finddepth1 modif ~3
               | Jaconst_null   => finddepth1 modif 1
               | Jaload _       => finddepth1 modif 1
               | Jarraylength   => finddepth1 modif 0
               | Jbaload        => finddepth1 modif ~1
               | Jbastore       => finddepth1 modif ~3
               | Jcaload        => finddepth1 modif ~1
               | Jcastore       => finddepth1 modif ~3
               | Jcheckcast _   => finddepth1 modif 0
               | Jclassconst _  => finddepth1 modif 1
               | Jd2f           => finddepth1 modif ~1
               | Jd2i           => finddepth1 modif ~1
               | Jd2l           => finddepth1 modif 0
               | Jdadd          => finddepth1 modif ~2
               | Jdaload        => finddepth1 modif 0
               | Jdastore       => finddepth1 modif ~4
               | Jdcmpg         => finddepth1 modif ~3
               | Jdcmpl         => finddepth1 modif ~3
               | Jdconst _      => finddepth1 modif 2
               | Jddiv          => finddepth1 modif ~2
               | Jdload _       => finddepth1 modif 2
               | Jdmul          => finddepth1 modif ~2
               | Jdneg          => finddepth1 modif 0
               | Jdrem          => finddepth1 modif ~2
               | Jdsub          => finddepth1 modif ~2
               | Jdup           => finddepth1 modif 1
               | Jdup_x1        => finddepth1 modif 1
               | Jdup_x2        => finddepth1 modif 1
               | Jdup2          => finddepth1 modif 2
               | Jdup2_x1       => finddepth1 modif 2
               | Jdup2_x2       => finddepth1 modif 2
               | Jf2d           => finddepth1 modif 1
               | Jf2i           => finddepth1 modif 0
               | Jf2l           => finddepth1 modif 1
               | Jfadd          => finddepth1 modif ~1
               | Jfaload        => finddepth1 modif ~1
               | Jfastore       => finddepth1 modif ~3
               | Jfcmpg         => finddepth1 modif ~1
               | Jfcmpl         => finddepth1 modif ~1
               | Jfconst _      => finddepth1 modif 1
               | Jfdiv          => finddepth1 modif ~1
               | Jfload _       => finddepth1 modif 1
               | Jfmul          => finddepth1 modif ~1
               | Jfneg          => finddepth1 modif 0
               | Jfrem          => finddepth1 modif ~1
               | Jfsub          => finddepth1 modif ~1
               | Jgetfield {ty, ...} =>
                      finddepth1 modif (fieldDelta ty - 1)
               | Jgetstatic {ty, ...} =>
                      finddepth1 modif (fieldDelta ty)
               | Ji2b           => finddepth1 modif 0
               | Ji2c           => finddepth1 modif 0
               | Ji2d           => finddepth1 modif 1
               | Ji2f           => finddepth1 modif 0
               | Ji2l           => finddepth1 modif 1
               | Ji2s           => finddepth1 modif 0
               | Jiadd          => finddepth1 modif ~1
               | Jiaload        => finddepth1 modif ~1
               | Jiand          => finddepth1 modif ~1
               | Jiastore       => finddepth1 modif ~3
               | Jiconst _      => finddepth1 modif 1
               | Jidiv          => finddepth1 modif ~1
               | Jif_acmpeq lbl => finddepth2 modif ~2 lbl
               | Jif_acmpne lbl => finddepth2 modif ~2 lbl
               | Jif_icmpeq lbl => finddepth2 modif ~2 lbl
               | Jif_icmpne lbl => finddepth2 modif ~2 lbl
               | Jif_icmplt lbl => finddepth2 modif ~2 lbl
               | Jif_icmpge lbl => finddepth2 modif ~2 lbl
               | Jif_icmpgt lbl => finddepth2 modif ~2 lbl
               | Jif_icmple lbl => finddepth2 modif ~2 lbl
               | Jifeq lbl      => finddepth2 modif ~1 lbl
               | Jifne lbl      => finddepth2 modif ~1 lbl
               | Jiflt lbl      => finddepth2 modif ~1 lbl
               | Jifge lbl      => finddepth2 modif ~1 lbl
               | Jifgt lbl      => finddepth2 modif ~1 lbl
               | Jifle lbl      => finddepth2 modif ~1 lbl
               | Jifnonnull lbl => finddepth2 modif ~1 lbl
               | Jifnull lbl    => finddepth2 modif ~1 lbl
               | Jiinc _        => finddepth1 modif 0
               | Jiload _       => finddepth1 modif 1
               | Jimul          => finddepth1 modif ~1
               | Jineg          => finddepth1 modif 0
               | Jinstanceof _  => finddepth1 modif 0
               | Jinvokeinterface {msig, ...} =>
                      finddepth1 modif (methodDelta msig - 1)
               | Jinvokespecial {msig, ...} =>
                      finddepth1 modif (methodDelta msig - 1)
               | Jinvokestatic {msig, ...} =>
                      finddepth1 modif (methodDelta msig)
               | Jinvokevirtual {msig, ...} =>
                     finddepth1 modif (methodDelta msig - 1)
               | Jior           => finddepth1 modif ~1
               | Jirem          => finddepth1 modif ~1
               | Jishl          => finddepth1 modif ~1
               | Jishr          => finddepth1 modif ~1
               | Jisub          => finddepth1 modif ~1
               | Jiushr         => finddepth1 modif ~1
               | Jixor          => finddepth1 modif ~1
               | Jl2d           => finddepth1 modif 0
               | Jl2f           => finddepth1 modif ~1
               | Jl2i           => finddepth1 modif ~1
               | Jladd          => finddepth1 modif ~2
               | Jlaload        => finddepth1 modif 0
               | Jland          => finddepth1 modif ~2
               | Jlastore       => finddepth1 modif ~4
               | Jlcmp          => finddepth1 modif ~3
               | Jlconst _      => finddepth1 modif 2
               | Jldiv          => finddepth1 modif ~2
               | Jlload _       => finddepth1 modif 2
               | Jlmul          => finddepth1 modif ~2
               | Jlneg          => finddepth1 modif 0
               | Jlor           => finddepth1 modif ~2
               | Jlrem          => finddepth1 modif ~2
               (* long shifts pop a long (2 words) and an int shift
                  count (1 word) and push a long (2 words): net -1 *)
               | Jlshl          => finddepth1 modif ~1
               | Jlshr          => finddepth1 modif ~1
               | Jlsub          => finddepth1 modif ~2
               | Jlushr         => finddepth1 modif ~1
               | Jlxor          => finddepth1 modif ~2
               | Jmonitorenter  => finddepth1 modif ~1
               | Jmonitorexit   => finddepth1 modif ~1
               | Jnew _         => finddepth1 modif 1
               | Jnewarray {dim, ...} => finddepth1 modif (1 - dim)
               | Jnop           => finddepth1 modif 0
               | Jpop           => finddepth1 modif ~1
               | Jpop2          => finddepth1 modif ~2
               | Jputfield {ty, ...} =>
                     finddepth1 modif (~(fieldDelta ty) - 1)
               | Jputstatic {ty, ...} =>
                     finddepth1 modif (~(fieldDelta ty))
               | Jsaload        => finddepth1 modif ~1
               | Jsastore       => finddepth1 modif ~3
               | Jswap          => finddepth1 modif 0
            end
    in
        (* Resolve the exception handlers first (depth 1, main program),
           then analyse the code from its start with an empty stack *)
        finddepth code NONE emptyset 0
                  (List.foldl (resolve 1 NONE emptyset) 0 hdlrs)
    end
