fse_encoder.go raw

   1  // Copyright 2019+ Klaus Post. All rights reserved.
   2  // License information can be found in the LICENSE file.
   3  // Based on work by Yann Collet, released under BSD License.
   4  
   5  package zstd
   6  
   7  import (
   8  	"errors"
   9  	"fmt"
  10  	"math"
  11  )
  12  
const (
	// Limits for FSE encoding tables: the encoder selects a table log in the
	// range [minEncTablelog, maxEncTableLog].
	maxEncTableLog = 8
	// NOTE(review): maxEncTablesize and maxEncTableMask are derived from
	// maxTableLog (declared elsewhere), not from maxEncTableLog — confirm this
	// is intended before relying on them for encoder table sizes.
	maxEncTablesize   = 1 << maxTableLog
	maxEncTableMask   = (1 << maxTableLog) - 1
	minEncTablelog    = 5
	maxEncSymbolValue = maxMatchLengthSymbol
)
  21  
// fseEncoder holds the per-stream state of an FSE encoder: the symbol
// histogram, its normalized counts and the compression tables built from them.
type fseEncoder struct {
	symbolLen      uint16 // Length of active part of the symbol table.
	actualTableLog uint8  // Selected tablelog.
	ct             cTable // Compression tables.
	maxCount       int    // count of the most probable symbol
	zeroBits       bool   // set by buildCTable when a symbol's normalized count exceeds half the table.
	clearCount     bool   // set by HistogramFinished to maxCount != 0; presumably the histogram must be cleared before reuse — confirm.
	useRLE         bool   // This encoder is for RLE
	preDefined     bool   // This encoder is predefined.
	reUsed         bool   // Set to know when the encoder has been reused.
	rleVal         uint8  // RLE Symbol
	maxBits        uint8  // Maximum output bits after transform.

	// TODO: Technically zstd should be fine with 64 bytes.
	count [256]uint32 // symbol histogram
	norm  [256]int16  // normalized counts; -1 marks a low-probability symbol
}
  40  
  41  // cTable contains tables used for compression.
  42  type cTable struct {
  43  	tableSymbol []byte
  44  	stateTable  []uint16
  45  	symbolTT    []symbolTransform
  46  }
  47  
  48  // symbolTransform contains the state transform for a symbol.
  49  type symbolTransform struct {
  50  	deltaNbBits    uint32
  51  	deltaFindState int16
  52  	outBits        uint8
  53  }
  54  
  55  // String prints values as a human readable string.
  56  func (s symbolTransform) String() string {
  57  	return fmt.Sprintf("{deltabits: %08x, findstate:%d outbits:%d}", s.deltaNbBits, s.deltaFindState, s.outBits)
  58  }
  59  
  60  // Histogram allows to populate the histogram and skip that step in the compression,
  61  // It otherwise allows to inspect the histogram when compression is done.
  62  // To indicate that you have populated the histogram call HistogramFinished
  63  // with the value of the highest populated symbol, as well as the number of entries
  64  // in the most populated entry. These are accepted at face value.
  65  func (s *fseEncoder) Histogram() *[256]uint32 {
  66  	return &s.count
  67  }
  68  
  69  // HistogramFinished can be called to indicate that the histogram has been populated.
  70  // maxSymbol is the index of the highest set symbol of the next data segment.
  71  // maxCount is the number of entries in the most populated entry.
  72  // These are accepted at face value.
  73  func (s *fseEncoder) HistogramFinished(maxSymbol uint8, maxCount int) {
  74  	s.maxCount = maxCount
  75  	s.symbolLen = uint16(maxSymbol) + 1
  76  	s.clearCount = maxCount != 0
  77  }
  78  
  79  // allocCtable will allocate tables needed for compression.
  80  // If existing tables a re big enough, they are simply re-used.
  81  func (s *fseEncoder) allocCtable() {
  82  	tableSize := 1 << s.actualTableLog
  83  	// get tableSymbol that is big enough.
  84  	if cap(s.ct.tableSymbol) < tableSize {
  85  		s.ct.tableSymbol = make([]byte, tableSize)
  86  	}
  87  	s.ct.tableSymbol = s.ct.tableSymbol[:tableSize]
  88  
  89  	ctSize := tableSize
  90  	if cap(s.ct.stateTable) < ctSize {
  91  		s.ct.stateTable = make([]uint16, ctSize)
  92  	}
  93  	s.ct.stateTable = s.ct.stateTable[:ctSize]
  94  
  95  	if cap(s.ct.symbolTT) < 256 {
  96  		s.ct.symbolTT = make([]symbolTransform, 256)
  97  	}
  98  	s.ct.symbolTT = s.ct.symbolTT[:256]
  99  }
 100  
 101  // buildCTable will populate the compression table so it is ready to be used.
 102  func (s *fseEncoder) buildCTable() error {
 103  	tableSize := uint32(1 << s.actualTableLog)
 104  	highThreshold := tableSize - 1
 105  	var cumul [256]int16
 106  
 107  	s.allocCtable()
 108  	tableSymbol := s.ct.tableSymbol[:tableSize]
 109  	// symbol start positions
 110  	{
 111  		cumul[0] = 0
 112  		for ui, v := range s.norm[:s.symbolLen-1] {
 113  			u := byte(ui) // one less than reference
 114  			if v == -1 {
 115  				// Low proba symbol
 116  				cumul[u+1] = cumul[u] + 1
 117  				tableSymbol[highThreshold] = u
 118  				highThreshold--
 119  			} else {
 120  				cumul[u+1] = cumul[u] + v
 121  			}
 122  		}
 123  		// Encode last symbol separately to avoid overflowing u
 124  		u := int(s.symbolLen - 1)
 125  		v := s.norm[s.symbolLen-1]
 126  		if v == -1 {
 127  			// Low proba symbol
 128  			cumul[u+1] = cumul[u] + 1
 129  			tableSymbol[highThreshold] = byte(u)
 130  			highThreshold--
 131  		} else {
 132  			cumul[u+1] = cumul[u] + v
 133  		}
 134  		if uint32(cumul[s.symbolLen]) != tableSize {
 135  			return fmt.Errorf("internal error: expected cumul[s.symbolLen] (%d) == tableSize (%d)", cumul[s.symbolLen], tableSize)
 136  		}
 137  		cumul[s.symbolLen] = int16(tableSize) + 1
 138  	}
 139  	// Spread symbols
 140  	s.zeroBits = false
 141  	{
 142  		step := tableStep(tableSize)
 143  		tableMask := tableSize - 1
 144  		var position uint32
 145  		// if any symbol > largeLimit, we may have 0 bits output.
 146  		largeLimit := int16(1 << (s.actualTableLog - 1))
 147  		for ui, v := range s.norm[:s.symbolLen] {
 148  			symbol := byte(ui)
 149  			if v > largeLimit {
 150  				s.zeroBits = true
 151  			}
 152  			for range v {
 153  				tableSymbol[position] = symbol
 154  				position = (position + step) & tableMask
 155  				for position > highThreshold {
 156  					position = (position + step) & tableMask
 157  				} /* Low proba area */
 158  			}
 159  		}
 160  
 161  		// Check if we have gone through all positions
 162  		if position != 0 {
 163  			return errors.New("position!=0")
 164  		}
 165  	}
 166  
 167  	// Build table
 168  	table := s.ct.stateTable
 169  	{
 170  		tsi := int(tableSize)
 171  		for u, v := range tableSymbol {
 172  			// TableU16 : sorted by symbol order; gives next state value
 173  			table[cumul[v]] = uint16(tsi + u)
 174  			cumul[v]++
 175  		}
 176  	}
 177  
 178  	// Build Symbol Transformation Table
 179  	{
 180  		total := int16(0)
 181  		symbolTT := s.ct.symbolTT[:s.symbolLen]
 182  		tableLog := s.actualTableLog
 183  		tl := (uint32(tableLog) << 16) - (1 << tableLog)
 184  		for i, v := range s.norm[:s.symbolLen] {
 185  			switch v {
 186  			case 0:
 187  			case -1, 1:
 188  				symbolTT[i].deltaNbBits = tl
 189  				symbolTT[i].deltaFindState = total - 1
 190  				total++
 191  			default:
 192  				maxBitsOut := uint32(tableLog) - highBit(uint32(v-1))
 193  				minStatePlus := uint32(v) << maxBitsOut
 194  				symbolTT[i].deltaNbBits = (maxBitsOut << 16) - minStatePlus
 195  				symbolTT[i].deltaFindState = total - v
 196  				total += v
 197  			}
 198  		}
 199  		if total != int16(tableSize) {
 200  			return fmt.Errorf("total mismatch %d (got) != %d (want)", total, tableSize)
 201  		}
 202  	}
 203  	return nil
 204  }
 205  
 206  var rtbTable = [...]uint32{0, 473195, 504333, 520860, 550000, 700000, 750000, 830000}
 207  
 208  func (s *fseEncoder) setRLE(val byte) {
 209  	s.allocCtable()
 210  	s.actualTableLog = 0
 211  	s.ct.stateTable = s.ct.stateTable[:1]
 212  	s.ct.symbolTT[val] = symbolTransform{
 213  		deltaFindState: 0,
 214  		deltaNbBits:    0,
 215  	}
 216  	if debugEncoder {
 217  		println("setRLE: val", val, "symbolTT", s.ct.symbolTT[val])
 218  	}
 219  	s.rleVal = val
 220  	s.useRLE = true
 221  }
 222  
 223  // setBits will set output bits for the transform.
 224  // if nil is provided, the number of bits is equal to the index.
 225  func (s *fseEncoder) setBits(transform []byte) {
 226  	if s.reUsed || s.preDefined {
 227  		return
 228  	}
 229  	if s.useRLE {
 230  		if transform == nil {
 231  			s.ct.symbolTT[s.rleVal].outBits = s.rleVal
 232  			s.maxBits = s.rleVal
 233  			return
 234  		}
 235  		s.maxBits = transform[s.rleVal]
 236  		s.ct.symbolTT[s.rleVal].outBits = s.maxBits
 237  		return
 238  	}
 239  	if transform == nil {
 240  		for i := range s.ct.symbolTT[:s.symbolLen] {
 241  			s.ct.symbolTT[i].outBits = uint8(i)
 242  		}
 243  		s.maxBits = uint8(s.symbolLen - 1)
 244  		return
 245  	}
 246  	s.maxBits = 0
 247  	for i, v := range transform[:s.symbolLen] {
 248  		s.ct.symbolTT[i].outBits = v
 249  		if v > s.maxBits {
 250  			// We could assume bits always going up, but we play safe.
 251  			s.maxBits = v
 252  		}
 253  	}
 254  }
 255  
 256  // normalizeCount will normalize the count of the symbols so
 257  // the total is equal to the table size.
 258  // If successful, compression tables will also be made ready.
 259  func (s *fseEncoder) normalizeCount(length int) error {
 260  	if s.reUsed {
 261  		return nil
 262  	}
 263  	s.optimalTableLog(length)
 264  	var (
 265  		tableLog          = s.actualTableLog
 266  		scale             = 62 - uint64(tableLog)
 267  		step              = (1 << 62) / uint64(length)
 268  		vStep             = uint64(1) << (scale - 20)
 269  		stillToDistribute = int16(1 << tableLog)
 270  		largest           int
 271  		largestP          int16
 272  		lowThreshold      = (uint32)(length >> tableLog)
 273  	)
 274  	if s.maxCount == length {
 275  		s.useRLE = true
 276  		return nil
 277  	}
 278  	s.useRLE = false
 279  	for i, cnt := range s.count[:s.symbolLen] {
 280  		// already handled
 281  		// if (count[s] == s.length) return 0;   /* rle special case */
 282  
 283  		if cnt == 0 {
 284  			s.norm[i] = 0
 285  			continue
 286  		}
 287  		if cnt <= lowThreshold {
 288  			s.norm[i] = -1
 289  			stillToDistribute--
 290  		} else {
 291  			proba := (int16)((uint64(cnt) * step) >> scale)
 292  			if proba < 8 {
 293  				restToBeat := vStep * uint64(rtbTable[proba])
 294  				v := uint64(cnt)*step - (uint64(proba) << scale)
 295  				if v > restToBeat {
 296  					proba++
 297  				}
 298  			}
 299  			if proba > largestP {
 300  				largestP = proba
 301  				largest = i
 302  			}
 303  			s.norm[i] = proba
 304  			stillToDistribute -= proba
 305  		}
 306  	}
 307  
 308  	if -stillToDistribute >= (s.norm[largest] >> 1) {
 309  		// corner case, need another normalization method
 310  		err := s.normalizeCount2(length)
 311  		if err != nil {
 312  			return err
 313  		}
 314  		if debugAsserts {
 315  			err = s.validateNorm()
 316  			if err != nil {
 317  				return err
 318  			}
 319  		}
 320  		return s.buildCTable()
 321  	}
 322  	s.norm[largest] += stillToDistribute
 323  	if debugAsserts {
 324  		err := s.validateNorm()
 325  		if err != nil {
 326  			return err
 327  		}
 328  	}
 329  	return s.buildCTable()
 330  }
 331  
 332  // Secondary normalization method.
 333  // To be used when primary method fails.
 334  func (s *fseEncoder) normalizeCount2(length int) error {
 335  	const notYetAssigned = -2
 336  	var (
 337  		distributed  uint32
 338  		total        = uint32(length)
 339  		tableLog     = s.actualTableLog
 340  		lowThreshold = total >> tableLog
 341  		lowOne       = (total * 3) >> (tableLog + 1)
 342  	)
 343  	for i, cnt := range s.count[:s.symbolLen] {
 344  		if cnt == 0 {
 345  			s.norm[i] = 0
 346  			continue
 347  		}
 348  		if cnt <= lowThreshold {
 349  			s.norm[i] = -1
 350  			distributed++
 351  			total -= cnt
 352  			continue
 353  		}
 354  		if cnt <= lowOne {
 355  			s.norm[i] = 1
 356  			distributed++
 357  			total -= cnt
 358  			continue
 359  		}
 360  		s.norm[i] = notYetAssigned
 361  	}
 362  	toDistribute := (1 << tableLog) - distributed
 363  
 364  	if (total / toDistribute) > lowOne {
 365  		// risk of rounding to zero
 366  		lowOne = (total * 3) / (toDistribute * 2)
 367  		for i, cnt := range s.count[:s.symbolLen] {
 368  			if (s.norm[i] == notYetAssigned) && (cnt <= lowOne) {
 369  				s.norm[i] = 1
 370  				distributed++
 371  				total -= cnt
 372  				continue
 373  			}
 374  		}
 375  		toDistribute = (1 << tableLog) - distributed
 376  	}
 377  	if distributed == uint32(s.symbolLen)+1 {
 378  		// all values are pretty poor;
 379  		//   probably incompressible data (should have already been detected);
 380  		//   find max, then give all remaining points to max
 381  		var maxV int
 382  		var maxC uint32
 383  		for i, cnt := range s.count[:s.symbolLen] {
 384  			if cnt > maxC {
 385  				maxV = i
 386  				maxC = cnt
 387  			}
 388  		}
 389  		s.norm[maxV] += int16(toDistribute)
 390  		return nil
 391  	}
 392  
 393  	if total == 0 {
 394  		// all of the symbols were low enough for the lowOne or lowThreshold
 395  		for i := uint32(0); toDistribute > 0; i = (i + 1) % (uint32(s.symbolLen)) {
 396  			if s.norm[i] > 0 {
 397  				toDistribute--
 398  				s.norm[i]++
 399  			}
 400  		}
 401  		return nil
 402  	}
 403  
 404  	var (
 405  		vStepLog = 62 - uint64(tableLog)
 406  		mid      = uint64((1 << (vStepLog - 1)) - 1)
 407  		rStep    = (((1 << vStepLog) * uint64(toDistribute)) + mid) / uint64(total) // scale on remaining
 408  		tmpTotal = mid
 409  	)
 410  	for i, cnt := range s.count[:s.symbolLen] {
 411  		if s.norm[i] == notYetAssigned {
 412  			var (
 413  				end    = tmpTotal + uint64(cnt)*rStep
 414  				sStart = uint32(tmpTotal >> vStepLog)
 415  				sEnd   = uint32(end >> vStepLog)
 416  				weight = sEnd - sStart
 417  			)
 418  			if weight < 1 {
 419  				return errors.New("weight < 1")
 420  			}
 421  			s.norm[i] = int16(weight)
 422  			tmpTotal = end
 423  		}
 424  	}
 425  	return nil
 426  }
 427  
 428  // optimalTableLog calculates and sets the optimal tableLog in s.actualTableLog
 429  func (s *fseEncoder) optimalTableLog(length int) {
 430  	tableLog := uint8(maxEncTableLog)
 431  	minBitsSrc := highBit(uint32(length)) + 1
 432  	minBitsSymbols := highBit(uint32(s.symbolLen-1)) + 2
 433  	minBits := uint8(minBitsSymbols)
 434  	if minBitsSrc < minBitsSymbols {
 435  		minBits = uint8(minBitsSrc)
 436  	}
 437  
 438  	maxBitsSrc := uint8(highBit(uint32(length-1))) - 2
 439  	if maxBitsSrc < tableLog {
 440  		// Accuracy can be reduced
 441  		tableLog = maxBitsSrc
 442  	}
 443  	if minBits > tableLog {
 444  		tableLog = minBits
 445  	}
 446  	// Need a minimum to safely represent all symbol values
 447  	if tableLog < minEncTablelog {
 448  		tableLog = minEncTablelog
 449  	}
 450  	if tableLog > maxEncTableLog {
 451  		tableLog = maxEncTableLog
 452  	}
 453  	s.actualTableLog = tableLog
 454  }
 455  
 456  // validateNorm validates the normalized histogram table.
 457  func (s *fseEncoder) validateNorm() (err error) {
 458  	var total int
 459  	for _, v := range s.norm[:s.symbolLen] {
 460  		if v >= 0 {
 461  			total += int(v)
 462  		} else {
 463  			total -= int(v)
 464  		}
 465  	}
 466  	defer func() {
 467  		if err == nil {
 468  			return
 469  		}
 470  		fmt.Printf("selected TableLog: %d, Symbol length: %d\n", s.actualTableLog, s.symbolLen)
 471  		for i, v := range s.norm[:s.symbolLen] {
 472  			fmt.Printf("%3d: %5d -> %4d \n", i, s.count[i], v)
 473  		}
 474  	}()
 475  	if total != (1 << s.actualTableLog) {
 476  		return fmt.Errorf("warning: Total == %d != %d", total, 1<<s.actualTableLog)
 477  	}
 478  	for i, v := range s.count[s.symbolLen:] {
 479  		if v != 0 {
 480  			return fmt.Errorf("warning: Found symbol out of range, %d after cut", i)
 481  		}
 482  	}
 483  	return nil
 484  }
 485  
