enc_best.go raw

   1  // Copyright 2019+ Klaus Post. All rights reserved.
   2  // License information can be found in the LICENSE file.
   3  // Based on work by Yann Collet, released under BSD License.
   4  
   5  package zstd
   6  
   7  import (
   8  	"bytes"
   9  	"fmt"
  10  
  11  	"github.com/klauspost/compress"
  12  )
  13  
// Sizes and hash lengths of the two match tables used by bestFastEncoder:
// the long table hashes bestLongLen bytes and the short table hashes
// bestShortLen bytes of input.
const (
	bestLongTableBits = 22                     // Bits used in the long match table
	bestLongTableSize = 1 << bestLongTableBits // Number of entries in the long match table
	bestLongLen       = 8                      // Bytes used for table hash

	// Note: Increasing the short table bits or making the hash shorter
	// can actually lead to compression degradation since it will 'steal' more from the
	// long match table and match offsets are quite big.
	// This greatly depends on the type of input.
	bestShortTableBits = 18                      // Bits used in the short match table
	bestShortTableSize = 1 << bestShortTableBits // Number of entries in the short match table
	bestShortLen       = 4                       // Bytes used for table hash

)
  28  
// match describes one candidate match found while encoding.
// Positions are indexes into the buffer being searched.
// rep is negative for a regular (non-repeat) match and positive for a match
// against one of the recent offsets; in that case the low two bits hold the
// repeat code and the value 4 is or'ed in for repeats that were checked
// directly at the start of the pending literals.
// The zero value of rep is used for "no match yet".
type match struct {
	offset int32 // Position of the earlier data that is being copied.
	s      int32 // Position where the match starts.
	length int32 // Match length in bytes; set to 0 by estBits for rejected matches.
	rep    int32 // Repeat code, see above.
	est    int32 // Estimated cost as computed by estBits; lower is better.
}
  36  
// highScore is the estimate given to a match that is not expected to save
// anything. estBits assigns it to rejected matches, and a match whose est is
// at least highScore is always replaced by the next candidate.
const highScore = maxMatchLen * 8
  38  
  39  // estBits will estimate output bits from predefined tables.
  40  func (m *match) estBits(bitsPerByte int32) {
  41  	mlc := mlCode(uint32(m.length - zstdMinMatch))
  42  	var ofc uint8
  43  	if m.rep < 0 {
  44  		ofc = ofCode(uint32(m.s-m.offset) + 3)
  45  	} else {
  46  		ofc = ofCode(uint32(m.rep) & 3)
  47  	}
  48  	// Cost, excluding
  49  	ofTT, mlTT := fsePredefEnc[tableOffsets].ct.symbolTT[ofc], fsePredefEnc[tableMatchLengths].ct.symbolTT[mlc]
  50  
  51  	// Add cost of match encoding...
  52  	m.est = int32(ofTT.outBits + mlTT.outBits)
  53  	m.est += int32(ofTT.deltaNbBits>>16 + mlTT.deltaNbBits>>16)
  54  	// Subtract savings compared to literal encoding...
  55  	m.est -= (m.length * bitsPerByte) >> 10
  56  	if m.est > 0 {
  57  		// Unlikely gain..
  58  		m.length = 0
  59  		m.est = highScore
  60  	}
  61  }
  62  
// bestFastEncoder uses 2 tables, one for short matches (hashing bestShortLen = 4 bytes)
// and one for long matches (hashing bestLongLen = 8 bytes).
// Each table entry also keeps the previous entry with the same hash,
// effectively making it a "chain" of length 2.
// Encode checks both entries of both tables, the recent offsets and a few
// positions ahead of the current one, and keeps the candidate with the lowest
// estimated cost (see match.estBits).
// dictTable and dictLongTable cache the tables built from a dictionary so that
// Reset can restore them with a copy.
type bestFastEncoder struct {
	fastBase
	table         [bestShortTableSize]prevEntry // Short match table.
	longTable     [bestLongTableSize]prevEntry  // Long match table.
	dictTable     []prevEntry                   // Short table contents for the current dictionary.
	dictLongTable []prevEntry                   // Long table contents for the current dictionary.
}
  76  
