enc_best.go raw
1 // Copyright 2019+ Klaus Post. All rights reserved.
2 // License information can be found in the LICENSE file.
3 // Based on work by Yann Collet, released under BSD License.
4
5 package zstd
6
7 import (
8 "bytes"
9 "fmt"
10
11 "github.com/klauspost/compress"
12 )
13
// Parameters of the long match table.
const (
	// bestLongTableBits is the number of hash bits used to index the long match table.
	bestLongTableBits = 22
	// bestLongTableSize is the number of entries in the long match table.
	bestLongTableSize = 1 << bestLongTableBits
	// bestLongLen is the number of input bytes hashed for a long table lookup.
	bestLongLen = 8
)

// Parameters of the short match table.
//
// Note: Increasing the short table bits or making the hash shorter
// can actually lead to compression degradation since it will 'steal' more from the
// long match table and match offsets are quite big.
// This greatly depends on the type of input.
const (
	// bestShortTableBits is the number of hash bits used to index the short match table.
	bestShortTableBits = 18
	// bestShortTableSize is the number of entries in the short match table.
	bestShortTableSize = 1 << bestShortTableBits
	// bestShortLen is the number of input bytes hashed for a short table lookup.
	bestShortLen = 4
)
28
// match describes one candidate match considered by the best encoder.
type match struct {
	// offset is the position in the search buffer the match copies from.
	offset int32

	// s is the position in the search buffer where the match starts.
	s int32

	// length is the match length in bytes; estBits sets it to 0 when the
	// match is not expected to give any gain.
	length int32

	// rep is negative for a regular offset match. Otherwise its low two
	// bits hold the repeat offset code used when estimating and emitting.
	rep int32

	// est is the estimated cost in bits as computed by estBits; lower is better.
	est int32
}
36
37 const highScore = maxMatchLen * 8
38
39 // estBits will estimate output bits from predefined tables.
40 func (m *match) estBits(bitsPerByte int32) {
41 mlc := mlCode(uint32(m.length - zstdMinMatch))
42 var ofc uint8
43 if m.rep < 0 {
44 ofc = ofCode(uint32(m.s-m.offset) + 3)
45 } else {
46 ofc = ofCode(uint32(m.rep) & 3)
47 }
48 // Cost, excluding
49 ofTT, mlTT := fsePredefEnc[tableOffsets].ct.symbolTT[ofc], fsePredefEnc[tableMatchLengths].ct.symbolTT[mlc]
50
51 // Add cost of match encoding...
52 m.est = int32(ofTT.outBits + mlTT.outBits)
53 m.est += int32(ofTT.deltaNbBits>>16 + mlTT.deltaNbBits>>16)
54 // Subtract savings compared to literal encoding...
55 m.est -= (m.length * bitsPerByte) >> 10
56 if m.est > 0 {
57 // Unlikely gain..
58 m.length = 0
59 m.est = highScore
60 }
61 }
62
// bestFastEncoder uses 2 tables, one for short matches (hashed over bestShortLen bytes)
// and one for long matches (hashed over bestLongLen bytes).
// Entries in both tables are prevEntry values, which contain the previous entry
// with the same hash, effectively making each a "chain" of length 2.
// When we find a long match we choose between the two values and select the longest.
// When we find a short match, after checking the long, we check if we can find a long at n+1
// and that it is longer (lazy matching).
type bestFastEncoder struct {
	fastBase
	// table is the short match table.
	table [bestShortTableSize]prevEntry
	// longTable is the long match table.
	longTable [bestLongTableSize]prevEntry
	// dictTable is the short table as built from the dictionary by Reset;
	// Reset copies it into table.
	dictTable []prevEntry
	// dictLongTable is the long table as built from the dictionary by Reset;
	// Reset copies it into longTable.
	dictLongTable []prevEntry
}
76
// Encode improves compression...
//
// Encode adds src to the encoder history, searches it for matches using the
// short and long hash tables as well as the recent (repeat) offsets, and
// appends the resulting literals and sequences to blk. blk.size is set to
// len(src) and blk.recentOffsets is updated before a normal return.
//
// Two early exits bypass the match search and leave blk.recentOffsets
// untouched: a block where matchLen reports that everything after the first
// byte matches the data one position earlier, and a block shorter than
// minNonLiteralBlockSize, which is stored as literals only.
func (e *bestFastEncoder) Encode(blk *blockEnc, src []byte) {
	const (
		// Input margin is the number of bytes we read (8)
		// and the maximum we will read ahead (2)
		// NOTE(review): the constant adds 4, not 2, for read-ahead; confirm which is intended.
		inputMargin            = 8 + 4
		minNonLiteralBlockSize = 16
	)

	// Protect against e.cur wraparound.
	// Both paths through the body end in break, so this runs at most once.
	for e.cur >= e.bufferReset-int32(len(e.hist)) {
		if len(e.hist) == 0 {
			// No history to preserve: just clear both tables.
			e.table = [bestShortTableSize]prevEntry{}
			e.longTable = [bestLongTableSize]prevEntry{}
			e.cur = e.maxMatchOff
			break
		}
		// Shift down everything in the table that isn't already too far away.
		minOff := e.cur + int32(len(e.hist)) - e.maxMatchOff
		for i := range e.table[:] {
			v := e.table[i].offset
			v2 := e.table[i].prev
			if v < minOff {
				// The newest entry is out of range, so the whole slot is dropped.
				v = 0
				v2 = 0
			} else {
				v = v - e.cur + e.maxMatchOff
				if v2 < minOff {
					v2 = 0
				} else {
					v2 = v2 - e.cur + e.maxMatchOff
				}
			}
			e.table[i] = prevEntry{
				offset: v,
				prev:   v2,
			}
		}
		// Same rebasing for the long table.
		for i := range e.longTable[:] {
			v := e.longTable[i].offset
			v2 := e.longTable[i].prev
			if v < minOff {
				v = 0
				v2 = 0
			} else {
				v = v - e.cur + e.maxMatchOff
				if v2 < minOff {
					v2 = 0
				} else {
					v2 = v2 - e.cur + e.maxMatchOff
				}
			}
			e.longTable[i] = prevEntry{
				offset: v,
				prev:   v2,
			}
		}
		e.cur = e.maxMatchOff
		break
	}

	// Add block to history
	// s is the position returned by addBlock; matching and literal emission start there.
	s := e.addBlock(src)
	blk.size = len(src)

	// Check RLE first
	// If matchLen reports that everything after the first byte matches the data
	// one position earlier, emit one literal plus a single offset-1 match.
	if len(src) > zstdMinMatch {
		ml := matchLen(src[1:], src)
		if ml == len(src)-1 {
			blk.literals = append(blk.literals, src[0])
			blk.sequences = append(blk.sequences, seq{litLen: 1, matchLen: uint32(len(src)-1) - zstdMinMatch, offset: 1 + 3})
			return
		}
	}

	// Too small to search for matches: store everything as literals.
	if len(src) < minNonLiteralBlockSize {
		blk.extraLits = len(src)
		// NOTE(review): reslicing assumes blk.literals has capacity for len(src)
		// and holds no pending literals - confirm against the caller.
		blk.literals = blk.literals[:len(src)]
		copy(blk.literals, src)
		return
	}

	// Use this to estimate literal cost.
	// Scaled by 10 bits.
	bitsPerByte := max(
		// Huffman can never go < 1 bit/byte
		int32((compress.ShannonEntropyBits(src)*1024)/len(src)), 1024)

	// Override src
	// From here on all positions (s, nextEmit, match offsets) index into e.hist.
	src = e.hist
	sLimit := int32(len(src)) - inputMargin
	const kSearchStrength = 10

	// nextEmit is where in src the next emitLiteral should start from.
	nextEmit := s

	// Relative offsets
	offset1 := int32(blk.recentOffsets[0])
	offset2 := int32(blk.recentOffsets[1])
	offset3 := int32(blk.recentOffsets[2])

	// addLiterals appends src[nextEmit:until] to the block literals and records
	// the count in s.litLen. It does not advance nextEmit.
	addLiterals := func(s *seq, until int32) {
		if until == nextEmit {
			return
		}
		blk.literals = append(blk.literals, src[nextEmit:until]...)
		s.litLen = uint32(until - nextEmit)
	}

	if debugEncoder {
		println("recent offsets:", blk.recentOffsets)
	}

encodeLoop:
	for {
		// We allow the encoder to optionally turn off repeat offsets across blocks
		// Repeats are only tried once more than two sequences are present in blk.
		canRepeat := len(blk.sequences) > 2

		if debugAsserts && canRepeat && offset1 == 0 {
			panic("offset0 was 0")
		}

		// Matches at least this long are accepted without looking further.
		const goodEnough = 250

		cv := load6432(src, s)

		nextHashL := hashLen(cv, bestLongTableBits, bestLongLen)
		nextHashS := hashLen(cv, bestShortTableBits, bestShortLen)
		candidateL := e.longTable[nextHashL]
		candidateS := e.table[nextHashS]

		// Set m to a match at offset if it looks like that will improve compression.
		// offset is the candidate position in src, s the position being matched,
		// first the 4 bytes at s, and rep is negative for a regular match or the
		// repeat code (optionally with bit 4 set) for a repeat offset candidate.
		improve := func(m *match, offset int32, s int32, first uint32, rep int32) {
			delta := s - offset
			// Reject candidates out of window, at/after s, or whose first 4 bytes differ.
			if delta >= e.maxMatchOff || delta <= 0 || load3232(src, offset) != first {
				return
			}
			// Try to quick reject if we already have a long match.
			if m.length > 16 {
				left := len(src) - int(m.s+m.length)
				// If we are too close to the end, keep as is.
				if left <= 0 {
					return
				}
				checkLen := m.length - (s - m.s) - 8
				if left > 2 && checkLen > 4 {
					// Check 4 bytes, 4 bytes from the end of the current match.
					a := load3232(src, offset+checkLen)
					b := load3232(src, s+checkLen)
					if a != b {
						return
					}
				}
			}
			// The first 4 bytes are known to match; measure the rest.
			l := 4 + e.matchlen(s+4, offset+4, src)
			if m.rep <= 0 {
				// Extend candidate match backwards as far as possible.
				// Do not extend repeats as we can assume they are optimal
				// and offsets change if s == nextEmit.
				tMin := max(s-e.maxMatchOff, 0)
				for offset > tMin && s > nextEmit && src[offset-1] == src[s-1] && l < maxMatchLength {
					s--
					offset--
					l++
				}
			}
			if debugAsserts {
				if offset >= s {
					panic(fmt.Sprintf("offset: %d - s:%d - rep: %d - cur :%d - max: %d", offset, s, rep, e.cur, e.maxMatchOff))
				}
				if !bytes.Equal(src[s:s+l], src[offset:offset+l]) {
					panic(fmt.Sprintf("second match mismatch: %v != %v, first: %08x", src[s:s+4], src[offset:offset+4], first))
				}
			}
			cand := match{offset: offset, s: s, length: l, rep: rep}
			cand.estBits(bitsPerByte)
			// Take the candidate if there is no usable match yet, or if its estimate
			// is lower once the literal cost of any difference in start position is added.
			if m.est >= highScore || cand.est-m.est+(cand.s-m.s)*bitsPerByte>>10 < 0 {
				*m = cand
			}
		}

		// Try both chain entries from each table at the current position.
		best := match{s: s, est: highScore}
		improve(&best, candidateL.offset-e.cur, s, uint32(cv), -1)
		improve(&best, candidateL.prev-e.cur, s, uint32(cv), -1)
		improve(&best, candidateS.offset-e.cur, s, uint32(cv), -1)
		improve(&best, candidateS.prev-e.cur, s, uint32(cv), -1)

		if canRepeat && best.length < goodEnough {
			if s == nextEmit {
				// Check repeats straight after a match.
				improve(&best, s-offset2, s, uint32(cv), 1|4)
				improve(&best, s-offset3, s, uint32(cv), 2|4)
				if offset1 > 1 {
					improve(&best, s-(offset1-1), s, uint32(cv), 3|4)
				}
			}

			// If either no match or a non-repeat match, check at + 1
			if best.rep <= 0 {
				cv32 := uint32(cv >> 8)
				spp := s + 1
				improve(&best, spp-offset1, spp, cv32, 1)
				improve(&best, spp-offset2, spp, cv32, 2)
				improve(&best, spp-offset3, spp, cv32, 3)
				if best.rep < 0 {
					// Still no repeat match: also try the repeats at s+3.
					cv32 = uint32(cv >> 24)
					spp += 2
					improve(&best, spp-offset1, spp, cv32, 1)
					improve(&best, spp-offset2, spp, cv32, 2)
					improve(&best, spp-offset3, spp, cv32, 3)
				}
			}
		}
		// Load next and check...
		// Index the current position, pushing the old head of each slot to prev.
		e.longTable[nextHashL] = prevEntry{offset: s + e.cur, prev: candidateL.offset}
		e.table[nextHashS] = prevEntry{offset: s + e.cur, prev: candidateS.offset}
		// index0 is the first position not yet added to the tables.
		index0 := s + 1

		// Look far ahead, unless we have a really long match already...
		if best.length < goodEnough {
			// No match found, move forward on input, no need to check forward...
			if best.length < 4 {
				// The step grows by one for every 1<<(kSearchStrength-1) bytes since the last emit.
				s += 1 + (s-nextEmit)>>(kSearchStrength-1)
				if s >= sLimit {
					break encodeLoop
				}
				continue
			}

			candidateS = e.table[hashLen(cv>>8, bestShortTableBits, bestShortLen)]
			cv = load6432(src, s+1)
			cv2 := load6432(src, s+2)
			candidateL = e.longTable[hashLen(cv, bestLongTableBits, bestLongLen)]
			candidateL2 := e.longTable[hashLen(cv2, bestLongTableBits, bestLongLen)]

			// Short at s+1
			improve(&best, candidateS.offset-e.cur, s+1, uint32(cv), -1)
			// Long at s+1, s+2
			improve(&best, candidateL.offset-e.cur, s+1, uint32(cv), -1)
			improve(&best, candidateL.prev-e.cur, s+1, uint32(cv), -1)
			improve(&best, candidateL2.offset-e.cur, s+2, uint32(cv2), -1)
			improve(&best, candidateL2.prev-e.cur, s+2, uint32(cv2), -1)
			if false {
				// Short at s+3.
				// Too often worse...
				improve(&best, e.table[hashLen(cv2>>8, bestShortTableBits, bestShortLen)].offset-e.cur, s+3, uint32(cv2>>8), -1)
			}

			// Start check at a fixed offset to allow for a few mismatches.
			// For this compression level 2 yields the best results.
			// We cannot do this if we have already indexed this position.
			const skipBeginning = 2
			if best.s > s-skipBeginning {
				// See if we can find a better match by checking where the current best ends.
				// Use that offset to see if we can find a better full match.
				if sAt := best.s + best.length; sAt < sLimit {
					nextHashL := hashLen(load6432(src, sAt), bestLongTableBits, bestLongLen)
					candidateEnd := e.longTable[nextHashL]

					// Project the candidate found at the end back to skipBeginning
					// bytes after the start of the current best match.
					if off := candidateEnd.offset - e.cur - best.length + skipBeginning; off >= 0 {
						improve(&best, off, best.s+skipBeginning, load3232(src, best.s+skipBeginning), -1)
						if off := candidateEnd.prev - e.cur - best.length + skipBeginning; off >= 0 {
							improve(&best, off, best.s+skipBeginning, load3232(src, best.s+skipBeginning), -1)
						}
					}
				}
			}
		}

		if debugAsserts {
			if best.offset >= best.s {
				panic(fmt.Sprintf("best.offset > s: %d >= %d", best.offset, best.s))
			}
			if best.s < nextEmit {
				panic(fmt.Sprintf("s %d < nextEmit %d", best.s, nextEmit))
			}
			if best.offset < s-e.maxMatchOff {
				panic(fmt.Sprintf("best.offset < s-e.maxMatchOff: %d < %d", best.offset, s-e.maxMatchOff))
			}
			if !bytes.Equal(src[best.s:best.s+best.length], src[best.offset:best.offset+best.length]) {
				panic(fmt.Sprintf("match mismatch: %v != %v", src[best.s:best.s+best.length], src[best.offset:best.offset+best.length]))
			}
		}

		// We have a match, we can store the forward value
		s = best.s
		if best.rep > 0 {
			// The chosen match uses one of the recent offsets.
			var seq seq
			seq.matchLen = uint32(best.length - zstdMinMatch)
			addLiterals(&seq, best.s)

			// Repeat. If bit 4 is set, this is a non-lit repeat.
			seq.offset = uint32(best.rep & 3)
			if debugSequences {
				println("repeat sequence", seq, "next s:", best.s, "off:", best.s-best.offset)
			}
			blk.sequences = append(blk.sequences, seq)

			// Index old s + 1 -> s - 1
			s = best.s + best.length
			nextEmit = s

			// Index skipped...
			end := min(s, sLimit+4)
			off := index0 + e.cur
			for index0 < end {
				cv0 := load6432(src, index0)
				h0 := hashLen(cv0, bestLongTableBits, bestLongLen)
				h1 := hashLen(cv0, bestShortTableBits, bestShortLen)
				e.longTable[h0] = prevEntry{offset: off, prev: e.longTable[h0].offset}
				e.table[h1] = prevEntry{offset: off, prev: e.table[h1].offset}
				off++
				index0++
			}

			// Rotate the recent offsets according to which repeat was used.
			// Rep code 1 without bit 4 leaves them unchanged.
			switch best.rep {
			case 2, 4 | 1:
				offset1, offset2 = offset2, offset1
			case 3, 4 | 2:
				offset1, offset2, offset3 = offset3, offset1, offset2
			case 4 | 3:
				offset1, offset2, offset3 = offset1-1, offset1, offset2
			}
			if s >= sLimit {
				if debugEncoder {
					println("repeat ended", s, best.length)
				}
				break encodeLoop
			}
			continue
		}

		// A 4-byte match has been found. Update recent offsets.
		// We'll later see if more than 4 bytes.
		t := best.offset
		offset1, offset2, offset3 = s-t, offset1, offset2

		if debugAsserts && s <= t {
			panic(fmt.Sprintf("s (%d) <= t (%d)", s, t))
		}

		if debugAsserts && int(offset1) > len(src) {
			panic("invalid offset")
		}

		// Write our sequence
		var seq seq
		l := best.length
		seq.litLen = uint32(s - nextEmit)
		seq.matchLen = uint32(l - zstdMinMatch)
		if seq.litLen > 0 {
			blk.literals = append(blk.literals, src[nextEmit:s]...)
		}
		// Regular offsets are stored with 3 added.
		seq.offset = uint32(s-t) + 3
		s += l
		if debugSequences {
			println("sequence", seq, "next s:", s)
		}
		blk.sequences = append(blk.sequences, seq)
		nextEmit = s

		// Index old s + 1 -> s - 1 or sLimit
		end := min(s, sLimit-4)

		off := index0 + e.cur
		for index0 < end {
			cv0 := load6432(src, index0)
			h0 := hashLen(cv0, bestLongTableBits, bestLongLen)
			h1 := hashLen(cv0, bestShortTableBits, bestShortLen)
			e.longTable[h0] = prevEntry{offset: off, prev: e.longTable[h0].offset}
			e.table[h1] = prevEntry{offset: off, prev: e.table[h1].offset}
			index0++
			off++
		}
		if s >= sLimit {
			break encodeLoop
		}
	}

	// Whatever follows the last sequence is emitted as trailing literals.
	if int(nextEmit) < len(src) {
		blk.literals = append(blk.literals, src[nextEmit:]...)
		blk.extraLits = len(src) - int(nextEmit)
	}
	// Store the recent offsets back into the block.
	blk.recentOffsets[0] = uint32(offset1)
	blk.recentOffsets[1] = uint32(offset2)
	blk.recentOffsets[2] = uint32(offset3)
	if debugEncoder {
		println("returning, recent offsets:", blk.recentOffsets, "extra literals:", blk.extraLits)
	}
}
467
468 // EncodeNoHist will encode a block with no history and no following blocks.
469 // Most notable difference is that src will not be copied for history and
470 // we do not need to check for max match length.
471 func (e *bestFastEncoder) EncodeNoHist(blk *blockEnc, src []byte) {
472 e.ensureHist(len(src))
473 e.Encode(blk, src)
474 }
475
476 // Reset will reset and set a dictionary if not nil
477 func (e *bestFastEncoder) Reset(d *dict, singleBlock bool) {
478 e.resetBase(d, singleBlock)
479 if d == nil {
480 return
481 }
482 // Init or copy dict table
483 if len(e.dictTable) != len(e.table) || d.id != e.lastDictID {
484 if len(e.dictTable) != len(e.table) {
485 e.dictTable = make([]prevEntry, len(e.table))
486 }
487 end := int32(len(d.content)) - 8 + e.maxMatchOff
488 for i := e.maxMatchOff; i < end; i += 4 {
489 const hashLog = bestShortTableBits
490
491 cv := load6432(d.content, i-e.maxMatchOff)
492 nextHash := hashLen(cv, hashLog, bestShortLen) // 0 -> 4
493 nextHash1 := hashLen(cv>>8, hashLog, bestShortLen) // 1 -> 5
494 nextHash2 := hashLen(cv>>16, hashLog, bestShortLen) // 2 -> 6
495 nextHash3 := hashLen(cv>>24, hashLog, bestShortLen) // 3 -> 7
496 e.dictTable[nextHash] = prevEntry{
497 prev: e.dictTable[nextHash].offset,
498 offset: i,
499 }
500 e.dictTable[nextHash1] = prevEntry{
501 prev: e.dictTable[nextHash1].offset,
502 offset: i + 1,
503 }
504 e.dictTable[nextHash2] = prevEntry{
505 prev: e.dictTable[nextHash2].offset,
506 offset: i + 2,
507 }
508 e.dictTable[nextHash3] = prevEntry{
509 prev: e.dictTable[nextHash3].offset,
510 offset: i + 3,
511 }
512 }
513 e.lastDictID = d.id
514 }
515
516 // Init or copy dict table
517 if len(e.dictLongTable) != len(e.longTable) || d.id != e.lastDictID {
518 if len(e.dictLongTable) != len(e.longTable) {
519 e.dictLongTable = make([]prevEntry, len(e.longTable))
520 }
521 if len(d.content) >= 8 {
522 cv := load6432(d.content, 0)
523 h := hashLen(cv, bestLongTableBits, bestLongLen)
524 e.dictLongTable[h] = prevEntry{
525 offset: e.maxMatchOff,
526 prev: e.dictLongTable[h].offset,
527 }
528
529 end := int32(len(d.content)) - 8 + e.maxMatchOff
530 off := 8 // First to read
531 for i := e.maxMatchOff + 1; i < end; i++ {
532 cv = cv>>8 | (uint64(d.content[off]) << 56)
533 h := hashLen(cv, bestLongTableBits, bestLongLen)
534 e.dictLongTable[h] = prevEntry{
535 offset: i,
536 prev: e.dictLongTable[h].offset,
537 }
538 off++
539 }
540 }
541 e.lastDictID = d.id
542 }
543 // Reset table to initial state
544 copy(e.longTable[:], e.dictLongTable)
545
546 e.cur = e.maxMatchOff
547 // Reset table to initial state
548 copy(e.table[:], e.dictTable)
549 }
550