(*
 mik:
 - inspired by Lennart's mkTheory which generates representation predicates
 - added portions and internal layered separation
 - connected it with Camelot using pickled "hints"
 *)

structure DT2Pred =
struct

open Util

(* Abbreviation for the Camelot abstract syntax. *)
structure A = Absyn
(* Names and types taken over unchanged from the abstract syntax. *)
type ConsNm = A.ConsNm
type DataNm = A.DataNm
type TVarNm = A.TVar
type Ty = A.Ty
(* in future, when field names are in Camelot:
type FldNm = A.FldNm
type TypeCon = A.TypeCon
type TypeConL = A.TypeConL
type TypeDec = A.TypeDec
type TypeDecL = A.TypeDecL
 *)
(* temporary: local versions of the constructor and declaration types in
   which every constructor argument type is paired with a field name.
   A.HeapUsage distinguishes boxed (A.HEAP) from heap-free (A.NOHEAP)
   constructors; the Location.Location components are carried along
   unchanged from the abstract syntax. *)
type FldNm = string
datatype TypeCon = TYPEcon of ConsNm * ((Ty * FldNm) list) * A.HeapUsage
type TypeConL = Location.Location * TypeCon
(* TYPEdec (type variables, datatype name, located constructors) *)
datatype TypeDec = TYPEdec of (TVarNm list) * DataNm * (TypeConL list)
type TypeDecL = Location.Location * TypeDec

(* Add default field names.
   Originally, each field name was globally unique,
   which explains why a counter n is threaded all over the place.
   Now, fields are named fld1, fld2, ..., indexed by their position
   within their constructor, so numbering restarts at 1 for every
   constructor.
*)
					
(* addFldConL : (A.TypeConL, int) -> (TypeConL, int)
   Pairs every argument type of a constructor with a generated field
   name "fld<i>", numbering from n upwards, and also returns the next
   unused index (n plus the number of arguments). *)
fun addFldConL ((loc, A.TYPEcon (consNm, tys, heapUsage)), n) =
    let
	fun fldName i = "fld" ^ Int.toString i
	fun label (_, []) = []
	  | label (i, ty :: more) = (ty, fldName i) :: label (i + 1, more)
    in
	((loc, TYPEcon (consNm, label (n, tys), heapUsage)), n + length tys)
    end