// Encode searches src for matches and appends the resulting literals and
// sequences to blk.
//
// src is added to the encoder history with addBlock and the search then runs
// over the whole history buffer, starting at the position addBlock returned.
// Two special cases return early: input where every byte equals the one before
// it is emitted as a single sequence with offset 1, and input shorter than
// minNonLiteralBlockSize is stored as literals only.
//
// In the main loop every position is looked up in the long and the short
// table, and optionally compared against the recent offsets and positions just
// ahead. Candidates are ranked by match.estBits using a literal cost derived
// from the entropy of src. When the loop ends, remaining bytes are stored as
// extra literals and blk.recentOffsets is updated.
func (e *bestFastEncoder) Encode(blk *blockEnc, src []byte) {
	const (
		// inputMargin is how far before the end of the buffer the search stops:
		// 8 bytes for a 64-bit load plus 4 bytes for positions examined ahead of s.
		// minNonLiteralBlockSize is the smallest input for which matches are searched.
		inputMargin            = 8 + 4
		minNonLiteralBlockSize = 16
	)

	// Protect against e.cur wraparound.
	// Every path through the body ends in break, so this runs at most once.
	for e.cur >= e.bufferReset-int32(len(e.hist)) {
		if len(e.hist) == 0 {
			// No history to preserve, simply clear both tables.
			e.table = [bestShortTableSize]prevEntry{}
			e.longTable = [bestLongTableSize]prevEntry{}
			e.cur = e.maxMatchOff
			break
		}
		// Shift down everything in the table that isn't already too far away.
		// Entries below minOff are zeroed; if the newest entry is too old,
		// the previous one is dropped as well.
		minOff := e.cur + int32(len(e.hist)) - e.maxMatchOff
		for i := range e.table[:] {
			v := e.table[i].offset
			v2 := e.table[i].prev
			if v < minOff {
				v = 0
				v2 = 0
			} else {
				v = v - e.cur + e.maxMatchOff
				if v2 < minOff {
					v2 = 0
				} else {
					v2 = v2 - e.cur + e.maxMatchOff
				}
			}
			e.table[i] = prevEntry{
				offset: v,
				prev:   v2,
			}
		}
		// Same rebasing for the long table.
		for i := range e.longTable[:] {
			v := e.longTable[i].offset
			v2 := e.longTable[i].prev
			if v < minOff {
				v = 0
				v2 = 0
			} else {
				v = v - e.cur + e.maxMatchOff
				if v2 < minOff {
					v2 = 0
				} else {
					v2 = v2 - e.cur + e.maxMatchOff
				}
			}
			e.longTable[i] = prevEntry{
				offset: v,
				prev:   v2,
			}
		}
		e.cur = e.maxMatchOff
		break
	}

	// Add block to history.
	// s is the position in the history buffer where encoding starts.
	s := e.addBlock(src)
	blk.size = len(src)

	// Check RLE first.
	// If src[1:] matches src over its full length, every byte equals its
	// predecessor: emit one literal followed by a match at offset 1.
	if len(src) > zstdMinMatch {
		ml := matchLen(src[1:], src)
		if ml == len(src)-1 {
			blk.literals = append(blk.literals, src[0])
			blk.sequences = append(blk.sequences, seq{litLen: 1, matchLen: uint32(len(src)-1) - zstdMinMatch, offset: 1 + 3})
			return
		}
	}

	// Too small to search, store everything as literals.
	if len(src) < minNonLiteralBlockSize {
		blk.extraLits = len(src)
		blk.literals = blk.literals[:len(src)]
		copy(blk.literals, src)
		return
	}

	// Use this to estimate literal cost.
	// Scaled by 10 bits.
	bitsPerByte := max(
		// Huffman can never go < 1 bit/byte
		int32((compress.ShannonEntropyBits(src)*1024)/len(src)), 1024)

	// Override src.
	// From here on src is the full history and all positions index into it.
	src = e.hist
	sLimit := int32(len(src)) - inputMargin
	const kSearchStrength = 10

	// nextEmit is where in src the next emitLiteral should start from.
	nextEmit := s

	// Relative offsets
	offset1 := int32(blk.recentOffsets[0])
	offset2 := int32(blk.recentOffsets[1])
	offset3 := int32(blk.recentOffsets[2])

	// addLiterals appends the pending literals src[nextEmit:until] to the block
	// and records their count in s. It does not advance nextEmit.
	addLiterals := func(s *seq, until int32) {
		if until == nextEmit {
			return
		}
		blk.literals = append(blk.literals, src[nextEmit:until]...)
		s.litLen = uint32(until - nextEmit)
	}

	if debugEncoder {
		println("recent offsets:", blk.recentOffsets)
	}

encodeLoop:
	for {
		// Repeat offsets are only tried once the block holds more than two sequences.
		// We allow the encoder to optionally turn off repeat offsets across blocks
		canRepeat := len(blk.sequences) > 2

		if debugAsserts && canRepeat && offset1 == 0 {
			panic("offset0 was 0")
		}

		// Matches of at least this length stop further searching at this position.
		const goodEnough = 250

		cv := load6432(src, s)

		nextHashL := hashLen(cv, bestLongTableBits, bestLongLen)
		nextHashS := hashLen(cv, bestShortTableBits, bestShortLen)
		candidateL := e.longTable[nextHashL]
		candidateS := e.table[nextHashS]

		// Set m to a match at offset if it looks like that will improve compression.
		// offset is the candidate position, s the position being matched, first the
		// 4 bytes at s and rep the repeat code to record (negative for non-repeats).
		improve := func(m *match, offset int32, s int32, first uint32, rep int32) {
			delta := s - offset
			// Reject candidates that are out of range, not before s, or whose first 4 bytes differ.
			if delta >= e.maxMatchOff || delta <= 0 || load3232(src, offset) != first {
				return
			}
			// Try to quick reject if we already have a long match.
			if m.length > 16 {
				left := len(src) - int(m.s+m.length)
				// If we are too close to the end, keep as is.
				if left <= 0 {
					return
				}
				checkLen := m.length - (s - m.s) - 8
				if left > 2 && checkLen > 4 {
					// Check 4 bytes, 4 bytes from the end of the current match.
					a := load3232(src, offset+checkLen)
					b := load3232(src, s+checkLen)
					if a != b {
						return
					}
				}
			}
			// The first 4 bytes are known to match, measure the rest.
			l := 4 + e.matchlen(s+4, offset+4, src)
			if m.rep <= 0 {
				// Extend candidate match backwards as far as possible.
				// Do not extend repeats as we can assume they are optimal
				// and offsets change if s == nextEmit.
				tMin := max(s-e.maxMatchOff, 0)
				for offset > tMin && s > nextEmit && src[offset-1] == src[s-1] && l < maxMatchLength {
					s--
					offset--
					l++
				}
			}
			if debugAsserts {
				if offset >= s {
					panic(fmt.Sprintf("offset: %d - s:%d - rep: %d - cur :%d - max: %d", offset, s, rep, e.cur, e.maxMatchOff))
				}
				if !bytes.Equal(src[s:s+l], src[offset:offset+l]) {
					panic(fmt.Sprintf("second match mismatch: %v != %v, first: %08x", src[s:s+4], src[offset:offset+4], first))
				}
			}
			cand := match{offset: offset, s: s, length: l, rep: rep}
			cand.estBits(bitsPerByte)
			// Take the candidate if there is no usable match yet, or if its estimate is
			// lower once the literal cost of a differing start position is included.
			if m.est >= highScore || cand.est-m.est+(cand.s-m.s)*bitsPerByte>>10 < 0 {
				*m = cand
			}
		}

		// Try both entries of the long and the short table at s.
		best := match{s: s, est: highScore}
		improve(&best, candidateL.offset-e.cur, s, uint32(cv), -1)
		improve(&best, candidateL.prev-e.cur, s, uint32(cv), -1)
		improve(&best, candidateS.offset-e.cur, s, uint32(cv), -1)
		improve(&best, candidateS.prev-e.cur, s, uint32(cv), -1)

		if canRepeat && best.length < goodEnough {
			if s == nextEmit {
				// Check repeats straight after a match.
				// There are no pending literals, so the repeat codes are flagged with 4
				// and refer to offset2, offset3 and offset1-1.
				improve(&best, s-offset2, s, uint32(cv), 1|4)
				improve(&best, s-offset3, s, uint32(cv), 2|4)
				if offset1 > 1 {
					improve(&best, s-(offset1-1), s, uint32(cv), 3|4)
				}
			}

			// If either no match or a non-repeat match, check at + 1
			if best.rep <= 0 {
				cv32 := uint32(cv >> 8)
				spp := s + 1
				improve(&best, spp-offset1, spp, cv32, 1)
				improve(&best, spp-offset2, spp, cv32, 2)
				improve(&best, spp-offset3, spp, cv32, 3)
				if best.rep < 0 {
					// Still a non-repeat match, also try the repeats at s + 3.
					cv32 = uint32(cv >> 24)
					spp += 2
					improve(&best, spp-offset1, spp, cv32, 1)
					improve(&best, spp-offset2, spp, cv32, 2)
					improve(&best, spp-offset3, spp, cv32, 3)
				}
			}
		}
		// Store the current position in both tables,
		// keeping the replaced entry as prev.
		e.longTable[nextHashL] = prevEntry{offset: s + e.cur, prev: candidateL.offset}
		e.table[nextHashS] = prevEntry{offset: s + e.cur, prev: candidateS.offset}
		// index0 is the first position that has not been added to the tables.
		index0 := s + 1

		// Look far ahead, unless we have a really long match already...
		if best.length < goodEnough {
			// No match found, move forward on input, no need to check forward...
			// The step grows with the number of pending literals.
			if best.length < 4 {
				s += 1 + (s-nextEmit)>>(kSearchStrength-1)
				if s >= sLimit {
					break encodeLoop
				}
				continue
			}

			candidateS = e.table[hashLen(cv>>8, bestShortTableBits, bestShortLen)]
			cv = load6432(src, s+1)
			cv2 := load6432(src, s+2)
			candidateL = e.longTable[hashLen(cv, bestLongTableBits, bestLongLen)]
			candidateL2 := e.longTable[hashLen(cv2, bestLongTableBits, bestLongLen)]

			// Short at s+1
			improve(&best, candidateS.offset-e.cur, s+1, uint32(cv), -1)
			// Long at s+1, s+2
			improve(&best, candidateL.offset-e.cur, s+1, uint32(cv), -1)
			improve(&best, candidateL.prev-e.cur, s+1, uint32(cv), -1)
			improve(&best, candidateL2.offset-e.cur, s+2, uint32(cv2), -1)
			improve(&best, candidateL2.prev-e.cur, s+2, uint32(cv2), -1)
			if false {
				// Short at s+3.
				// Too often worse...
				improve(&best, e.table[hashLen(cv2>>8, bestShortTableBits, bestShortLen)].offset-e.cur, s+3, uint32(cv2>>8), -1)
			}

			// Start check at a fixed offset to allow for a few mismatches.
			// For this compression level 2 yields the best results.
			// We cannot do this if we have already indexed this position.
			const skipBeginning = 2
			if best.s > s-skipBeginning {
				// See if we can find a better match by checking where the current best ends.
				// Use that offset to see if we can find a better full match.
				if sAt := best.s + best.length; sAt < sLimit {
					nextHashL := hashLen(load6432(src, sAt), bestLongTableBits, bestLongLen)
					candidateEnd := e.longTable[nextHashL]

					// Translate the position found for the end of the match back to a
					// candidate for best.s+skipBeginning.
					if off := candidateEnd.offset - e.cur - best.length + skipBeginning; off >= 0 {
						improve(&best, off, best.s+skipBeginning, load3232(src, best.s+skipBeginning), -1)
						if off := candidateEnd.prev - e.cur - best.length + skipBeginning; off >= 0 {
							improve(&best, off, best.s+skipBeginning, load3232(src, best.s+skipBeginning), -1)
						}
					}
				}
			}
		}

		if debugAsserts {
			if best.offset >= best.s {
				panic(fmt.Sprintf("best.offset > s: %d >= %d", best.offset, best.s))
			}
			if best.s < nextEmit {
				panic(fmt.Sprintf("s %d < nextEmit %d", best.s, nextEmit))
			}
			if best.offset < s-e.maxMatchOff {
				panic(fmt.Sprintf("best.offset < s-e.maxMatchOff: %d < %d", best.offset, s-e.maxMatchOff))
			}
			if !bytes.Equal(src[best.s:best.s+best.length], src[best.offset:best.offset+best.length]) {
				panic(fmt.Sprintf("match mismatch: %v != %v", src[best.s:best.s+best.length], src[best.offset:best.offset+best.length]))
			}
		}

		// We have a match, we can store the forward value
		s = best.s
		if best.rep > 0 {
			// The match uses one of the recent offsets.
			var seq seq
			seq.matchLen = uint32(best.length - zstdMinMatch)
			addLiterals(&seq, best.s)

			// The repeat code is in the low two bits of rep. The value 4 only marks
			// repeats that were found with no pending literals and is masked off here.
			seq.offset = uint32(best.rep & 3)
			if debugSequences {
				println("repeat sequence", seq, "next s:", best.s, "off:", best.s-best.offset)
			}
			blk.sequences = append(blk.sequences, seq)

			// Continue after the match.
			s = best.s + best.length
			nextEmit = s

			// Index skipped...
			// Add the positions from index0 up to the end of the match to both tables.
			// NOTE(review): this path caps indexing at sLimit+4 while the non-repeat
			// path below caps at sLimit-4; confirm the difference is intended.
			end := min(s, sLimit+4)
			off := index0 + e.cur
			for index0 < end {
				cv0 := load6432(src, index0)
				h0 := hashLen(cv0, bestLongTableBits, bestLongLen)
				h1 := hashLen(cv0, bestShortTableBits, bestShortLen)
				e.longTable[h0] = prevEntry{offset: off, prev: e.longTable[h0].offset}
				e.table[h1] = prevEntry{offset: off, prev: e.table[h1].offset}
				off++
				index0++
			}

			// Reorder the recent offsets so the one just used becomes offset1.
			// Repeat code 1 (offset1 after literals) leaves them unchanged.
			switch best.rep {
			case 2, 4 | 1:
				offset1, offset2 = offset2, offset1
			case 3, 4 | 2:
				offset1, offset2, offset3 = offset3, offset1, offset2
			case 4 | 3:
				offset1, offset2, offset3 = offset1-1, offset1, offset2
			}
			if s >= sLimit {
				if debugEncoder {
					println("repeat ended", s, best.length)
				}
				break encodeLoop
			}
			continue
		}

		// A regular match has been found. Update recent offsets:
		// the new offset becomes offset1 and the others shift down.
		t := best.offset
		offset1, offset2, offset3 = s-t, offset1, offset2

		if debugAsserts && s <= t {
			panic(fmt.Sprintf("s (%d) <= t (%d)", s, t))
		}

		if debugAsserts && int(offset1) > len(src) {
			panic("invalid offset")
		}

		// Write our sequence
		var seq seq
		l := best.length
		seq.litLen = uint32(s - nextEmit)
		seq.matchLen = uint32(l - zstdMinMatch)
		if seq.litLen > 0 {
			blk.literals = append(blk.literals, src[nextEmit:s]...)
		}
		// Regular offsets are stored with 3 added.
		seq.offset = uint32(s-t) + 3
		s += l
		if debugSequences {
			println("sequence", seq, "next s:", s)
		}
		blk.sequences = append(blk.sequences, seq)
		nextEmit = s

		// Index old s + 1 -> s - 1 or sLimit
		end := min(s, sLimit-4)

		off := index0 + e.cur
		for index0 < end {
			cv0 := load6432(src, index0)
			h0 := hashLen(cv0, bestLongTableBits, bestLongLen)
			h1 := hashLen(cv0, bestShortTableBits, bestShortLen)
			e.longTable[h0] = prevEntry{offset: off, prev: e.longTable[h0].offset}
			e.table[h1] = prevEntry{offset: off, prev: e.table[h1].offset}
			index0++
			off++
		}
		if s >= sLimit {
			break encodeLoop
		}
	}

	// Whatever is left after the last match is stored as literals.
	if int(nextEmit) < len(src) {
		blk.literals = append(blk.literals, src[nextEmit:]...)
		blk.extraLits = len(src) - int(nextEmit)
	}
	// Hand the recent offsets back to the block.
	blk.recentOffsets[0] = uint32(offset1)
	blk.recentOffsets[1] = uint32(offset2)
	blk.recentOffsets[2] = uint32(offset3)
	if debugEncoder {
		println("returning, recent offsets:", blk.recentOffsets, "extra literals:", blk.extraLits)
	}
}
 467  
