(*======================================================================*)
(* gf: Grail decompiler                					*)
(*======================================================================*)

open Classdecl

type 'a substitution = ('a * 'a) list
type layout_substitution = string substitution * int substitution * string substitution
type constructors = string * int * (string * string) list

(* val printSubst: layout_substitution -> unit *)

exception gfError of string

open GrailAbsyn
open Bytecode
open Label



(* Mutable global settings.  outputDir defaults to the current directory
   (presumably where output files go -- its users are not in this chunk);
   make_wrapper is a flag that defaults to false. *)
val outputDir : string ref = ref "."
val make_wrapper : bool ref = ref false

(* prn label value: print "label: value" and a newline on standard output. *)
fun prn label value = print (String.concat [label, ": ", Int.toString value, "\n"])

(* Debug hooks, currently disabled: both ignore their argument and return ().
   The previous bodies were printJvmInstr.prInstr i and print s respectively. *)
fun prInstr _ = ()
fun debugprint _ = ()
(* member x xs: true iff some element of xs is equal to x. *)
fun member x xs = List.exists (fn y => x = y) xs


(* Pair projections.  (A second, identical definition of prn used to follow
   here; it only shadowed the one defined above and has been removed.) *)
fun fst (x, _) = x
fun snd (_, y) = y

(* makeFullFilename dir base ext: the path of file "base.ext" inside dir,
   built with the Path library. *)
fun makeFullFilename dir base ext =
    let
	val fileName = Path.joinBaseExt {base = base, ext = SOME ext}
    in
	Path.joinDirFile {dir = dir, file = fileName}
    end


(* ==================== Various conversions ==================== *)

