block.mx raw
1 // Copyright 2023 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package zstd
6
7 import (
8 "io"
9 )
10
// debug can be set in the source to print debug info using println.
// It is a constant, so enabling it requires editing this file and
// rebuilding. When true, execSeqs prints each decoded sequence.
const debug = false
13
14 // compressedBlock decompresses a compressed block, storing the decompressed
15 // data in r.buffer. The blockSize argument is the compressed size.
16 // RFC 3.1.1.3.
17 func (r *Reader) compressedBlock(blockSize int) error {
18 if len(r.compressedBuf) >= blockSize {
19 r.compressedBuf = r.compressedBuf[:blockSize]
20 } else {
21 // We know that blockSize <= 128K,
22 // so this won't allocate an enormous amount.
23 need := blockSize - len(r.compressedBuf)
24 r.compressedBuf = append(r.compressedBuf, []byte{:need}...)
25 }
26
27 if _, err := io.ReadFull(r.r, r.compressedBuf); err != nil {
28 return r.wrapNonEOFError(0, err)
29 }
30
31 data := block(r.compressedBuf)
32 off := 0
33 r.buffer = r.buffer[:0]
34
35 litoff, litbuf, err := r.readLiterals(data, off, r.literals[:0])
36 if err != nil {
37 return err
38 }
39 r.literals = litbuf
40
41 off = litoff
42
43 seqCount, off, err := r.initSeqs(data, off)
44 if err != nil {
45 return err
46 }
47
48 if seqCount == 0 {
49 // No sequences, just literals.
50 if off < len(data) {
51 return r.makeError(off, "extraneous data after no sequences")
52 }
53
54 r.buffer = append(r.buffer, litbuf...)
55
56 return nil
57 }
58
59 return r.execSeqs(data, off, litbuf, seqCount)
60 }
61
// seqCode is the kind of sequence codes we have to handle.
// Values of this type are used as indexes into seqCodeInfo and into the
// Reader's per-kind tables (r.seqTables, r.seqTableBits, r.seqTableBuffers).
type seqCode int