// EncodeNoHist will encode a block with no history and no following blocks.
// This encoder has no separate implementation for that case: it calls
// ensureHist with the length of src and then encodes the block with Encode,
// so src goes through the same path as any other block.
func (e *bestFastEncoder) EncodeNoHist(blk *blockEnc, src []byte) {
	e.ensureHist(len(src))
	e.Encode(blk, src)
}
 475  
 476  // Reset will reset and set a dictionary if not nil
 477  func (e *bestFastEncoder) Reset(d *dict, singleBlock bool) {
 478  	e.resetBase(d, singleBlock)
 479  	if d == nil {
 480  		return
 481  	}
 482  	// Init or copy dict table
 483  	if len(e.dictTable) != len(e.table) || d.id != e.lastDictID {
 484  		if len(e.dictTable) != len(e.table) {
 485  			e.dictTable = make([]prevEntry, len(e.table))
 486  		}
 487  		end := int32(len(d.content)) - 8 + e.maxMatchOff
 488  		for i := e.maxMatchOff; i < end; i += 4 {
 489  			const hashLog = bestShortTableBits
 490  
 491  			cv := load6432(d.content, i-e.maxMatchOff)
 492  			nextHash := hashLen(cv, hashLog, bestShortLen)      // 0 -> 4
 493  			nextHash1 := hashLen(cv>>8, hashLog, bestShortLen)  // 1 -> 5
 494  			nextHash2 := hashLen(cv>>16, hashLog, bestShortLen) // 2 -> 6
 495  			nextHash3 := hashLen(cv>>24, hashLog, bestShortLen) // 3 -> 7
 496  			e.dictTable[nextHash] = prevEntry{
 497  				prev:   e.dictTable[nextHash].offset,
 498  				offset: i,
 499  			}
 500  			e.dictTable[nextHash1] = prevEntry{
 501  				prev:   e.dictTable[nextHash1].offset,
 502  				offset: i + 1,
 503  			}
 504  			e.dictTable[nextHash2] = prevEntry{
 505  				prev:   e.dictTable[nextHash2].offset,
 506  				offset: i + 2,
 507  			}
 508  			e.dictTable[nextHash3] = prevEntry{
 509  				prev:   e.dictTable[nextHash3].offset,
 510  				offset: i + 3,
 511  			}
 512  		}
 513  		e.lastDictID = d.id
 514  	}
 515  
 516  	// Init or copy dict table
 517  	if len(e.dictLongTable) != len(e.longTable) || d.id != e.lastDictID {
 518  		if len(e.dictLongTable) != len(e.longTable) {
 519  			e.dictLongTable = make([]prevEntry, len(e.longTable))
 520  		}
 521  		if len(d.content) >= 8 {
 522  			cv := load6432(d.content, 0)
 523  			h := hashLen(cv, bestLongTableBits, bestLongLen)
 524  			e.dictLongTable[h] = prevEntry{
 525  				offset: e.maxMatchOff,
 526  				prev:   e.dictLongTable[h].offset,
 527  			}
 528  
 529  			end := int32(len(d.content)) - 8 + e.maxMatchOff
 530  			off := 8 // First to read
 531  			for i := e.maxMatchOff + 1; i < end; i++ {
 532  				cv = cv>>8 | (uint64(d.content[off]) << 56)
 533  				h := hashLen(cv, bestLongTableBits, bestLongLen)
 534  				e.dictLongTable[h] = prevEntry{
 535  					offset: i,
 536  					prev:   e.dictLongTable[h].offset,
 537  				}
 538  				off++
 539  			}
 540  		}
 541  		e.lastDictID = d.id
 542  	}
 543  	// Reset table to initial state
 544  	copy(e.longTable[:], e.dictLongTable)
 545  
 546  	e.cur = e.maxMatchOff
 547  	// Reset table to initial state
 548  	copy(e.table[:], e.dictTable)
 549  }
 550