// writeCount appends the table description (the normalized counts in
// s.norm[:s.symbolLen]) to out and returns the extended slice.
// This is read back by readNCount.
// For RLE encoders only the RLE symbol byte is appended; predefined and
// reused encoders append nothing.
func (s *fseEncoder) writeCount(out []byte) ([]byte, error) {
	if s.useRLE {
		return append(out, s.rleVal), nil
	}
	if s.preDefined || s.reUsed {
		// Never write predefined.
		return out, nil
	}

	var (
		tableLog  = s.actualTableLog
		tableSize = 1 << tableLog
		previous0 bool
		charnum   uint16

		// maximum header size plus 2 extra bytes for final output if bitCount == 0.
		maxHeaderSize = ((int(s.symbolLen) * int(tableLog)) >> 3) + 3 + 2

		// Write Table Size
		bitStream = uint32(tableLog - minEncTablelog)
		bitCount  = uint(4)
		remaining = int16(tableSize + 1) /* +1 for extra accuracy */
		threshold = int16(tableSize)
		nbBits    = uint(tableLog + 1)
		outP      = len(out)
	)
	// Make room for the worst-case header. When growing, extra slack
	// (3x maxHeaderSize) is allocated, then the length is restored.
	if cap(out) < outP+maxHeaderSize {
		out = append(out, make([]byte, maxHeaderSize*3)...)
		out = out[:len(out)-maxHeaderSize*3]
	}
	out = out[:outP+maxHeaderSize]

	// stops at 1
	for remaining > 1 {
		if previous0 {
			// The previous value written was 1 (normally a zero count):
			// encode the run of following zero-count symbols as repeat flags.
			// 0xFFFF covers 24 zeros, each 2-bit 3 covers 3 more, and a final
			// 2-bit field holds the remainder.
			start := charnum
			for s.norm[charnum] == 0 {
				charnum++
			}
			for charnum >= start+24 {
				start += 24
				bitStream += uint32(0xFFFF) << bitCount
				out[outP] = byte(bitStream)
				out[outP+1] = byte(bitStream >> 8)
				outP += 2
				bitStream >>= 16
			}
			for charnum >= start+3 {
				start += 3
				bitStream += 3 << bitCount
				bitCount += 2
			}
			bitStream += uint32(charnum-start) << bitCount
			bitCount += 2
			if bitCount > 16 {
				out[outP] = byte(bitStream)
				out[outP+1] = byte(bitStream >> 8)
				outP += 2
				bitStream >>= 16
				bitCount -= 16
			}
		}

		count := s.norm[charnum]
		charnum++
		// Values below max can be written with one bit less than nbBits.
		max := (2*threshold - 1) - remaining
		if count < 0 {
			remaining += count
		} else {
			remaining -= count
		}
		count++ // +1 for extra accuracy
		if count >= threshold {
			count += max // [0..max[ [max..threshold[ (...) [threshold+max 2*threshold[
		}
		bitStream += uint32(count) << bitCount
		bitCount += nbBits
		if count < max {
			bitCount--
		}

		previous0 = count == 1
		if remaining < 1 {
			return nil, errors.New("internal error: remaining < 1")
		}
		// Fewer bits are needed as the remaining probability mass shrinks.
		for remaining < threshold {
			nbBits--
			threshold >>= 1
		}

		// Flush two whole bytes once more than 16 bits are pending.
		if bitCount > 16 {
			out[outP] = byte(bitStream)
			out[outP+1] = byte(bitStream >> 8)
			outP += 2
			bitStream >>= 16
			bitCount -= 16
		}
	}

	if outP+2 > len(out) {
		return nil, fmt.Errorf("internal error: %d > %d, maxheader: %d, sl: %d, tl: %d, normcount: %v", outP+2, len(out), maxHeaderSize, s.symbolLen, int(tableLog), s.norm[:s.symbolLen])
	}
	// Write the final partial bits; only the bytes actually used are kept.
	out[outP] = byte(bitStream)
	out[outP+1] = byte(bitStream >> 8)
	outP += int((bitCount + 7) / 8)

	if charnum > s.symbolLen {
		return nil, errors.New("internal error: charnum > s.symbolLen")
	}
	return out[:outP], nil
}
 599  
 600  // Approximate symbol cost, as fractional value, using fixed-point format (accuracyLog fractional bits)
 601  // note 1 : assume symbolValue is valid (<= maxSymbolValue)
 602  // note 2 : if freq[symbolValue]==0, @return a fake cost of tableLog+1 bits *
 603  func (s *fseEncoder) bitCost(symbolValue uint8, accuracyLog uint32) uint32 {
 604  	minNbBits := s.ct.symbolTT[symbolValue].deltaNbBits >> 16
 605  	threshold := (minNbBits + 1) << 16
 606  	if debugAsserts {
 607  		if !(s.actualTableLog < 16) {
 608  			panic("!s.actualTableLog < 16")
 609  		}
 610  		// ensure enough room for renormalization double shift
 611  		if !(uint8(accuracyLog) < 31-s.actualTableLog) {
 612  			panic("!uint8(accuracyLog) < 31-s.actualTableLog")
 613  		}
 614  	}
 615  	tableSize := uint32(1) << s.actualTableLog
 616  	deltaFromThreshold := threshold - (s.ct.symbolTT[symbolValue].deltaNbBits + tableSize)
 617  	// linear interpolation (very approximate)
 618  	normalizedDeltaFromThreshold := (deltaFromThreshold << accuracyLog) >> s.actualTableLog
 619  	bitMultiplier := uint32(1) << accuracyLog
 620  	if debugAsserts {
 621  		if s.ct.symbolTT[symbolValue].deltaNbBits+tableSize > threshold {
 622  			panic("s.ct.symbolTT[symbolValue].deltaNbBits+tableSize > threshold")
 623  		}
 624  		if normalizedDeltaFromThreshold > bitMultiplier {
 625  			panic("normalizedDeltaFromThreshold > bitMultiplier")
 626  		}
 627  	}
 628  	return (minNbBits+1)*bitMultiplier - normalizedDeltaFromThreshold
 629  }
 630  
 631  // Returns the cost in bits of encoding the distribution in count using ctable.
 632  // Histogram should only be up to the last non-zero symbol.
 633  // Returns an -1 if ctable cannot represent all the symbols in count.
 634  func (s *fseEncoder) approxSize(hist []uint32) uint32 {
 635  	if int(s.symbolLen) < len(hist) {
 636  		// More symbols than we have.
 637  		return math.MaxUint32
 638  	}
 639  	if s.useRLE {
 640  		// We will never reuse RLE encoders.
 641  		return math.MaxUint32
 642  	}
 643  	const kAccuracyLog = 8
 644  	badCost := (uint32(s.actualTableLog) + 1) << kAccuracyLog
 645  	var cost uint32
 646  	for i, v := range hist {
 647  		if v == 0 {
 648  			continue
 649  		}
 650  		if s.norm[i] == 0 {
 651  			return math.MaxUint32
 652  		}
 653  		bitCost := s.bitCost(uint8(i), kAccuracyLog)
 654  		if bitCost > badCost {
 655  			return math.MaxUint32
 656  		}
 657  		cost += v * bitCost
 658  	}
 659  	return cost >> kAccuracyLog
 660  }
 661  
 662  // maxHeaderSize returns the maximum header size in bits.
 663  // This is not exact size, but we want a penalty for new tables anyway.
 664  func (s *fseEncoder) maxHeaderSize() uint32 {
 665  	if s.preDefined {
 666  		return 0
 667  	}
 668  	if s.useRLE {
 669  		return 8
 670  	}
 671  	return (((uint32(s.symbolLen) * uint32(s.actualTableLog)) >> 3) + 3) * 8
 672  }
 673  
 674  // cState contains the compression state of a stream.
 675  type cState struct {
 676  	bw         *bitWriter
 677  	stateTable []uint16
 678  	state      uint16
 679  }
 680  
 681  // init will initialize the compression state to the first symbol of the stream.
 682  func (c *cState) init(bw *bitWriter, ct *cTable, first symbolTransform) {
 683  	c.bw = bw
 684  	c.stateTable = ct.stateTable
 685  	if len(c.stateTable) == 1 {
 686  		// RLE
 687  		c.stateTable[0] = uint16(0)
 688  		c.state = 0
 689  		return
 690  	}
 691  	nbBitsOut := (first.deltaNbBits + (1 << 15)) >> 16
 692  	im := int32((nbBitsOut << 16) - first.deltaNbBits)
 693  	lu := (im >> nbBitsOut) + int32(first.deltaFindState)
 694  	c.state = c.stateTable[lu]
 695  }
 696  
 697  // flush will write the tablelog to the output and flush the remaining full bytes.
 698  func (c *cState) flush(tableLog uint8) {
 699  	c.bw.flush32()
 700  	c.bw.addBits16NC(c.state, tableLog)
 701  }
 702