(********************************************************************)
(*                                                                  *)
(*  inflate.s7i   Inflate uncompression algorithm                   *)
(*  Copyright (C) 2008, 2013, 2015, 2017  Thomas Mertes             *)
(*                2022 - 2025  Thomas Mertes                        *)
(*                                                                  *)
(*  This file is part of the Seed7 Runtime Library.                 *)
(*                                                                  *)
(*  The Seed7 Runtime Library is free software; you can             *)
(*  redistribute it and/or modify it under the terms of the GNU     *)
(*  Lesser General Public License as published by the Free Software *)
(*  Foundation; either version 2.1 of the License, or (at your      *)
(*  option) any later version.                                      *)
(*                                                                  *)
(*  The Seed7 Runtime Library is distributed in the hope that it    *)
(*  will be useful, but WITHOUT ANY WARRANTY; without even the      *)
(*  implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR *)
(*  PURPOSE.  See the GNU Lesser General Public License for more    *)
(*  details.                                                        *)
(*                                                                  *)
(*  You should have received a copy of the GNU Lesser General       *)
(*  Public License along with this program; if not, write to the    *)
(*  Free Software Foundation, Inc., 51 Franklin Street,             *)
(*  Fifth Floor, Boston, MA  02110-1301, USA.                       *)
(*                                                                  *)
(********************************************************************)


include "bitdata.s7i";
include "bytedata.s7i";
include "huffman.s7i";


(**
 *  Literal/length symbol that terminates a compressed block.
 *  Symbols below this value are literal bytes and symbols above
 *  it are length codes (see the decode loops in this file).
 *)
const integer: INFLATE_END_OF_BLOCK is 256;


(**
 *  Append the content of a stored (non-compressed) block to ''uncompressed''.
 *  The block header consists of two 16-bit little-endian values:
 *  LEN (the number of data bytes) and NLEN. DEFLATE defines NLEN as
 *  the one's complement of LEN, which is used here as a consistency
 *  check before LEN bytes are copied.
 *  @param compressedStream Bit stream positioned at the block header.
 *  @param uncompressed Destination to which the block data is appended.
 *  @exception RANGE_ERROR If NLEN is not the one's complement of LEN, or
 *             if fewer than LEN bytes could be read from the stream.
 *)
const proc: getNonCompressedBlock (inout lsbInBitStream: compressedStream,
    inout string: uncompressed) is func
  local
    var integer: length is 0;
    var integer: nlength is 0;
    var string: uncompressedBlock is "";
  begin
    length := bytes2Int(gets(compressedStream, 2), UNSIGNED, LE);
    nlength := bytes2Int(gets(compressedStream, 2), UNSIGNED, LE);
    # Both values are in the range 0 .. 16#ffff, so NLEN is the one's
    # complement of LEN exactly when the two values add up to 16#ffff.
    if length + nlength <> 16#ffff then
      raise RANGE_ERROR;
    end if;
    uncompressedBlock := gets(compressedStream, length);
    if length(uncompressedBlock) <> length then
      # The stream ended before the announced number of bytes was read.
      raise RANGE_ERROR;
    end if;
    uncompressed &:= uncompressedBlock;
  end func;


(**
 *  Read one literal/length symbol that is encoded with the fixed
 *  Huffman code.
 *  The code is examined most significant bit first. Depending on its
 *  prefix it is 7, 8 or 9 bits long:
 *   7 bits, up to 2#0010111:    symbols 256 and above
 *   8 bits, up to 2#10111111:   symbols starting at 0
 *   8 bits, up to 2#11000111:   symbols starting at 280
 *   9 bits, from 2#110010000:   symbols starting at 144
 *  @return the decoded literal/length symbol.
 *)
