fse.mx raw

   1  // Copyright 2023 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  package zstd
   6  
   7  import (
   8  	"math/bits"
   9  )
  10  
// fseEntry is one entry in an FSE decoding table, as filled in by buildFSE.
type fseEntry struct {
	sym  uint8  // value that this entry records
	bits uint8  // number of bits to read to determine next state
	base uint16 // add those bits to base to get the next state
}
  17  
// readFSE reads an FSE table description from data starting at off and
// builds the corresponding decoding table. RFC 4.1.1.
//
// maxSym is the maximum symbol value.
// maxBits is the maximum accuracy log (number of table bits) permitted.
// The decoding table is written into table, which must have at least
// 1<<maxBits entries.
//
// It returns the number of bits in the FSE table (the accuracy log) and
// the new offset into data.
func (r *Reader) readFSE(data block, off, maxSym, maxBits int, table []fseEntry) (tableBits, roff int, err error) {
	br := r.makeBitReader(data, off)
	if err := br.moreBits(); err != nil {
		return 0, 0, err
	}

	// The first 4-bit field is the accuracy log minus 5.
	accuracyLog := int(br.val(4)) + 5
	if accuracyLog > maxBits {
		return 0, 0, br.makeError("FSE accuracy log too large")
	}

	// The number of remaining probabilities, plus 1.
	// This determines the number of bits to be read for the next value.
	remaining := (1 << accuracyLog) + 1

	// The current difference between small and large values,
	// which depends on the number of remaining values.
	// Small values use 1 less bit.
	threshold := 1 << accuracyLog

	// The number of bits needed to compute threshold.
	bitsNeeded := accuracyLog + 1

	// The next character value.
	sym := 0

	// Whether the last count was 0.
	prev0 := false

	// norm[s] receives the normalized count for symbol s. A value of -1
	// marks a "less than one" probability (see count-- below).
	var norm [256]int16

	// Read one count per symbol until the probability budget is used up
	// (remaining == 1) or every symbol up to maxSym has a value.
	for remaining > 1 && sym <= maxSym {
		if err := br.moreBits(); err != nil {
			return 0, 0, err
		}

		if prev0 {
			// Previous count was 0, so there is a 2-bit
			// repeat flag. If the 2-bit flag is 0b11,
			// it adds 3 and then there is another repeat flag.
			zsym := sym
			// Shortcut: twelve 1 bits are six 0b11 flags in a row,
			// worth 3*6 more zero-probability symbols.
			for (br.bits & 0xfff) == 0xfff {
				zsym += 3 * 6
				br.bits >>= 12
				br.cnt -= 12
				if err := br.moreBits(); err != nil {
					return 0, 0, err
				}
			}
			for (br.bits & 3) == 3 {
				zsym += 3
				br.bits >>= 2
				br.cnt -= 2
				if err := br.moreBits(); err != nil {
					return 0, 0, err
				}
			}

			// We have at least 14 bits here,
			// no need to call moreBits

			// The terminating 2-bit flag (not 0b11) adds its own value.
			zsym += int(br.val(2))

			if zsym > maxSym {
				return 0, 0, br.makeError("FSE symbol index overflow")
			}

			// Symbols sym..zsym-1 all have probability 0.
			for ; sym < zsym; sym++ {
				norm[uint8(sym)] = 0
			}

			prev0 = false
			continue
		}

		// Values below max fit in bitsNeeded-1 bits; the others
		// take bitsNeeded bits.
		max := (2*threshold - 1) - remaining
		var count int
		if int(br.bits&uint32(threshold-1)) < max {
			// A small value.
			count = int(br.bits & uint32((threshold - 1)))
			br.bits >>= bitsNeeded - 1
			br.cnt -= uint32(bitsNeeded - 1)
		} else {
			// A large value.
			count = int(br.bits & uint32((2*threshold - 1)))
			if count >= threshold {
				count -= max
			}
			br.bits >>= bitsNeeded
			br.cnt -= uint32(bitsNeeded)
		}

		// The encoded value is the count plus 1, so a decoded count of
		// -1 marks a "less than one" probability. Such a symbol still
		// consumes one unit of the remaining probability, and buildFSE
		// gives it a single slot at the top of the table.
		count--
		if count >= 0 {
			remaining -= count
		} else {
			remaining--
		}
		if sym >= 256 {
			return 0, 0, br.makeError("FSE sym overflow")
		}
		norm[uint8(sym)] = int16(count)
		sym++

		prev0 = count == 0

		// Fewer remaining probabilities need fewer bits to encode.
		for remaining < threshold {
			bitsNeeded--
			threshold >>= 1
		}
	}

	// remaining started at table size + 1, so this checks that the
	// counts add up exactly to the table size.
	if remaining != 1 {
		return 0, 0, br.makeError("too many symbols in FSE table")
	}

	// Symbols that were never reached have probability 0.
	for ; sym <= maxSym; sym++ {
		norm[uint8(sym)] = 0
	}

	// NOTE(review): presumably gives back buffered bits that were not
	// consumed, so br.off ends up just past the table description —
	// confirm against the bit reader's backup method.
	br.backup()

	if err := r.buildFSE(off, norm[:maxSym+1], table, accuracyLog); err != nil {
		return 0, 0, err
	}

	return accuracyLog, int(br.off), nil
}
 152  
 153  // buildFSE builds an FSE decoding table from a list of probabilities.
 154  // The probabilities are in norm. next is scratch space. The number of bits
 155  // in the table is tableBits.
 156  func (r *Reader) buildFSE(off int, norm []int16, table []fseEntry, tableBits int) error {
 157  	tableSize := 1 << tableBits
 158  	highThreshold := tableSize - 1
 159  
 160  	var next [256]uint16
 161  
 162  	for i, n := range norm {
 163  		if n >= 0 {
 164  			next[uint8(i)] = uint16(n)
 165  		} else {
 166  			table[highThreshold].sym = uint8(i)
 167  			highThreshold--
 168  			next[uint8(i)] = 1
 169  		}
 170  	}
 171  
 172  	pos := 0
 173  	step := (tableSize >> 1) + (tableSize >> 3) + 3
 174  	mask := tableSize - 1
 175  	for i, n := range norm {
 176  		for j := 0; j < int(n); j++ {
 177  			table[pos].sym = uint8(i)
 178  			pos = (pos + step) & mask
 179  			for pos > highThreshold {
 180  				pos = (pos + step) & mask
 181  			}
 182  		}
 183  	}
 184  	if pos != 0 {
 185  		return r.makeError(off, "FSE count error")
 186  	}
 187  
 188  	for i := 0; i < tableSize; i++ {
 189  		sym := table[i].sym
 190  		nextState := next[sym]
 191  		next[sym]++
 192  
 193  		if nextState == 0 {
 194  			return r.makeError(off, "FSE state error")
 195  		}
 196  
 197  		highBit := 15 - bits.LeadingZeros16(nextState)
 198  
 199  		bits := tableBits - highBit
 200  		table[i].bits = uint8(bits)
 201  		table[i].base = (nextState << bits) - uint16(tableSize)
 202  	}
 203  
 204  	return nil
 205  }
 206  
