(* EstimateJumps.sml
   Kenneth MacKenzie,  May 2004
*)

(* The following code calculates upper bounds for jump sizes so that we
   can reduce the number of unnecessary goto_w and jsr_w instructions.
   We calculate jump sizes before the sizes of some intermediate
   instructions can be determined (the exact size is only known when code is
   emitted and everything has been put in the constant pool).  This gives us
   non-exact upper bounds for jump distances,  but since the actual
   distances can only be smaller we can be certain that some jumps don't
   have to be wide.  In fact we only require goto_w and jsr_w for
   jumps whose offsets do not fit in a signed 16-bit value (roughly,
   jumps of more than 32767 bytes).  Our estimate can only be wrong for
   jumps whose estimated distance is over this limit but whose actual
   distance is not;  in that case we'll end up emitting goto_w when goto
   would suffice.  Such large jumps don't seem to occur often in
   practice.  It would be possible to produce an algorithm which would
   calculate exact jump distances,  but it seems likely that this would
   require another intermediate representation for bytecode,  and would
   require multiple passes over the bytecode in bad cases. *)

(* The main function (shortJumps) returns a list of booleans,  one for
   each goto or jsr instruction in a list of bytecodes.  The boolean
   values are "true" if the corresponding jump can be guaranteed to be
   short,  "false" otherwise. *)

open Bytecode

local
    (* Inclusive range test: lo <= k <= hi. *)
    fun between (lo : int, hi : int) (k : int) = lo <= k andalso k <= hi
in
(* Operand-range predicates (copied from Emitcode.sml). *)
fun isU1    k = between (0, 0xff) k           (* unsigned 8-bit  *)
fun isU2    k = between (0, 0xffff) k         (* unsigned 16-bit *)
fun isByte  k = between (~128, 127) k         (* signed 8-bit    *)
fun isShort k = between (~32768, 32767) k     (* signed 16-bit   *)
end




(* ---------------- (upper bounds for) sizes of instructions ---------------- *)

(* Size of an instruction that takes a local-variable index operand (used
   here for ret).  An index that fits in one unsigned byte gives the 2-byte
   form (opcode + u1);  otherwise the 4-byte wide form
   (wide + opcode + u2) is counted. *)
fun sizeOfVarAcc j =
    case isU1 j of
        true  => 2
      | false => 4

(* Size of a typed load/store (xload/xstore) of local variable [index]:
     slots <= 3          -> 1 byte  (the dedicated xload_<n>/xstore_<n> opcode)
     slot fits in a u1   -> 2 bytes (opcode + 1-byte index)
     otherwise           -> 4 bytes (wide + opcode + 2-byte index) *)
fun sizeOfImmVarAcc index =
    let
        val slot = Localvar.toInt index
        val hasShortOpcode = slot <= 3
    in
        if hasShortOpcode then 1
        else if isU1 slot then 2
        else 4
    end


local
    open Real64
    (* Bit patterns of the two doubles that have dedicated opcodes.
       Comparing byte images (rather than using real equality) means a value
       such as -0.0, whose bytes differ from 0.0, falls through to ldc2_w. *)
    val zeroBytes = toBytes (fromReal 0.0)
    val oneBytes  = toBytes (fromReal 1.0)
in
(* Size of pushing the double constant [d]:  1 byte for dconst_0/dconst_1,
   else 3 bytes for ldc2_w with a 2-byte constant-pool index. *)
fun sizeOfDConst d =
    let
        val image = toBytes d
    in
        if image = zeroBytes orelse image = oneBytes then 1
        else 3
    end
end

local
    open Real32
    (* Bit patterns of the floats that have dedicated opcodes.  The 0.0
       pattern is spelled out as four zero bytes instead of
       toBytes (fromReal 0.0);  the reason is not visible here. *)
    val zeroBytes = Word8Vector.fromList [0w0, 0w0, 0w0, 0w0]
    val oneBytes  = toBytes (fromReal 1.0)
    val twoBytes  = toBytes (fromReal 2.0)
in
(* Size of pushing the float constant [f]:  1 byte for fconst_0/1/2;
   otherwise 3 bytes.  The 3 is pessimistic:  it may end up as a 2-byte ldc,
   but f is not in the constant pool yet, so its index is unknown. *)
fun sizeOfFConst f =
    let
        val image = toBytes f
    in
        if image = zeroBytes orelse image = oneBytes orelse image = twoBytes
        then 1
        else 3
    end