const func integer: getLiteralOrLength (inout lsbInBitStream: compressedStream) is func
  result
    var integer: literalOrLength is 0;
  local
    var integer: prefix is 0;
  begin
    # The first 7 bits are reversed, such that the bit read first
    # becomes the most significant bit of the prefix.
    prefix := reverseBits[7][getBits(compressedStream, 7)];
    if prefix > 2#0010111 then
      # Not a 7-bit code: Extend the prefix to 8 bits.
      prefix := (prefix << 1) + getBit(compressedStream);
      if prefix > 2#11000111 then
        # Not an 8-bit code either: Extend the prefix to 9 bits.
        prefix := (prefix << 1) + getBit(compressedStream);
        literalOrLength := prefix - 2#110010000 + 144;
      elsif prefix > 2#10111111 then
        literalOrLength := prefix - 2#11000000 + 280;
      else
        literalOrLength := prefix - 2#00110000;
      end if;
    else
      literalOrLength := prefix + 256;
    end if;
  end func;


(**
 *  Read a distance code that is stored as a fixed 5-bit code.
 *  The five bits are reversed, such that the bit read first becomes
 *  the most significant bit of the result.
 *  @return the distance code (not the distance itself).
 *)
const func integer: getDistance (inout lsbInBitStream: compressedStream) is func
  result
    var integer: distanceCode is 0;
  local
    var integer: rawBits is 0;
  begin
    rawBits := getBits(compressedStream, 5);
    distanceCode := reverseBits[5][rawBits];
  end func;


(**
 *  Convert a length code to the match length it stands for.
 *  Codes up to 264 map directly to a length (code - 254). The codes
 *  265 to 284 form groups of four codes which share the same number
 *  of extra bits (1 to 5). Within a group the base length grows by
 *  2 ** extraBits from one code to the next. The extra bits are read
 *  from ''compressedStream'' and added to the base length.
 *  @param code Length code as delivered by the literal/length decoder.
 *  @param deflate64 TRUE if code 285 is followed by 16 extra bits,
 *         FALSE if code 285 stands for the fixed length 258.
 *  @return the length of the match.
 *  @exception RANGE_ERROR If ''code'' is greater than 285.
 *)
const func integer: decodeLength (inout lsbInBitStream: compressedStream,
    in integer: code, in boolean: deflate64) is func
  result
    var integer: length is 0;
  begin
    if code > 264 then
      case code of
        when {265, 266, 267, 268}:
          length :=  11 +  2 * (code - 265) + getBit(compressedStream);
        when {269, 270, 271, 272}:
          length :=  19 +  4 * (code - 269) + getBits(compressedStream, 2);
        when {273, 274, 275, 276}:
          length :=  35 +  8 * (code - 273) + getBits(compressedStream, 3);
        when {277, 278, 279, 280}:
          length :=  67 + 16 * (code - 277) + getBits(compressedStream, 4);
        when {281, 282, 283, 284}:
          length := 131 + 32 * (code - 281) + getBits(compressedStream, 5);
        when {285}:
          if deflate64 then
            length := 3 + getBits(compressedStream, 16);
          else
            length := 258;
          end if;
        otherwise:
          raise RANGE_ERROR;
      end case;
    else
      # Codes without extra bits.
      length := code - 254;
    end if;
  end func;


(**
 *  Convert a distance code to the distance it stands for.
 *  Every code has a base distance. The codes 0 to 3 have no extra
 *  bits. From code 4 on, two consecutive codes share the same number
 *  of extra bits, which is (code div 2) - 1. The extra bits are read
 *  from ''compressedStream'' and added to the base distance.
 *  The codes 30 and 31 are only used by enhanced deflate (deflate64).
 *  @param code Distance code in the range 0 to 31.
 *  @return the distance of the match.
 *  @exception RANGE_ERROR If ''code'' is not in the range 0 to 31.
 *)
const func integer: decodeDistance (inout lsbInBitStream: compressedStream,
    in integer: code) is func
  result
    var integer: distance is 0;
  local
    # Base distance of every distance code (indexed by the code).
    const array integer: baseDistance is [0] (
            1,     2,     3,     4,     5,     7,     9,    13,
           17,    25,    33,    49,    65,    97,   129,   193,
          257,   385,   513,   769,  1025,  1537,  2049,  3073,
         4097,  6145,  8193, 12289, 16385, 24577, 32769, 49153);
  begin
    if code < 0 or code > 31 then
      raise RANGE_ERROR;
    elsif code <= 3 then
      distance := baseDistance[code];
    elsif code <= 5 then
      # One extra bit.
      distance := baseDistance[code] + getBit(compressedStream);
    else
      distance := baseDistance[code] +
                  getBits(compressedStream, pred(code div 2));
    end if;
  end func;