// fseBaselineEntry is an entry in an FSE baseline table.
// We use these for literal length, match length, and offset values.
// Those require mapping the symbol to a baseline value,
// and then reading zero or more bits and adding the value to the baseline.
// Rather than looking these up in separate tables,
// we convert the FSE table to an FSE baseline table
// (see makeLiteralBaselineFSE, makeOffsetBaselineFSE and
// makeMatchBaselineFSE).
//
// The predefined tables in this file are written as positional composite
// literals, so the field order below must not change.
type fseBaselineEntry struct {
	baseline uint32 // baseline for value that this entry represents
	basebits uint8  // number of bits to read to add to baseline
	bits     uint8  // number of bits to read to determine next state
	base     uint16 // add the bits to this base to get the next state
}
 219  
// Given a literal length code, we need to read a number of bits and
// add that to a baseline. For codes 0 to 15 the baseline is the
// code itself and the number of bits is zero. RFC 3.1.1.3.2.1.1.

// literalLengthOffset is the first literal length code that is looked up
// in literalLengthBase.
const literalLengthOffset = 16

// literalLengthBase holds, for each literal length code from
// literalLengthOffset up, the baseline in the low 24 bits and the number
// of extra bits to read in the high 8 bits.
var literalLengthBase = []uint32{
	16 | (1 << 24),
	18 | (1 << 24),
	20 | (1 << 24),
	22 | (1 << 24),
	24 | (2 << 24),
	28 | (2 << 24),
	32 | (3 << 24),
	40 | (3 << 24),
	48 | (4 << 24),
	64 | (6 << 24),
	128 | (7 << 24),
	256 | (8 << 24),
	512 | (9 << 24),
	1024 | (10 << 24),
	2048 | (11 << 24),
	4096 | (12 << 24),
	8192 | (13 << 24),
	16384 | (14 << 24),
	32768 | (15 << 24),
	65536 | (16 << 24),
}
 248  
 249  // makeLiteralBaselineFSE converts the literal length fseTable to baselineTable.
 250  func (r *Reader) makeLiteralBaselineFSE(off int, fseTable []fseEntry, baselineTable []fseBaselineEntry) error {
 251  	for i, e := range fseTable {
 252  		be := fseBaselineEntry{
 253  			bits: e.bits,
 254  			base: e.base,
 255  		}
 256  		if e.sym < literalLengthOffset {
 257  			be.baseline = uint32(e.sym)
 258  			be.basebits = 0
 259  		} else {
 260  			if e.sym > 35 {
 261  				return r.makeError(off, "FSE baseline symbol overflow")
 262  			}
 263  			idx := e.sym - literalLengthOffset
 264  			basebits := literalLengthBase[idx]
 265  			be.baseline = basebits & 0xffffff
 266  			be.basebits = uint8(basebits >> 24)
 267  		}
 268  		baselineTable[i] = be
 269  	}
 270  	return nil
 271  }
 272  
 273  // makeOffsetBaselineFSE converts the offset length fseTable to baselineTable.
 274  func (r *Reader) makeOffsetBaselineFSE(off int, fseTable []fseEntry, baselineTable []fseBaselineEntry) error {
 275  	for i, e := range fseTable {
 276  		be := fseBaselineEntry{
 277  			bits: e.bits,
 278  			base: e.base,
 279  		}
 280  		if e.sym > 31 {
 281  			return r.makeError(off, "FSE offset symbol overflow")
 282  		}
 283  
 284  		// The simple way to write this is
 285  		//     be.baseline = 1 << e.sym
 286  		//     be.basebits = e.sym
 287  		// That would give us an offset value that corresponds to
 288  		// the one described in the RFC. However, for offsets > 3
 289  		// we have to subtract 3. And for offset values 1, 2, 3
 290  		// we use a repeated offset.
 291  		//
 292  		// The baseline is always a power of 2, and is never 0,
 293  		// so for those low values we will see one entry that is
 294  		// baseline 1, basebits 0, and one entry that is baseline 2,
 295  		// basebits 1. All other entries will have baseline >= 4
 296  		// basebits >= 2.
 297  		//
 298  		// So we can check for RFC offset <= 3 by checking for
 299  		// basebits <= 1. That means that we can subtract 3 here
 300  		// and not worry about doing it in the hot loop.
 301  
 302  		be.baseline = 1 << e.sym
 303  		if e.sym >= 2 {
 304  			be.baseline -= 3
 305  		}
 306  		be.basebits = e.sym
 307  		baselineTable[i] = be
 308  	}
 309  	return nil
 310  }
 311  