(* Replace every '/' by '.', e.g. "java/lang/Object" becomes "java.lang.Object". *)
fun fixSlashes s = String.map (fn #"/" => #"." | c => c) s

(* Convert a dotted qualified name "p1.p2.C" into a Jvmtype class with
   packages [p1, p2] and name C.  The empty string maps to the empty class.
   (Could use Jvmtype.qualNameToClass name.) *)
fun qnameToJClass "" = Jvmtype.class {pkgs = [], name = ""}
  | qnameToJClass name =
    let
	val parts = String.tokens (fn c => c = #".") name
	val count = length parts
    in
	Jvmtype.class {pkgs = List.take (parts, count - 1),
		       name = List.nth (parts, count - 1)}
    end

(* Dot-separated fully qualified name of a Jvmtype class. *)
fun qualName cls = fixSlashes (Jvmtype.qualName cls)

(* Print a string followed by a newline. *)
fun Prl s = TextIO.print (s ^ "\n")

(* Translate a Grail type into the corresponding Jvmtype type.  Array types
   are translated element-wise; reference types are resolved with
   qnameToJClass. *)
fun tyToJty ty =
    case ty of
	BOOLEANty    => Jvmtype.Tboolean
      | CHARty       => Jvmtype.Tchar
      | FLOATty      => Jvmtype.Tfloat
      | DOUBLEty     => Jvmtype.Tdouble
      | BYTEty       => Jvmtype.Tbyte
      | SHORTty      => Jvmtype.Tshort
      | INTty        => Jvmtype.Tint
      | LONGty       => Jvmtype.Tlong
      | ARRAYty elem => Jvmtype.Tarray (tyToJty elem)
      | REFty cname  => Jvmtype.Tclass (qnameToJClass cname)

(* Translate an optional (return) type; NONE stays NONE. *)
fun rtyToJty NONE = NONE
  | rtyToJty (SOME t) = SOME (tyToJty t)

(* Inverse of tyToJty: translate a Jvmtype type into a Grail type.  Class
   types become REFty of their dotted qualified name. *)
fun jtyToTy jty =
    case jty of
	Jvmtype.Tboolean     => BOOLEANty
      | Jvmtype.Tchar        => CHARty
      | Jvmtype.Tfloat       => FLOATty
      | Jvmtype.Tdouble      => DOUBLEty
      | Jvmtype.Tbyte        => BYTEty
      | Jvmtype.Tshort       => SHORTty
      | Jvmtype.Tint         => INTty
      | Jvmtype.Tlong        => LONGty
      | Jvmtype.Tarray elem  => ARRAYty (jtyToTy elem)
      | Jvmtype.Tclass cls   => REFty (qualName cls)

(* Class name of an optional reference type; raises gfError for anything
   other than SOME (REFty _). *)
fun refTy (SOME (REFty cname)) = cname
  | refTy _ = raise gfError "trying to extract classname from non-reference type"

(* Translate an optional Jvmtype (return) type; NONE stays NONE. *)
fun jtyToRty NONE = NONE
  | jtyToRty (SOME t) = SOME (jtyToTy t)

(* Convert a classfile field declaration into a Grail field definition,
   keeping its flags and name and translating its type. *)
fun jfieldToFieldDef {attrs: attribute list, flags, name, ty} =
    let
	val grailTy = jtyToTy ty
    in
	FDEF (flags, grailTy, name)
    end


(* Convert a Grail method descriptor with dotted name "p1.p2.C.m" into a
   Bytecode method reference: method name m, class p1.p2.C and a signature
   built from the translated argument and return types.  An unqualified
   name gets an empty class. *)
fun mdescToMethodRef (MDESC (ty, name, argtypes)) =
    let
	val parts = String.tokens (fn c => c = #".") name
	val count = length parts
	val signature' = (map tyToJty argtypes, rtyToJty ty)
	val methodName = List.nth (parts, count - 1)
	val (pkgs, cname) =
	    if count = 1 then ([], "")
	    else (List.take (parts, count - 2), List.nth (parts, count - 2))
    in
	{name = methodName,
	 msig = signature',
	 class = Jvmtype.class {pkgs = pkgs, name = cname}}
    end

(* Inverse of mdescToMethodRef: build a Grail method descriptor named
   "<qualified class>.<method>" with translated argument/return types. *)
fun methodRefToMdesc ({class, name, msig = (args, rty)} : Bytecode.method_ref) =
    MDESC (jtyToRty rty, qualName class ^ "." ^ name, map jtyToTy args)

(* Convert a Grail field descriptor with dotted name "p1.p2.C.f" into a
   Bytecode field reference (field f of class p1.p2.C, translated type).
   An unqualified name gets an empty class. *)
fun fdescToFieldRef (FDESC (fty, name)) =
    let
	val parts = String.tokens (fn c => c = #".") name
	val count = length parts
	val jty = tyToJty fty
	val fieldName = List.nth (parts, count - 1)
	val (pkgs, cname) =
	    if count = 1 then ([], "")
	    else (List.take (parts, count - 2), List.nth (parts, count - 2))
    in
	{name = fieldName, ty = jty, class = Jvmtype.class {pkgs = pkgs, name = cname}}
    end

(* Inverse of fdescToFieldRef: a Grail field descriptor named
   "<qualified class>.<field>" with the translated field type. *)
fun fieldRefToFdesc ({class, name, ty} : Bytecode.field_ref) =
    FDESC (jtyToTy ty, qualName class ^ "." ^ name)


(* Head of a list; unlike the Basis hd, raises gfError (not Empty) on []. *)
fun hd (first :: _) = first
  | hd [] = raise gfError "Empty list"




(* ================ Grail local function information ================ *)

(* The next function scans the code to determine which labels are targets
   of goto instructions and hence local function entry points.  It might 
   be better to encode this information as metadata. *)

(* The function returns an Intmap which maps the integer n to the index of
   LBL n in the sequence of entry points;  eg,  if the 5th local function 
   (counting from 0) starts at LBL 23,  then the map maps 23 to 5 *)

local
    (* Label number -> index of the local function starting at that label
       (see the comment above).  Replaced wholesale by getEntryPoints. *)
    val entryPoints = ref (Intmap.empty())
in
    (* Collect the targets of every goto and tableswitch in the instruction
       list l (the tableswitch default is not added separately) and number
       them 0, 1, ... in increasing label order, storing the result in
       entryPoints. *)
    fun getEntryPoints l  = 
	let fun invert l n acc = 
		case l of 
			[] => acc
		      | h::t => invert t (n+1) (Intmap.insert(acc,h,n))
	    fun f l acc =
		case l of
		    [] => invert (Binaryset.listItems acc) 0 (Intmap.empty())
		  | (Jgoto tgt)::t => f t (Binaryset.add(acc, Label.toInt tgt))
		  | (Jtableswitch {default, offset, targets})::t =>
		    f t (Vector.foldl (fn (x,b) => Binaryset.add(b, Label.toInt x)) acc targets)
		  | _::t => f t acc
	in
	    entryPoints := f l (Binaryset.empty Int.compare)
	end
	
	
    (* Debugging aid: print every "label ---> index" pair. *)
    fun print_ep () = Intmap.app (fn (a,b) => (print (Int.toString a ^ " ---> " ^ Int.toString b ^ "\n"))) (!entryPoints)

    (* Name and parameters of the local function starting at label l.
       The label's entry-point index m selects an element of funInfo
       (presumably 0-based, via nth); the result is (name, map snd params).
       If that lookup raises anything, falls back to ("fn<label>", []).
       Raises gfError if l is not a recorded entry point. *)
    fun nameOfFun l funInfo = 
	let 
	    val n = Label.toInt l
	in
	    case Intmap.peek(!entryPoints, n) of
		NONE => raise gfError ("Missing entry point " ^ Int.toString n)
	      | SOME m => let
		    val (name, params) = nth (funInfo, m)  (* Error with HDatabase *)
			handle _ => ("fn" ^ Int.toString n, [])
		in 
		    (name, map snd params)
		end
	end
	
	(* Label numbers of all recorded entry points (the map's keys). *)
	fun listEntryPoints () = map fst (Intmap.listItems (!entryPoints))

end (* local *)


(* ================ Grail local variable information  ================ *)

(* Some stuff for determining names and types of local variables.
   This is obtained from the LOCALVAR attribute in the classfile.
   We assume that this has been created by the gdf backend.
   Maybe some consistency checks would be in order. *)

(* This is still a bit messy.  Usually we want to look up information
   by local var index,  but sometimes we need to look it up by name
   (because when we're parsing args on the stack we store absyn
   Values).  This means that we have to keep two different data
   structures. *)


(* Tables describing the current method's local variables, rebuilt from
   the LOCALVAR attribute entries by setVarInfo:
     varInfo  : local-variable index -> (Grail type, name)
     nameInfo : variable name        -> Grail type *)
local
    val varInfo = ref (Intmap.empty())
    val nameInfo = ref (Binarymap.mkDict String.compare)
    val varnum = Localvar.toInt

in
    (* Convert one LOCALVAR entry into (index, name, Grail type). *)
    fun transVarInfo {from:  label,
		      thru:  label,
		      name:  string,
		      ty:    Jvmtype.jtype,
		      index: Localvar.index}
      = (Localvar.toInt index,
	 name,
	 jtyToTy ty)

    (* Fold step: add one LOCALVAR entry to the index-keyed map t. *)
    fun insertVarInfo ({from, thru, name, ty, index}, t) =
	Intmap.insert (t, varnum index, (jtyToTy ty, name))

    (* Fold step: add one LOCALVAR entry to the name-keyed map t. *)
    fun insertNameInfo ({from, thru, name, ty, index}, t) =
	Binarymap.insert (t, name, jtyToTy ty)

    (* Replace both tables with fresh ones built from the entries in v
       (folded right-to-left). *)
    fun setVarInfo v = (
	varInfo := List.foldr insertVarInfo (Intmap.empty()) v;
	nameInfo := List.foldr insertNameInfo (Binarymap.mkDict String.compare) v
    )

    (* (type, name) of local variable number n; raises gfError if unknown. *)
    fun typeAndNameOfVar n =
	case Intmap.peek (!varInfo, n) of
	    NONE => raise gfError ("Can't find local variable " ^ Int.toString n)
	  | SOME p => p

    (* Source name of the local variable with index n.  When there is no
       LOCALVAR entry, slot 0 is called "this" and slot k "var<k>". *)
    fun nameOfVar n =
	let
	    val m = varnum n
	in
	    snd (typeAndNameOfVar m)
	    handle gfError _ => if m = 0 then "this" else "var" ^ Int.toString m
	    (* FIX:  TO DEAL WITH 'this' PROBLEM *)
	end

    (* Grail type of the local variable called var; raises gfError if unknown. *)
    fun typeOfVar var =
	case Binarymap.peek (!nameInfo, var) of
	    NONE => raise gfError ("Can't find local variable " ^ var)
	  | SOME p => p

end (* local *)


(* ================ Parsing the bytecode ================ *)

(* chkStk s n: succeed iff the operand stack s holds exactly n items,
   otherwise raise gfError saying whether there are too few or too many. *)
fun chkStk s n =
    case Int.compare (length s, n) of
	LESS    => raise gfError "Too few items on stack"
      | EQUAL   => ()
      | GREATER => raise gfError "Too many items on stack"


(* Now we're trying to parse a fun body;  this involves collecting
   let decs then the result of the function.  This is tricky because 
   it's difficult to tell where the decs end and the result begins.
   The functions below scan the bytecode looking for primops then
   do some lookahead to try and work out what's coming.  The "Thing" 
   datatype below describes what's happening.  It's particularly
   messy when the result of the function is "if ..." because by the
   time we see the if we've already got two values on the stack,
   which might also have been args for a primop.  *)

(* What getPrimOp found at the front of the instruction stream:
     VAL p    - operation p, which produces a value;
     VOID p   - operation p, executed only for its effect;
     NOTHING  - no operation before the next non-operand instruction;
     CASE (x, low, high, targets)
              - tableswitch on variable x over low..high, one label per value;
     IF (x, y, tst, l)
              - conditional jump to label l, comparing x and y with tst. *)
datatype Thing =
    VAL of PrimOp 
  | VOID of PrimOp
  | NOTHING
  | CASE of string * int * int * Label.label list
  | IF of Value * Value * Test * Label.label

(* Separate the first instruction from the rest; running out of bytecode
   is reported as gfError. *)
fun split (first :: rest) = (first, rest)
  | split [] = raise gfError "Premature end of bytecode"


(* getPrimOp C: parse one operation from the front of the instruction list C.
   Operand-pushing instructions (loads, constants, aconst_null) are first
   gathered by eatValues into a list of Values in push order; the next
   instruction then decides the result:
     - arithmetic, conversion, array, field, cast, new and invoke
       instructions give VAL or VOID of the corresponding PrimOp (VOID for
       void invocations and for array/field stores);
     - conditional branches give IF, tableswitch gives CASE;
     - labels are skipped;
     - any other instruction is left unconsumed and gives NOTHING (no
       pending operand) or VAL (VALop x) (exactly one pending operand).
   Returns the Thing together with the remaining instructions.
   Raises gfError for stack shapes it does not understand. *)
fun getPrimOp C =
    let 
	val () = prInstr (hd C)
	fun mkBinOp oper stk = (* Poor error checking *)
	    case stk of 
		[] => raise gfError "Too few arguments for BinOp"
	      | [_] => raise gfError "Too few arguments for BinOp"
	      | [x, y] =>BINop(oper, x, y) 
	      | _ => raise gfError "Too many arguments for BinOp"

	fun getVar [] = raise gfError "Looking for value on empty stack"
	  | getVar (h::_) = 
	    let in case h of
		VARval v => v
	      | _ => raise gfError "Seeking variable on stack,  but found something else"
	    end
    
	(* Build an IF from the two compared values.  A null constant gets the
	   type of the variable it is compared with, or java.lang.Object when
	   both operands are nulls. *)
	fun mkIF stk tst label = 
	    case stk of
		[]  => raise gfError "Too few arguments for test"
	      | [_] => raise gfError "Too few arguments for test"
	      | [NULLval (_,m1), NULLval (_,m2)] => IF(NULLval ("java.lang.Object", m1),
					     NULLval ("java.lang.Object", m2),
					     tst, label)
	      | [NULLval (_,m1), VARval y] => IF(NULLval (tyToString (typeOfVar y), m1), 
					    VARval y, tst, label)	       
	      | [VARval x, NULLval (_,m2)] => IF(VARval x, 
					    NULLval (tyToString(typeOfVar x), m2), tst, label)
	      | [x,y] => IF(x, y, tst, label)
	      | _ => raise gfError "Too many arguments for test"


	(* tableswitch on a single variable: values range over
	   offset .. offset + #targets - 1, and the default target must be the
	   same as the last target. *)
	fun mkCASE stk (default, offset, targets) = 
	    case stk of 
		[] => raise gfError "Missing argument for case statement"
	      | [VARval x] =>  
		let val low = Int32.toInt offset
		    val high = low + Vector.length targets -1
		    val targets' = Vector.foldr (fn (v,l) => v::l) [] targets
(*		    val () = app (fn l => (print (Label.toString l ^ "\n"))) targets'*)
		    val last = Vector.sub (targets, Vector.length targets - 1)
		    val () = if default = last then () 
			     else raise gfError "Bad default value in case statement"
		in
		    CASE (x, low, high, targets')
		end
	      | [_] => raise gfError "Need variable on stack for case"
	      | _ => raise gfError "Too many items on stack at start of case statement"

	fun mkGET stk =
	    case stk of 
		[]  => raise gfError "Too few arguments for array load"
	      | [_] => raise gfError "Too few arguments for array load"
	      | [arr,i] => VAL(GETop(arr,i))
	      | _ => raise gfError "Too many arguments for array load"

	(* Array store.  A null stored into an array variable is given the
	   element type of that variable (java.lang.Object if its recorded type
	   is not an array type). *)
	fun mkSET stk =
	    case stk of 
		[VARval(arr),i,NULLval (_,m)] => 
		let val x = 
		    (case typeOfVar arr of
			 ARRAYty ty => tyToString ty
		       | _ => "java.lang.Object")
		in 
		    VOID(SETop(VARval(arr),i,NULLval (x, m)))
		end
	      | [arr,i,a] => VOID(SETop(arr,i,a))
	      | _ => raise gfError "Wrong number of arguments for array store"


	fun getCl (REFty s) = s | getCl _ = raise gfError "Oops: trying to get class from non-ref type"

	fun fixNulls argtys stk = (* invoke* arguments may contain nulls;  
				     we have to get the type from the mdesc *)
            ListPair.map (fn(a,b) => 
			    case b of NULLval (_,m) => NULLval (getCl a, m) 
				    | _ => b) 
	    (argtys, stk)
	    

	(* Gather operand pushes (loads and constants) from the front of C onto
	   stk (built in reverse), stopping at the first other instruction.
	   Returns the remaining code (starting with that instruction) and the
	   operands in push order.  dup is skipped without pushing anything. *)
	fun eatValues C stk =
	    let
		val (h,t) = split C
		val () = (debugprint ("eat "); prInstr h)
	    in
		case h of 
		    Jiload n  => eatValues t (VARval (nameOfVar n)::stk)
		  | Jfload n  => eatValues t (VARval (nameOfVar n)::stk)
		  | Jaload n  => eatValues t (VARval (nameOfVar n)::stk)
		  | Jiconst n => eatValues t (INTval (Int32.toInt n)::stk)
		  | Jfconst x => eatValues t (FLOATval (Real32.toReal x)::stk)
		  | Jsconst s => eatValues t (STRINGval (String.toString s)::stk)
		  (* A null gets a placeholder type here; mkIF, mkSET, fixNulls
		     and getFunDec substitute a real type where they can. *)
		  | Jaconst_null => eatValues t (NULLval ("the.great.spoon.from.the.void",
							  Metadata.nextMarker())::stk)
					      (*** DANGER ****) (* FIX *)
		  | Jdup => eatValues t stk (* Mmmm...? *)
		  | _ => (C, rev stk)
	    end

	val (D, stk) = eatValues C []
	val (h,C') = split D
	val () = (debugprint (Int.toString(length stk)^ " > "); prInstr h)
    in
	case h of
	    Jiadd => (VAL(mkBinOp ADDop stk), C')
	  | Jisub => (VAL(mkBinOp SUBop stk), C')
	  | Jimul => (VAL(mkBinOp MULop stk), C')
	  | Jidiv => (VAL(mkBinOp DIVop stk), C')
	  | Jirem => (VAL(mkBinOp MODop stk), C')

	  | Jiand => (VAL(mkBinOp ANDop stk), C')
	  | Jior  => (VAL(mkBinOp ORop stk), C')
	  | Jixor => (VAL(mkBinOp XORop stk), C')
	  | Jishl => (VAL(mkBinOp SHLop stk), C')
	  | Jishr => (VAL(mkBinOp SHRop stk), C')
	  | Jiushr => (VAL(mkBinOp USHRop stk), C')

	  | Jfadd => (VAL(mkBinOp ADDop stk), C')
	  | Jfsub => (VAL(mkBinOp SUBop stk), C')
	  | Jfmul => (VAL(mkBinOp MULop stk), C')
	  | Jfdiv => (VAL(mkBinOp DIVop stk), C')
	  | Jfrem => (VAL(mkBinOp MODop stk), C')

	  | Ji2f => 
	    let 
		val () = chkStk stk 1
		val v = VARval(getVar stk)
	    in 
		(VAL(ITOFop(v)), C')
	    end

	  | Jf2i => 
	    let 
		val () = chkStk stk 1
		val v = VARval(getVar stk)
	    in 
		(VAL(FTOIop(v)), C')
	    end

	  | Jarraylength => 		
	    (case stk of
		 [obj] => (VAL(LENGTHop(obj)), C')
	       | _ => raise gfError "Must be 1 value on stack"
	    )

	  (* Object creation: expects "new; dup; <args>; invokespecial".  The
	     invokespecial is parsed recursively and turned into NEWop, with the
	     last component of the method name (the constructor) stripped. *)
	  | Jnew class => 
	    let 
		val () = chkStk stk 0
		val () = let 
			     val d = hd C' 
			 in case d of
			     Jdup => ()
			   | _ => raise gfError "No dup after new"
			 end
		(* consume the constructor args *)
		val pop = getPrimOp ((Jiload(Localvar.fromInt 0)) :: (tl C'))
		    (* This is a bit nasty;  invokespecial requires 
		       an object on the stack.  We've only got a dup,
		       so let's replace it with an arbitrary iload 0
		       (which is going to be discarded anyway) *)
	    in
		let 
		    fun insDots [] = ""
		      | insDots [h] = h
		      | insDots (h::t) = h ^ "." ^ insDots t
		    fun stripMname (MDESC(rty, name, types)) =
			let 
			    val l = String.tokens (fn c => c = #".") name
			    val l' = List.take (l, length l - 1)
			    val name' = insDots l'
			in
			    MDESC(rty, name', types)
			end
			
		in case pop of 
		       (VOID(INVOKESPECIALop(v,md,args)), C'') => (VAL(NEWop(stripMname md,args)), C'')
		     | (VAL _, _) => raise gfError "Constructor returns value"
		     | _ => raise gfError "Expecting invokespecial after new"
		end
	    end

	  | Jcheckcast cref => 
		let 
		    val () = chkStk stk 1 
		    val class = case cref of 
			CLASS jclass => jclass
		      | ARRAY _ => raise gfError "Trying to cast to array type"
		in
		    (VAL(CHECKCASTop(qualName class, getVar stk)), C')
		end 
	    
	  | Jinstanceof cref => 
		let 
		    val () = chkStk stk 1 
		    val class = case cref of 
			CLASS jclass => jclass
		      | ARRAY _ => raise gfError "Trying to call 'instanceof' on array type"
		in
		    (VAL(INSTANCEop(qualName class, getVar stk)), C')
		end 
	    
	  (* Invocations: the stack must hold exactly the receiver (except for
	     invokestatic) followed by the arguments; a void method gives VOID. *)
	  | Jinvokestatic mref =>
		let 
		    val mdesc = methodRefToMdesc mref
		    val (MDESC(rty,_,argtys)) = mdesc
		    val () = chkStk stk (length argtys)
		in
		    case rty of 
			NONE => (VOID(INVOKESTATICop (mdesc, fixNulls argtys stk)), C')
		      | _    => (VAL(INVOKESTATICop (mdesc, fixNulls argtys stk)), C')
		end 
       
	  | Jinvokevirtual mref =>
		let 
		    val mdesc = methodRefToMdesc mref
		    val (MDESC(rty,_,argtys)) = mdesc
		    val () = chkStk stk (1+length argtys)
		    val obj = getVar stk
		    val args = tl stk
		in
		    case rty of 
			NONE => (VOID(INVOKEVIRTUALop (obj, mdesc, fixNulls argtys args)), C')
		      | _    => (VAL (INVOKEVIRTUALop (obj, mdesc, fixNulls argtys args)), C')
		end
	    
	  | Jinvokespecial mref =>
		let 
		    val mdesc = methodRefToMdesc mref
		    val (MDESC(rty,_,argtys)) = mdesc
		    val () = chkStk stk (1+length argtys)
		    val obj = getVar stk
		    val args = tl stk
		in
		    case rty of 
			NONE => (VOID(INVOKESPECIALop (obj, mdesc, fixNulls argtys args)), C')
		      | _    => (VAL (INVOKESPECIALop (obj, mdesc, fixNulls argtys args)), C')
		end
	    
	  | Jinvokeinterface mref =>
		let 
		    val mdesc = methodRefToMdesc mref
		    val (MDESC(rty,_,argtys)) = mdesc
		    val () = chkStk stk (1+length argtys)
		    val obj = getVar stk
		    val args = tl stk
		in
		    case rty of 
			NONE => (VOID(INVOKEINTERFACEop (obj, mdesc, fixNulls argtys args)), C')
		      | _    => (VAL (INVOKEINTERFACEop (obj, mdesc, fixNulls argtys args)), C')
		end
	    
	  | Jdup => raise gfError "Unexpected dup opcode"
	  | Jgetstatic fref =>
		let 
		    val () = chkStk stk 0
		in
		    (VAL(GETSTATICop (fieldRefToFdesc fref)), C')
		end
	    
	  | Jgetfield fref => 
		let 
		    val () = chkStk stk 1
		    val obj = getVar stk
		in
		    (VAL(GETFIELDop (obj, fieldRefToFdesc fref)), C')
		end
	    
	  | Jputstatic fref =>
		let 
		    val () = chkStk stk 1
		    val input = hd stk
		in
		    (VOID(PUTSTATICop (fieldRefToFdesc fref, input)), C')
		end
	    
	  | Jputfield fref =>
		let 
		    val () = chkStk stk 2
		    val obj = getVar stk
		    val input = hd (tl stk)
		in
		    (VOID(PUTFIELDop (obj, fieldRefToFdesc fref, input)), C')
		end

	  | Jlabel _ => 
	    (* if length C' = 0 then (NOTHING, []) else *)
	    getPrimOp C'  (* Slight problem because first and 
			   last instructions have labels *)

	  | Jif_acmpeq l => (mkIF stk EQtest l, C')
	  | Jif_acmpne l => (mkIF stk NEtest l, C')

	  | Jif_icmpeq l => (mkIF stk EQtest l, C')
	  | Jif_icmpne l => (mkIF stk NEtest l, C')
	  | Jif_icmplt l => (mkIF stk Ltest l, C')
	  | Jif_icmple l => (mkIF stk LEtest l, C')
	  | Jif_icmpgt l => (mkIF stk Gtest l, C')
	  | Jif_icmpge l => (mkIF stk GEtest l, C')

	  (* Float comparisons arrive as fcmpl followed by a single-operand
	     if<cond>; both instructions are consumed here. *)
	  | Jfcmpl => 
	    let 
		val nxt = hd C'
		val C'' = tl C'
	    in case nxt of 
		Jifeq l => (mkIF stk EQtest l, C'')
	      | Jifne l => (mkIF stk NEtest l, C'')
	      | Jiflt l => (mkIF stk Ltest l, C'')
	      | Jifle l => (mkIF stk LEtest l, C'')
	      | Jifgt l => (mkIF stk Gtest l, C'')
	      | Jifge l => (mkIF stk GEtest l, C'')
	      | _ => raise gfError "fcmpl must be followed by if---"
	    end
	  | Jifeq l => raise gfError "Found ifeq  without fcmpl"
	  | Jifne l => raise gfError "Found ifne  without fcmpl"
	  | Jiflt l => raise gfError "Found iflt  without fcmpl"
	  | Jifle l => raise gfError "Found ifle  without fcmpl"
	  | Jifgt l => raise gfError "Found ifgt  without fcmpl"
	  | Jifge l => raise gfError "Found ifge  without fcmpl"
	  | Jiaload => (mkGET stk, C')
	  | Jfaload => (mkGET stk, C')
	  | Jaaload => (mkGET stk, C')
	  | Jiastore => (mkSET stk, C')
	  | Jfastore => (mkSET stk, C')
	  | Jaastore => (mkSET stk, C')
	  | Jnewarray _ => raise gfError "NEWARRAY" (*{elem,dim} => 
	    let 
		val () = chkStk stk 1
	    in
	    case dim of 
		1 => ((VAL (EMPTYop(hd stk, jtyToTy elem))), C')
	      | _ => raise gfError "Only one-dimensional arrays yet. Give me a minute."
	    end*)

(* Experimental  *)

          | Jiinc {var, const} => 
	    let 
		val () = chkStk stk 0
	    in
		(VAL (BINop(ADDop, VARval(nameOfVar var), INTval(const))), C')
	    end

(* /Experimental *)


(* Something else experimental *)

	  | Jtableswitch {default, offset, targets} => (mkCASE stk (default, offset, targets), C')

(* End *)

	  (* Not an operation: leave the instruction in D for the caller. *)
	  | _ => 
	    let in 
		case stk of
		    [] => (NOTHING, D)
		  | [x] => (VAL (VALop x), D)
		  | _ => raise gfError "Too many values on stack"
	    end
    end


(* getFunDec C decs funInfo rty: parse one Grail function block from C.

   Repeatedly calls getPrimOp, accumulating let-declarations (reversed,
   in decs) until the block's result is reached:
     - VOID p              -> VOIDdec p, continue;
     - VAL p, then istore/astore/fstore n
                           -> VALdec (name of local n, p), continue
                              (astore of a null re-types the null with the
                              target variable's type);
     - VAL p, then return  -> result OPres p;
     - NOTHING             -> result from the next instruction (return, or
                              goto = call of the local function there);
     - IF                  -> two branch results follow: the fall-through
                              one, then (after the test's target label)
                              the jump target's one;
     - CASE                -> one local-function call per tableswitch target.
   funInfo is the local-function metadata passed to nameOfFun; rty is the
   method's return type, used to type a returned null constant.
   Returns (declarations in order, result, remaining instructions).
   Raises gfError for code that does not fit this shape. *)
fun getFunDec C decs funInfo rty =
    let
	(* TEMPORARY FIX: a returned null constant carries no usable type, so
	   give it the method's (reference) return type. *)
	fun FIX_RETURN_REFTY p =
	    case p of
		VALop (NULLval (_, m)) => VALop (NULLval (refTy rty, m))
	      | _ => p

	(* Result when no operation precedes the final instruction. *)
	fun getRes1 Jreturn = VOIDres
	  | getRes1 (Jgoto n) = FUNres (nameOfFun n funInfo)
	  | getRes1 x = (prInstr x; raise gfError "Malformed result")

	(* Result of one 'if' branch: thing q followed by instruction i. *)
	fun getRes q i =
	    case q of
		VOID p => (case i of
			       Jreturn => OPres p
			     | _ => raise gfError "Missing return")
	      | VAL p => (case i of
			      Jreturn => OPres (FIX_RETURN_REFTY p)
			    | _ => raise gfError "Missing return")
	      | NOTHING => getRes1 i
	      | _ => raise gfError "Malformed result (nested if?)"

	val (q, D) = getPrimOp C
	val () = debugprint "Got primop\n"
	val (h, D') = split D
    in
	case q of
	    VOID p => getFunDec D ((VOIDdec p) :: decs) funInfo rty

	  | VAL p =>
	    (case h of
		 Jistore n => getFunDec D' (VALdec (nameOfVar n, p) :: decs) funInfo rty
	       | Jastore n =>
		 let
		     (* A stored null takes the type of the variable it is stored in. *)
		     val x = case p of
				 VALop (NULLval (_, m)) =>
				 VALdec (nameOfVar n,
					 VALop (NULLval (tyToString (typeOfVar (nameOfVar n)), m)))
			       | _ => VALdec (nameOfVar n, p)
		 in
		     getFunDec D' (x :: decs) funInfo rty
		 end
	       | Jfstore n => getFunDec D' (VALdec (nameOfVar n, p) :: decs) funInfo rty
	       | Jreturn => (rev decs, PRIMres (OPres (FIX_RETURN_REFTY p)), D')
	       | _ => raise gfError "Unused value on stack")

	  | NOTHING => (rev decs, PRIMres (getRes1 h), D')

	  | IF (x, y, tst, dest) =>
	    let
		(* Fall-through branch: the code directly after the test. *)
		val (r, E) = getPrimOp D
		val (h, t) = split E
		val p1 = getRes r h
		(* The other branch must start at the test's target label. *)
		val (h', t') = split t
		val () = case h' of
			     Jlabel l => if l <> dest
					 then raise gfError "Bad destination in 'if ...'"
					 else ()
			   | _ => raise gfError "Missing label in 'if ...'"
		val (r', E') = getPrimOp t'
		val (h'', t'') = split E'
		val p2 = getRes r' h''
	    in
		(* p2 (code at the jump target) is the result when the test
		   holds; p1 otherwise.  (An earlier version swapped these.) *)
		(rev decs, CHOICEres (x, tst, y, p2, p1), t'')
	    end

	  | CASE (x, low, high, targets) =>
	    let
		(* Pair the values low, low+1, ... with the local functions
		   starting at the corresponding targets. *)
		fun doCases c p acc =
		    case c of
			[] => rev acc
		      | h :: t =>
			let
			    val (fname, args) = nameOfFun h funInfo
			in
			    doCases t (p + 1) ((p, fname, args) :: acc)
			end
	    in
		if length targets <> high - low + 1
		then raise gfError "Malformed case statement: number of targets does not match range"
		else (rev decs, CASEres (x, low, high, doCases targets low []), D)
	    end
    end




(* getFunDecs C funInfo rty: split a method's bytecode C into consecutive
   function blocks, each parsed by getFunDec into (decs, result).  Every
   block except the first must start with a label; a lone trailing label
   (the end-of-code marker) is discarded. *)
fun getFunDecs C funInfo rty =
    let
	fun isLabel (Jlabel _) = true
	  | isLabel _ = false

	fun loop [] _ acc = rev acc
	  | loop [Jlabel _] _ acc = rev acc (* discard end-of-code marker *)
	  | loop (code as (first :: _)) isFirst acc =
	    let
		(* first funblock has no name *)
		val () = if isLabel first orelse isFirst then ()
			 else raise gfError "Missing label at start of function"
		val (decs, result, rest) = getFunDec code [] funInfo rty
	    in
		loop rest false ((decs, result) :: acc)
	    end
    in
	loop C true []
    end



(* Local-function metadata for a method, read from its attributes by
   Metadata.getFunInfo, which is handed typeAndNameOfVar to resolve locals.
   ONLY call this AFTER varInfo has been set (setVarInfo), because
   typeAndNameOfVar depends on the contents of varInfo. *)
fun getFunInfo attributes = Metadata.getFunInfo attributes typeAndNameOfVar

(* jmethodToMethodDef className m: decompile one JVM method declaration into
   a Grail method definition (MDEF).

   The method's CODE attribute supplies the bytecode; a LOCALVAR attribute
   nested in it (if any) supplies local-variable names and types.  The
   bytecode is split into function blocks: the first becomes the method's
   own declarations and result, the others become local functions named
   from getFunInfo's metadata (or "fn<label>" with no parameters when the
   metadata does not match the number of blocks).

   Side effects: resets the module-level local-variable and entry-point
   tables (setVarInfo, getEntryPoints) and calls Metadata.readNullInfo.
   className is currently unused.
   Raises gfError if there is no CODE attribute or the bytecode cannot be
   parsed; parse errors are annotated with the method name. *)
fun jmethodToMethodDef className (m as {flags, name, msig as (m_args, m_ret), attrs}: Classdecl.method_decl) =
let
    val argtypes = map jtyToTy m_args
    val ret = jtyToRty m_ret

    (* Entries of the first LOCALVAR attribute, or [] if there is none. *)
    fun getLocalVariables [] = []
      | getLocalVariables ((LOCALVAR l) :: _) = l
      | getLocalVariables (_ :: t) = getLocalVariables t

    (* The first CODE attribute; a method without one cannot be decompiled. *)
    fun getCode [] = raise gfError "Can't find code"
      | getCode ((CODE c) :: _) = c
      | getCode (_ :: t) = getCode t

    val codeA = getCode attrs
    val code = #code codeA
    val lvars = getLocalVariables (#attrs codeA)

    (* Order matters: nameOfVar, nameOfFun and getFunInfo read the tables
       populated by these two calls. *)
    val () = setVarInfo lvars
    val () = getEntryPoints code

    (* Formal parameter names.  Parameters start at local slot 1 for
       instance methods (slot 0 is 'this') and slot 0 for static ones. *)
    val argnames =
	let
	    val offset = if member M_ACCstatic flags then 0 else 1
	in
	    List.tabulate (length argtypes,
			   fn i => nameOfVar (Localvar.fromInt (i + offset)))
	end

    val funInfo = getFunInfo attrs
    val () = Metadata.readNullInfo attrs

    val funDecs = getFunDecs code funInfo ret
	handle gfError p => raise gfError (p ^ " [in " ^ name ^ "] ")

    val ((decs1, result1), rest) =
	case funDecs of
	    [] => raise gfError "Oops - no funDecs at all"
	  | first :: others => (first, others)

    val bodies = map (fn (d, r) => FUNbody (d, r)) rest

    val funs =
	if length rest <> length funInfo then
	    (* Metadata does not match the code: recover by naming each local
	       function after its entry label and giving it no parameters. *)
	    let
		val fnames = map (fn n => "fn" ^ Int.toString n) (listEntryPoints ())
	    in
		ListPair.map (fn (n, b) => FDEC (n, [], b)) (fnames, bodies)
	    end
	else
	    ListPair.map (fn ((id, args), body) => FDEC (id, args, body)) (funInfo, bodies)
in
    MDEF (flags,
	  ret,
	  name,
	  ListPair.zip (argtypes, argnames),
	  MBODY (decs1, funs, result1))
end

(* --------- from Diamond.sml *)
(* layout = string * (string * int * (string * string) list) list *)

(* val get_arities : string -> (string * int * (string * string) list) list -> int list *)
(* get_arities tname cs: for each constructor of type tname, in order, its
   number of arguments -- plus 100 if one of its argument types (second
   component of each argument pair) is tname itself (a recursive
   constructor). *)
fun get_arities tname ts =
  let
    fun arity (_, _, cargs) =
      let
        val isRecursive = GrailUtils.is_elem tname (List.map GrailUtils.snd cargs)
        val n = length cargs
      in
        if isRecursive then 100 + n else n
      end
  in
    List.map arity ts
  end

(* val print_layout: GrailAbsyn.layout option -> unit *)
(* Print a data layout to stdout: one " TYPE <name>" line per type, then
   one line per constructor giving its name, tag and <argname, argtype>
   pairs.  For NONE an error line is printed instead (now newline-terminated
   like every other line this function prints). *)
fun print_layout NONE = TextIO.print "ERROR in print_layout: No data layout given\n"
  | print_layout (SOME layout) =
    let
      fun showArg (argname, argtype) = "<" ^ argname ^ ", " ^ argtype ^ ">, "
      fun showCon (cname, ctag, cargs) =
          "   " ^ cname ^ ":\t$ = " ^ Int.toString ctag ^ "\t" ^
          String.concat (List.map showArg cargs) ^ "\n"
      fun showType (tname, xs) =
          " TYPE " ^ tname ^ "\n" ^ String.concat (List.map showCon xs)
    in
      TextIO.print (String.concat (List.map showType layout))
    end


(* find an entry in the global layout, matching the arities given here *)
(* val find_t: int list -> GrailAbsyn.layout option -> (string * constructors list) option *)
(* find_t arities layout: the first (type-name, constructors) entry of
   layout whose constructor arities (see get_arities) equal arities, or
   NONE if there is no layout or no entry matches.  Previously an
   exhausted search raised gfError ("weird structure"), because the
   catch-all clause could only ever match the empty list; callers such as
   unify_layouts expect NONE in that case. *)
fun find_t _ NONE = NONE
  | find_t arities (SOME ts) =
      List.find (fn (tname, cs) => arities = get_arities tname cs) ts

(*
val combSubst: (layout_substitution * layout_substitution) -> layout_substitution
val mkSubst: string * constructors list -> string * constructors list -> layout_substitution
val unify_layouts: GrailAbsyn.layout option -> layout_substitution
*)

(* The substitution that renames nothing: no type, tag or field mappings. *)
val emptySubst : layout_substitution = (nil, nil, nil)
(*
fun apply_subst subst tg =
  let
    val z = GrailUtils.lookup tg subst
  in 
    case z of
       NONE => tg
     | SOME tg' => tg' 
  end

fun apply_tag_subst ((subst,_): layout_substitution) tg = apply_subst subst tg
fun apply_field_subst ((_,subst): layout_substitution) fld = apply_subst subst fld
*)
(* Print a layout substitution to stdout: its type, tag and field
   mappings, one "  old -> new" line each. *)
fun printSubst (tys, tags, fields) =
  let
    fun mapping (lhs, rhs) = "  " ^ lhs ^ " -> " ^ rhs ^ "\n"
    fun showStrs pairs = String.concat (List.map mapping pairs)
    fun showInts pairs =
        String.concat (List.map (fn (a, b) => mapping (Int.toString a, Int.toString b)) pairs)
  in
    TextIO.print (String.concat
      ["SUBSTITUTION:\n ",
       "Types:\n", showStrs tys,
       "Tags:\n", showInts tags,
       " Fields:\n", showStrs fields,
       "\n"])
  end
    
(* Concatenate two layout substitutions component-wise (types, tags,
   fields), entries of the first argument first.
   To do: check for duplicates! should never occur *)
fun combSubst (first, second) =
    let
      val (tys1, tags1, fields1) = first
      val (tys2, tags2, fields2) = second
    in
      (tys1 @ tys2, tags1 @ tags2, fields1 @ fields2)
    end

(* mkSubst (tname, cs) (tname', cs'): substitution mapping type tname and
   its constructors cs onto type tname' and constructors cs'.  Each
   constructor in cs is paired with the first constructor in cs' that has
   the same number of arguments; for each pair the type and tag mappings
   are recorded, and every argument is mapped to the matching target field
   renamed to "V<k>" for int arguments or "R<k>" otherwise, where <k> is
   the character at index 1 of the target field name.  Constructors with
   no partner contribute nothing. *)
fun mkSubst (tname, cs) (tname', cs') =
  let
    fun findByArity n = List.find (fn (_, _, cargs) => length cargs = n) cs'

    fun fixName nam' "int" = "V" ^ String.substring (nam', 1, 1)
      | fixName nam' _ = "R" ^ String.substring (nam', 1, 1)

    fun substFor (_, ctag, cargs) =
        case findByArity (length cargs) of
            NONE => emptySubst
          | SOME (_, ctag', cargs') =>
              ([(tname, tname')],
               [(ctag, ctag')],
               GrailUtils.zipWith (fn ((nam, t), (nam', _)) => (nam, fixName nam' t))
                                  cargs cargs')
  in
    List.foldl combSubst emptySubst (List.map substFor cs)
  end


       
(* val unify_layouts_with : GrailAbsyn.layout option -> GrailAbsyn.layout option -> layout_substitution *)
(* unify_layouts_with target layout: for every type in layout, look for a
   type in target with the same constructor arities (find_t) and combine
   the substitutions mkSubst builds for the matches.  A type without a
   match is reported on stdout and contributes nothing. *)
fun unify_layouts_with _ NONE = emptySubst
  | unify_layouts_with target (SOME ts) =
    let
      fun step ((tname, cs), subst) =
          case find_t (get_arities tname cs) target of
              NONE => (TextIO.print ("NO type found for " ^ tname ^ "\n"); subst)
            | SOME (tname', cs') => combSubst (mkSubst (tname, cs) (tname', cs'), subst)
    in
      List.foldl step emptySubst ts
    end

(* val unify_layouts : GrailAbsyn.layout option -> layout_substitution *)
(* unify_layouts layout: unify against the fixed reference layout
   ToyGrailAbsyn.layout_TREELIST. *)
fun unify_layouts layout = unify_layouts_with ToyGrailAbsyn.layout_TREELIST layout

(*
      let
   	    fun fname (G.FDESC (_,desc)) = valOf (Path.ext desc)
   	    (* Abuse of path function *)

   	    fun prf (G.FDEF (flags, ty, name)) =
   		if Util.member (Classdecl.F_ACCstatic) flags then ()
   		else print ((Util.fillString (name ^ ": ") 5) ^ G.tyToString ty ^ "\n")

   	    fun fname (G.FDESC (_,desc)) = valOf (Path.ext desc)
   	    (* Abuse of path function *)

   	    fun getArgs C i n = 
   		if i=n then []
   		else (fname (getFieldDesc (C,i))) ::(getArgs C (i+1) n)
   		     
   	    fun pad s = substring (s^":             ", 0, 11)
   		     
   	    fun prCon (N.TYPEcon (cname, (args,_), usage, _)) = 
   		(print (pad cname ^ "$ = " 
   			^ Int.toString (getTagInfo cname)
   			^ ", ");
   		 if usage = N.NOHEAP 
   		 then print " heap-free\n"
   		 else print (" args = ("
   			^ listToString id ", " (getArgs cname 0 (length args))
   			^ (if (cname="Cons")
                              then if ((fname (getFieldDesc (cname,0))) = "V0" andalso (fname (getFieldDesc (cname,1))) = "R1")
                                     then " OK "
                                     else " SWAP "
                              else "")
   			^ (if (cname="Some")
                              then if ((fname (getFieldDesc (cname,0))) = "V0" andalso (fname (getFieldDesc (cname,1))) = "R1")
                                     then " OK "
                                     else " SWAP "
                              else "")
   			^ (if (cname="Node")
                              then if ((fname (getFieldDesc (cname,0))) = "V0" andalso (fname (getFieldDesc (cname,1))) = "R1" andalso (fname (getFieldDesc (cname,2))) = "R2")
                                     then " OK "
                                     else " SWAP "
                              else "")
   			^ ")\n")
   		)
   	in (
   	    print "\n(*\nInstance fields in diamond class:\n";
   	    print "($n is freelist link field, $ is tag value)\n";
   	    app prf fielddefs;
   	    print "\nDatatype mapping:\n";
   	    app prCon constructors;
   	    print "*)\n\n")
   	end
*)
(* ---------- *)

(* decompile classfile (printCert, dataLayout, tagOffset, thySyntax, tFlavour)

   Decompile one parsed class file into a Grail class definition and emit
   the results:
     - reads type metadata and layout info from the class attributes and
       translates methods and fields into Grail definitions;
     - computes a layout substitution with unify_layouts and stores it in
       ToyGrailAbsyn.global_subst;
     - if printCert is true, writes a certificate and tactics into
       !outputDir, plus a wrapper when !make_wrapper is set;
     - prints the Grail code to stdout unless !ToyGrailAbsyn.shut_up;
     - unless thySyntax = 0, writes <className>.thy into !outputDir.
   The settings tuple is the one threaded through doArgs.  Returns unit. *)
fun decompile classfile (printCert,dataLayout,tagOffset,thySyntax,tFlavour) =
let
    val {flags, this, super, ifcs, fdecls, mdecls, attrs, ...} = Decompile.toClassDecl classfile

    val className = qualName this
    (* val mdecls = List.filter (fn m => (#name m <> "<init>")) mdecls  (* REALLY ???? *)*)


    (* Called for its side effect only (result is unit).  It runs before the
       method translation below.  NOTE(review): presumably
       jmethodToMethodDef relies on this type info; confirm before reordering. *)
    val () = Metadata.readTypeInfo attrs
    val mds = map (jmethodToMethodDef className) mdecls
    val fields = fdecls
    val fds =  map jfieldToFieldDef fields
    val layout = Metadata.getLayoutInfo attrs

    (* code for matching data layouts *)
    (*
    val _ = TextIO.print ".. GF says Layout info:\n"
    val _ = GrailAbsyn.prLayout layout
    val _ = TextIO.print ".. Layout in German:\n"
    val _ = print_layout layout
    val _ = TextIO.print ".. Bloody stupid hardwired layout is:\n"
    val _ = GrailAbsyn.prLayout ToyGrailAbsyn.layout_TREELIST 
    val _ = TextIO.print ".. Unifying layouts...\n"
    *)
    (* Publish the layout substitution globally; NOTE(review): presumably
       read later by the ToyGrailAbsyn printers, confirm there. *)
    val subst = unify_layouts layout
    val _ = ToyGrailAbsyn.global_subst := subst
    (*
    val _ = TextIO.print ".. yields this substitution...\n"
    val _ = printSubst subst
    *)

    val cd = GrailAbsyn.CDEF(flags, className, Option.map qualName super, map qualName ifcs, fds, mds,layout)

    (*+
    val () = case layout of
                NONE => TextIO.print "WARNING: empty layout"
              | SOME _ => TextIO.print "SOME layout is here"
    +*)

    (* A .thy file is generated for every thySyntax other than 0. *)
    val generateThy = not (thySyntax=0)

    (* Certificate generation (enabled by -C): certificate and tactics are
       written into !outputDir.  With -W (!make_wrapper) a wrapper is also
       produced; which one is selected by !CertGenC.which_wrapper. *)
    val () = if printCert
               then let
		     (* val _ = CertGenC.makeCert2 className mdecls (!outputDir)*)
		     val _ = CertGenC.makeCert className thySyntax tFlavour mdecls (!outputDir)
                     val _ = CertGenC.mk_tactics className thySyntax tFlavour (!outputDir)
                     val _ = if (!make_wrapper)
                               then let 
                                      val wrapStr = CertGenC.getCertMethodFromCert className
                                      val _ = case (!CertGenC.which_wrapper) of
                                                0 => CertGenC.makeWrapper4 className wrapStr (!outputDir) (*default: demo-y4 wrapper*)
                                              | 1 => CertGenC.makeWrapper1 className wrapStr (!outputDir) (*review-y3 wrapper*)
                                              | 2 => CertGenC.makeWrapper0 className wrapStr (!outputDir)
                                              | _ => CertGenC.makeWrapper className thySyntax tFlavour wrapStr (!outputDir) (*pick wrapper based on tactic flavour*)
                                    in
                                     () (* YUCK *)
                                    end
                               else ()
                    in            
                     ()
                    end
               else ()

    (*
    val _ =  if generateThy
               then TextIO.print "Generating Tactic and consumer side certificates"
               else TextIO.print "NOT Generating Tactic and consumer side certificates"
    *)
    (*
    val _ = CertGenC.makeCert2 className mdecls
    val _ = CertGenC.makeCert3 className mdecls
    val _ = CertGenC.mk_tactics className
    *)

    (* print Grail code to stdout (suppressed by -q) *)
    val () = if (!ToyGrailAbsyn.shut_up)
               then ()
               else prClassDef cd

    (* optionally, also generate a .thy file out of the grail code.
       NOTE(review): if printGrailPROG raises, the stream opened here is
       never closed; consider a handler that closes it and re-raises. *)
    val () = if generateThy then
               let
                 val file = makeFullFilename (!outputDir) className "thy"
	         val thy = TextIO.openOut file
	         val () = ToyGrailAbsyn.printGrailPROG thy cd (printCert,dataLayout,tagOffset,thySyntax,tFlavour)
	         val () = TextIO.closeOut thy
	       in
                  TextIO.print (" Wrote " ^ file ^ "\n")
	       end
	     else ()
in
    ()
end


(* Matthew wants to be able to read class files from stdin. *)
(* Unfortunately mosmllib doesn't appear to provide the necessary
   functions to do this.  Only TextIO and BinIO seem to provide
   access to stdin, and only TextIO allows you to read in the
   entire stream; however, it gives you a CharVector.vector,
   which is the same as a string.  We want a Word8Vector.vector,
   which has the same representation, but there's no legal means
   of performing the conversion.  We get round this by using
   a coercion.  I'm not guaranteeing that this is fully safe
   (but it should be). *)
   

(* Read an entire class file from standard input and decompile it with the
   given settings tuple (see decompile). *)
fun decompileStdin thySyntax =
let
    (* Identity coercion: a string (CharVector.vector) and a
       Word8Vector.vector share one runtime representation in Moscow ML. *)
    prim_val toBytes : CharVector.vector -> Word8Vector.vector = 1 "identity"
    val classfile = Classfile.vectorToClass (toBytes (TextIO.inputAll TextIO.stdIn))
in
    decompile classfile thySyntax
end


(* Load the class file at path fname and decompile it with the given
   settings tuple (see decompile). *)
fun decompileFile (fname: string) thySyntax =
    decompile (Classfile.inputClassFile fname) thySyntax



(* extend filename extn: append the extension extn to filename unless
   filename already ends in exactly that extension.
   e.g. extend "Foo" "class" = "Foo.class",
        extend "Foo.class" "class" = "Foo.class",
        extend "Foo.java" "class" = "Foo.java.class" *)
fun extend filename extn =
    if Path.ext filename <> SOME extn
    then Path.joinBaseExt {base = filename, ext = SOME extn}
    else filename

(* Print the command-line synopsis.  The option list mirrors the clauses of
   doArgs: options are processed left to right, the first non-option
   argument is the class to decompile (".class" is appended if missing),
   and when no class is named the class file is read from stdin.
   The previous text advertised a non-existent "-l" flag and a "VCG" syntax
   that doArgs does not recognise (it would be taken as the class name). *)
fun usage () =
    TextIO.print (String.concat
      ["Usage: gf [options] [Java_class_name]\n",
       "  If no class name is given, the class file is read from stdin.\n",
       "Options:\n",
       "  -h            print this help and stop\n",
       "  -q            quiet: do not print the Grail code to stdout\n",
       "  -t <syntax>   logic syntax of the generated .thy file:\n",
       "                toy, bcl, dal, vcg or mrg (default: mrg)\n",
       "  -C            generate a certificate and tactics\n",
       "  -W            with -C, also generate a certificate wrapper\n",
       "  -F <n>        tactic flavour, n mod 10 (default: 4)\n",
       "  -L <n>        data layout (default: 0)\n",
       "  -T <n>        tag offset (default: 0)\n",
       "  -d <dir>      output directory (default: .)\n"])

(* Parse an integer option argument; a non-numeric argument yields 0. *)
fun intArgOr0 s = getOpt (Int.fromString s, 0)

(* doArgs args settings
   Process the command line left to right.  settings is the tuple
   (printCertificate, dataLayout, tagOffset, thySyntax, tacticFlavour)
   that is finally passed to decompile.  The first non-option argument is
   decompiled as a class file (".class" appended if missing) and any later
   arguments are ignored; with no such argument, stdin is decompiled.
   An option that needs a value but has none prints the usage text. *)
fun doArgs [] thySyntax' = decompileStdin thySyntax'
  | doArgs ("-h"::_) _ = usage ()
  | doArgs ("-q"::t) thySyntax' = (ToyGrailAbsyn.shut_up := true; doArgs t thySyntax')
  | doArgs ("-L"::t) (p,_,tag,thy,tFlavour) =      (* set data layout *)
    (case t of
         nil => usage ()
       | layoutStr::t' =>
           (* the raw string is kept globally; the tuple gets its integer value *)
           (ToyGrailAbsyn.data_layout := layoutStr;
            doArgs t' (p, intArgOr0 layoutStr, tag, thy, tFlavour)))
  | doArgs ("-T"::t) (p,d,_,thy,tFlavour) =        (* set tag offset *)
    (case t of
         nil => usage ()
       | tagStr::t' =>
           let val tag' = intArgOr0 tagStr
           in
             ToyGrailAbsyn.tag_offset := tag';
             doArgs t' (p, d, tag', thy, tFlavour)
           end)
  | doArgs ("-F"::t) (p,d,tag,thy,_) =             (* set tactic flavour *)
    (case t of
         nil => usage ()
       | flavourStr::t' =>
           (* Only the last decimal digit selects the flavour.
              NOTE(review): the tens digit (n div 10) used to be computed as a
              wrapper selector but was never applied; CertGenC.which_wrapper
              is not set here. *)
           doArgs t' (p, d, tag, thy, intArgOr0 flavourStr mod 10))
  | doArgs ("-W"::t) settings =                    (* also generate a wrapper *)
    (make_wrapper := true; doArgs t settings)
  | doArgs ("-C"::t) (_,d,tag,thy,tFlavour) =      (* enable certificate generation *)
    doArgs t (true, d, tag, thy, tFlavour)
  | doArgs ("-d"::t) thySyntax' =                  (* set output directory *)
    (case t of
         nil => usage ()
       | dir::t' => (outputDir := dir; doArgs t' thySyntax'))
  | doArgs ("-t"::t) (p,d,tag,_,tFlavour) =        (* select logic syntax *)
    let
      fun withThy n rest = doArgs rest (p, d, tag, n, tFlavour)
    in
      case t of
          nil => usage ()
        | z::t' =>
            (case z of
                 "toy" => withThy 2 t'   (* ToyGrail syntax *)
               | "bcl" => withThy 3 t'   (* BytecodeLogic syntax *)
               | "dal" => withThy 4 t'   (* DA Logic syntax *)
               | "vcg" => withThy 5 t'   (* VCG Logic syntax *)
               | "mrg" => withThy 6 t'   (* NULLTP Logic syntax *)
                 (* no recognised value (including a following option):
                    default to 6 and reprocess z as an ordinary argument *)
               | _ => withThy 6 (z::t'))
    end
  | doArgs (h::_) thySyntax' =
    (TextIO.print ("Reading " ^ h ^ "...\n");
     decompileFile (extend h "class") thySyntax')    (* should really do this via classpath *)


(* Program entry point.  Initial settings: don't print a certificate,
   data layout 0, tag offset 0, theory syntax 6 (mrg), tactic flavour 4. *)
val () = doArgs (CommandLine.arguments ()) (false, 0, 0, 6, 4)
