zstd.mx raw
1 // Copyright 2023 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 // Package zstd provides a decompressor for zstd streams,
6 // described in RFC 8878. It does not support dictionaries.
7 package zstd
8
9 import (
10 "encoding/binary"
11 "errors"
12 "fmt"
13 "io"
14 )
15
// fuzzing is a hook that the fuzzer flips to true. While set, the decoder
// rejects some inputs that reference zstd would also reject, so that fuzz
// comparisons against zstd stay meaningful.
var fuzzing bool
19
20 // Reader implements [io.Reader] to read a zstd compressed stream.
21 type Reader struct {
22 // The underlying Reader.
23 r io.Reader
24
25 // Whether we have read the frame header.
26 // This is of interest when buffer is empty.
27 // If true we expect to see a new block.
28 sawFrameHeader bool
29
30 // Whether the current frame expects a checksum.
31 hasChecksum bool
32
33 // Whether we have read at least one frame.
34 readOneFrame bool
35
36 // True if the frame size is not known.
37 frameSizeUnknown bool
38
39 // The number of uncompressed bytes remaining in the current frame.
40 // If frameSizeUnknown is true, this is not valid.
41 remainingFrameSize uint64
42
43 // The number of bytes read from r up to the start of the current
44 // block, for error reporting.
45 blockOffset int64
46
47 // Buffered decompressed data.
48 buffer []byte
49 // Current read offset in buffer.
50 off int
51
52 // The current repeated offsets.
53 repeatedOffset1 uint32
54 repeatedOffset2 uint32
55 repeatedOffset3 uint32
56
57 // The current Huffman tree used for compressing literals.
58 huffmanTable []uint16
59 huffmanTableBits int
60
61 // The window for back references.
62 window window
63
64 // A buffer available to hold a compressed block.
65 compressedBuf []byte
66
67 // A buffer for literals.
68 literals []byte
69
70 // Sequence decode FSE tables.
71 seqTables [3][]fseBaselineEntry
72 seqTableBits [3]uint8
73
74 // Buffers for sequence decode FSE tables.
75 seqTableBuffers [3][]fseBaselineEntry
76
77 // Scratch space used for small reads, to avoid allocation.
78 scratch [16]byte
79
80 // A scratch table for reading an FSE. Only temporarily valid.
81 fseScratch []fseEntry
82
83 // For checksum computation.
84 checksum xxhash64
85 }
86
87 // NewReader creates a new Reader that decompresses data from the given reader.
88 func NewReader(input io.Reader) *Reader {
89 r := &Reader{}
90 r.Reset(input)
91 return r
92 }
93
94 // Reset discards the current state and starts reading a new stream from r.
95 // This permits reusing a Reader rather than allocating a new one.
96 func (r *Reader) Reset(input io.Reader) {
97 r.r = input
98
99 // Several fields are preserved to avoid allocation.
100 // Others are always set before they are used.
101 r.sawFrameHeader = false
102 r.hasChecksum = false
103 r.readOneFrame = false
104 r.frameSizeUnknown = false
105 r.remainingFrameSize = 0
106 r.blockOffset = 0
107 r.buffer = r.buffer[:0]
108 r.off = 0
109 // repeatedOffset1
110 // repeatedOffset2
111 // repeatedOffset3
112 // huffmanTable
113 // huffmanTableBits
114 // window
115 // compressedBuf
116 // literals
117 // seqTables
118 // seqTableBits
119 // seqTableBuffers
120 // scratch
121 // fseScratch
122 }
123
124 // Read implements [io.Reader].
125 func (r *Reader) Read(p []byte) (int, error) {
126 if err := r.refillIfNeeded(); err != nil {
127 return 0, err
128 }
129 n := copy(p, r.buffer[r.off:])
130 r.off += n
131 return n, nil
132 }
133
134 // ReadByte implements [io.ByteReader].
135 func (r *Reader) ReadByte() (byte, error) {
136 if err := r.refillIfNeeded(); err != nil {
137 return 0, err
138 }
139 ret := r.buffer[r.off]
140 r.off++
141 return ret, nil
142 }
143
144 // refillIfNeeded reads the next block if necessary.
145 func (r *Reader) refillIfNeeded() error {
146 for r.off >= len(r.buffer) {
147 if err := r.refill(); err != nil {
148 return err
149 }
150 r.off = 0
151 }
152 return nil
153 }
154
155 // refill reads and decompresses the next block.
156 func (r *Reader) refill() error {
157 if !r.sawFrameHeader {
158 if err := r.readFrameHeader(); err != nil {
159 return err
160 }
161 }
162 return r.readBlock()
163 }
164
165 // readFrameHeader reads the frame header and prepares to read a block.
166 func (r *Reader) readFrameHeader() error {
167 retry:
168 relativeOffset := 0
169
170 // Read magic number. RFC 3.1.1.
171 if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil {
172 // We require that the stream contains at least one frame.
173 if err == io.EOF && !r.readOneFrame {
174 err = io.ErrUnexpectedEOF
175 }
176 return r.wrapError(relativeOffset, err)
177 }
178
179 if magic := binary.LittleEndian.Uint32(r.scratch[:4]); magic != 0xfd2fb528 {
180 if magic >= 0x184d2a50 && magic <= 0x184d2a5f {
181 // This is a skippable frame.
182 r.blockOffset += int64(relativeOffset) + 4
183 if err := r.skipFrame(); err != nil {
184 return err
185 }
186 r.readOneFrame = true
187 goto retry
188 }
189
190 return r.makeError(relativeOffset, "invalid magic number")
191 }
192
193 relativeOffset += 4
194
195 // Read Frame_Header_Descriptor. RFC 3.1.1.1.1.
196 if _, err := io.ReadFull(r.r, r.scratch[:1]); err != nil {
197 return r.wrapNonEOFError(relativeOffset, err)
198 }
199 descriptor := r.scratch[0]
200
201 singleSegment := descriptor&(1<<5) != 0
202
203 fcsFieldSize := 1 << (descriptor >> 6)
204 if fcsFieldSize == 1 && !singleSegment {
205 fcsFieldSize = 0
206 }
207
208 var windowDescriptorSize int
209 if singleSegment {
210 windowDescriptorSize = 0
211 } else {
212 windowDescriptorSize = 1
213 }
214
215 if descriptor&(1<<3) != 0 {
216 return r.makeError(relativeOffset, "reserved bit set in frame header descriptor")
217 }
218
219 r.hasChecksum = descriptor&(1<<2) != 0
220 if r.hasChecksum {
221 r.checksum.reset()
222 }
223
224 // Dictionary_ID_Flag. RFC 3.1.1.1.1.6.
225 dictionaryIdSize := 0
226 if dictIdFlag := descriptor & 3; dictIdFlag != 0 {
227 dictionaryIdSize = 1 << (dictIdFlag - 1)
228 }
229
230 relativeOffset++
231
232 headerSize := windowDescriptorSize + dictionaryIdSize + fcsFieldSize
233
234 if _, err := io.ReadFull(r.r, r.scratch[:headerSize]); err != nil {
235 return r.wrapNonEOFError(relativeOffset, err)
236 }
237
238 // Figure out the maximum amount of data we need to retain
239 // for backreferences.
240 var windowSize uint64
241 if !singleSegment {
242 // Window descriptor. RFC 3.1.1.1.2.
243 windowDescriptor := r.scratch[0]
244 exponent := uint64(windowDescriptor >> 3)
245 mantissa := uint64(windowDescriptor & 7)
246 windowLog := exponent + 10
247 windowBase := uint64(1) << windowLog
248 windowAdd := (windowBase / 8) * mantissa
249 windowSize = windowBase + windowAdd
250
251 // Default zstd sets limits on the window size.
252 if fuzzing && (windowLog > 31 || windowSize > 1<<27) {
253 return r.makeError(relativeOffset, "windowSize too large")
254 }
255 }
256
257 // Dictionary_ID. RFC 3.1.1.1.3.
258 if dictionaryIdSize != 0 {
259 dictionaryId := r.scratch[windowDescriptorSize : windowDescriptorSize+dictionaryIdSize]
260 // Allow only zero Dictionary ID.
261 for _, b := range dictionaryId {
262 if b != 0 {
263 return r.makeError(relativeOffset, "dictionaries are not supported")
264 }
265 }
266 }
267
268 // Frame_Content_Size. RFC 3.1.1.1.4.
269 r.frameSizeUnknown = false
270 r.remainingFrameSize = 0
271 fb := r.scratch[windowDescriptorSize+dictionaryIdSize:]
272 switch fcsFieldSize {
273 case 0:
274 r.frameSizeUnknown = true
275 case 1:
276 r.remainingFrameSize = uint64(fb[0])
277 case 2:
278 r.remainingFrameSize = 256 + uint64(binary.LittleEndian.Uint16(fb))
279 case 4:
280 r.remainingFrameSize = uint64(binary.LittleEndian.Uint32(fb))
281 case 8:
282 r.remainingFrameSize = binary.LittleEndian.Uint64(fb)
283 default:
284 panic("unreachable")
285 }
286
287 // RFC 3.1.1.1.2.
288 // When Single_Segment_Flag is set, Window_Descriptor is not present.
289 // In this case, Window_Size is Frame_Content_Size.
290 if singleSegment {
291 windowSize = r.remainingFrameSize
292 }
293
294 // RFC 8878 3.1.1.1.1.2. permits us to set an 8M max on window size.
295 const maxWindowSize = 8 << 20
296 if windowSize > maxWindowSize {
297 windowSize = maxWindowSize
298 }
299
300 relativeOffset += headerSize
301
302 r.sawFrameHeader = true
303 r.readOneFrame = true
304 r.blockOffset += int64(relativeOffset)
305
306 // Prepare to read blocks from the frame.
307 r.repeatedOffset1 = 1
308 r.repeatedOffset2 = 4
309 r.repeatedOffset3 = 8
310 r.huffmanTableBits = 0
311 r.window.reset(int(windowSize))
312 r.seqTables[0] = nil
313 r.seqTables[1] = nil
314 r.seqTables[2] = nil
315
316 return nil
317 }
318
319 // skipFrame skips a skippable frame. RFC 3.1.2.
320 func (r *Reader) skipFrame() error {
321 relativeOffset := 0
322
323 if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil {
324 return r.wrapNonEOFError(relativeOffset, err)
325 }
326
327 relativeOffset += 4
328
329 size := binary.LittleEndian.Uint32(r.scratch[:4])
330 if size == 0 {
331 r.blockOffset += int64(relativeOffset)
332 return nil
333 }
334
335 if seeker, ok := r.r.(io.Seeker); ok {
336 r.blockOffset += int64(relativeOffset)
337 // Implementations of Seeker do not always detect invalid offsets,
338 // so check that the new offset is valid by comparing to the end.
339 prev, err := seeker.Seek(0, io.SeekCurrent)
340 if err != nil {
341 return r.wrapError(0, err)
342 }
343 end, err := seeker.Seek(0, io.SeekEnd)
344 if err != nil {
345 return r.wrapError(0, err)
346 }
347 if prev > end-int64(size) {
348 r.blockOffset += end - prev
349 return r.makeEOFError(0)
350 }
351
352 // The new offset is valid, so seek to it.
353 _, err = seeker.Seek(prev+int64(size), io.SeekStart)
354 if err != nil {
355 return r.wrapError(0, err)
356 }
357 r.blockOffset += int64(size)
358 return nil
359 }
360
361 n, err := io.CopyN(io.Discard, r.r, int64(size))
362 relativeOffset += int(n)
363 if err != nil {
364 return r.wrapNonEOFError(relativeOffset, err)
365 }
366 r.blockOffset += int64(relativeOffset)
367 return nil
368 }
369
370 // readBlock reads the next block from a frame.
371 func (r *Reader) readBlock() error {
372 relativeOffset := 0
373
374 // Read Block_Header. RFC 3.1.1.2.
375 if _, err := io.ReadFull(r.r, r.scratch[:3]); err != nil {
376 return r.wrapNonEOFError(relativeOffset, err)
377 }
378
379 relativeOffset += 3
380
381 header := uint32(r.scratch[0]) | (uint32(r.scratch[1]) << 8) | (uint32(r.scratch[2]) << 16)
382
383 lastBlock := header&1 != 0
384 blockType := (header >> 1) & 3
385 blockSize := int(header >> 3)
386
387 // Maximum block size is smaller of window size and 128K.
388 // We don't record the window size for a single segment frame,
389 // so just use 128K. RFC 3.1.1.2.3, 3.1.1.2.4.
390 if blockSize > 128<<10 || (r.window.size > 0 && blockSize > r.window.size) {
391 return r.makeError(relativeOffset, "block size too large")
392 }
393
394 // Handle different block types. RFC 3.1.1.2.2.
395 switch blockType {
396 case 0:
397 r.setBufferSize(blockSize)
398 if _, err := io.ReadFull(r.r, r.buffer); err != nil {
399 return r.wrapNonEOFError(relativeOffset, err)
400 }
401 relativeOffset += blockSize
402 r.blockOffset += int64(relativeOffset)
403 case 1:
404 r.setBufferSize(blockSize)
405 if _, err := io.ReadFull(r.r, r.scratch[:1]); err != nil {
406 return r.wrapNonEOFError(relativeOffset, err)
407 }
408 relativeOffset++
409 v := r.scratch[0]
410 for i := range r.buffer {
411 r.buffer[i] = v
412 }
413 r.blockOffset += int64(relativeOffset)
414 case 2:
415 r.blockOffset += int64(relativeOffset)
416 if err := r.compressedBlock(blockSize); err != nil {
417 return err
418 }
419 r.blockOffset += int64(blockSize)
420 case 3:
421 return r.makeError(relativeOffset, "invalid block type")
422 }
423
424 if !r.frameSizeUnknown {
425 if uint64(len(r.buffer)) > r.remainingFrameSize {
426 return r.makeError(relativeOffset, "too many uncompressed bytes in frame")
427 }
428 r.remainingFrameSize -= uint64(len(r.buffer))
429 }
430
431 if r.hasChecksum {
432 r.checksum.update(r.buffer)
433 }
434
435 if !lastBlock {
436 r.window.save(r.buffer)
437 } else {
438 if !r.frameSizeUnknown && r.remainingFrameSize != 0 {
439 return r.makeError(relativeOffset, "not enough uncompressed bytes for frame")
440 }
441 // Check for checksum at end of frame. RFC 3.1.1.
442 if r.hasChecksum {
443 if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil {
444 return r.wrapNonEOFError(0, err)
445 }
446
447 inputChecksum := binary.LittleEndian.Uint32(r.scratch[:4])
448 dataChecksum := uint32(r.checksum.digest())
449 if inputChecksum != dataChecksum {
450 return r.wrapError(0, fmt.Errorf("invalid checksum: got %#x want %#x", dataChecksum, inputChecksum))
451 }
452
453 r.blockOffset += 4
454 }
455 r.sawFrameHeader = false
456 }
457
458 return nil
459 }
460
461 // setBufferSize sets the decompressed buffer size.
462 // When this is called the buffer is empty.
463 func (r *Reader) setBufferSize(size int) {
464 if cap(r.buffer) < size {
465 need := size - cap(r.buffer)
466 r.buffer = append(r.buffer[:cap(r.buffer)], []byte{:need}...)
467 }
468 r.buffer = r.buffer[:size]
469 }
470
// zstdError describes a failure encountered while decompressing.
type zstdError struct {
	offset int64 // position in the compressed input where it was detected
	err    error // underlying cause
}
476
477 func (ze *zstdError) Error() string {
478 return fmt.Sprintf("zstd decompression error at %d: %v", ze.offset, ze.err)
479 }
480
481 func (ze *zstdError) Unwrap() error {
482 return ze.err
483 }
484
485 func (r *Reader) makeEOFError(off int) error {
486 return r.wrapError(off, io.ErrUnexpectedEOF)
487 }
488
489 func (r *Reader) wrapNonEOFError(off int, err error) error {
490 if err == io.EOF {
491 err = io.ErrUnexpectedEOF
492 }
493 return r.wrapError(off, err)
494 }
495
496 func (r *Reader) makeError(off int, msg []byte) error {
497 return r.wrapError(off, errors.New(msg))
498 }
499
500 func (r *Reader) wrapError(off int, err error) error {
501 if err == io.EOF {
502 return err
503 }
504 return &zstdError{r.blockOffset + int64(off), err}
505 }
506