// Given a match length code, we need to read a number of bits and add
// that to a baseline. For codes 0 to 31 the baseline is code+3 and
// the number of bits is zero. RFC 3.1.1.3.2.1.1.

// matchLengthOffset is the first match length code that is looked up
// in matchLengthBase.
const matchLengthOffset = 32

// matchLengthBase holds, for each match length code from
// matchLengthOffset up, the baseline in the low 24 bits and the number
// of extra bits to read in the high 8 bits.
var matchLengthBase = []uint32{
	35 | (1 << 24),
	37 | (1 << 24),
	39 | (1 << 24),
	41 | (1 << 24),
	43 | (2 << 24),
	47 | (2 << 24),
	51 | (3 << 24),
	59 | (3 << 24),
	67 | (4 << 24),
	83 | (4 << 24),
	99 | (5 << 24),
	131 | (7 << 24),
	259 | (8 << 24),
	515 | (9 << 24),
	1027 | (10 << 24),
	2051 | (11 << 24),
	4099 | (12 << 24),
	8195 | (13 << 24),
	16387 | (14 << 24),
	32771 | (15 << 24),
	65539 | (16 << 24),
}
 341  
 342  // makeMatchBaselineFSE converts the match length fseTable to baselineTable.
 343  func (r *Reader) makeMatchBaselineFSE(off int, fseTable []fseEntry, baselineTable []fseBaselineEntry) error {
 344  	for i, e := range fseTable {
 345  		be := fseBaselineEntry{
 346  			bits: e.bits,
 347  			base: e.base,
 348  		}
 349  		if e.sym < matchLengthOffset {
 350  			be.baseline = uint32(e.sym) + 3
 351  			be.basebits = 0
 352  		} else {
 353  			if e.sym > 52 {
 354  				return r.makeError(off, "FSE baseline symbol overflow")
 355  			}
 356  			idx := e.sym - matchLengthOffset
 357  			basebits := matchLengthBase[idx]
 358  			be.baseline = basebits & 0xffffff
 359  			be.basebits = uint8(basebits >> 24)
 360  		}
 361  		baselineTable[i] = be
 362  	}
 363  	return nil
 364  }
 365  
