huff.mx raw

   1  // Copyright 2023 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  package zstd
   6  
   7  import (
   8  	"io"
   9  	"math/bits"
  10  )
  11  
  12  // maxHuffmanBits is the largest possible Huffman table bits.
  13  const maxHuffmanBits = 11
  14  
  15  // readHuff reads Huffman table from data starting at off into table.
  16  // Each entry in a Huffman table is a pair of bytes.
  17  // The high byte is the encoded value. The low byte is the number
  18  // of bits used to encode that value. We index into the table
  19  // with a value of size tableBits. A value that requires fewer bits
  20  // appear in the table multiple times.
  21  // This returns the number of bits in the Huffman table and the new offset.
  22  // RFC 4.2.1.
  23  func (r *Reader) readHuff(data block, off int, table []uint16) (tableBits, roff int, err error) {
  24  	if off >= len(data) {
  25  		return 0, 0, r.makeEOFError(off)
  26  	}
  27  
  28  	hdr := data[off]
  29  	off++
  30  
  31  	var weights [256]uint8
  32  	var count int
  33  	if hdr < 128 {
  34  		// The table is compressed using an FSE. RFC 4.2.1.2.
  35  		if len(r.fseScratch) < 1<<6 {
  36  			r.fseScratch = []fseEntry{:1<<6}
  37  		}
  38  		fseBits, noff, err := r.readFSE(data, off, 255, 6, r.fseScratch)
  39  		if err != nil {
  40  			return 0, 0, err
  41  		}
  42  		fseTable := r.fseScratch
  43  
  44  		if off+int(hdr) > len(data) {
  45  			return 0, 0, r.makeEOFError(off)
  46  		}
  47  
  48  		rbr, err := r.makeReverseBitReader(data, off+int(hdr)-1, noff)
  49  		if err != nil {
  50  			return 0, 0, err
  51  		}
  52  
  53  		state1, err := rbr.val(uint8(fseBits))
  54  		if err != nil {
  55  			return 0, 0, err
  56  		}
  57  
  58  		state2, err := rbr.val(uint8(fseBits))
  59  		if err != nil {
  60  			return 0, 0, err
  61  		}
  62  
  63  		// There are two independent FSE streams, tracked by
  64  		// state1 and state2. We decode them alternately.
  65  
  66  		for {
  67  			pt := &fseTable[state1]
  68  			if !rbr.fetch(pt.bits) {
  69  				if count >= 254 {
  70  					return 0, 0, rbr.makeError("Huffman count overflow")
  71  				}
  72  				weights[count] = pt.sym
  73  				weights[count+1] = fseTable[state2].sym
  74  				count += 2
  75  				break
  76  			}
  77  
  78  			v, err := rbr.val(pt.bits)
  79  			if err != nil {
  80  				return 0, 0, err
  81  			}
  82  			state1 = uint32(pt.base) + v
  83  
  84  			if count >= 255 {
  85  				return 0, 0, rbr.makeError("Huffman count overflow")
  86  			}
  87  
  88  			weights[count] = pt.sym
  89  			count++
  90  
  91  			pt = &fseTable[state2]
  92  
  93  			if !rbr.fetch(pt.bits) {
  94  				if count >= 254 {
  95  					return 0, 0, rbr.makeError("Huffman count overflow")
  96  				}
  97  				weights[count] = pt.sym
  98  				weights[count+1] = fseTable[state1].sym
  99  				count += 2
 100  				break
 101  			}
 102  
 103  			v, err = rbr.val(pt.bits)
 104  			if err != nil {
 105  				return 0, 0, err
 106  			}
 107  			state2 = uint32(pt.base) + v
 108  
 109  			if count >= 255 {
 110  				return 0, 0, rbr.makeError("Huffman count overflow")
 111  			}
 112  
 113  			weights[count] = pt.sym
 114  			count++
 115  		}
 116  
 117  		off += int(hdr)
 118  	} else {
 119  		// The table is not compressed. Each weight is 4 bits.
 120  
 121  		count = int(hdr) - 127
 122  		if off+((count+1)/2) >= len(data) {
 123  			return 0, 0, io.ErrUnexpectedEOF
 124  		}
 125  		for i := 0; i < count; i += 2 {
 126  			b := data[off]
 127  			off++
 128  			weights[i] = b >> 4
 129  			weights[i+1] = b & 0xf
 130  		}
 131  	}
 132  
 133  	// RFC 4.2.1.3.
 134  
 135  	var weightMark [13]uint32
 136  	weightMask := uint32(0)
 137  	for _, w := range weights[:count] {
 138  		if w > 12 {
 139  			return 0, 0, r.makeError(off, "Huffman weight overflow")
 140  		}
 141  		weightMark[w]++
 142  		if w > 0 {
 143  			weightMask += 1 << (w - 1)
 144  		}
 145  	}
 146  	if weightMask == 0 {
 147  		return 0, 0, r.makeError(off, "bad Huffman weights")
 148  	}
 149  
 150  	tableBits = 32 - bits.LeadingZeros32(weightMask)
 151  	if tableBits > maxHuffmanBits {
 152  		return 0, 0, r.makeError(off, "bad Huffman weights")
 153  	}
 154  
 155  	if len(table) < 1<<tableBits {
 156  		return 0, 0, r.makeError(off, "Huffman table too small")
 157  	}
 158  
 159  	// Work out the last weight value, which is omitted because
 160  	// the weights must sum to a power of two.
 161  	left := (uint32(1) << tableBits) - weightMask
 162  	if left == 0 {
 163  		return 0, 0, r.makeError(off, "bad Huffman weights")
 164  	}
 165  	highBit := 31 - bits.LeadingZeros32(left)
 166  	if uint32(1)<<highBit != left {
 167  		return 0, 0, r.makeError(off, "bad Huffman weights")
 168  	}
 169  	if count >= 256 {
 170  		return 0, 0, r.makeError(off, "Huffman weight overflow")
 171  	}
 172  	weights[count] = uint8(highBit + 1)
 173  	count++
 174  	weightMark[highBit+1]++
 175  
 176  	if weightMark[1] < 2 || weightMark[1]&1 != 0 {
 177  		return 0, 0, r.makeError(off, "bad Huffman weights")
 178  	}
 179  
 180  	// Change weightMark from a count of weights to the index of
 181  	// the first symbol for that weight. We shift the indexes to
 182  	// also store how many we have seen so far,
 183  	next := uint32(0)
 184  	for i := 0; i < tableBits; i++ {
 185  		cur := next
 186  		next += weightMark[i+1] << i
 187  		weightMark[i+1] = cur
 188  	}
 189  
 190  	for i, w := range weights[:count] {
 191  		if w == 0 {
 192  			continue
 193  		}
 194  		length := uint32(1) << (w - 1)
 195  		tval := uint16(i)<<8 | (uint16(tableBits) + 1 - uint16(w))
 196  		start := weightMark[w]
 197  		for j := uint32(0); j < length; j++ {
 198  			table[start+j] = tval
 199  		}
 200  		weightMark[w] += length
 201  	}
 202  
 203  	return tableBits, off, nil
 204  }
 205