(*
 Inflate.sml
 K.Mackenzie,  May 2004
*)

(* Decoder for deflate algorithm.  See RFC 1951 for details. *)
(* This implementation is fairly slow.  Perhaps we could speed
   things up by using a more compact (but probably less readable)
   implementation of Huffman trees. *)


(* Raised by the decoder when the compressed data is malformed or
   truncated, or when the decompressed size differs from the size the
   caller asked for.  The string describes the problem. *)
exception InflateError of string

(* Huffman decoding tree.
     Empty        a position no codeword has been assigned to
     Node (l, r)  l is followed on a 0 bit (false), r on a 1 bit (true)
     Leaf v       a complete codeword; v is the decoded symbol *)
datatype tree = Empty | Node of tree * tree | Leaf of int

(* insert x path tr
   Return tr extended with a leaf holding x at the position reached by
   following path from the root (false = left branch, true = right
   branch).  Missing interior nodes are created on the way down.
   Raises InflateError if the path runs through an existing leaf or
   ends on a position that is already occupied. *)
fun insert x path tr =
    let
        fun collision () =
            raise InflateError ("Trying to insert value "
                                ^ Int.toString x
                                ^ " at bad location in Huffman tree")

        (* Children of a position we need to descend through; an empty
           position behaves like a node with two empty children. *)
        fun children Empty = (Empty, Empty)
          | children (Node (l, r)) = (l, r)
          | children (Leaf _) = collision ()
    in
        case path of
            [] => (case tr of
                       Empty => Leaf x
                     | _ => collision ())
          | goRight :: rest =>
            let
                val (l, r) = children tr
            in
                if goRight
                then Node (l, insert x rest r)
                else Node (insert x rest l, r)
            end
    end

(* Infix shorthands for byte-level bit operations.  Precedence 4 is
   lower than the arithmetic operators, so parenthesise when mixing
   these with + or *.  (>> is not used in the code visible here.) *)
val andb = Word8.andb
val >> = Word8.>>
infix 4 andb >>

(* arr $ i  is  Array.sub (arr, i) *)
val $ = Array.sub (* makes array operations a little more readable *)
infix 4 $

(* bit_set j n: true iff bit j of n is 1 (bit 0 is the least
   significant bit). *)
fun bit_set j n =
    let
        val value = Word.fromInt n
        val mask  = Word.<< (0w1, Word.fromInt j)
    in
        Word.andb (value, mask) <> 0w0
    end

(* bit_list n len: bits len, len-1, ..., 0 of n as booleans, highest
   bit first (len+1 elements in all).  Returns [] when len < 0. *)
fun bit_list n len =
    let
        (* Walk upwards from bit 0, consing as we go, so the highest
           bit ends up at the head of the list. *)
        fun collect i acc =
            if i > len then acc
            else collect (i + 1) (bit_set i n :: acc)
    in
        collect 0 []
    end

(* makeTree bitlen_list
   Build the decoding tree for the canonical Huffman code described by
   a list of code lengths (RFC 1951, section 3.2.2).

     bitlen_list  element j (counting from 0) is the codeword length in
                  bits of symbol j; 0 means the symbol has no codeword.

   Returns a tree whose leaves hold symbol numbers.  If every length is
   0 the result is Empty.

   Raises InflateError if a length lies outside 0..15, or (from insert)
   if two codewords end up competing for the same place in the tree. *)