(**
 *  Decode a block that is compressed with the fixed Huffman codes.
 *  Symbols are read until the end-of-block symbol is found. Literal
 *  symbols are appended to ''uncompressed''. A length symbol starts a
 *  match: ''length'' characters are copied from the position ''distance''
 *  characters before the current end of ''uncompressed''.
 *  @param compressedStream Bit stream positioned at the first symbol.
 *  @param deflate64 Selects the deflate64 interpretation of length codes.
 *  @param uncompressed Output to which the decoded data is appended.
 *  @exception RANGE_ERROR If a match refers to data before the
 *             beginning of ''uncompressed'', or if a length code or
 *             distance code is illegal.
 *)
const proc: decodeFixedHuffmanCodes (inout lsbInBitStream: compressedStream,
    in boolean: deflate64, inout string: uncompressed) is func
  local
    var integer: literalOrLength is 0;
    var integer: length is 0;
    var integer: distance is 0;
    var integer: number is 0;
    var integer: nextPos is 0;
  begin
    literalOrLength := getLiteralOrLength(compressedStream);
    while literalOrLength <> INFLATE_END_OF_BLOCK do
      if literalOrLength < INFLATE_END_OF_BLOCK then
        uncompressed &:= chr(literalOrLength);
      else
        length := decodeLength(compressedStream, literalOrLength, deflate64);
        distance := getDistance(compressedStream);
        distance := decodeDistance(compressedStream, distance);
        if distance > length(uncompressed) then
          # The match would start before the first character of the
          # output. Reject it before the output is changed.
          raise RANGE_ERROR;
        elsif length > distance then
          # Source and destination overlap: The copy must be done
          # character by character, because it reads characters that
          # have been written by the same copy.
          nextPos := succ(length(uncompressed));
          uncompressed &:= "\0;" mult length;
          for number range nextPos to nextPos + length - 1 do
            uncompressed @:= [number] uncompressed[number - distance];
          end for;
        else
          uncompressed &:= uncompressed[succ(length(uncompressed)) - distance fixLen length];
        end if;
      end if;
      literalOrLength := getLiteralOrLength(compressedStream);
    end while;
  end func;


(**
 *  Decode the data of a block with the given Huffman decoders.
 *  Symbols are read with ''literalOrLengthDecoder'' until the
 *  end-of-block symbol is found. Literal symbols are appended to
 *  ''uncompressed''. A length symbol starts a match: The distance
 *  code is read with ''distanceDecoder'' and ''length'' characters are
 *  copied from the position ''distance'' characters before the current
 *  end of ''uncompressed''.
 *  @param compressedStream Bit stream positioned at the first symbol.
 *  @param literalOrLengthDecoder Decoder for literal/length symbols.
 *  @param distanceDecoder Decoder for distance codes.
 *  @param deflate64 Selects the deflate64 interpretation of length codes.
 *  @param uncompressed Output to which the decoded data is appended.
 *  @exception RANGE_ERROR If a match refers to data before the
 *             beginning of ''uncompressed'', or if a length code or
 *             distance code is illegal.
 *)