const (
	seqLiteral seqCode = iota // codes giving the number of literal bytes to copy
	seqOffset                 // codes giving the match offset
	seqMatch                  // codes giving the number of match bytes to copy
)
70
// seqCodeInfoData is the information needed to set up seqTables and
// seqTableBits for a particular kind of sequence code.
type seqCodeInfoData struct {
	predefTable     []fseBaselineEntry // predefined FSE
	predefTableBits int                // number of bits in predefTable
	maxSym          int                // max symbol value in FSE
	maxBits         int                // max bits for FSE

	// toBaseline converts from an FSE table to an FSE baseline table.
	// setSeqTable calls it with the Reader, the current offset into the
	// block, the source FSE table, and the destination baseline table.
	// NOTE(review): the offset is presumably only used for error
	// reporting; confirm against the make*BaselineFSE implementations.
	toBaseline func(*Reader, int, []fseEntry, []fseBaselineEntry) error
}
82
// seqCodeInfo is the seqCodeInfoData for each kind of sequence code.
// It is indexed by seqCode.
var seqCodeInfo = [3]seqCodeInfoData{
	seqLiteral: {
		predefTable:     predefinedLiteralTable[:],
		predefTableBits: 6,
		maxSym:          35,
		maxBits:         9,
		toBaseline:      (*Reader).makeLiteralBaselineFSE,
	},
	seqOffset: {
		predefTable:     predefinedOffsetTable[:],
		predefTableBits: 5,
		maxSym:          31,
		maxBits:         8,
		toBaseline:      (*Reader).makeOffsetBaselineFSE,
	},
	seqMatch: {
		predefTable:     predefinedMatchTable[:],
		predefTableBits: 6,
		maxSym:          52,
		maxBits:         9,
		toBaseline:      (*Reader).makeMatchBaselineFSE,
	},
}
107
108 // initSeqs reads the Sequences_Section_Header and sets up the FSE
109 // tables used to read the sequence codes. It returns the number of
110 // sequences and the new offset. RFC 3.1.1.3.2.1.
111 func (r *Reader) initSeqs(data block, off int) (int, int, error) {
112 if off >= len(data) {
113 return 0, 0, r.makeEOFError(off)
114 }
115
116 seqHdr := data[off]
117 off++
118 if seqHdr == 0 {
119 return 0, off, nil
120 }
121
122 var seqCount int
123 if seqHdr < 128 {
124 seqCount = int(seqHdr)
125 } else if seqHdr < 255 {
126 if off >= len(data) {
127 return 0, 0, r.makeEOFError(off)
128 }
129 seqCount = ((int(seqHdr) - 128) << 8) + int(data[off])
130 off++
131 } else {
132 if off+1 >= len(data) {
133 return 0, 0, r.makeEOFError(off)
134 }
135 seqCount = int(data[off]) + (int(data[off+1]) << 8) + 0x7f00
136 off += 2
137 }
138
139 // Read the Symbol_Compression_Modes byte.
140
141 if off >= len(data) {
142 return 0, 0, r.makeEOFError(off)
143 }
144 symMode := data[off]
145 if symMode&3 != 0 {
146 return 0, 0, r.makeError(off, "invalid symbol compression mode")
147 }
148 off++
149
150 // Set up the FSE tables used to decode the sequence codes.
151
152 var err error
153 off, err = r.setSeqTable(data, off, seqLiteral, (symMode>>6)&3)
154 if err != nil {
155 return 0, 0, err
156 }
157
158 off, err = r.setSeqTable(data, off, seqOffset, (symMode>>4)&3)
159 if err != nil {
160 return 0, 0, err
161 }
162
163 off, err = r.setSeqTable(data, off, seqMatch, (symMode>>2)&3)
164 if err != nil {
165 return 0, 0, err
166 }
167
168 return seqCount, off, nil
169 }
170
171 // setSeqTable uses the Compression_Mode in mode to set up r.seqTables and
172 // r.seqTableBits for kind. We store these in the Reader because one of
173 // the modes simply reuses the value from the last block in the frame.
174 func (r *Reader) setSeqTable(data block, off int, kind seqCode, mode byte) (int, error) {
175 info := &seqCodeInfo[kind]
176 switch mode {
177 case 0:
178 // Predefined_Mode
179 r.seqTables[kind] = info.predefTable
180 r.seqTableBits[kind] = uint8(info.predefTableBits)
181 return off, nil
182
183 case 1:
184 // RLE_Mode
185 if off >= len(data) {
186 return 0, r.makeEOFError(off)
187 }
188 rle := data[off]
189 off++
190
191 // Build a simple baseline table that always returns rle.
192
193 entry := []fseEntry{
194 {
195 sym: rle,
196 bits: 0,
197 base: 0,
198 },
199 }
200 if cap(r.seqTableBuffers[kind]) == 0 {
201 r.seqTableBuffers[kind] = []fseBaselineEntry{:1<<info.maxBits}
202 }
203 r.seqTableBuffers[kind] = r.seqTableBuffers[kind][:1]
204 if err := info.toBaseline(r, off, entry, r.seqTableBuffers[kind]); err != nil {
205 return 0, err
206 }
207
208 r.seqTables[kind] = r.seqTableBuffers[kind]
209 r.seqTableBits[kind] = 0
210 return off, nil
211
212 case 2:
213 // FSE_Compressed_Mode
214 if cap(r.fseScratch) < 1<<info.maxBits {
215 r.fseScratch = []fseEntry{:1<<info.maxBits}
216 }
217 r.fseScratch = r.fseScratch[:1<<info.maxBits]
218
219 tableBits, roff, err := r.readFSE(data, off, info.maxSym, info.maxBits, r.fseScratch)
220 if err != nil {
221 return 0, err
222 }
223 r.fseScratch = r.fseScratch[:1<<tableBits]
224
225 if cap(r.seqTableBuffers[kind]) == 0 {
226 r.seqTableBuffers[kind] = []fseBaselineEntry{:1<<info.maxBits}
227 }
228 r.seqTableBuffers[kind] = r.seqTableBuffers[kind][:1<<tableBits]
229
230 if err := info.toBaseline(r, roff, r.fseScratch, r.seqTableBuffers[kind]); err != nil {
231 return 0, err
232 }
233
234 r.seqTables[kind] = r.seqTableBuffers[kind]
235 r.seqTableBits[kind] = uint8(tableBits)
236 return roff, nil
237
238 case 3:
239 // Repeat_Mode
240 if len(r.seqTables[kind]) == 0 {
241 return 0, r.makeError(off, "missing repeat sequence FSE table")
242 }
243 return off, nil
244 }
245 panic("unreachable")
246 }
247
// execSeqs reads and executes the sequences. RFC 3.1.1.3.2.1.2.
//
// data is the compressed block and off is the offset passed to
// makeReverseBitReader as the start of the sequence bitstream. litbuf
// holds the literal bytes for the block and seqCount is the number of
// sequences to execute. The output is appended to r.buffer; any literals
// left over after the final sequence are appended at the end. The
// repeated offsets in r are updated as sequences are processed.
// It returns an error if the bitstream is malformed, if a sequence asks
// for more literals than remain, if the output would be too large, or if
// bits remain in the bit reader after the last sequence.
func (r *Reader) execSeqs(data block, off int, litbuf []byte, seqCount int) error {
	// Set up the initial states for the sequence code readers.

	rbr, err := r.makeReverseBitReader(data, len(data)-1, off)
	if err != nil {
		return err
	}

	// The initial states are read in the order literal, offset, match,
	// each using the number of bits of the corresponding table.
	literalState, err := rbr.val(r.seqTableBits[seqLiteral])
	if err != nil {
		return err
	}

	offsetState, err := rbr.val(r.seqTableBits[seqOffset])
	if err != nil {
		return err
	}

	matchState, err := rbr.val(r.seqTableBits[seqMatch])
	if err != nil {
		return err
	}

	// Read and perform all the sequences. RFC 3.1.1.4.

	seq := 0
	for seq < seqCount {
		// Give up as soon as the output so far plus all remaining
		// literals would exceed 128 KiB.
		if len(r.buffer)+len(litbuf) > 128<<10 {
			return rbr.makeError("uncompressed size too big")
		}

		ptoffset := &r.seqTables[seqOffset][offsetState]
		ptmatch := &r.seqTables[seqMatch][matchState]
		ptliteral := &r.seqTables[seqLiteral][literalState]

		// The extra bits for each value are read in the order
		// offset, match, literal and added to the table baseline.
		add, err := rbr.val(ptoffset.basebits)
		if err != nil {
			return err
		}
		offset := ptoffset.baseline + add

		add, err = rbr.val(ptmatch.basebits)
		if err != nil {
			return err
		}
		match := ptmatch.baseline + add

		add, err = rbr.val(ptliteral.basebits)
		if err != nil {
			return err
		}
		literal := ptliteral.baseline + add

		// Handle repeat offsets. RFC 3.1.1.5.
		// See the comment in makeOffsetBaselineFSE.
		if ptoffset.basebits > 1 {
			// The decoded value is used as the offset directly and
			// becomes the most recent repeated offset.
			r.repeatedOffset3 = r.repeatedOffset2
			r.repeatedOffset2 = r.repeatedOffset1
			r.repeatedOffset1 = offset
		} else {
			// The decoded value selects one of the repeated offsets.
			// With no literals in this sequence the selector is
			// shifted up by one.
			if literal == 0 {
				offset++
			}
			switch offset {
			case 1:
				// Most recent offset; the history is unchanged.
				offset = r.repeatedOffset1
			case 2:
				// Second most recent offset; swap it to the front.
				offset = r.repeatedOffset2
				r.repeatedOffset2 = r.repeatedOffset1
				r.repeatedOffset1 = offset
			case 3:
				// Third most recent offset; rotate it to the front.
				offset = r.repeatedOffset3
				r.repeatedOffset3 = r.repeatedOffset2
				r.repeatedOffset2 = r.repeatedOffset1
				r.repeatedOffset1 = offset
			case 4:
				// Only reachable via the literal == 0 shift above:
				// one less than the most recent offset, pushed to
				// the front of the history.
				offset = r.repeatedOffset1 - 1
				r.repeatedOffset3 = r.repeatedOffset2
				r.repeatedOffset2 = r.repeatedOffset1
				r.repeatedOffset1 = offset
			}
		}

		seq++
		if seq < seqCount {
			// Update the states.
			// This is skipped after the final sequence, and the
			// state bits are read in the order literal, match, offset.
			add, err = rbr.val(ptliteral.bits)
			if err != nil {
				return err
			}
			literalState = uint32(ptliteral.base) + add

			add, err = rbr.val(ptmatch.bits)
			if err != nil {
				return err
			}
			matchState = uint32(ptmatch.base) + add

			add, err = rbr.val(ptoffset.bits)
			if err != nil {
				return err
			}
			offsetState = uint32(ptoffset.base) + add
		}

		// The next sequence is now in literal, offset, match.

		if debug {
			println("literal", literal, "offset", offset, "match", match)
		}

		// Copy literal bytes from litbuf.
		if literal > uint32(len(litbuf)) {
			return rbr.makeError("literal byte overflow")
		}
		if literal > 0 {
			r.buffer = append(r.buffer, litbuf[:literal]...)
			litbuf = litbuf[literal:]
		}

		// Copy match bytes from earlier output at the given offset.
		if match > 0 {
			if err := r.copyFromWindow(&rbr, offset, match); err != nil {
				return err
			}
		}
	}

	// Any literals not consumed by a sequence follow the last match.
	r.buffer = append(r.buffer, litbuf...)

	// The sequences must have consumed the whole bitstream.
	if rbr.cnt != 0 {
		return r.makeError(off, "extraneous data after sequences")
	}

	return nil
}
384
385 // Copy match bytes from the decoded output, or the window, at offset.
386 func (r *Reader) copyFromWindow(rbr *reverseBitReader, offset, match uint32) error {
387 if offset == 0 {
388 return rbr.makeError("invalid zero offset")
389 }
390
391 // Offset may point into the buffer or the window and
392 // match may extend past the end of the initial buffer.
393 // |--r.window--|--r.buffer--|
394 // |<-----offset------|
395 // |------match----------->|
396 bufferOffset := uint32(0)
397 lenBlock := uint32(len(r.buffer))
398 if lenBlock < offset {
399 lenWindow := r.window.len()
400 copy := offset - lenBlock
401 if copy > lenWindow {
402 return rbr.makeError("offset past window")
403 }
404 windowOffset := lenWindow - copy
405 if copy > match {
406 copy = match
407 }
408 r.buffer = r.window.appendTo(r.buffer, windowOffset, windowOffset+copy)
409 match -= copy
410 } else {
411 bufferOffset = lenBlock - offset
412 }
413
414 // We are being asked to copy data that we are adding to the
415 // buffer in the same copy.
416 for match > 0 {
417 copy := uint32(len(r.buffer)) - bufferOffset
418 if copy > match {
419 copy = match
420 }
421 r.buffer = append(r.buffer, r.buffer[bufferOffset:bufferOffset+copy]...)
422 match -= copy
423 }
424 return nil
425 }
426