// predefinedLiteralTable is the predefined table to use for literal lengths.
// Generated from table in RFC 3.1.1.3.2.2.1.
// Checked by TestPredefinedTables.
//
// It has 64 (1<<6) entries. Each entry is a positional
// fseBaselineEntry literal: {baseline, basebits, bits, base}.
var predefinedLiteralTable = [...]fseBaselineEntry{
	{0, 0, 4, 0}, {0, 0, 4, 16}, {1, 0, 5, 32},
	{3, 0, 5, 0}, {4, 0, 5, 0}, {6, 0, 5, 0},
	{7, 0, 5, 0}, {9, 0, 5, 0}, {10, 0, 5, 0},
	{12, 0, 5, 0}, {14, 0, 6, 0}, {16, 1, 5, 0},
	{20, 1, 5, 0}, {22, 1, 5, 0}, {28, 2, 5, 0},
	{32, 3, 5, 0}, {48, 4, 5, 0}, {64, 6, 5, 32},
	{128, 7, 5, 0}, {256, 8, 6, 0}, {1024, 10, 6, 0},
	{4096, 12, 6, 0}, {0, 0, 4, 32}, {1, 0, 4, 0},
	{2, 0, 5, 0}, {4, 0, 5, 32}, {5, 0, 5, 0},
	{7, 0, 5, 32}, {8, 0, 5, 0}, {10, 0, 5, 32},
	{11, 0, 5, 0}, {13, 0, 6, 0}, {16, 1, 5, 32},
	{18, 1, 5, 0}, {22, 1, 5, 32}, {24, 2, 5, 0},
	{32, 3, 5, 32}, {40, 3, 5, 0}, {64, 6, 4, 0},
	{64, 6, 4, 16}, {128, 7, 5, 32}, {512, 9, 6, 0},
	{2048, 11, 6, 0}, {0, 0, 4, 48}, {1, 0, 4, 16},
	{2, 0, 5, 32}, {3, 0, 5, 32}, {5, 0, 5, 32},
	{6, 0, 5, 32}, {8, 0, 5, 32}, {9, 0, 5, 32},
	{11, 0, 5, 32}, {12, 0, 5, 32}, {15, 0, 6, 0},
	{18, 1, 5, 32}, {20, 1, 5, 32}, {24, 2, 5, 32},
	{28, 2, 5, 32}, {40, 3, 5, 32}, {48, 4, 5, 32},
	{65536, 16, 6, 0}, {32768, 15, 6, 0}, {16384, 14, 6, 0},
	{8192, 13, 6, 0},
}
 393  