const proc: decodeDynamicHuffmanCodes (inout lsbInBitStream: compressedStream,
    in lsbHuffmanDecoder: literalOrLengthDecoder,
    in lsbHuffmanDecoder: distanceDecoder, in boolean: deflate64,
    inout string: uncompressed) is func
  local
    var integer: literalOrLength is 0;
    var integer: length is 0;
    var integer: distance is 0;
    var integer: nextPos is 0;
    var integer: number is 0;
  begin
    literalOrLength := getHuffmanSymbol(compressedStream, literalOrLengthDecoder);
    while literalOrLength <> INFLATE_END_OF_BLOCK do
      if literalOrLength < INFLATE_END_OF_BLOCK then
        uncompressed &:= chr(literalOrLength);
      else
        length := decodeLength(compressedStream, literalOrLength, deflate64);
        distance := getHuffmanSymbol(compressedStream, distanceDecoder);
        distance := decodeDistance(compressedStream, distance);
        if distance > length(uncompressed) then
          # The match would start before the first character of the
          # output. Reject it before the output is changed.
          raise RANGE_ERROR;
        elsif length > distance then
          # Source and destination overlap: The copy must be done
          # character by character, because it reads characters that
          # have been written by the same copy.
          nextPos := succ(length(uncompressed));
          uncompressed &:= "\0;" mult length;
          for number range nextPos to nextPos + length - 1 do
            uncompressed @:= [number] uncompressed[number - distance];
          end for;
        else
          uncompressed &:= uncompressed[succ(length(uncompressed)) - distance fixLen length];
        end if;
      end if;
      literalOrLength := getHuffmanSymbol(compressedStream, literalOrLengthDecoder);
    end while;
  end func;


(**
 *  Decode a block that is compressed with dynamic Huffman codes.
 *  The block header is processed in the following steps:
 *   1. Read the sizes hlit (+ 257), hdist (+ 1) and hclen (+ 4).
 *   2. Read hclen + 4 code lengths of 3 bits each. They are stored in
 *      the order given by ''mapToOrderedLengths'' and define the
 *      Huffman code that is used for the code lengths themselves.
 *   3. Decode hlit + 257 + hdist + 1 code lengths. The symbols 0 to 15
 *      are code lengths, 16 repeats the previous code length 3 to 6
 *      times, 17 stands for 3 to 10 zeros and 18 for 11 to 138 zeros.
 *   4. Split the code lengths into the literal/length part and the
 *      distance part and create a Huffman decoder for each part.
 *  Afterwards the data of the block is decoded with these decoders.
 *  @param compressedStream Bit stream positioned at the block header.
 *  @param deflate64 Selects the deflate64 interpretation of length codes.
 *  @param uncompressed Output to which the decoded data is appended.
 *  @exception RANGE_ERROR If the repeat symbol 16 is used without a
 *             preceding code length, or if a repetition would define
 *             more code lengths than announced by the header.
 *)
const proc: decodeDynamicHuffmanCodes (inout lsbInBitStream: compressedStream,
    in boolean: deflate64, inout string: uncompressed) is func
  local
    const array integer: mapToOrderedLengths is []
        (16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15);
    var integer: literalOrLengthTableSize is 0;  # hlit + 257
    var integer: distanceTableSize is 0;         # hdist + 1
    var integer: combinedDataTableSize is 0;     # hclen + 4
    var integer: combinedDataSize is 0;          # hlit + 257 + hdist + 1
    var array integer:     combinedDataCodeLengths is [0 .. 18] times 0;
    var lsbHuffmanDecoder: combinedDataDecoder is lsbHuffmanDecoder.value;
    var array integer:     combinedData is [0 .. -1] times 0;
    var array integer:     literalOrLengthCodeLengths is [0 .. -1] times 0;
    var lsbHuffmanDecoder: literalOrLengthDecoder is lsbHuffmanDecoder.value;
    var array integer:     distanceCodeLengths is [0 .. -1] times 0;
    var lsbHuffmanDecoder: distanceDecoder is lsbHuffmanDecoder.value;
    var integer: number is 0;
    var integer: combinedDataElement is 0;
    var integer: factor is 0;
    var integer: repeatedLength is 0;
  begin
    literalOrLengthTableSize := getBits(compressedStream, 5) + 257;
    distanceTableSize        := getBits(compressedStream, 5) + 1;
    combinedDataTableSize    := getBits(compressedStream, 4) + 4;
    for number range 1 to combinedDataTableSize do
      combinedDataCodeLengths[mapToOrderedLengths[number]] := getBits(compressedStream, 3);
    end for;
    combinedDataDecoder := createLsbHuffmanDecoder(combinedDataCodeLengths);
    combinedDataSize := literalOrLengthTableSize + distanceTableSize;
    # The variable number is always one more than the number of
    # code lengths that have been stored in combinedData.
    number := 1;
    while number <= combinedDataSize do
      combinedDataElement := getHuffmanSymbol(compressedStream, combinedDataDecoder);
      if combinedDataElement <= 15 then
        combinedData &:= combinedDataElement;
        incr(number);
      else
        if combinedDataElement = 16 then
          if length(combinedData) = 0 then
            # There is no previous code length that could be repeated.
            raise RANGE_ERROR;
          end if;
          factor := getBits(compressedStream, 2) + 3;
          repeatedLength := combinedData[maxIdx(combinedData)];
        elsif combinedDataElement = 17 then
          factor := getBits(compressedStream, 3) + 3;
          repeatedLength := 0;
        else # combinedDataElement = 18
          factor := getBits(compressedStream, 7) + 11;
          repeatedLength := 0;
        end if;
        if pred(number) + factor > combinedDataSize then
          # The repetition runs past the announced number of code lengths.
          raise RANGE_ERROR;
        end if;
        combinedData &:= factor times repeatedLength;
        number +:= factor;
      end if;
    end while;
    literalOrLengthCodeLengths := combinedData[.. pred(literalOrLengthTableSize)];
    distanceCodeLengths        := combinedData[literalOrLengthTableSize .. ];
    literalOrLengthDecoder := createLsbHuffmanDecoder(literalOrLengthCodeLengths);
    distanceDecoder        := createLsbHuffmanDecoder(distanceCodeLengths);
    decodeDynamicHuffmanCodes(compressedStream, literalOrLengthDecoder,
                              distanceDecoder, deflate64, uncompressed);
  end func;


