(* Environments for function and variable types *)

open Normsyn NAsyntfn
local
structure S = Splaymap
in

(* Dummy unit value carrying no data.
   NOTE(review): presumably exported so other compilation units can mention
   this structure (and so force it to be loaded) -- confirm against users. *)
val required = ()

(* Signal an internal error raised by this module: the message is prefixed
   with "[Env.sml]: " and passed on to Util.ierror. *)
fun env_error msg = Util.ierror (String.concat ["[Env.sml]: ", msg])

(* Which part of the program an environment describes: the main program,
   or the class with the given name. *)
datatype envType = MAINenv | CLASSenv of string

(* Maps variable names to their types. *)
type varEnv = (string, Ty) S.dict
(* Maps each function name to a pair of the function's own type and a varEnv
   holding the types of that function's variables. *)
type funEnv = (string, Ty * varEnv) S.dict  (* function name, function type, types of fn vars *)

(* The function environment belonging to one part of the program. *)
type classEnv = {class: envType, funenv: funEnv}
(* Environment for a whole program: a list of per-part environments.
   The lookup functions below (getMainEnv, getClassEnv) take the first entry
   whose class matches. *)
datatype progEnv = ProgEnv of classEnv list

(* Return the function environment of the main program.
   If several MAINenv entries exist, the first one in the list is used;
   if none exists, an internal error is raised via env_error. *)
fun getMainEnv (ProgEnv classEnvs) =
    let
        fun isMain ({class, ...} : classEnv) = class = MAINenv
    in
        case List.find isMain classEnvs of
            NONE => env_error "Couldn't find main environment"
          | SOME ({funenv, ...} : classEnv) => funenv
    end

(* Return the function environment of class cname.
   The first entry tagged CLASSenv cname is used; if there is none, an
   internal error naming the class is raised via env_error. *)
fun getClassEnv cname (ProgEnv classEnvs) =
    let
        val wanted = CLASSenv cname
        fun matches ({class, ...} : classEnv) = class = wanted
    in
        case List.find matches classEnvs of
            NONE => env_error ("Couldn't find environment for class " ^ cname)
          | SOME ({funenv, ...} : classEnv) => funenv
    end

(* Return the type recorded for function fname in the function environment
   env (the first component of its entry). Raises an internal error via
   env_error when fname is not bound. *)
fun getFunTy fname env =
    let
        val entry = S.peek (env, fname)
    in
        case entry of
            NONE => env_error ("Couldn't find type for function " ^ fname)
          | SOME (funTy, _) => funTy
    end

(* Return the variable-type environment recorded for function fname in env
   (the second component of its entry). Raises an internal error via
   env_error when fname is not bound. *)
fun getVarEnv fname env =
    let
        val entry = S.peek (env, fname)
    in
        case entry of
            NONE => env_error ("Couldn't find variable types for function " ^ fname)
          | SOME (_, varTys) => varTys
    end

(* Return the type of variable v in env; raises an internal error via
   env_error when v is not bound. See getVarTyOpt for a non-failing lookup. *)
fun getVarTy v env =
    let
        val found = S.peek (env, v)
    in
        case found of
            NONE => env_error ("Couldn't find type of variable " ^ v)
          | SOME varTy => varTy
    end

(* Look up the type of a variable, returning NONE instead of failing when
   the name is unbound. *)
fun getVarTyOpt name varTys = S.peek (varTys, name)

(* True exactly when env records a type for v. *)
fun isLocal v env =
    case getVarTyOpt v env of
        SOME _ => true
      | NONE => false

(* Each call builds a fresh, empty dictionary keyed by strings ordered with
   String.compare; one constructor per environment kind. *)
fun newVarEnv () = S.mkDict String.compare
fun newFunEnv () = S.mkDict String.compare



(* Print every binding of the variable environment e, one per line, as
   "    name: type", in the order S.app visits them. *)
fun prVarEnv e =
    let
        fun prBinding (name, ty) =
            print (String.concat ["    ", name, ": ", typeToString ty, "\n"])
    in
        S.app prBinding e
    end


(* Print one function-environment entry: a line " name: type", then the
   function's variable environment (via prVarEnv) between " {" and " }"
   lines, followed by a blank line. *)
fun prFunEnv (fname, (fty, varTys)) =
    let
        val header = String.concat [" ", fname, ": ", typeToString fty, "\n {\n"]
    in
        print header;
        prVarEnv varTys;
        print " }\n\n"
    end

(* Print one classEnv: a heading naming either the main program or the
   class, followed by every function entry printed with prFunEnv.
   The argument is explicitly typed as classEnv: using only #class/#funenv
   selectors on an unconstrained argument leaves a flexible record type,
   which some SML compilers (e.g. SML/NJ) reject as unresolved. *)
fun prEnvInner ({class, funenv} : classEnv) =
    let
        val heading =
            case class of
                MAINenv => "Main environment:\n"
              | CLASSenv t => "Environment for class " ^ t ^ ":\n"
    in
        print heading;
        S.app prFunEnv funenv
    end

(* Print an optional whole-program environment, wrapped in banner lines and
   an SML comment so the output can sit inside SML source.
   NONE prints nothing. *)
fun prEnv NONE = ()
  | prEnv (SOME (ProgEnv classEnvs)) =
    (print "(* ++++++++++++ Environment ++++++++++++ *)\n(*\n";
     app prEnvInner classEnvs;
     print "\n*)\n";
     print "(* ++++++++++++++++ end ++++++++++++++++ *)\n")

end