end

local
    (* The only longs with dedicated push opcodes. *)
    val lconst0 = Int64.fromInt 0
    val lconst1 = Int64.fromInt 1
in
(* Size of pushing the long constant [l]:  1 byte for lconst_0/lconst_1,
   else 3 bytes for ldc2_w with a 2-byte constant-pool index. *)
fun sizeOfLConst l =
    if l = lconst0 orelse l = lconst1 then 1
    else 3
end


(* Size of pushing the int constant [i]:
     ~1 .. 5             -> 1 byte  (iconst_m1 .. iconst_5)
     fits a signed byte  -> 2 bytes (bipush)
     otherwise           -> 3 bytes (sipush, or ldc_w as an upper bound)
   If Int32.toInt raises Int32.Int32Overflow, 3 is assumed. *)
fun sizeOfIConst i =
    let
        val n = Int32.toInt i
        val hasIconstOpcode = ~1 <= n andalso n <= 5
    in
        if hasIconstOpcode then 1
        else if isByte n then 2
        else 3
    end
    handle Int32.Int32Overflow _ => 3

(* Size of an iinc instruction:  3 bytes (iinc u1-index s1-const) when the
   variable index fits in a u1 and the constant in a signed byte;  otherwise
   6 bytes (wide iinc u2-index s2-const).
   NOTE(review): a const outside the signed 16-bit range would not fit even
   the wide form;  presumably the front end never produces one -- confirm. *)
fun sizeOfIinc {var = v, const = c} =
    let
        val slot = Localvar.toInt v
        val fitsCompact = isU1 slot andalso isByte c
    in
        if fitsCompact then 3 else 6
    end


(* Upper bound on the size of a lookupswitch:
     opcode (1) + maximum alignment padding (3) + default (4) + npairs (4)
     + 8 bytes (4-byte key + 4-byte offset) per case.
   Pessimistic because the padding depends on the final position, and
   because repeated cases are only removed during emission. *)
fun sizeOfLookupswitch {default = _, cases} =
    let
        val header  = 1 + 3 + 4 + 4
        val perCase = 4 + 4
    in
        header + perCase * length cases
    end

(* Upper bound on the size of a tableswitch:  a 16-byte fixed part
   (opcode + up to 3 padding bytes + default + low + high) plus 8 bytes per
   target.  Pessimistic because the padding depends on the final position.
   NOTE(review): a tableswitch jump table holds one 4-byte offset per
   target, so 8 per target is more conservative than the instruction itself
   needs;  check Emitcode before tightening this. *)
fun sizeOfTableswitch {default = _, offset = _, targets} =
    let
        val n = Vector.length targets
    in
        16 + n * 8
    end

(* Size of an array-creation instruction:
     multi-dimensional              -> 4 bytes (multianewarray cp_index dim)
     1-D, Jvmtype.isSimple element  -> 2 bytes (newarray atype)
     1-D, other element             -> 3 bytes (anewarray cp_index) *)
fun sizeOfNewarray {elem = ty, dim = n} =
    if n <> 1 then 4
    else if Jvmtype.isSimple ty then 2
    else 3



(* Upper bound, in bytes, on the encoding of a single instruction.
   Constructors not matched explicitly are counted as 1 byte.
   goto and jsr are counted at their wide (goto_w/jsr_w) size. *)