(**
 *  Process one DEFLATE block and append its data to ''uncompressed''.
 *  The block starts with a one bit final-block flag, followed by a
 *  two bit block type: 0 selects a non-compressed block, 1 a block
 *  with fixed Huffman codes and 2 a block with dynamic Huffman codes.
 *  @param compressedStream Bit stream positioned at the block header.
 *  @param uncompressed Output to which the decoded data is appended.
 *  @param bfinal Set to TRUE if the block is marked as the last one.
 *  @exception RANGE_ERROR If the block type is 3.
 *)
const proc: processCompressedBlock (inout lsbInBitStream: compressedStream,
    inout string: uncompressed, inout boolean: bfinal) is func
  local
    var integer: blockType is 0;
  begin
    bfinal := odd(getBit(compressedStream));
    blockType := getBits(compressedStream, 2);
    if blockType = 0 then
      getNonCompressedBlock(compressedStream, uncompressed);
    elsif blockType = 1 then
      decodeFixedHuffmanCodes(compressedStream, FALSE, uncompressed);
    elsif blockType = 2 then
      decodeDynamicHuffmanCodes(compressedStream, FALSE, uncompressed);
    else
      raise RANGE_ERROR;
    end if;
  end func;


(**
 *  Decompress DEFLATE data that is read from a file.
 *  DEFLATE combines the LZ77 algorithm with Huffman coding.
 *  Blocks are processed until the block that is marked as final
 *  has been decoded. Afterwards the bit stream is closed.
 *  @param compressed File from which the compressed data is read.
 *  @return the uncompressed string.
 *  @exception RANGE_ERROR If ''compressed'' is not in DEFLATE format.
 *)
const func string: inflate (inout file: compressed) is func
  result
    var string: uncompressed is "";
  local
    var lsbInBitStream: bitStream is lsbInBitStream.value;
    var boolean: lastBlockDone is FALSE;
  begin
    bitStream := openLsbInBitStream(compressed);
    # lastBlockDone starts with FALSE, so at least one block is processed.
    while not lastBlockDone do
      processCompressedBlock(bitStream, uncompressed, lastBlockDone);
    end while;
    close(bitStream);
  end func;


(**
 *  Decompress DEFLATE data that is given as string.
 *  DEFLATE combines the LZ77 algorithm with Huffman coding.
 *  Blocks are processed until the block that is marked as final
 *  has been decoded.
 *  @param compressed String with the compressed data.
 *  @return the uncompressed string.
 *  @exception RANGE_ERROR If ''compressed'' is not in DEFLATE format.
 *)
