framedec.go raw

   1  // Copyright 2019+ Klaus Post. All rights reserved.
   2  // License information can be found in the LICENSE file.
   3  // Based on work by Yann Collet, released under BSD License.
   4  
   5  package zstd
   6  
   7  import (
   8  	"encoding/binary"
   9  	"encoding/hex"
  10  	"errors"
  11  	"io"
  12  
  13  	"github.com/klauspost/compress/zstd/internal/xxhash"
  14  )
  15  
// frameDec holds the state needed to decode a single zstd frame:
// the parsed frame header plus the history buffer shared by its blocks.
type frameDec struct {
	// Decoder options; newFrameDec clamps maxWindowSize to maxDecodedSize.
	o decoderOptions
	// Running checksum of the decoded frame content. Allocated lazily by
	// reset when the frame header announces a checksum.
	crc *xxhash.Digest

	// Window size from the Window_Descriptor, or derived from the frame
	// content size for single-segment frames.
	WindowSize uint64

	// Frame history passed between blocks
	history history

	// Input the frame header was read from; blocks and the trailing
	// checksum are read from here as well.
	rawInput byteBuffer

	// Byte buffer that can be reused for small input blocks.
	bBuf byteBuf

	// Decompressed size declared in the header, or fcsUnknown when absent.
	FrameContentSize uint64

	DictionaryID  uint32 // 0 when the header carries no Dictionary_ID.
	HasCheckSum   bool   // Content_Checksum_flag (bit 2 of the descriptor).
	SingleSegment bool   // Single_Segment_flag (bit 5 of the descriptor).
}
  36  
  37  const (
  38  	// MinWindowSize is the minimum Window Size, which is 1 KB.
  39  	MinWindowSize = 1 << 10
  40  
  41  	// MaxWindowSize is the maximum encoder window size
  42  	// and the default decoder maximum window size.
  43  	MaxWindowSize = 1 << 29
  44  )
  45  
  46  const (
  47  	frameMagic          = "\x28\xb5\x2f\xfd"
  48  	skippableFrameMagic = "\x2a\x4d\x18"
  49  )
  50  
  51  func newFrameDec(o decoderOptions) *frameDec {
  52  	if o.maxWindowSize > o.maxDecodedSize {
  53  		o.maxWindowSize = o.maxDecodedSize
  54  	}
  55  	d := frameDec{
  56  		o: o,
  57  	}
  58  	return &d
  59  }
  60  
  61  // reset will read the frame header and prepare for block decoding.
  62  // If nothing can be read from the input, io.EOF will be returned.
  63  // Any other error indicated that the stream contained data, but
  64  // there was a problem.
  65  func (d *frameDec) reset(br byteBuffer) error {
  66  	d.HasCheckSum = false
  67  	d.WindowSize = 0
  68  	var signature [4]byte
  69  	for {
  70  		var err error
  71  		// Check if we can read more...
  72  		b, err := br.readSmall(1)
  73  		switch err {
  74  		case io.EOF, io.ErrUnexpectedEOF:
  75  			return io.EOF
  76  		case nil:
  77  			signature[0] = b[0]
  78  		default:
  79  			return err
  80  		}
  81  		// Read the rest, don't allow io.ErrUnexpectedEOF
  82  		b, err = br.readSmall(3)
  83  		switch err {
  84  		case io.EOF:
  85  			return io.EOF
  86  		case nil:
  87  			copy(signature[1:], b)
  88  		default:
  89  			return err
  90  		}
  91  
  92  		if string(signature[1:4]) != skippableFrameMagic || signature[0]&0xf0 != 0x50 {
  93  			if debugDecoder {
  94  				println("Not skippable", hex.EncodeToString(signature[:]), hex.EncodeToString([]byte(skippableFrameMagic)))
  95  			}
  96  			// Break if not skippable frame.
  97  			break
  98  		}
  99  		// Read size to skip
 100  		b, err = br.readSmall(4)
 101  		if err != nil {
 102  			if debugDecoder {
 103  				println("Reading Frame Size", err)
 104  			}
 105  			return err
 106  		}
 107  		n := uint32(b[0]) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24)
 108  		println("Skipping frame with", n, "bytes.")
 109  		err = br.skipN(int64(n))
 110  		if err != nil {
 111  			if debugDecoder {
 112  				println("Reading discarded frame", err)
 113  			}
 114  			return err
 115  		}
 116  	}
 117  	if string(signature[:]) != frameMagic {
 118  		if debugDecoder {
 119  			println("Got magic numbers: ", signature, "want:", []byte(frameMagic))
 120  		}
 121  		return ErrMagicMismatch
 122  	}
 123  
 124  	// Read Frame_Header_Descriptor
 125  	fhd, err := br.readByte()
 126  	if err != nil {
 127  		if debugDecoder {
 128  			println("Reading Frame_Header_Descriptor", err)
 129  		}
 130  		return err
 131  	}
 132  	d.SingleSegment = fhd&(1<<5) != 0
 133  
 134  	if fhd&(1<<3) != 0 {
 135  		return errors.New("reserved bit set on frame header")
 136  	}
 137  
 138  	// Read Window_Descriptor
 139  	// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#window_descriptor
 140  	d.WindowSize = 0
 141  	if !d.SingleSegment {
 142  		wd, err := br.readByte()
 143  		if err != nil {
 144  			if debugDecoder {
 145  				println("Reading Window_Descriptor", err)
 146  			}
 147  			return err
 148  		}
 149  		if debugDecoder {
 150  			printf("raw: %x, mantissa: %d, exponent: %d\n", wd, wd&7, wd>>3)
 151  		}
 152  		windowLog := 10 + (wd >> 3)
 153  		windowBase := uint64(1) << windowLog
 154  		windowAdd := (windowBase / 8) * uint64(wd&0x7)
 155  		d.WindowSize = windowBase + windowAdd
 156  	}
 157  
 158  	// Read Dictionary_ID
 159  	// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#dictionary_id
 160  	d.DictionaryID = 0
 161  	if size := fhd & 3; size != 0 {
 162  		if size == 3 {
 163  			size = 4
 164  		}
 165  
 166  		b, err := br.readSmall(int(size))
 167  		if err != nil {
 168  			println("Reading Dictionary_ID", err)
 169  			return err
 170  		}
 171  		var id uint32
 172  		switch len(b) {
 173  		case 1:
 174  			id = uint32(b[0])
 175  		case 2:
 176  			id = uint32(b[0]) | (uint32(b[1]) << 8)
 177  		case 4:
 178  			id = uint32(b[0]) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24)
 179  		}
 180  		if debugDecoder {
 181  			println("Dict size", size, "ID:", id)
 182  		}
 183  		d.DictionaryID = id
 184  	}
 185  
 186  	// Read Frame_Content_Size
 187  	// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#frame_content_size
 188  	var fcsSize int
 189  	v := fhd >> 6
 190  	switch v {
 191  	case 0:
 192  		if d.SingleSegment {
 193  			fcsSize = 1
 194  		}
 195  	default:
 196  		fcsSize = 1 << v
 197  	}
 198  	d.FrameContentSize = fcsUnknown
 199  	if fcsSize > 0 {
 200  		b, err := br.readSmall(fcsSize)
 201  		if err != nil {
 202  			println("Reading Frame content", err)
 203  			return err
 204  		}
 205  		switch len(b) {
 206  		case 1:
 207  			d.FrameContentSize = uint64(b[0])
 208  		case 2:
 209  			// When FCS_Field_Size is 2, the offset of 256 is added.
 210  			d.FrameContentSize = uint64(b[0]) | (uint64(b[1]) << 8) + 256
 211  		case 4:
 212  			d.FrameContentSize = uint64(b[0]) | (uint64(b[1]) << 8) | (uint64(b[2]) << 16) | (uint64(b[3]) << 24)
 213  		case 8:
 214  			d1 := uint32(b[0]) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24)
 215  			d2 := uint32(b[4]) | (uint32(b[5]) << 8) | (uint32(b[6]) << 16) | (uint32(b[7]) << 24)
 216  			d.FrameContentSize = uint64(d1) | (uint64(d2) << 32)
 217  		}
 218  		if debugDecoder {
 219  			println("Read FCS:", d.FrameContentSize)
 220  		}
 221  	}
 222  
 223  	// Move this to shared.
 224  	d.HasCheckSum = fhd&(1<<2) != 0
 225  	if d.HasCheckSum {
 226  		if d.crc == nil {
 227  			d.crc = xxhash.New()
 228  		}
 229  		d.crc.Reset()
 230  	}
 231  
 232  	if d.WindowSize > d.o.maxWindowSize {
 233  		if debugDecoder {
 234  			printf("window size %d > max %d\n", d.WindowSize, d.o.maxWindowSize)
 235  		}
 236  		return ErrWindowSizeExceeded
 237  	}
 238  
 239  	if d.WindowSize == 0 && d.SingleSegment {
 240  		// We may not need window in this case.
 241  		d.WindowSize = max(d.FrameContentSize, MinWindowSize)
 242  		if d.WindowSize > d.o.maxDecodedSize {
 243  			if debugDecoder {
 244  				printf("window size %d > max %d\n", d.WindowSize, d.o.maxWindowSize)
 245  			}
 246  			return ErrDecoderSizeExceeded
 247  		}
 248  	}
 249  
 250  	// The minimum Window_Size is 1 KB.
 251  	if d.WindowSize < MinWindowSize {
 252  		if debugDecoder {
 253  			println("got window size: ", d.WindowSize)
 254  		}
 255  		return ErrWindowSizeTooSmall
 256  	}
 257  	d.history.windowSize = int(d.WindowSize)
 258  	if !d.o.lowMem || d.history.windowSize < maxBlockSize {
 259  		// Alloc 2x window size if not low-mem, or window size below 2MB.
 260  		d.history.allocFrameBuffer = d.history.windowSize * 2
 261  	} else {
 262  		if d.o.lowMem {
 263  			// Alloc with 1MB extra.
 264  			d.history.allocFrameBuffer = d.history.windowSize + maxBlockSize/2
 265  		} else {
 266  			// Alloc with 2MB extra.
 267  			d.history.allocFrameBuffer = d.history.windowSize + maxBlockSize
 268  		}
 269  	}
 270  
 271  	if debugDecoder {
 272  		println("Frame: Dict:", d.DictionaryID, "FrameContentSize:", d.FrameContentSize, "singleseg:", d.SingleSegment, "window:", d.WindowSize, "crc:", d.HasCheckSum)
 273  	}
 274  
 275  	// history contains input - maybe we do something
 276  	d.rawInput = br
 277  	return nil
 278  }
 279  
 280  // next will start decoding the next block from stream.
 281  func (d *frameDec) next(block *blockDec) error {
 282  	if debugDecoder {
 283  		println("decoding new block")
 284  	}
 285  	err := block.reset(d.rawInput, d.WindowSize)
 286  	if err != nil {
 287  		println("block error:", err)
 288  		// Signal the frame decoder we have a problem.
 289  		block.sendErr(err)
 290  		return err
 291  	}
 292  	return nil
 293  }
 294  
 295  // checkCRC will check the checksum, assuming the frame has one.
 296  // Will return ErrCRCMismatch if crc check failed, otherwise nil.
 297  func (d *frameDec) checkCRC() error {
 298  	// We can overwrite upper tmp now
 299  	buf, err := d.rawInput.readSmall(4)
 300  	if err != nil {
 301  		println("CRC missing?", err)
 302  		return err
 303  	}
 304  
 305  	want := binary.LittleEndian.Uint32(buf[:4])
 306  	got := uint32(d.crc.Sum64())
 307  
 308  	if got != want {
 309  		if debugDecoder {
 310  			printf("CRC check failed: got %08x, want %08x\n", got, want)
 311  		}
 312  		return ErrCRCMismatch
 313  	}
 314  	if debugDecoder {
 315  		printf("CRC ok %08x\n", got)
 316  	}
 317  	return nil
 318  }
 319  
 320  // consumeCRC skips over the checksum, assuming the frame has one.
 321  func (d *frameDec) consumeCRC() error {
 322  	_, err := d.rawInput.readSmall(4)
 323  	if err != nil {
 324  		println("CRC missing?", err)
 325  	}
 326  	return err
 327  }
 328  
 329  // runDecoder will run the decoder for the remainder of the frame.
 330  func (d *frameDec) runDecoder(dst []byte, dec *blockDec) ([]byte, error) {
 331  	saved := d.history.b
 332  
 333  	// We use the history for output to avoid copying it.
 334  	d.history.b = dst
 335  	d.history.ignoreBuffer = len(dst)
 336  	// Store input length, so we only check new data.
 337  	crcStart := len(dst)
 338  	d.history.decoders.maxSyncLen = 0
 339  	if d.o.limitToCap {
 340  		d.history.decoders.maxSyncLen = uint64(cap(dst) - len(dst))
 341  	}
 342  	if d.FrameContentSize != fcsUnknown {
 343  		if !d.o.limitToCap || d.FrameContentSize+uint64(len(dst)) < d.history.decoders.maxSyncLen {
 344  			d.history.decoders.maxSyncLen = d.FrameContentSize + uint64(len(dst))
 345  		}
 346  		if d.history.decoders.maxSyncLen > d.o.maxDecodedSize {
 347  			if debugDecoder {
 348  				println("maxSyncLen:", d.history.decoders.maxSyncLen, "> maxDecodedSize:", d.o.maxDecodedSize)
 349  			}
 350  			return dst, ErrDecoderSizeExceeded
 351  		}
 352  		if debugDecoder {
 353  			println("maxSyncLen:", d.history.decoders.maxSyncLen)
 354  		}
 355  		if !d.o.limitToCap && uint64(cap(dst)) < d.history.decoders.maxSyncLen {
 356  			// Alloc for output
 357  			dst2 := make([]byte, len(dst), d.history.decoders.maxSyncLen+compressedBlockOverAlloc)
 358  			copy(dst2, dst)
 359  			dst = dst2
 360  		}
 361  	}
 362  	var err error
 363  	for {
 364  		err = dec.reset(d.rawInput, d.WindowSize)
 365  		if err != nil {
 366  			break
 367  		}
 368  		if debugDecoder {
 369  			println("next block:", dec)
 370  		}
 371  		err = dec.decodeBuf(&d.history)
 372  		if err != nil {
 373  			break
 374  		}
 375  		if uint64(len(d.history.b)-crcStart) > d.o.maxDecodedSize {
 376  			println("runDecoder: maxDecodedSize exceeded", uint64(len(d.history.b)-crcStart), ">", d.o.maxDecodedSize)
 377  			err = ErrDecoderSizeExceeded
 378  			break
 379  		}
 380  		if d.o.limitToCap && len(d.history.b) > cap(dst) {
 381  			println("runDecoder: cap exceeded", uint64(len(d.history.b)), ">", cap(dst))
 382  			err = ErrDecoderSizeExceeded
 383  			break
 384  		}
 385  		if uint64(len(d.history.b)-crcStart) > d.FrameContentSize {
 386  			println("runDecoder: FrameContentSize exceeded", uint64(len(d.history.b)-crcStart), ">", d.FrameContentSize)
 387  			err = ErrFrameSizeExceeded
 388  			break
 389  		}
 390  		if dec.Last {
 391  			break
 392  		}
 393  		if debugDecoder {
 394  			println("runDecoder: FrameContentSize", uint64(len(d.history.b)-crcStart), "<=", d.FrameContentSize)
 395  		}
 396  	}
 397  	dst = d.history.b
 398  	if err == nil {
 399  		if d.FrameContentSize != fcsUnknown && uint64(len(d.history.b)-crcStart) != d.FrameContentSize {
 400  			err = ErrFrameSizeMismatch
 401  		} else if d.HasCheckSum {
 402  			if d.o.ignoreChecksum {
 403  				err = d.consumeCRC()
 404  			} else {
 405  				d.crc.Write(dst[crcStart:])
 406  				err = d.checkCRC()
 407  			}
 408  		}
 409  	}
 410  	d.history.b = saved
 411  	return dst, err
 412  }
 413