fun sizeOf instr =
    let
        val branch16 = 3   (* opcode + signed 2-byte branch offset *)
        val cpRef    = 3   (* opcode + 2-byte constant-pool index *)
        val wideJump = 5   (* opcode + 4-byte offset (worst case) *)
    in
        case instr of
            (* pseudo-instruction: emits no bytes *)
            Jlabel _           => 0

            (* may turn out to be a 2-byte ldc once the constant-pool
               index is known *)
          | Jsconst _          => cpRef
          | Jclassconst _      => cpRef

            (* typed local-variable loads and stores *)
          | Jaload j           => sizeOfImmVarAcc j
          | Jastore j          => sizeOfImmVarAcc j
          | Jdload j           => sizeOfImmVarAcc j
          | Jdstore j          => sizeOfImmVarAcc j
          | Jfload j           => sizeOfImmVarAcc j
          | Jfstore j          => sizeOfImmVarAcc j
          | Jiload j           => sizeOfImmVarAcc j
          | Jistore j          => sizeOfImmVarAcc j
          | Jlload j           => sizeOfImmVarAcc j
          | Jlstore j          => sizeOfImmVarAcc j

            (* conditional branches *)
          | Jif_acmpeq _       => branch16
          | Jif_acmpne _       => branch16
          | Jif_icmpeq _       => branch16
          | Jif_icmpne _       => branch16
          | Jif_icmplt _       => branch16
          | Jif_icmpge _       => branch16
          | Jif_icmpgt _       => branch16
          | Jif_icmple _       => branch16
          | Jifeq _            => branch16
          | Jifne _            => branch16
          | Jiflt _            => branch16
          | Jifge _            => branch16
          | Jifgt _            => branch16
          | Jifle _            => branch16
          | Jifnull _          => branch16
          | Jifnonnull _       => branch16

            (* instructions carrying a constant-pool reference *)
          | Jcheckcast _       => cpRef
          | Jgetfield _        => cpRef
          | Jgetstatic _       => cpRef
          | Jputfield _        => cpRef
          | Jputstatic _       => cpRef
          | Jnew _             => cpRef
          | Jinstanceof _      => cpRef
          | Jinvokespecial _   => cpRef
          | Jinvokestatic _    => cpRef
          | Jinvokevirtual _   => cpRef

            (* opcode + cp index + arg count + zero byte *)
          | Jinvokeinterface _ => 5

            (* variable-size instructions *)
          | Jdconst d          => sizeOfDConst d
          | Jfconst f          => sizeOfFConst f
          | Jiconst i          => sizeOfIConst i
          | Jlconst l          => sizeOfLConst l
          | Jiinc a            => sizeOfIinc a
          | Jlookupswitch a    => sizeOfLookupswitch a
          | Jtableswitch a     => sizeOfTableswitch a
          | Jnewarray a        => sizeOfNewarray a
          | Jret j             => sizeOfVarAcc (Localvar.toInt j)

            (* unconditional jumps: assume the wide form *)
          | Jgoto _            => wideJump
          | Jjsr _             => wideJump

          | _                  => 1
    end



(* Return a table mapping each label to its approximate (upper-bound)
   location,  together with a list of (target label, approximate position)
   pairs,  one for each goto or jsr instruction,  in program order. *)

(* approxJumps code pos labels acc
   Walks [code] starting at estimated byte offset [pos].  Each instruction
   advances the offset by sizeOf.  Jlabel instructions are recorded in the
   Intmap [labels] at their estimated offset.  Each goto/jsr pushes
   (target label as int, estimated offset of the jump) onto [acc].
   Returns the final label table and the jump list in program order. *)
fun approxJumps [] _ labels acc = (labels, rev acc)
  | approxJumps (instr :: rest) pos labels acc =
    let
        val next = pos + sizeOf instr
        val (labels', acc') =
            case instr of
                Jgoto target => (labels, (Label.toInt target, pos) :: acc)
              | Jjsr target  => (labels, (Label.toInt target, pos) :: acc)
              | Jlabel here  => (Intmap.insert (labels, Label.toInt here, pos), acc)
              | _            => (labels, acc)
    in
        approxJumps rest next labels' acc'
    end


(* Use the result of the previous function to find goto and jsr instructions which
   definitely don't have to be wide.  Return a list of booleans which match up with
   goto/jsr instructions (in order), telling you which ones are safe. *)

(* shortJumps code
   Returns one boolean per goto/jsr instruction in [code], in program order.
   A value of true means the jump's (upper-bound) offset is guaranteed to
   fit in a signed 16-bit branch offset, so the short goto/jsr form can be
   used.  A value of false means the wide form must be assumed.
   NOTE(review): Intmap.retrieve is expected to fail if a jump targets a
   label that never appears as a Jlabel in [code] -- confirm its exception. *)
fun shortJumps code =
    let
        val (labels, jumps) = approxJumps code 0 (Intmap.empty ()) []

        (* A JVM branch offset is (target - address of the jump), so test
           dest - pos.  The signed 16-bit range [-32768, 32767] is
           asymmetric, so testing pos - dest instead would accept a forward
           offset of +32768 (which does not fit) and reject a backward
           offset of -32768 (which does). *)
        fun isOk (j, pos) =
            let
                val dest = Intmap.retrieve (labels, j)
            in
                isShort (dest - pos)
            end
    in
        map isOk jumps
    end