// predefinedOffsetTable is the predefined table to use for offsets.
// Generated from table in RFC 3.1.1.3.2.2.3.
// Checked by TestPredefinedTables.
//
// It has 32 (1<<5) entries. Each entry is a positional
// fseBaselineEntry literal: {baseline, basebits, bits, base}. Baselines
// follow the same convention as makeOffsetBaselineFSE: 1<<code, less 3
// for codes >= 2.
var predefinedOffsetTable = [...]fseBaselineEntry{
	{1, 0, 5, 0}, {61, 6, 4, 0}, {509, 9, 5, 0},
	{32765, 15, 5, 0}, {2097149, 21, 5, 0}, {5, 3, 5, 0},
	{125, 7, 4, 0}, {4093, 12, 5, 0}, {262141, 18, 5, 0},
	{8388605, 23, 5, 0}, {29, 5, 5, 0}, {253, 8, 4, 0},
	{16381, 14, 5, 0}, {1048573, 20, 5, 0}, {1, 2, 5, 0},
	{125, 7, 4, 16}, {2045, 11, 5, 0}, {131069, 17, 5, 0},
	{4194301, 22, 5, 0}, {13, 4, 5, 0}, {253, 8, 4, 16},
	{8189, 13, 5, 0}, {524285, 19, 5, 0}, {2, 1, 5, 0},
	{61, 6, 4, 16}, {1021, 10, 5, 0}, {65533, 16, 5, 0},
	{268435453, 28, 5, 0}, {134217725, 27, 5, 0}, {67108861, 26, 5, 0},
	{33554429, 25, 5, 0}, {16777213, 24, 5, 0},
}
 410  
// predefinedMatchTable is the predefined table to use for match lengths.
// Generated from table in RFC 3.1.1.3.2.2.2.
// Checked by TestPredefinedTables.
//
// It has 64 (1<<6) entries. Each entry is a positional
// fseBaselineEntry literal: {baseline, basebits, bits, base}.
var predefinedMatchTable = [...]fseBaselineEntry{
	{3, 0, 6, 0}, {4, 0, 4, 0}, {5, 0, 5, 32},
	{6, 0, 5, 0}, {8, 0, 5, 0}, {9, 0, 5, 0},
	{11, 0, 5, 0}, {13, 0, 6, 0}, {16, 0, 6, 0},
	{19, 0, 6, 0}, {22, 0, 6, 0}, {25, 0, 6, 0},
	{28, 0, 6, 0}, {31, 0, 6, 0}, {34, 0, 6, 0},
	{37, 1, 6, 0}, {41, 1, 6, 0}, {47, 2, 6, 0},
	{59, 3, 6, 0}, {83, 4, 6, 0}, {131, 7, 6, 0},
	{515, 9, 6, 0}, {4, 0, 4, 16}, {5, 0, 4, 0},
	{6, 0, 5, 32}, {7, 0, 5, 0}, {9, 0, 5, 32},
	{10, 0, 5, 0}, {12, 0, 6, 0}, {15, 0, 6, 0},
	{18, 0, 6, 0}, {21, 0, 6, 0}, {24, 0, 6, 0},
	{27, 0, 6, 0}, {30, 0, 6, 0}, {33, 0, 6, 0},
	{35, 1, 6, 0}, {39, 1, 6, 0}, {43, 2, 6, 0},
	{51, 3, 6, 0}, {67, 4, 6, 0}, {99, 5, 6, 0},
	{259, 8, 6, 0}, {4, 0, 4, 32}, {4, 0, 4, 48},
	{5, 0, 4, 16}, {7, 0, 5, 32}, {8, 0, 5, 32},
	{10, 0, 5, 32}, {11, 0, 5, 32}, {14, 0, 6, 0},
	{17, 0, 6, 0}, {20, 0, 6, 0}, {23, 0, 6, 0},
	{26, 0, 6, 0}, {29, 0, 6, 0}, {32, 0, 6, 0},
	{65539, 16, 6, 0}, {32771, 15, 6, 0}, {16387, 14, 6, 0},
	{8195, 13, 6, 0}, {4099, 12, 6, 0}, {2051, 11, 6, 0},
	{1027, 10, 6, 0},
}
 438