const func string: inflate (in string: compressed) is func
  result
    var string: uncompressed is "";
  local
    var lsbInBitStream: bitStream is lsbInBitStream.value;
    var boolean: lastBlockDone is FALSE;
  begin
    bitStream := openLsbInBitStream(compressed);
    # lastBlockDone starts with FALSE, so at least one block is processed.
    while not lastBlockDone do
      processCompressedBlock(bitStream, uncompressed, lastBlockDone);
    end while;
  end func;


(**
 *  Process one DEFLATE64 block and append its data to ''uncompressed''.
 *  The block starts with a one bit final-block flag, followed by a
 *  two bit block type: 0 selects a non-compressed block, 1 a block
 *  with fixed Huffman codes and 2 a block with dynamic Huffman codes.
 *  The Huffman decoders are called with the deflate64 flag set to TRUE.
 *  @param compressedStream Bit stream positioned at the block header.
 *  @param uncompressed Output to which the decoded data is appended.
 *  @param bfinal Set to TRUE if the block is marked as the last one.
 *  @exception RANGE_ERROR If the block type is 3.
 *)
const proc: processCompressedBlock64 (inout lsbInBitStream: compressedStream,
    inout string: uncompressed, inout boolean: bfinal) is func
  local
    var integer: blockType is 0;
  begin
    bfinal := odd(getBit(compressedStream));
    blockType := getBits(compressedStream, 2);
    if blockType = 0 then
      getNonCompressedBlock(compressedStream, uncompressed);
    elsif blockType = 1 then
      decodeFixedHuffmanCodes(compressedStream, TRUE, uncompressed);
    elsif blockType = 2 then
      decodeDynamicHuffmanCodes(compressedStream, TRUE, uncompressed);
    else
      raise RANGE_ERROR;
    end if;
  end func;


(**
 *  Decompress a file that was compressed with DEFLATE64.
 *  DEFLATE64 is a compression algorithm that uses a combination of
 *  the LZ77 algorithm and Huffman coding. It differs from DEFLATE in
 *  the decoding of matches: The length code 285 is followed by 16
 *  extra bits and the distance codes 30 and 31 are used. Therefore
 *  the blocks are processed with ''processCompressedBlock64''.
 *  @param compressed File from which the compressed data is read.
 *  @return the uncompressed string.
 *  @exception RANGE_ERROR If ''compressed'' is not in DEFLATE64 format.
 *)
const func string: inflate64 (inout file: compressed) is func
  result
    var string: uncompressed is "";
  local
    var lsbInBitStream: compressedStream is lsbInBitStream.value;
    var boolean: bfinal is FALSE;
  begin
    compressedStream := openLsbInBitStream(compressed);
    repeat
      # The DEFLATE64 variant must be used. The plain DEFLATE block
      # decoder would interpret length code 285 as fixed length 258.
      processCompressedBlock64(compressedStream, uncompressed, bfinal);
    until bfinal;
    close(compressedStream);
  end func;


(**
 *  Decompress a string that was compressed with DEFLATE64.
 *  DEFLATE64 is a compression algorithm that uses a combination of
 *  the LZ77 algorithm and Huffman coding. It differs from DEFLATE in
 *  the decoding of matches: The length code 285 is followed by 16
 *  extra bits and the distance codes 30 and 31 are used. Therefore
 *  the blocks are processed with ''processCompressedBlock64''.
 *  @param compressed String with the compressed data.
 *  @return the uncompressed string.
 *  @exception RANGE_ERROR If ''compressed'' is not in DEFLATE64 format.
 *)
const func string: inflate64 (in string: compressed) is func
  result
    var string: uncompressed is "";
  local
    var lsbInBitStream: compressedStream is lsbInBitStream.value;
    var boolean: bfinal is FALSE;
  begin
    compressedStream := openLsbInBitStream(compressed);
    repeat
      # The DEFLATE64 variant must be used. The plain DEFLATE block
      # decoder would interpret length code 285 as fixed length 258.
      processCompressedBlock64(compressedStream, uncompressed, bfinal);
    until bfinal;
  end func;