zstd.mx raw

   1  // Copyright 2023 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  // Package zstd provides a decompressor for zstd streams,
   6  // described in RFC 8878. It does not support dictionaries.
   7  package zstd
   8  
   9  import (
  10  	"encoding/binary"
  11  	"errors"
  12  	"fmt"
  13  	"io"
  14  )
  15  
  16  // fuzzing is a fuzzer hook set to true when fuzzing.
  17  // This is used to reject cases where we don't match zstd.
  18  var fuzzing = false
  19  
  20  // Reader implements [io.Reader] to read a zstd compressed stream.
  21  type Reader struct {
  22  	// The underlying Reader.
  23  	r io.Reader
  24  
  25  	// Whether we have read the frame header.
  26  	// This is of interest when buffer is empty.
  27  	// If true we expect to see a new block.
  28  	sawFrameHeader bool
  29  
  30  	// Whether the current frame expects a checksum.
  31  	hasChecksum bool
  32  
  33  	// Whether we have read at least one frame.
  34  	readOneFrame bool
  35  
  36  	// True if the frame size is not known.
  37  	frameSizeUnknown bool
  38  
  39  	// The number of uncompressed bytes remaining in the current frame.
  40  	// If frameSizeUnknown is true, this is not valid.
  41  	remainingFrameSize uint64
  42  
  43  	// The number of bytes read from r up to the start of the current
  44  	// block, for error reporting.
  45  	blockOffset int64
  46  
  47  	// Buffered decompressed data.
  48  	buffer []byte
  49  	// Current read offset in buffer.
  50  	off int
  51  
  52  	// The current repeated offsets.
  53  	repeatedOffset1 uint32
  54  	repeatedOffset2 uint32
  55  	repeatedOffset3 uint32
  56  
  57  	// The current Huffman tree used for compressing literals.
  58  	huffmanTable     []uint16
  59  	huffmanTableBits int
  60  
  61  	// The window for back references.
  62  	window window
  63  
  64  	// A buffer available to hold a compressed block.
  65  	compressedBuf []byte
  66  
  67  	// A buffer for literals.
  68  	literals []byte
  69  
  70  	// Sequence decode FSE tables.
  71  	seqTables    [3][]fseBaselineEntry
  72  	seqTableBits [3]uint8
  73  
  74  	// Buffers for sequence decode FSE tables.
  75  	seqTableBuffers [3][]fseBaselineEntry
  76  
  77  	// Scratch space used for small reads, to avoid allocation.
  78  	scratch [16]byte
  79  
  80  	// A scratch table for reading an FSE. Only temporarily valid.
  81  	fseScratch []fseEntry
  82  
  83  	// For checksum computation.
  84  	checksum xxhash64
  85  }
  86  
  87  // NewReader creates a new Reader that decompresses data from the given reader.
  88  func NewReader(input io.Reader) *Reader {
  89  	r := &Reader{}
  90  	r.Reset(input)
  91  	return r
  92  }
  93  
  94  // Reset discards the current state and starts reading a new stream from r.
  95  // This permits reusing a Reader rather than allocating a new one.
  96  func (r *Reader) Reset(input io.Reader) {
  97  	r.r = input
  98  
  99  	// Several fields are preserved to avoid allocation.
 100  	// Others are always set before they are used.
 101  	r.sawFrameHeader = false
 102  	r.hasChecksum = false
 103  	r.readOneFrame = false
 104  	r.frameSizeUnknown = false
 105  	r.remainingFrameSize = 0
 106  	r.blockOffset = 0
 107  	r.buffer = r.buffer[:0]
 108  	r.off = 0
 109  	// repeatedOffset1
 110  	// repeatedOffset2
 111  	// repeatedOffset3
 112  	// huffmanTable
 113  	// huffmanTableBits
 114  	// window
 115  	// compressedBuf
 116  	// literals
 117  	// seqTables
 118  	// seqTableBits
 119  	// seqTableBuffers
 120  	// scratch
 121  	// fseScratch
 122  }
 123  
 124  // Read implements [io.Reader].
 125  func (r *Reader) Read(p []byte) (int, error) {
 126  	if err := r.refillIfNeeded(); err != nil {
 127  		return 0, err
 128  	}
 129  	n := copy(p, r.buffer[r.off:])
 130  	r.off += n
 131  	return n, nil
 132  }
 133  
 134  // ReadByte implements [io.ByteReader].
 135  func (r *Reader) ReadByte() (byte, error) {
 136  	if err := r.refillIfNeeded(); err != nil {
 137  		return 0, err
 138  	}
 139  	ret := r.buffer[r.off]
 140  	r.off++
 141  	return ret, nil
 142  }
 143  
 144  // refillIfNeeded reads the next block if necessary.
 145  func (r *Reader) refillIfNeeded() error {
 146  	for r.off >= len(r.buffer) {
 147  		if err := r.refill(); err != nil {
 148  			return err
 149  		}
 150  		r.off = 0
 151  	}
 152  	return nil
 153  }
 154  
 155  // refill reads and decompresses the next block.
 156  func (r *Reader) refill() error {
 157  	if !r.sawFrameHeader {
 158  		if err := r.readFrameHeader(); err != nil {
 159  			return err
 160  		}
 161  	}
 162  	return r.readBlock()
 163  }
 164  
 165  // readFrameHeader reads the frame header and prepares to read a block.
 166  func (r *Reader) readFrameHeader() error {
 167  retry:
 168  	relativeOffset := 0
 169  
 170  	// Read magic number. RFC 3.1.1.
 171  	if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil {
 172  		// We require that the stream contains at least one frame.
 173  		if err == io.EOF && !r.readOneFrame {
 174  			err = io.ErrUnexpectedEOF
 175  		}
 176  		return r.wrapError(relativeOffset, err)
 177  	}
 178  
 179  	if magic := binary.LittleEndian.Uint32(r.scratch[:4]); magic != 0xfd2fb528 {
 180  		if magic >= 0x184d2a50 && magic <= 0x184d2a5f {
 181  			// This is a skippable frame.
 182  			r.blockOffset += int64(relativeOffset) + 4
 183  			if err := r.skipFrame(); err != nil {
 184  				return err
 185  			}
 186  			r.readOneFrame = true
 187  			goto retry
 188  		}
 189  
 190  		return r.makeError(relativeOffset, "invalid magic number")
 191  	}
 192  
 193  	relativeOffset += 4
 194  
 195  	// Read Frame_Header_Descriptor. RFC 3.1.1.1.1.
 196  	if _, err := io.ReadFull(r.r, r.scratch[:1]); err != nil {
 197  		return r.wrapNonEOFError(relativeOffset, err)
 198  	}
 199  	descriptor := r.scratch[0]
 200  
 201  	singleSegment := descriptor&(1<<5) != 0
 202  
 203  	fcsFieldSize := 1 << (descriptor >> 6)
 204  	if fcsFieldSize == 1 && !singleSegment {
 205  		fcsFieldSize = 0
 206  	}
 207  
 208  	var windowDescriptorSize int
 209  	if singleSegment {
 210  		windowDescriptorSize = 0
 211  	} else {
 212  		windowDescriptorSize = 1
 213  	}
 214  
 215  	if descriptor&(1<<3) != 0 {
 216  		return r.makeError(relativeOffset, "reserved bit set in frame header descriptor")
 217  	}
 218  
 219  	r.hasChecksum = descriptor&(1<<2) != 0
 220  	if r.hasChecksum {
 221  		r.checksum.reset()
 222  	}
 223  
 224  	// Dictionary_ID_Flag. RFC 3.1.1.1.1.6.
 225  	dictionaryIdSize := 0
 226  	if dictIdFlag := descriptor & 3; dictIdFlag != 0 {
 227  		dictionaryIdSize = 1 << (dictIdFlag - 1)
 228  	}
 229  
 230  	relativeOffset++
 231  
 232  	headerSize := windowDescriptorSize + dictionaryIdSize + fcsFieldSize
 233  
 234  	if _, err := io.ReadFull(r.r, r.scratch[:headerSize]); err != nil {
 235  		return r.wrapNonEOFError(relativeOffset, err)
 236  	}
 237  
 238  	// Figure out the maximum amount of data we need to retain
 239  	// for backreferences.
 240  	var windowSize uint64
 241  	if !singleSegment {
 242  		// Window descriptor. RFC 3.1.1.1.2.
 243  		windowDescriptor := r.scratch[0]
 244  		exponent := uint64(windowDescriptor >> 3)
 245  		mantissa := uint64(windowDescriptor & 7)
 246  		windowLog := exponent + 10
 247  		windowBase := uint64(1) << windowLog
 248  		windowAdd := (windowBase / 8) * mantissa
 249  		windowSize = windowBase + windowAdd
 250  
 251  		// Default zstd sets limits on the window size.
 252  		if fuzzing && (windowLog > 31 || windowSize > 1<<27) {
 253  			return r.makeError(relativeOffset, "windowSize too large")
 254  		}
 255  	}
 256  
 257  	// Dictionary_ID. RFC 3.1.1.1.3.
 258  	if dictionaryIdSize != 0 {
 259  		dictionaryId := r.scratch[windowDescriptorSize : windowDescriptorSize+dictionaryIdSize]
 260  		// Allow only zero Dictionary ID.
 261  		for _, b := range dictionaryId {
 262  			if b != 0 {
 263  				return r.makeError(relativeOffset, "dictionaries are not supported")
 264  			}
 265  		}
 266  	}
 267  
 268  	// Frame_Content_Size. RFC 3.1.1.1.4.
 269  	r.frameSizeUnknown = false
 270  	r.remainingFrameSize = 0
 271  	fb := r.scratch[windowDescriptorSize+dictionaryIdSize:]
 272  	switch fcsFieldSize {
 273  	case 0:
 274  		r.frameSizeUnknown = true
 275  	case 1:
 276  		r.remainingFrameSize = uint64(fb[0])
 277  	case 2:
 278  		r.remainingFrameSize = 256 + uint64(binary.LittleEndian.Uint16(fb))
 279  	case 4:
 280  		r.remainingFrameSize = uint64(binary.LittleEndian.Uint32(fb))
 281  	case 8:
 282  		r.remainingFrameSize = binary.LittleEndian.Uint64(fb)
 283  	default:
 284  		panic("unreachable")
 285  	}
 286  
 287  	// RFC 3.1.1.1.2.
 288  	// When Single_Segment_Flag is set, Window_Descriptor is not present.
 289  	// In this case, Window_Size is Frame_Content_Size.
 290  	if singleSegment {
 291  		windowSize = r.remainingFrameSize
 292  	}
 293  
 294  	// RFC 8878 3.1.1.1.1.2. permits us to set an 8M max on window size.
 295  	const maxWindowSize = 8 << 20
 296  	if windowSize > maxWindowSize {
 297  		windowSize = maxWindowSize
 298  	}
 299  
 300  	relativeOffset += headerSize
 301  
 302  	r.sawFrameHeader = true
 303  	r.readOneFrame = true
 304  	r.blockOffset += int64(relativeOffset)
 305  
 306  	// Prepare to read blocks from the frame.
 307  	r.repeatedOffset1 = 1
 308  	r.repeatedOffset2 = 4
 309  	r.repeatedOffset3 = 8
 310  	r.huffmanTableBits = 0
 311  	r.window.reset(int(windowSize))
 312  	r.seqTables[0] = nil
 313  	r.seqTables[1] = nil
 314  	r.seqTables[2] = nil
 315  
 316  	return nil
 317  }
 318  
 319  // skipFrame skips a skippable frame. RFC 3.1.2.
 320  func (r *Reader) skipFrame() error {
 321  	relativeOffset := 0
 322  
 323  	if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil {
 324  		return r.wrapNonEOFError(relativeOffset, err)
 325  	}
 326  
 327  	relativeOffset += 4
 328  
 329  	size := binary.LittleEndian.Uint32(r.scratch[:4])
 330  	if size == 0 {
 331  		r.blockOffset += int64(relativeOffset)
 332  		return nil
 333  	}
 334  
 335  	if seeker, ok := r.r.(io.Seeker); ok {
 336  		r.blockOffset += int64(relativeOffset)
 337  		// Implementations of Seeker do not always detect invalid offsets,
 338  		// so check that the new offset is valid by comparing to the end.
 339  		prev, err := seeker.Seek(0, io.SeekCurrent)
 340  		if err != nil {
 341  			return r.wrapError(0, err)
 342  		}
 343  		end, err := seeker.Seek(0, io.SeekEnd)
 344  		if err != nil {
 345  			return r.wrapError(0, err)
 346  		}
 347  		if prev > end-int64(size) {
 348  			r.blockOffset += end - prev
 349  			return r.makeEOFError(0)
 350  		}
 351  
 352  		// The new offset is valid, so seek to it.
 353  		_, err = seeker.Seek(prev+int64(size), io.SeekStart)
 354  		if err != nil {
 355  			return r.wrapError(0, err)
 356  		}
 357  		r.blockOffset += int64(size)
 358  		return nil
 359  	}
 360  
 361  	n, err := io.CopyN(io.Discard, r.r, int64(size))
 362  	relativeOffset += int(n)
 363  	if err != nil {
 364  		return r.wrapNonEOFError(relativeOffset, err)
 365  	}
 366  	r.blockOffset += int64(relativeOffset)
 367  	return nil
 368  }
 369  
 370  // readBlock reads the next block from a frame.
 371  func (r *Reader) readBlock() error {
 372  	relativeOffset := 0
 373  
 374  	// Read Block_Header. RFC 3.1.1.2.
 375  	if _, err := io.ReadFull(r.r, r.scratch[:3]); err != nil {
 376  		return r.wrapNonEOFError(relativeOffset, err)
 377  	}
 378  
 379  	relativeOffset += 3
 380  
 381  	header := uint32(r.scratch[0]) | (uint32(r.scratch[1]) << 8) | (uint32(r.scratch[2]) << 16)
 382  
 383  	lastBlock := header&1 != 0
 384  	blockType := (header >> 1) & 3
 385  	blockSize := int(header >> 3)
 386  
 387  	// Maximum block size is smaller of window size and 128K.
 388  	// We don't record the window size for a single segment frame,
 389  	// so just use 128K. RFC 3.1.1.2.3, 3.1.1.2.4.
 390  	if blockSize > 128<<10 || (r.window.size > 0 && blockSize > r.window.size) {
 391  		return r.makeError(relativeOffset, "block size too large")
 392  	}
 393  
 394  	// Handle different block types. RFC 3.1.1.2.2.
 395  	switch blockType {
 396  	case 0:
 397  		r.setBufferSize(blockSize)
 398  		if _, err := io.ReadFull(r.r, r.buffer); err != nil {
 399  			return r.wrapNonEOFError(relativeOffset, err)
 400  		}
 401  		relativeOffset += blockSize
 402  		r.blockOffset += int64(relativeOffset)
 403  	case 1:
 404  		r.setBufferSize(blockSize)
 405  		if _, err := io.ReadFull(r.r, r.scratch[:1]); err != nil {
 406  			return r.wrapNonEOFError(relativeOffset, err)
 407  		}
 408  		relativeOffset++
 409  		v := r.scratch[0]
 410  		for i := range r.buffer {
 411  			r.buffer[i] = v
 412  		}
 413  		r.blockOffset += int64(relativeOffset)
 414  	case 2:
 415  		r.blockOffset += int64(relativeOffset)
 416  		if err := r.compressedBlock(blockSize); err != nil {
 417  			return err
 418  		}
 419  		r.blockOffset += int64(blockSize)
 420  	case 3:
 421  		return r.makeError(relativeOffset, "invalid block type")
 422  	}
 423  
 424  	if !r.frameSizeUnknown {
 425  		if uint64(len(r.buffer)) > r.remainingFrameSize {
 426  			return r.makeError(relativeOffset, "too many uncompressed bytes in frame")
 427  		}
 428  		r.remainingFrameSize -= uint64(len(r.buffer))
 429  	}
 430  
 431  	if r.hasChecksum {
 432  		r.checksum.update(r.buffer)
 433  	}
 434  
 435  	if !lastBlock {
 436  		r.window.save(r.buffer)
 437  	} else {
 438  		if !r.frameSizeUnknown && r.remainingFrameSize != 0 {
 439  			return r.makeError(relativeOffset, "not enough uncompressed bytes for frame")
 440  		}
 441  		// Check for checksum at end of frame. RFC 3.1.1.
 442  		if r.hasChecksum {
 443  			if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil {
 444  				return r.wrapNonEOFError(0, err)
 445  			}
 446  
 447  			inputChecksum := binary.LittleEndian.Uint32(r.scratch[:4])
 448  			dataChecksum := uint32(r.checksum.digest())
 449  			if inputChecksum != dataChecksum {
 450  				return r.wrapError(0, fmt.Errorf("invalid checksum: got %#x want %#x", dataChecksum, inputChecksum))
 451  			}
 452  
 453  			r.blockOffset += 4
 454  		}
 455  		r.sawFrameHeader = false
 456  	}
 457  
 458  	return nil
 459  }
 460  
 461  // setBufferSize sets the decompressed buffer size.
 462  // When this is called the buffer is empty.
 463  func (r *Reader) setBufferSize(size int) {
 464  	if cap(r.buffer) < size {
 465  		need := size - cap(r.buffer)
 466  		r.buffer = append(r.buffer[:cap(r.buffer)], []byte{:need}...)
 467  	}
 468  	r.buffer = r.buffer[:size]
 469  }
 470  
 471  // zstdError is an error while decompressing.
 472  type zstdError struct {
 473  	offset int64
 474  	err    error
 475  }
 476  
 477  func (ze *zstdError) Error() string {
 478  	return fmt.Sprintf("zstd decompression error at %d: %v", ze.offset, ze.err)
 479  }
 480  
 481  func (ze *zstdError) Unwrap() error {
 482  	return ze.err
 483  }
 484  
 485  func (r *Reader) makeEOFError(off int) error {
 486  	return r.wrapError(off, io.ErrUnexpectedEOF)
 487  }
 488  
 489  func (r *Reader) wrapNonEOFError(off int, err error) error {
 490  	if err == io.EOF {
 491  		err = io.ErrUnexpectedEOF
 492  	}
 493  	return r.wrapError(off, err)
 494  }
 495  
 496  func (r *Reader) makeError(off int, msg []byte) error {
 497  	return r.wrapError(off, errors.New(msg))
 498  }
 499  
 500  func (r *Reader) wrapError(off int, err error) error {
 501  	if err == io.EOF {
 502  		return err
 503  	}
 504  	return &zstdError{r.blockOffset + int64(off), err}
 505  }
 506