(* addFldDec : (A.TypeDec, int) -> (TypeDec, int)
   Adds default field names to every constructor of one datatype
   declaration; field numbering restarts at 1 for each constructor.
   The returned counter is the one produced by the last constructor
   processed, or n unchanged when there are no constructors.
   NOTE(review): constructors are accumulated by prepending, so the
   resulting constructor list is in REVERSE declaration order.  Since
   getConsInfo assigns TAG numbers by list position, confirm that this
   matches Camelot's tag numbering. *)
fun addFldDec (A.TYPEdec (tvars, dataNm, typeConLs0), n) =
    let
	fun walk ([], acc, counter) = (acc, counter)
	  | walk (typeConL0 :: more, acc, _) =
	    let
		val (typeConL, counter') = addFldConL (typeConL0, 1)
	    in
		walk (more, typeConL :: acc, counter')
	    end
	val (typeConLs, m) = walk (typeConLs0, [], n)
    in
	(TYPEdec (tvars, dataNm, typeConLs), m)
    end
	
(* addFldDecs : A.TypeDecL list -> TypeDecL list
   Adds default field names to all datatype declarations, keeping each
   declaration's location.  Declarations are accumulated by prepending,
   so the result is in reverse order of the input. *)
fun addFldDecs typeDecLs0 =
    let
	fun walk ([], acc, _) = acc
	  | walk ((loc, typeDec0) :: more, acc, counter) =
	    let
		val (typeDec, counter') = addFldDec (typeDec0, counter)
	    in
		walk (more, (loc, typeDec) :: acc, counter')
	    end
    in
	walk (typeDecLs0, [], 1)
    end
(* end temporary *)

(**************************************************)
(* massage data in suitable formats from typeDecs *)
(**************************************************)

type DataFM = (DataNm, ConsNm list * ConsNm list) FM.dict
(* For each datatype:
   first list: boxed constructors
   second list: heap-free constructors
 *)
type FieldFM = (FldNm, DataNm) FM.dict
(* Maps a constructor field to the datatype of that field.
   Fields whose type is not a datatype are left out;
   therefore iheap is ignored except for the TAG element.
 *)
type ConsFM = (ConsNm, DataNm * FldNm list * FieldFM * int) FM.dict
(* Defined for each boxed constructor consNm.
   DataNm: the datatype which consNm constructs
   FldNm list: the fields of consNm whose types are datatypes
   FieldFM: maps each of those fields to its datatype
   int: the TAG number of consNm (the boxed constructors of a datatype
        are numbered 1, 2, ... in list order; heap-free ones are skipped)
 *)
type HFConsFM = (ConsNm, DataNm) FM.dict
(* Defined for each heap-free constructor: the datatype it constructs *)
		
type PreAddr = string list
		      (* [dataNm, consNm1, fldNm1, consNm2, fldNm2, ..., consNmN, fldNmN] *)
		      (* a sequence that can be completed to a portion or data address *)
type PortnAddr = string list
			(* pre_addr @ [consNm] *)
			(* address of an occurrence of consNm in an unfolded type term; *)
			(* these correspond to heap portions *)
type RecAddr = string list
		      (* pre_addr @ [dataNm] *)
		      (* address of an occurrence of dataNm in an unfolded type term; *)
		      (* used to indicate separation of unfoldings along this dataNm *)
type UnfoldFM = (DataNm, PreAddr list) FM.dict
(* address prefixes of recursive occurrences of dataNm
   within the recursive sum-and-product type term of dataNm
 *)

(* Accessors on portion addresses [dataNm, consNm1, fldNm1, ..., consNmN]. *)

(* A top portion consists of just [dataNm, consNm]. *)
fun isTopPortn portn =
    (case portn of [_, _] => true | _ => false)
(* The datatype at the root of the address. *)
fun getPortnData portn = hd portn
(* The first constructor on the path. *)
fun getPortn1stCons portn = List.nth (portn, 1)
(* The first field on the path. *)
fun getPortnFld portn = List.nth (portn, 2)
(* The final element of the address. *)
fun getPortnLastCons portn = List.last portn
				       
(* Drops the leading [dataNm, consNm1, fldNm1] of a portion address and
   re-roots the remainder at the datatype constructed by its first
   constructor, giving the sub-portion reached through fldNm1.
   The constructor is looked up with FMfind_err (message "in portnTail"). *)
fun portnTail consFM portn =
    let
	val below = (tl o tl o tl) portn
	val (subDataNm, _, _, _) = FMfind_err (consFM, hd below) "in portnTail"
    in
	subDataNm :: below
    end

(* Like portnTail, but for pre-addresses: drops [dataNm, consNm1, fldNm1]
   and re-roots what remains at the datatype of its first constructor.
   When nothing remains the result is [default_dataNm]. *)
fun preAddrTail default_dataNm consFM portn =
    case (tl o tl o tl) portn of
	[] => [default_dataNm]
      | rest as (firstCons :: _) =>
	let
	    val (subDataNm, _, _, _) =
		FMfind_err (consFM, firstCons) "in preAddrTail"
	in
	    subDataNm :: rest
	end

(* Removes the final component of an address. *)
fun dropLastAddrBit portn =
    let val keep = length portn - 1
    in List.take (portn, keep) end
(* A top pre-address consists of just [dataNm]. *)
fun isTopPreAddr pre_addr =
    (case pre_addr of [_] => true | _ => false)
(* Applies Util's minusPostfix to portn_postfix without its first
   element and to portn (presumably stripping that postfix off the end
   of portn). *)
fun getPreAddr portn_postfix portn =
    let val postfix = tl portn_postfix
    in minusPostfix postfix portn end
(* The last constructor of a pre-address [..., consNmN, fldNmN]. *)
fun getPreAddrLastCons pre_addr =
    let val idx = length pre_addr - 2
    in List.nth (pre_addr, idx) end
(* Completes a pre-address with a final constructor or datatype name. *)
fun complete_portn pre_addr s = List.concat [pre_addr, [s]]
(* Appends portn, without its leading datatype name, to pre_addr. *)
fun join_portns pre_addr portn =
    List.concat [pre_addr, tl portn]

    
(* Printable name of an address: the components after the leading
   datatype name, glued with "_" by Util's glue_list.  A bare [dataNm]
   prints as "<empty>". *)
fun portnNm portn =
    case portn of
	[_] => "<empty>"
      | _ => glue_list "_" (tl portn)
(* Printable name of a recursion address, built the same way (no
   special case for a single-element address). *)
fun recAddrNm recAddr =
    let val tailBits = tl recAddr
    in glue_list "_" tailBits end

(* getConsInfo : int -> DataNm -> TypeConL list ->
		 ConsNm list * (ConsNm * (DataNm * FldNm list * FieldFM * int)) list
		 * ConsNm list * (ConsNm * DataNm) list
		 * FldNm list
   Splits the constructors of datatype dataNm, in list order, into boxed
   (A.HEAP) and heap-free (A.NOHEAP) ones.
   - n is the TAG given to the first boxed constructor; later boxed
     constructors get n+1, n+2, ...; heap-free constructors take no TAG.
   - The 2nd and 4th components are association lists, to be added to a
     ConsFM and a HFConsFM respectively.
   - The last component lists the datatype-valued fields of all boxed
     constructors (possibly with repeats across constructors).
 *)
fun getConsInfo _ _ [] = ([], [], [], [], [])
  | getConsInfo n dataNm ((_, TYPEcon (consNm, tyflds, A.HEAP)) :: rest) =
    let
	val (rest_consNms, rest_consFMlist,
	     rest_hfConsNms, rest_hfConsFMlist,
	     rest_rfldNms) = getConsInfo (n + 1) dataNm rest
	(* keep only the fields whose type is a datatype, paired with it;
	   mapPartial avoids a partial helper with an impossible case *)
	val fld_data =
	    List.mapPartial
		(fn (A.CONty (_, fld_dataNm), fldNm) => SOME (fldNm, fld_dataNm)
		  | _ => NONE)
		tyflds
	val rfldNms = map (fn (fldNm, _) => fldNm) fld_data
	val fldFM = listToFM fld_data
    in (consNm :: rest_consNms, (consNm, (dataNm, rfldNms, fldFM, n)) :: rest_consFMlist,
	rest_hfConsNms, rest_hfConsFMlist,
	rfldNms @ rest_rfldNms)
    end
  | getConsInfo n dataNm ((_, TYPEcon (consNm, tyflds, A.NOHEAP)) :: rest) =
    let
	(* heap-free constructors do not consume a TAG number *)
	val (rest_consNms, rest_consFMlist,
	     rest_hfConsNms, rest_hfConsFMlist,
	     rest_rfldNms) = getConsInfo n dataNm rest
    in (rest_consNms, rest_consFMlist,
	consNm :: rest_hfConsNms, (consNm, dataNm) :: rest_hfConsFMlist,
	rest_rfldNms)
    end
	
(* getDataConsInfo : TypeDecL list ->
		      DataNm list * DataFM
		      * ConsNm list * ConsFM
		      * ConsNm list * HFConsFM
		      * FldNm list
   Collects, over all declarations and in list order: the datatype names
   and DataFM; the boxed constructors and ConsFM; the heap-free
   constructors and HFConsFM; and the datatype-valued field names
   (possibly with repeats).  TAGs are numbered from 1 per datatype.
*)
fun getDataConsInfo typeDecLs =
    let
	fun addDec ((_, TYPEdec (_, dataNm, typeConLs)),
		    (dataNms, dataFM, consNms, consFM, hfConsNms, hfConsFM, rfldNms)) =
	    let
		val (cs, csEntries, hfs, hfEntries, rfs) =
		    getConsInfo 1 dataNm typeConLs
	    in
		(dataNm :: dataNms, FM.insert (dataFM, dataNm, (cs, hfs)),
		 cs @ consNms, addListToFM consFM csEntries,
		 hfs @ hfConsNms, addListToFM hfConsFM hfEntries,
		 rfs @ rfldNms)
	    end
	(* three separate dictionaries: they hold values of different types *)
	val start = ([], FM.mkDict String.compare,
		     [], FM.mkDict String.compare,
		     [], FM.mkDict String.compare,
		     [])
    in
	(* foldr visits the last declaration first, like the original
	   recursion which processed the tail before the head *)
	List.foldr addDec start typeDecLs
    end


(* getDataConsPortnInfo : TypeDecL list -> int ->
			  DataNm list * DataFM
                          * ConsNm list * ConsFM
                          * ConsNm list * HFConsFM
		          * FldNm list
		          * Portn list * UnfoldFM * Portn list
   Extends getDataConsInfo with:
   - portns: for every datatype, every portion address
     [dataNm, cons1, fld1, ..., consK] obtained by walking boxed
     constructors and datatype-valued fields without revisiting a
     datatype already on the path;
   - unfoldFM: for every datatype, the pre-addresses at which the walk
     reaches that same datatype again (its recursive occurrences);
   - long_portns: only when level > 2, every portion of a datatype
     grafted onto every recursive pre-address of it; [] otherwise.
 *)
fun getDataConsPortnInfo typeDecLs level =
    let
	val (dataNms, dataFM,
	     consNms, consFM,
	     hfConsNms, hfConsFM,
	     rfldNms) = getDataConsInfo typeDecLs
	(* Boxed constructors of dataNm; ctx only labels the error message. *)
	fun boxed_conses ctx dataNm =
	    let
		val (conses, _) =
		    FMfind_err (dataFM, dataNm) ("in " ^ ctx ^ " (dataNm)")
	    in
		conses
	    end
	(* Datatype-valued fields of consNm, each paired with its datatype. *)
	fun data_fields ctx consNm =
	    let
		val (_, flds, fldFM, _) =
		    FMfind_err (consFM, consNm) ("in " ^ ctx ^ " (consNm)")
	    in
		map (fn fldNm =>
			(fldNm, FMfind_err (fldFM, fldNm) ("in " ^ ctx ^ " (fldNm)")))
		    flds
	    end
	(* Portion addresses reachable from dataNm below prefix; datatypes in
	   boundNms are already on the path and are not entered again. *)
	fun get_data_portns prefix boundNms dataNm =
	    if Set.member (boundNms, dataNm) then []
	    else
		let
		    val ctx = "getDataConsPortnInfo"
		    val boundNms' = Set.add (boundNms, dataNm)
		    fun from_cons consNm =
			let
			    val here = prefix @ [consNm]
			    fun from_fld (fldNm, fld_dataNm) =
				get_data_portns (here @ [fldNm]) boundNms' fld_dataNm
			in
			    here :: List.concat (map from_fld (data_fields ctx consNm))
			end
		in
		    List.concat (map from_cons (boxed_conses ctx dataNm))
		end
	(* Pre-addresses below prefix at which the root datatype of prefix
	   occurs again. *)
	fun get_data_preAddrs prefix boundNms dataNm =
	    if Set.member (boundNms, dataNm)
	    then (if getPortnData prefix = dataNm then [prefix] else [])
	    else
		let
		    val ctx = "get_data_preAddrs"
		    val boundNms' = Set.add (boundNms, dataNm)
		    fun from_cons consNm =
			let
			    fun from_fld (fldNm, fld_dataNm) =
				get_data_preAddrs (prefix @ [consNm, fldNm]) boundNms' fld_dataNm
			in
			    List.concat (map from_fld (data_fields ctx consNm))
			end
		in
		    List.concat (map from_cons (boxed_conses ctx dataNm))
		end
	fun add_data (dataNm, (done_portns, done_unfoldFM)) =
	    let
		val noneBound = Set.empty String.compare
		val portns = get_data_portns [dataNm] noneBound dataNm
		val unf_preAddrs = get_data_preAddrs [dataNm] noneBound dataNm
		val unfoldings = map (fn p => (dataNm, p)) unf_preAddrs
	    in
		(portns @ done_portns, mergeListToFM done_unfoldFM unfoldings)
	    end
	val (portns, unfoldFM) = foldl add_data ([], FM.mkDict String.compare) dataNms
	(* Graft every portion of each datatype onto each of its recursive
	   pre-addresses. *)
	fun get_long_portns [] = []
	  | get_long_portns (dataNm :: rest) =
	    let
		val unfs = case FM.peek (unfoldFM, dataNm) of NONE => []
							    | SOME l => l
		val own_portns = List.filter (fn p => getPortnData p = dataNm) portns
		fun append_data_portns unf = map (fn l => unf @ tl l) own_portns
		val done_rest = get_long_portns rest
	    in
		List.concat (map append_data_portns unfs) @ done_rest
	    end
	val long_portns = if level > 2 then get_long_portns dataNms else []
    in
	(dataNms, dataFM,
	 consNms, consFM,
	 hfConsNms, hfConsFM,
	 rfldNms,
	 portns, unfoldFM, long_portns)
    end

(******************************************************************************)
(******************************************************************************)
(****************************** theory generation *****************************)
(******************************************************************************)
(******************************************************************************)
	
(*************************************)
(* Representation of datatype values *)
(*************************************)
    
(* Name of the representation predicate of a datatype. *)
fun reprPred dataNm = String.concat ["models_", dataNm]

(* Suffix identifying a constructor in rule and lemma names. *)
fun consPostfix consNm = String.concat ["constr_", consNm]

(* Name of the rule of reprPred dataNm for constructor consNm. *)
fun reprRuleNm dataNm consNm =
    String.concat [reprPred dataNm, "_", consPostfix consNm]

(* Isabelle term text used in generated premises; h is the heap and l
   the location under consideration. *)

(* The object at l has class dataNm.
   (An earlier notation was "s\\<guillemotleft>l\\<guillemotright> = Some " ^ dataNm.) *)
fun dataNm_on_hl dataNm =
    String.concat ["fmap_lookup (oheap h) l = Some ", dataNm]

(* The TAG field of l in h equals tag. *)
fun tag_on_hl tag =
    String.concat ["h<l\\<bullet>TAG> = ", Int.toString tag]

(* Conjunction of the two facts above. *)
fun dataNm_tag_on_hl dataNm tag =
    String.concat [dataNm_on_hl dataNm, " \\<and> ", tag_on_hl tag]

(* Field fldNm of l holds the reference l_<fldNm>. *)
fun fld_loc fldNm =
    String.concat ["h\\<lfloor>l\\<diamondsuit>", fldNm, "\\<rfloor> = Ref l_", fldNm]

(* (l_<fldNm>, h, locs_<fldNm>) belongs to the set setNm. *)
fun fld_locs_in_set setNm fldNm =
    String.concat ["(l_", fldNm, ",h,locs_", fldNm, ") \\<in> ", setNm]

(* Conjunction of fld_loc and fld_locs_in_set. *)
fun fld_loc_locs_in_set setNm fldNm =
    String.concat [fld_loc fldNm, " \\<and> ", fld_locs_in_set setNm fldNm]

(* Premise text for field fld: it points to l_<fld>, and that value, with
   location list locs_<fld>, satisfies the representation predicate of the
   field's datatype (looked up in fldFM). *)
fun fld_repr fldFM fld =
    fld_loc_locs_in_set (reprPred (FMfind_err (fldFM, fld) "in fld_repr")) fld

(* reprStuff (rfldNms, dataFM, dataNms, consFM, consNms)
   Generates the "Representation of datatype values" section:
   - consts for every datatype name (:: cname), every distinct
     datatype-valued field name (:: rfldname) and TAG (:: ifldname);
   - the type models_pred and one predicate models_<dataNm> per datatype;
   - one inductive definition of all models_<dataNm> predicates together,
     with a rule per boxed constructor in consNms;
   - an "_intro" lemma per boxed constructor and an "_elim" lemma per
     datatype.
   Heap-free constructors get no rules, and the elim lemmas only
   enumerate the boxed constructors of each datatype.
 *)
fun reprStuff (rfldNms, dataFM, dataNms, consFM, consNms) =
    let
	(* type and field names declaration; the field names are
	   deduplicated by passing them through a set *)
	val uniq_rfldNms = Set.listItems (Set.addList (Set.empty String.compare, rfldNms))
	val dataConsts = "consts"
			 ^ wrap_list "\n  " dataNms " :: cname"
			 ^ wrap_list "\n  " uniq_rfldNms " :: rfldname"
			 ^ "\n  TAG :: ifldname\n"
        (* models_<type> declaration *)
	val reprConsts =
	    "types\n  models_pred = \"(locn \\<times> heap \\<times> locn list) set\"\n\n" ^
	    "consts" ^ wrap_list "\n  " (map reprPred dataNms) " :: models_pred" ^ "\n"
	(* models_<type> definition: one rule per boxed constructor.
	   Premises: l holds an object of the constructor's datatype with its
	   TAG; each datatype-valued field is represented by the models_
	   predicate of its datatype; l is not in any field's location list.
	   Conclusion: the location list is l followed by the fields' lists. *)
	fun reprRule consNm =
	    let
		val (dataNm, flds, fldFM, tag) = FMfind_err (consFM, consNm) "in reprRule"
	    in
		"  " ^ reprRuleNm dataNm consNm ^ ":" ^
		"\n  \"\\<lbrakk> " ^ dataNm_tag_on_hl dataNm tag ^
		wrap_list ";\n    " (map (fld_repr fldFM) flds) "" ^
		(case flds of [] => ""
			    | _ => ";\n    " ^ glue_wrap_list "; " "\\<not>(l mem locs_" flds ")") ^
		"\\<rbrakk>\n  \\<Longrightarrow> " ^
		(case flds of [] => "(l, h, [l])"
			    | _ => "(l, h, l#(" ^ glue_wrap_list " @ " "locs_" flds "" ^ "))") ^
		" \\<in> " ^ reprPred dataNm ^ "\"\n"
	    end
	val reprRules =
	    "inductive " ^ glue_list " " (map reprPred dataNms) ^ " intros\n" ^
	    glue_list "" (map reprRule consNms)
	(* name of the joint inductive definition, used to refer to its elims *)
	val allReprPreds =
	    glue_list "_" (map reprPred dataNms)
	(* intro lemmas for models_<type>: the rule restated with the premises
	   as a conjunction and the location list written as [l] @ locs_... *)
	fun reprIntroLemma consNm =
	    let
		val (dataNm, flds, fldFM, tag) = FMfind_err (consFM, consNm) "in reprIntroLemma"
	    in
		"lemma " ^ reprRuleNm dataNm consNm ^ "_intro:" ^
		"\n  \"" ^ dataNm_tag_on_hl dataNm tag ^
		wrap_list "\n  \\<and> " (map (fld_repr fldFM) flds) "" ^
		(case flds of [] => ""
			    | _ => "\n " ^ wrap_list " \\<and> \\<not> (l mem locs_" flds ")") ^
		"\n  \\<Longrightarrow> (l, h, [l]" ^ wrap_list " @ locs_" flds "" ^ ")" ^
		" \\<in> " ^ reprPred dataNm ^ "\"" ^
		"\n  by (simp?, rule " ^ reprRuleNm dataNm consNm ^ ", auto)\n"
	    end
	val reprIntros =
	    mini_comment "introduction lemmas" ^ "\n" ^
	    glue_list "\n" (map reprIntroLemma consNms)
	(* elim lemmas for models_<type>: membership implies the class of l
	   and a disjunction with one case per boxed constructor *)
	fun reprElimLemma dataNm =
	    let
		val (consNms, hfConsNms) = FMfind_err (dataFM, dataNm) "in reprElimLemma (dataNm)"
		(* one disjunct: the TAG, and either locs = [l] (no datatype
		   fields) or the existentially quantified field locations *)
		fun on_hl consNm =
		    let
			val (_, flds, fldFM, tag) = FMfind_err (consFM, consNm) "in reprElimLemma (consNm)"
			val nl = "\n      "
		    in
			tag_on_hl tag ^
			(case flds of
			     [] => " \\<and> locs = [l]"
			   | _ =>
			     nl ^ "\\<and> (\\<exists>" ^
			     wrap_list " l_" flds "" ^
			     wrap_list " locs_" flds "" ^ "." ^
			     nl ^ "  locs = [l]" ^ wrap_list " @ locs_" flds "" ^
			     nl ^ " " ^ wrap_list " \\<and> (\\<not> l mem locs_" flds ")" ^
			     wrap_list (nl ^ "  \\<and> ") (map (fld_repr fldFM) flds) "" ^ ")")
		    end
	    in
		"lemma " ^ reprPred dataNm ^ "_elim:" ^
		"\n  \"(l, h, locs) \\<in> " ^ reprPred dataNm ^
		"\n  \\<Longrightarrow> " ^ dataNm_on_hl dataNm ^
		"\n  \\<and> (" ^
		glue_wrap_list "\n     \\<or>\n     " "(" (map on_hl consNms) ")" ^
		")\"" ^
		"\n  apply (erule " ^ allReprPreds ^ ".elims)" ^
		"\n  apply auto" ^
		"\n  done\n"
	    end
	val reprElims = 
	    mini_comment "elimination lemmas" ^ "\n" ^
	    glue_list "\n" (map reprElimLemma dataNms) ^ "\n"
    in
	section "Representation of datatype values" ^
	glue_list "\n" [dataConsts, reprConsts, reprRules, reprIntros, reprElims]
    end
	
(*******************************)
(* Portions of datatype values *)
(*******************************)

(* Name of the predicate describing portion portn. *)
fun portnPred portn =
    String.concat ["type_", getPortnData portn, "_portion_", portnNm portn]
(* Name of the rule of portnPred portn for constructor consNm. *)
fun portRuleNm portn consNm =
    String.concat [portnPred portn, "_", consPostfix consNm]

(* Premise text for field fld: it points to l_<fld>, and
   (l_<fld>, h, locs_<fld>) belongs to the predicate of the sub-portion of
   portn reached through fld. *)
fun fld_portion_on_hl consFM (fld, portn) =
    fld_loc_locs_in_set (portnPred (portnTail consFM portn)) fld

			      
(* portionsPerCons consPortnsFM
		   (dataFM, dataNms, consFM, consNms, all_portns, unfoldFM)
		   final_consNm
   Generates the subsection for all portions whose address ends with the
   constructor final_consNm.  consPortnsFM maps final_consNm to the pair
   (portion addresses, their predicate names), as built by portionStuff.
   Emits: consts for the predicates; one inductive definition of all of
   them, with a rule per (portion, boxed constructor of the portion's
   datatype); an "_intro" lemma per such rule; an "_elim" lemma per
   portion; and an unproved ("sorry") lemma stating that models_<type>
   implies each portion predicate for a subset of the locations.
   Of the tuple argument, only dataFM, consFM and unfoldFM are used here.
 *)
fun portionsPerCons consPortnsFM
		   (dataFM, dataNms, consFM, consNms, all_portns, unfoldFM)
		   final_consNm =
    let
	val (portns, preds) = FMfind_err (consPortnsFM, final_consNm) "in portionsPerCons (final_consNm)"
	(* models_<type>_portion_<zaddr> declarations *)
	val portionsConsts =
	    "consts\n" ^ wrap_list "  " preds " :: models_pred\n"
	(* models_<type>_portion_<zaddr> definitions *)
	(* For portion portn and a boxed constructor consNm of its datatype,
	   returns (dataNm, tag, flds, fld_portns, locs):
	   - the relevant portions are, for a top portion [dataNm, c], the
	     recursive unfoldings of dataNm completed with c, and otherwise
	     portn itself; flds are the fields through which those portions
	     leave consNm, and fld_portns pairs each field with its portion;
	   - locs is the location-list text: l followed by the fields' lists
	     when portn is a top portion ending in consNm itself, otherwise
	     just the fields' lists ("[]" when there are none).
	   pred is not used here. *)
	fun portn_cons_bits (portn, pred) consNm =
	    let 
		val (dataNm, _, _, tag) = FMfind_err (consFM, consNm) "in portionsPerCons (consNm)"
		val unf_portns =
		    if isTopPortn portn
		    then let
			    val last_consNm = getPortnLastCons portn
			    fun f pre_addr = complete_portn pre_addr last_consNm
			in
			    map f (case FM.peek (unfoldFM, getPortnData portn) of NONE => []
										| SOME l => l)
			end
		    else [portn]
		val this_cons_portns =
		    List.filter (fn p => getPortn1stCons p = consNm) unf_portns
		val flds =
		    map (fn p => getPortnFld p) this_cons_portns
		val fld_portns =
		    ListPair.zip (flds, this_cons_portns)
		val locs =
		    if isTopPortn portn andalso consNm = getPortnLastCons portn
		    then "[l]" ^ wrap_list " @ locs_" flds ""
		    else
			if flds = []
			then "[]"
			else glue_wrap_list " @ " "locs_" flds ""
	    in
		(dataNm, tag, flds, fld_portns, locs)
	    end
	(* one introduction rule of portion predicate pred for consNm *)
	fun portionRule (portn, pred) consNm =
	    let
		val (dataNm, tag, flds, fld_portns, locs) =
		    portn_cons_bits (portn, pred) consNm
	    in
		"  " ^ portRuleNm portn consNm ^ ":" ^
		"\n  \"\\<lbrakk> " ^ dataNm_tag_on_hl dataNm tag ^
		wrap_list "\n  \\<and>" (map (fld_portion_on_hl consFM) fld_portns) "" ^
		"\\<rbrakk>\n  \\<Longrightarrow> (l, h, " ^ locs ^ ")" ^
		" \\<in> " ^ pred ^ "\"\n"
	    end
	(* all rules of one portion: one per boxed constructor of its datatype *)
	fun portionRules (portn, pred) =
	    let
		val (consNms, hfConsNms) = FMfind_err (dataFM, getPortnData portn) "in portionRules"
	    in
		mini_comment ("portion " ^ portnNm portn) ^ "\n" ^
		concat (map (portionRule (portn, pred)) consNms)
	    end
	val portionsRules =
	    "inductive " ^ wrap_list "" preds " " ^ "intros\n" ^
	    concat (map portionRules (ListPair.zip (portns, preds))) ^ "\n"
	(* name of the joint inductive definition, used to refer to its elims *)
	val allPortnPreds = glue_list "_" preds

	(* lemmas: models_<type> implies  models_<type>_portion_<zaddr> *)
	fun portionIfRepr (portn, pred) =
	    "  \"(l,h,locs) \\<in> " ^ reprPred (getPortnData portn) ^
	    "\n     \\<Longrightarrow> \\<exists> locs_p . (set locs_p) \\<subseteq> (set locs)" ^
	    "\n             \\<and> (l,h,locs_p) \\<in> " ^ pred ^ "\"\n"
	(* one lemma with a statement per portion, named after the final
	   constructor of the first portion, left unproved ("sorry") *)
	val portionsIfRepr =
	    mini_comment "is this needed? how to prove it?" ^ "\n" ^
	    "lemma " ^ consPostfix (getPortnLastCons (hd portns)) ^ "_portions_if_models:\n" ^
	    concat (map portionIfRepr (ListPair.zip (portns, preds))) ^
	    "sorry\n"
	(* intro lemmas for models_<type>_portion_<zaddr> *)
	fun portionIntro (portn, pred) consNm =
	    let
		val (dataNm, tag, flds, fld_portns, locs) =
		    portn_cons_bits (portn, pred) consNm
	    in
		"lemma " ^ portRuleNm portn consNm ^ "_intro:" ^
		"\n  \" " ^ dataNm_tag_on_hl dataNm tag ^
		wrap_list "\n  \\<and>" (map (fld_portion_on_hl consFM) fld_portns) "" ^
		"\n  \\<Longrightarrow> (l, h, " ^ locs ^ ")" ^
		" \\<in> " ^ pred ^ "\"" ^
		"\nby (rule " ^ portRuleNm portn consNm ^ ", fastsimp)\n"
	    end
	fun portionIntros (portn, pred) =
	    let
		val (consNms, hfConsNms) = FMfind_err (dataFM, getPortnData portn) "portionIntros"
	    in
		glue_list "\n" (map (portionIntro (portn, pred)) consNms)
	    end
	val portionsIntros =
	    mini_comment "introduction lemmas" ^ "\n" ^
	    glue_list "\n" (map portionIntros (ListPair.zip (portns, preds)))

	(* elim lemmas for models_<type>_portion_<zaddr>: membership implies
	   the class of l and one disjunct per boxed constructor *)
	fun portionElim (portn, pred) =
	    let
		val dataNm = getPortnData portn
		val (consNms, hfConsNms) = FMfind_err (dataFM, dataNm) "in portionElim"
		fun on_hl consNm =
		    let
			val nl = "\n      "
			val (dataNm, tag, flds, fld_portns, locs) =
			    portn_cons_bits (portn, pred) consNm
		    in
			tag_on_hl tag ^
			(case flds of
			     [] => " \\<and> locs = " ^ locs
			   | _ =>
			     nl ^ "\\<and> (\\<exists>" ^
			     wrap_list " l_" flds "" ^
			     wrap_list " locs_" flds "" ^ "." ^
			     nl ^ "  locs = " ^ locs ^
			     wrap_list (nl ^ "  \\<and> ") (map (fld_portion_on_hl consFM) fld_portns) "" ^ ")")
		    end
	    in
		"lemma " ^ portnPred portn ^ "_elim:" ^
		"\n  \"(l, h, locs) \\<in> " ^ pred ^
		"\n  \\<Longrightarrow> " ^ dataNm_on_hl dataNm ^
		"\n  \\<and> (" ^
		glue_wrap_list "\n     \\<or>\n     " "(" (map on_hl consNms) ")" ^
		")\"" ^
		"\n  apply (erule " ^ allPortnPreds ^ ".elims)" ^
		"\n  apply auto" ^
		"\n  done\n"
	    end
	val portionsElims = 
	    mini_comment "elimination lemmas" ^ "\n" ^
	    glue_list "\n" (map portionElim (ListPair.zip (portns, preds)))
    in
	subsection ("portions ending with " ^ final_consNm) ^
	glue_list "\n" [portionsConsts,
			portionsRules,
			portionsIntros,
			portionsElims,
			portionsIfRepr]
    end

    
(* portionStuff (dataFM, dataNms, consFM, consNms, all_portns, unfoldFM)
   Generates the "Portions of datatype values" section: the portions are
   grouped by their last constructor, and one subsection is produced per
   boxed constructor in consNms (see portionsPerCons).  Within a group,
   portions appear in reverse order of all_portns. *)
fun portionStuff (params as (dataFM, dataNms, consFM, consNms, all_portns, unfoldFM)) =
    let
	(* add one portion, and its predicate name, to its last constructor's group *)
	fun group (portn, fm) =
	    let
		val lastCons = List.last portn
		val (ps, predNms) =
		    case FM.peek (fm, lastCons) of
			NONE => ([], [])
		      | SOME entry => entry
	    in
		FM.insert (fm, lastCons, (portn :: ps, portnPred portn :: predNms))
	    end
	val byLastCons = List.foldl group (FM.mkDict String.compare) all_portns
    in
	section "Portions of datatype values" ^
	glue_list "\n" (map (portionsPerCons byLastCons params) consNms)
    end


(******************************************)
(* Internal separation of datatype values *)
(******************************************)

(* Names of the internal-separation predicates and their rules.  The
   "along"/"across" forms are indexed by a recursion address recAddr and
   a portion portn; the alo*/acr* shorthands build the recursion address
   by completing pre_addr with dataNm.  (An earlier variant passed
   (join_portns pre_addr portn) in place of portn.) *)

fun alongPred recAddr portn =
    String.concat ["sep_along_", recAddrNm recAddr, "_portion_", portnNm portn]

fun aloPred portn dataNm pre_addr =
    alongPred (complete_portn pre_addr dataNm) portn

fun alongRuleName recAddr portn consNm =
    String.concat ["ALONG_", recAddrNm recAddr,
		   "_portion_", portnNm portn,
		   "_", consPostfix consNm]

fun aloRuleName portn dataNm pre_addr consNm =
    alongRuleName (complete_portn pre_addr dataNm) portn consNm

fun acrossPred recAddr portn =
    String.concat ["sep_across_", recAddrNm recAddr, "_portion_", portnNm portn]

fun acrPred portn dataNm pre_addr =
    acrossPred (complete_portn pre_addr dataNm) portn

fun acrossRuleName recAddr portn consNm =
    String.concat ["ACROSS_", recAddrNm recAddr,
		   "_portion_", portnNm portn,
		   "_", consPostfix consNm]

fun acrRuleName portn dataNm pre_addr consNm =
    acrossRuleName (complete_portn pre_addr dataNm) portn consNm

(* Premise text: (l_<fldNm>, h) belongs to setNm. *)
fun fld_in_set setNm fldNm =
    String.concat ["(l_", fldNm, ",h) \\<in> ", setNm]

(* Premise text: field fldNm of l points to l_<fldNm>, and
   (l_<fldNm>, h) belongs to setNm. *)
fun fld_loc_in_set setNm fldNm =
    String.concat [fld_loc fldNm, " \\<and> ", fld_in_set setNm fldNm]

(* The three helpers below re-root pre_addr at the datatype of its next
   constructor (or at pre_addr's own datatype when nothing follows, see
   preAddrTail) and emit the premise text for field fld. *)

(* fld points to a location in the "across" predicate of the sub-address. *)
fun fld_pre_portion_across consFM portn (fld, pre_addr) =
    let
	val sub = preAddrTail (getPortnData pre_addr) consFM pre_addr
    in
	fld_loc_in_set (acrPred portn (getPortnData portn) sub) fld
    end

(* fld points to a location in the "along" predicate of the sub-address. *)
fun fld_pre_portion_along consFM portn (fld, pre_addr) =
    let
	val sub = preAddrTail (getPortnData pre_addr) consFM pre_addr
    in
	fld_loc_in_set (aloPred portn (getPortnData portn) sub) fld
    end

(* (l_<fld>, h, locs_<fld>) belongs to the portion predicate of the
   sub-address extended with portn. *)
fun fld_pre_portion_locs consFM portn (fld, pre_addr) =
    let
	val sub = preAddrTail (getPortnData pre_addr) consFM pre_addr
    in
	fld_locs_in_set (portnPred (join_portns sub portn)) fld
    end

(* Premise text: the location lists locs1 and locs2 are disjoint as sets. *)
fun locs_disjoint_locs (locs1, locs2) =
    String.concat ["(set ", locs1, ") \\<inter> (set ", locs2, ") = {}"]

(* Variants naming one or both lists through a field (locs_<fld>). *)
fun fld_disjoint_locs fld locs = locs_disjoint_locs ("locs_" ^ fld, locs)

fun locs_disjoint_fld locs fld = locs_disjoint_locs (locs, "locs_" ^ fld)

fun fld_disjoint_fld (fld1, fld2) =
    let
	val lhs = "locs_" ^ fld1
	val rhs = "locs_" ^ fld2
    in
	locs_disjoint_locs (lhs, rhs)
    end

(* pre_addr_cons_bits consFM unfoldFM portn (pre_addr, pred) consNm
   Collects what a separation rule for boxed constructor consNm at
   pre-address pre_addr needs: (dataNm, tag, flds, fld_pre_addrs), i.e.
   the datatype and TAG of consNm, the fields of consNm through which the
   relevant pre-addresses continue, and each such field paired with its
   pre-address.  For a top pre-address [dataNm] the candidates are the
   recursive unfoldings of dataNm; otherwise just pre_addr itself.
   portn and pred are not used; they are kept for call compatibility.
 *)
fun pre_addr_cons_bits consFM unfoldFM portn (pre_addr, pred) consNm =
    let 
	val (dataNm, _, _, tag) = FMfind_err (consFM, consNm) "in pre_addr_cons_bits (consNm)"
	(* unfoldFM may lack an entry for a datatype without recursive
	   occurrences; like every other lookup of unfoldFM in this file,
	   treat a missing entry as "no unfoldings" instead of failing. *)
	val unf_pre_addrs =
	    if isTopPreAddr pre_addr
	    then (case FM.peek (unfoldFM, getPortnData pre_addr) of
		      NONE => []
		    | SOME l => l)
	    else [pre_addr]
	val this_cons_pre_addrs =
	    List.filter (fn p => getPortn1stCons p = consNm) unf_pre_addrs
	val flds = map getPortnFld this_cons_pre_addrs
    in
	(dataNm, tag, flds, ListPair.zip (flds, this_cons_pre_addrs))
    end


(* separationAlong (dataFM, consFM, all_portns, unfoldFM) dataNm portn
   Generates the "along" internal-separation predicates for portion portn
   of dataNm.  There is one predicate per pre-address, obtained by dropping
   the last element of each portion in all_portns that ends with portn's
   first constructor.
   The case counts as trivial when portn is a prefix of some recursive
   unfolding of dataNm (per Util's isPrefix), or when no unfolding starts
   with portn's first constructor.  In that case only placeholder constdefs
   with an undetermined ("?") body are emitted.  Otherwise the predicates
   are declared as consts and defined by one inductive definition, with a
   rule per (pre-address, boxed constructor).
 *)
fun separationAlong (dataFM, consFM, all_portns, unfoldFM) dataNm portn =
    let
	val trivial =
	    let
		val unfoldings = case FM.peek (unfoldFM, dataNm) of NONE => []
								  | SOME l => l
		fun portn_is_prefix pre_portn =
		    isPrefix portn pre_portn
		fun different_branch pre_portn =
		    getPortn1stCons pre_portn <> getPortn1stCons portn
	    in
		List.exists portn_is_prefix unfoldings
		orelse
		List.all different_branch unfoldings
	    end
	(* pre-addresses at which portn can be attached *)
	val pre_addrs =
	    map dropLastAddrBit
		(List.filter (fn p => getPortn1stCons portn = getPortnLastCons p) all_portns)
    in
	if trivial
	then
	    let
		fun f pre_addr =
		    let
			val pred = aloPred portn dataNm pre_addr
		    in
			"  " ^ pred ^ " :: int_sep_pred\n" ^
			"  \"" ^ pred ^ " \\<equiv> ?\" (* help! *)\n"
		    end
	    in
		"constdefs\n" ^
		glue_list "" (map f pre_addrs) ^
		"\n"
	    end
	else
	    let
		(* one predicate per pre-address, computed once and reused *)
		val preds = map (aloPred portn dataNm) pre_addrs
		(* sep_... declarations *)
		val sepConsts =
		    "consts\n" ^
		    wrap_list "  " preds " :: int_sep_pred\n"
		(* sep_... definitions *)
		fun sepRule (pre_addr, pred) consNm =
		    let
			val (cons_dataNm, tag, flds, fld_pre_addrs) =
			    pre_addr_cons_bits consFM unfoldFM portn (pre_addr, pred) consNm
		    in
			aloRuleName portn dataNm pre_addr consNm ^ ":" ^
			"\n  \"\\<lbrakk> " ^ dataNm_tag_on_hl cons_dataNm tag ^
			wrap_list "\n  \\<and> " (map (fld_pre_portion_along consFM portn) fld_pre_addrs) "" ^
			(* at a top pre-address with fields, also require
			   (l,h,locs) in portn's predicate, the fields'
			   portion locations, and their disjointness from locs *)
			(if isTopPreAddr pre_addr andalso length flds > 0
			 then
			     "\n  \\<and> (l,h,locs) \\<in> " ^ portnPred portn ^
			     wrap_list "\n  \\<and> " (map (fld_pre_portion_locs consFM portn) fld_pre_addrs) "" ^
			     wrap_list "\n  \\<and> " (map (locs_disjoint_fld "locs") flds) ""
			 else "") ^
			"\\<rbrakk>\n  \\<Longrightarrow> (l, h)" ^
			" \\<in> " ^ pred ^ "\"\n"
		    end
		fun sepRulesPerPre (pre_addr, pred) =
		    let
			val (consNms, _) =
			    FMfind_err (dataFM, getPortnData pre_addr) "in sepRulesPerPre"
		    in
			mini_comment ("portion prefix " ^ portnNm pre_addr) ^ "\n" ^
			glue_list "" (map (sepRule (pre_addr, pred)) consNms)
		    end
		val sepRules =
		    "inductive " ^ glue_list " " preds ^ " intros\n" ^
		    glue_list "" (map sepRulesPerPre (ListPair.zip (pre_addrs, preds)))
		(* TODO: intro and elim lemmas for sep_... are not generated
		   yet; these are empty placeholders *)
		val sepIntros = ""
		val sepElims = ""
	    in
		glue_list "\n" [sepConsts,
				sepRules,
				sepIntros,
				sepElims]
	    end
    end

(* separationAcross (dataFM, consFM, all_portns, unfoldFM) dataNm portn
   Generates the "across" internal-separation predicates for portion
   portn of dataNm.  There is one predicate per pre-address, obtained by
   dropping the last element of each portion in all_portns that ends with
   portn's first constructor.
   The recursive unfoldings of dataNm are grouped by Util's factor,
   presumably by their first constructor.  The case counts as trivial when
   every group has fewer than two unfoldings.  In that case only
   placeholder constdefs with an undetermined ("?") body are emitted.
   Otherwise the predicates are declared as consts and defined by one
   inductive definition, with a rule per (pre-address, boxed constructor).
 *)
fun separationAcross (dataFM, consFM, all_portns, unfoldFM) dataNm portn =
    let
	val unfoldings = case FM.peek (unfoldFM, dataNm) of NONE => []
							  | SOME l => l
	val unfss = factor getPortn1stCons String.compare unfoldings
	val trivial = List.all (fn (_,l) => length l < 2) unfss
	(* pre-addresses at which portn can be attached *)
	val pre_addrs =
	    map dropLastAddrBit
		(List.filter (fn p => getPortn1stCons portn = getPortnLastCons p) all_portns)
    in
	if trivial
	then
	    let
		fun f pre_addr =
		    let
			val pred = acrPred portn dataNm pre_addr
		    in
			"  " ^ pred ^ " :: int_sep_pred\n" ^
			"  \"" ^ pred ^ " \\<equiv> ?\" (* help! *)\n"
		    end
	    in
		"constdefs\n" ^
		glue_list "" (map f pre_addrs) ^
		"\n"
	    end
	else
	    let
		(* one predicate per pre-address, computed once and reused *)
		val preds = map (acrPred portn dataNm) pre_addrs
		(* sep_... declarations *)
		val sepConsts =
		    "consts\n" ^
		    wrap_list "  " preds " :: int_sep_pred\n"
		(* sep_... definitions *)
		fun sepRule (pre_addr, pred) consNm =
		    let
			val (cons_dataNm, tag, flds, fld_pre_addrs) =
			    pre_addr_cons_bits consFM unfoldFM portn (pre_addr, pred) consNm
		    in
			acrRuleName portn dataNm pre_addr consNm ^ ":" ^
			"\n  \"\\<lbrakk> " ^ dataNm_tag_on_hl cons_dataNm tag ^
			wrap_list "\n  \\<and> " (map (fld_pre_portion_across consFM portn) fld_pre_addrs) "" ^
			(* at a top pre-address, also require the fields'
			   portion locations and disjointness of the field
			   location lists for the pairs given by pairs_ord *)
			(if isTopPreAddr pre_addr
			 then
			     wrap_list "\n  \\<and> " (map (fld_pre_portion_locs consFM portn) fld_pre_addrs) "" ^
			     wrap_list "\n  \\<and> " (map fld_disjoint_fld (pairs_ord flds)) ""
			 else "") ^
			"\\<rbrakk>\n  \\<Longrightarrow> (l, h)" ^
			" \\<in> " ^ pred ^ "\"\n"
		    end
		fun sepRulesPerPre (pre_addr, pred) =
		    let
			val (consNms, _) =
			    FMfind_err (dataFM, getPortnData pre_addr) "in sepRulesPerPre (across)"
		    in
			mini_comment ("portion prefix " ^ portnNm pre_addr) ^ "\n" ^
			glue_list "" (map (sepRule (pre_addr, pred)) consNms)
		    end
		val sepRules =
		    "inductive " ^ glue_list " " preds ^ " intros\n" ^
		    glue_list "" (map sepRulesPerPre (ListPair.zip (pre_addrs, preds)))
		(* TODO: intro and elim lemmas for sep_... are not generated
		   yet; these are empty placeholders *)
		val sepIntros = ""
		val sepElims = ""
	    in
		glue_list "\n" [sepConsts,
				sepRules,
				sepIntros,
				sepElims]
	    end
    end

(* One subsection holding both the "along" and the "across" separation
   predicates for sub-portion portn, relative to portn's own datatype. *)
fun separationAlongAcross params portn =
    let
	val dataNm = getPortnData portn
	val heading =
	    subsection ("separation of sub-portion " ^ portnNm portn ^ " along/across " ^ dataNm )
	val along = separationAlong params dataNm portn
	val across = separationAcross params dataNm portn
    in
	heading ^ along ^ across
    end
    
    
(* separationStuff (dataFM, consFM, portns, unfoldFM)
   Generates the "Internal separation of datatype values" section: the
   int_sep_pred type, followed by the along/across separation subsection
   for every portion in portns. *)
fun separationStuff (params as (_, _, portns, _)) =
    let
	val header = section "Internal separation of datatype values"
	val sepType = "types\n  int_sep_pred = \"(locn \\<times> heap) set\"\n"
	val body = glue_list "" (map (separationAlongAcross params) portns)
    in
	String.concat [header, sepType, "\n", body]
    end
	
(***********************************************)
(*** the top level theory creation function ****)
(***********************************************)

(* mkTheory : A.TypeDecL list -> int -> string
   Builds the theory text for the given Camelot datatype declarations.
   level selects how much is generated:
     any level : representation predicates;
     level > 1 : portion predicates (basic portions plus long portions);
     level > 2 : internal separation predicates.
   The long portions returned by getDataConsPortnInfo are empty unless
   level > 2, so at level 2 only the basic portions appear. *)
fun mkTheory typeDecLs0 level =
    let
	val typeDecLs = addFldDecs typeDecLs0 (* delete in future *)
	(* get various lists and maps from the datatype defs *)
	val (dataNms, dataFM, consNms, consFM, _, _, rfldNms, portns, unfoldFM, long_portns) =
	    getDataConsPortnInfo typeDecLs level
	val reprPart = reprStuff (rfldNms, dataFM, dataNms, consFM, consNms)
	val portionPart =
	    if level <= 1 then ""
	    else portionStuff (dataFM, dataNms, consFM, consNms, portns @ long_portns, unfoldFM)
	val separationPart =
	    if level <= 2 then ""
	    else separationStuff (dataFM, consFM, portns, unfoldFM)
    in
	reprPart ^ portionPart ^ separationPart
    end

end