fun makeTree bitlen_list =
let
    (* Deflate codewords are at most 15 bits long, so tables indexed by
       codeword length need 16 slots (indices 0..15). *)
    val maxbits = 16
    val counts = Array.array (maxbits, 0)     (* # of codewords of each length *)
    val codewords = Array.array (maxbits, 0)  (* jth entry = next codeword of length j *)

    fun incr a n = Array.update (a, n, (a $ n) + 1)

    (* Counting touches counts[len] for every length in the list, so an
       out-of-range length is detected here; every later array access
       uses one of these same lengths and is therefore in range. *)
    val () = app (incr counts) bitlen_list
        handle Subscript => raise InflateError "Bad code length in Huffman table"

    (* Compute the first codeword of each length j = 1 .. n-1:
         first[j] = (first[j-1] + count[j-1]) * 2
       prev carries count[j-1]; it starts at 0 rather than counts[0],
       so symbols of length 0 take up no code space. *)
    fun get_init j codebase prev n =
        if j = n then ()
        else
            let
                val t = counts $ j
                val nxt_codebase = (codebase + prev) * 2
                val () = Array.update (codewords, j, nxt_codebase)
            in
                get_init (j+1) nxt_codebase t n
            end

    val () = get_init 1 0 0 maxbits

    (* Walk the lengths in symbol order; n is the current symbol.  Each
       used symbol takes the next free codeword of its length and is
       inserted with the most significant codeword bit first. *)
    fun mktree [] _ tr = tr
      | mktree (0::t) n tr = mktree t (n+1) tr       (* unused symbol *)
      | mktree (nbits::t) n tr =
        let
            val c = codewords $ nbits
            val () = Array.update (codewords, nbits, c+1)
        in
            mktree t (n+1) (insert n (bit_list c (nbits-1)) tr)
        end
in
    mktree bitlen_list 0 Empty
end

(* dupn x n acc: acc with n copies of x added at the front. *)
fun dupn x n acc =
    if n = 0 then acc
    else dupn x (n - 1) (x :: acc)

(* (x, n) @@ rest  is  dupn x n rest.  Right-associative, so a chain
   such as  (a, 2) @@ (b, 3) @@ []  reads in list order. *)
infixr 5 @@
fun op @@ ((x, y), z) = dupn x y z

(* Code lengths of the fixed literal/length code, one entry per symbol
   0..287: 144 entries of 8, then 112 of 9, then 24 of 7, then 8 of 8.
   Built back to front because @@ adds its run to the front. *)
val default_layout =
    let
        val last_8s = (8, 8) @@ []
        val from_7s = (7, 24) @@ last_8s
        val from_9s = (9, 112) @@ from_7s
    in
        (8, 144) @@ from_9s
    end

(* fixed_Huffman (): the (literal/length tree, distance tree) pair for
   blocks that use the fixed Huffman codes.  The distance code has 32
   symbols, all of length 5.  Both trees are built afresh on each call. *)
fun fixed_Huffman () =
    let
        val literal_tree  = makeTree default_layout
        val distance_tree = makeTree (dupn 5 32 [])
    in
        (literal_tree, distance_tree)
    end

(* inflate input output_size
   Decompress a complete deflate stream (RFC 1951).

     input        the compressed data; decoding starts at the first bit
                  of byte 0
     output_size  the exact number of bytes the stream expands to

   Returns the decompressed bytes as a vector.

   Raises InflateError if the input is empty, truncated or malformed,
   if the data would not fit in output_size bytes, or if the final size
   is not exactly output_size. *)
fun inflate (input: Word8Vector.vector) (output_size: int) =
let
    (* Output buffer.  The fill value is arbitrary: a successful run
       writes every byte, because output is contiguous from index 0 and
       must finish at exactly output_size. *)
    val out = Word8Array.array(output_size, Word8.fromInt 126)

    val outPtr = ref 0    (* index in out of the next byte to write *)
    fun incrOutPtr n = outPtr := !outPtr + n
    fun push x = (Word8Array.update(out, !outPtr, x); outPtr := !outPtr + 1)
        handle Subscript => raise InflateError "Decompressed data exceeds expected size"

    (* Input cursor: thisByte caches input[bytePtr]; bitPtr is the next
       bit of it to read (0 = least significant).  bitPtr = 8 means the
       byte is used up but the next one has not been fetched yet. *)
    val bytePtr = ref 0
    val bitPtr  = ref 0
    val thisByte = ref (Word8Vector.sub(input, 0)
                        handle Subscript => raise InflateError "Ran out of input")

    fun nextByte () =
        (bytePtr := !bytePtr + 1; bitPtr := 0; thisByte := Word8Vector.sub(input, !bytePtr))
        handle Subscript => raise InflateError "Ran out of input"

    (* Skip to the start of the next byte unless already at bit 0. *)
    fun align () = if !bitPtr = 0 then () else nextByte ()

    (* Read one bit; true means 1. *)
    fun nextBit () =
        let
            val () = if !bitPtr = 8 then nextByte() else ()
            (* Don't get the next byte until we really need it,
               in case we're at the end of the stream *)

            val q = !thisByte andb
                     (case !bitPtr of
                          0 => 0wx01 | 1 => 0wx02 | 2 => 0wx04 | 3 => 0wx08
                        | 4 => 0wx10 | 5 => 0wx20 | 6 => 0wx40 | 7 => 0wx80
                        | _ => 0wx00)   (* not expected: bitPtr is 0..7 here *)

            val () = bitPtr := !bitPtr+1
        in
            q <> 0w0
        end

    (* Decode one symbol: follow input bits down the tree to a leaf. *)
    fun nextVal tree =
        case tree of
            Node(l,r) => if nextBit() then nextVal r else nextVal l
          | Leaf v => v
          | Empty => raise InflateError "Fell out of tree"

    local
        fun f n p a = (* Get n-bit int, least significant bit first *)
            if n=0 then a
            else let
                    val b = if nextBit() then 1 else 0
                in
                    f (n-1) (p*2) (a+b*p)
                end
    in
    fun getbitsRev n = f n 1 0
    end

    (* The fixed Huffman trees never change, so build them at most once
       per call, the first time a block of type 1 is met. *)
    val cachedFixed : (tree * tree) option ref = ref NONE
    fun fixedTrees () =
        case !cachedFixed of
            SOME trees => trees
          | NONE =>
            let
                val trees = fixed_Huffman ()
            in
                cachedFixed := SOME trees;
                trees
            end

    (* Read the header of a dynamic-Huffman block and return its
       (literal/length tree, distance tree) pair. *)
    fun dynamic_Huffman () =
        let
            val hlit =  getbitsRev 5   (* # of literal/length codes - 257 *)
            val hdist = getbitsRev 5   (* # of distance codes - 1 *)
            val hclen = getbitsRev 4   (* # of code length codes - 4 *)

            (* The lengths of the code length code arrive in this order
               of symbols:
                 16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15
               code_order is the inverse of that permutation: entry s is
               the position at which the length for symbol s arrived. *)
            val code_order = [3, 17, 15, 13, 11, 9, 7, 5, 4, 6, 8, 10, 12, 14, 16, 18, 0, 1, 2]

            (* Lengths in arrival order; positions not sent stay 0. *)
            val code_lens_a = Array.array (19, 0)

            fun get_code_lens j lim =
                if j = lim then ()
                else
                    let
                        val bits = getbitsRev 3
                        val () = Array.update (code_lens_a, j, bits)
                    in
                        get_code_lens (j+1) lim
                    end

            val () = get_code_lens 0 (hclen+4)

            (* Rearranged into symbol order 0..18. *)
            val code_lens = map (fn j => code_lens_a $ j) code_order

            val cl_tree = makeTree (code_lens)

            (* Push n more copies of the most recent entry.  l is kept
               in reverse, so that entry is its head. *)
            fun replast n l =
                case l of [] => raise InflateError "Attempting to replicate last element of empty list"
                        | (h::t) => dupn h n l

            (* Read n code lengths, accumulating them in reverse.
               Symbols 16, 17 and 18 are run-length instructions; a run
               that overshoots n is an error. *)
            fun getTrees n acc =
                if n < 0 then raise InflateError "Bad code encoding"
                else if n=0 then rev acc
                else
                    let
                        val c = nextVal cl_tree
                        val (acc', n_got) =
                            case c of 18 => let val n = (getbitsRev 7) + 11 in (dupn 0 n acc, n) end
                                    | 17 => let val n = (getbitsRev 3) + 3  in (dupn 0 n acc, n) end
                                    | 16 => let val n = (getbitsRev 2) + 3  in (replast n acc, n) end
                                    | n => (n::acc, 1)
                    in
                        getTrees (n-n_got) acc'
                    end

            (* hlit+257 literal/length lengths followed by hdist+1
               distance lengths, read as a single sequence. *)
            val p = getTrees (hlit+hdist+258) []

            val ltree = makeTree (List.take (p, hlit+257))
            val dtree = makeTree (List.drop (p, hlit+257))
        in
            (ltree, dtree)
        end


    local
        val lens =   (* base values for lengths *)
            #[3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 17, 19, 23, 27, 31,
              35, 43, 51, 59, 67, 83, 99, 115, 131, 163, 195, 227, 258]

        val lext = (* number of extra bits for lengths *)
            #[0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2,
              3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 0]

        val dists = (* base values for distances *)
            #[ 1, 2, 3, 4, 5, 7, 9, 13, 17, 25, 33, 49, 65, 97, 129, 193, 257, 385,
               513, 769, 1025, 1537, 2049, 3073, 4097, 6145, 8193, 12289, 16385, 24577]

        val dext =  (* number of extra bits for distances *)
            #[0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6,
              7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13]

        (* Table lookup; a symbol with no table entry is bad data. *)
        fun get v n = Vector.sub(v,n)
            handle Subscript => raise InflateError "Bad index in length/distance lookup table"

        (* Base value plus the extra bits that follow in the input. *)
        fun decode_len v =  get lens v  + getbitsRev (get lext v)
        fun decode_dist v = get dists v + getbitsRev (get dext v)
    in

    (* Decode blocks until the one flagged as final has been processed,
       then check the total size and return the output as a vector. *)
    fun inflateBlock () =
        let
            (* Decode one Huffman-compressed block with the given trees. *)
            fun decodeBlock (ltree, dtree) =
                let
                    (* Handle a <length, distance> pair; v is the length
                       symbol minus 257. *)
                    fun copyBlock v =
                        let
                            val len  = decode_len v
                            val dist = decode_dist (nextVal dtree)

                            (* Copy one byte at a time in ascending
                               order, so the source may overlap the
                               bytes being written. *)
                            fun mv n p q = (* Word8Array.copy is tempting, but doesn't do what we want *)
                                if n=0 then ()
                                else (Word8Array.update(out, q, Word8Array.sub (out, p));
                                      mv (n-1) (p+1) (q+1))

                            (* Subscript here means the distance reaches
                               back before the start of the output or
                               the copy runs past the end of it. *)
                            val () = mv len (!outPtr-dist) (!outPtr)
                                handle Subscript => raise InflateError "Bad index in copyBlock"
                        in
                            incrOutPtr len
                        end

                    fun loop() =
                        let
                            val v = nextVal ltree
                        in
                            if v = 256 then ()  (* end of block *)
                            else if v < 256
                            then (push (Word8.fromInt v); loop()) (* verbatim byte *)
                            else (copyBlock (v-257); loop())
                        end
                in
                    loop()
                end


            (* Here we go *)

            val bfinal = getbitsRev 1  (* last block? *)
            val btype  = getbitsRev 2  (* compression type for this block *)

            val () =
                case btype of
                    0 => (* no compression *)
                    let
                        val () = align ()  (* start of byte *)
                        val len = (getbitsRev 8) + (getbitsRev 8) * 256
                        (* Next comes NLEN, which should be the bitwise
                           complement of len.  It is read to step over
                           it but deliberately not checked. *)
                        val _ = (getbitsRev 8) + (getbitsRev 8) * 256

                        (* Note that at this point bytePtr points to the byte before
                           the first byte which we want to copy *)
                        (* Subscript means len overruns the remaining
                           input or the space left in the output. *)
                        val () = Word8Array.copyVec {src = input, si = !bytePtr+1 , len = SOME len,
                                                     dst = out, di = !outPtr}
                            handle Subscript => raise InflateError "Bad length in stored block"

                        (* bytePtr now rests on the last byte copied and
                           bitPtr is still 8, so the next bit read
                           fetches the byte after the data. *)
                        val () = bytePtr := !bytePtr + len
                        val () = incrOutPtr len
                    in
                        ()
                    end
                  | 1 => decodeBlock (fixedTrees ())
                  | 2 => decodeBlock (dynamic_Huffman())
                  | _ => raise InflateError "Bad BTYPE"

        in
            if bfinal = 1 then (* We've just done the last block *)
               if !outPtr <> output_size
               then raise InflateError "Decompressed data doesn't have expected size"
               else
                   Word8Array.extract (out, 0, NONE) (* Convert it into a vector *)
            else inflateBlock ()
        end   (* inflateBlock *)
    end       (* local *)
in
    inflateBlock ()
end
