snappy.go raw

   1  // Copyright 2019+ Klaus Post. All rights reserved.
   2  // License information can be found in the LICENSE file.
   3  // Based on work by Yann Collet, released under BSD License.
   4  
   5  package zstd
   6  
   7  import (
   8  	"encoding/binary"
   9  	"errors"
  10  	"hash/crc32"
  11  	"io"
  12  
  13  	"github.com/klauspost/compress/huff0"
  14  	snappy "github.com/klauspost/compress/internal/snapref"
  15  )
  16  
  17  const (
  18  	snappyTagLiteral = 0x00
  19  	snappyTagCopy1   = 0x01
  20  	snappyTagCopy2   = 0x02
  21  	snappyTagCopy4   = 0x03
  22  )
  23  
  24  const (
  25  	snappyChecksumSize = 4
  26  	snappyMagicBody    = "sNaPpY"
  27  
  28  	// snappyMaxBlockSize is the maximum size of the input to encodeBlock. It is not
  29  	// part of the wire format per se, but some parts of the encoder assume
  30  	// that an offset fits into a uint16.
  31  	//
  32  	// Also, for the framing format (Writer type instead of Encode function),
  33  	// https://github.com/google/snappy/blob/master/framing_format.txt says
  34  	// that "the uncompressed data in a chunk must be no longer than 65536
  35  	// bytes".
  36  	snappyMaxBlockSize = 65536
  37  
  38  	// snappyMaxEncodedLenOfMaxBlockSize equals MaxEncodedLen(snappyMaxBlockSize), but is
  39  	// hard coded to be a const instead of a variable, so that obufLen can also
  40  	// be a const. Their equivalence is confirmed by
  41  	// TestMaxEncodedLenOfMaxBlockSize.
  42  	snappyMaxEncodedLenOfMaxBlockSize = 76490
  43  )
  44  
  45  const (
  46  	chunkTypeCompressedData   = 0x00
  47  	chunkTypeUncompressedData = 0x01
  48  	chunkTypePadding          = 0xfe
  49  	chunkTypeStreamIdentifier = 0xff
  50  )
  51  
  52  var (
  53  	// ErrSnappyCorrupt reports that the input is invalid.
  54  	ErrSnappyCorrupt = errors.New("snappy: corrupt input")
  55  	// ErrSnappyTooLarge reports that the uncompressed length is too large.
  56  	ErrSnappyTooLarge = errors.New("snappy: decoded block is too large")
  57  	// ErrSnappyUnsupported reports that the input isn't supported.
  58  	ErrSnappyUnsupported = errors.New("snappy: unsupported input")
  59  
  60  	errUnsupportedLiteralLength = errors.New("snappy: unsupported literal length")
  61  )
  62  
  63  // SnappyConverter can read SnappyConverter-compressed streams and convert them to zstd.
  64  // Conversion is done by converting the stream directly from Snappy without intermediate
  65  // full decoding.
  66  // Therefore the compression ratio is much less than what can be done by a full decompression
  67  // and compression, and a faulty Snappy stream may lead to a faulty Zstandard stream without
  68  // any errors being generated.
  69  // No CRC value is being generated and not all CRC values of the Snappy stream are checked.
  70  // However, it provides really fast recompression of Snappy streams.
  71  // The converter can be reused to avoid allocations, even after errors.
  72  type SnappyConverter struct {
  73  	r     io.Reader
  74  	err   error
  75  	buf   []byte
  76  	block *blockEnc
  77  }
  78  
  79  // Convert the Snappy stream supplied in 'in' and write the zStandard stream to 'w'.
  80  // If any error is detected on the Snappy stream it is returned.
  81  // The number of bytes written is returned.
  82  func (r *SnappyConverter) Convert(in io.Reader, w io.Writer) (int64, error) {
  83  	initPredefined()
  84  	r.err = nil
  85  	r.r = in
  86  	if r.block == nil {
  87  		r.block = &blockEnc{}
  88  		r.block.init()
  89  	}
  90  	r.block.initNewEncode()
  91  	if len(r.buf) != snappyMaxEncodedLenOfMaxBlockSize+snappyChecksumSize {
  92  		r.buf = make([]byte, snappyMaxEncodedLenOfMaxBlockSize+snappyChecksumSize)
  93  	}
  94  	r.block.litEnc.Reuse = huff0.ReusePolicyNone
  95  	var written int64
  96  	var readHeader bool
  97  	{
  98  		header := frameHeader{WindowSize: snappyMaxBlockSize}.appendTo(r.buf[:0])
  99  
 100  		var n int
 101  		n, r.err = w.Write(header)
 102  		if r.err != nil {
 103  			return written, r.err
 104  		}
 105  		written += int64(n)
 106  	}
 107  
 108  	for {
 109  		if !r.readFull(r.buf[:4], true) {
 110  			// Add empty last block
 111  			r.block.reset(nil)
 112  			r.block.last = true
 113  			err := r.block.encodeLits(r.block.literals, false)
 114  			if err != nil {
 115  				return written, err
 116  			}
 117  			n, err := w.Write(r.block.output)
 118  			if err != nil {
 119  				return written, err
 120  			}
 121  			written += int64(n)
 122  
 123  			return written, r.err
 124  		}
 125  		chunkType := r.buf[0]
 126  		if !readHeader {
 127  			if chunkType != chunkTypeStreamIdentifier {
 128  				println("chunkType != chunkTypeStreamIdentifier", chunkType)
 129  				r.err = ErrSnappyCorrupt
 130  				return written, r.err
 131  			}
 132  			readHeader = true
 133  		}
 134  		chunkLen := int(r.buf[1]) | int(r.buf[2])<<8 | int(r.buf[3])<<16
 135  		if chunkLen > len(r.buf) {
 136  			println("chunkLen > len(r.buf)", chunkType)
 137  			r.err = ErrSnappyUnsupported
 138  			return written, r.err
 139  		}
 140  
 141  		// The chunk types are specified at
 142  		// https://github.com/google/snappy/blob/master/framing_format.txt
 143  		switch chunkType {
 144  		case chunkTypeCompressedData:
 145  			// Section 4.2. Compressed data (chunk type 0x00).
 146  			if chunkLen < snappyChecksumSize {
 147  				println("chunkLen < snappyChecksumSize", chunkLen, snappyChecksumSize)
 148  				r.err = ErrSnappyCorrupt
 149  				return written, r.err
 150  			}
 151  			buf := r.buf[:chunkLen]
 152  			if !r.readFull(buf, false) {
 153  				return written, r.err
 154  			}
 155  			//checksum := uint32(buf[0]) | uint32(buf[1])<<8 | uint32(buf[2])<<16 | uint32(buf[3])<<24
 156  			buf = buf[snappyChecksumSize:]
 157  
 158  			n, hdr, err := snappyDecodedLen(buf)
 159  			if err != nil {
 160  				r.err = err
 161  				return written, r.err
 162  			}
 163  			buf = buf[hdr:]
 164  			if n > snappyMaxBlockSize {
 165  				println("n > snappyMaxBlockSize", n, snappyMaxBlockSize)
 166  				r.err = ErrSnappyCorrupt
 167  				return written, r.err
 168  			}
 169  			r.block.reset(nil)
 170  			r.block.pushOffsets()
 171  			if err := decodeSnappy(r.block, buf); err != nil {
 172  				r.err = err
 173  				return written, r.err
 174  			}
 175  			if r.block.size+r.block.extraLits != n {
 176  				printf("invalid size, want %d, got %d\n", n, r.block.size+r.block.extraLits)
 177  				r.err = ErrSnappyCorrupt
 178  				return written, r.err
 179  			}
 180  			err = r.block.encode(nil, false, false)
 181  			switch err {
 182  			case errIncompressible:
 183  				r.block.popOffsets()
 184  				r.block.reset(nil)
 185  				r.block.literals, err = snappy.Decode(r.block.literals[:n], r.buf[snappyChecksumSize:chunkLen])
 186  				if err != nil {
 187  					return written, err
 188  				}
 189  				err = r.block.encodeLits(r.block.literals, false)
 190  				if err != nil {
 191  					return written, err
 192  				}
 193  			case nil:
 194  			default:
 195  				return written, err
 196  			}
 197  
 198  			n, r.err = w.Write(r.block.output)
 199  			if r.err != nil {
 200  				return written, r.err
 201  			}
 202  			written += int64(n)
 203  			continue
 204  		case chunkTypeUncompressedData:
 205  			if debugEncoder {
 206  				println("Uncompressed, chunklen", chunkLen)
 207  			}
 208  			// Section 4.3. Uncompressed data (chunk type 0x01).
 209  			if chunkLen < snappyChecksumSize {
 210  				println("chunkLen < snappyChecksumSize", chunkLen, snappyChecksumSize)
 211  				r.err = ErrSnappyCorrupt
 212  				return written, r.err
 213  			}
 214  			r.block.reset(nil)
 215  			buf := r.buf[:snappyChecksumSize]
 216  			if !r.readFull(buf, false) {
 217  				return written, r.err
 218  			}
 219  			checksum := uint32(buf[0]) | uint32(buf[1])<<8 | uint32(buf[2])<<16 | uint32(buf[3])<<24
 220  			// Read directly into r.decoded instead of via r.buf.
 221  			n := chunkLen - snappyChecksumSize
 222  			if n > snappyMaxBlockSize {
 223  				println("n > snappyMaxBlockSize", n, snappyMaxBlockSize)
 224  				r.err = ErrSnappyCorrupt
 225  				return written, r.err
 226  			}
 227  			r.block.literals = r.block.literals[:n]
 228  			if !r.readFull(r.block.literals, false) {
 229  				return written, r.err
 230  			}
 231  			if snappyCRC(r.block.literals) != checksum {
 232  				println("literals crc mismatch")
 233  				r.err = ErrSnappyCorrupt
 234  				return written, r.err
 235  			}
 236  			err := r.block.encodeLits(r.block.literals, false)
 237  			if err != nil {
 238  				return written, err
 239  			}
 240  			n, r.err = w.Write(r.block.output)
 241  			if r.err != nil {
 242  				return written, r.err
 243  			}
 244  			written += int64(n)
 245  			continue
 246  
 247  		case chunkTypeStreamIdentifier:
 248  			if debugEncoder {
 249  				println("stream id", chunkLen, len(snappyMagicBody))
 250  			}
 251  			// Section 4.1. Stream identifier (chunk type 0xff).
 252  			if chunkLen != len(snappyMagicBody) {
 253  				println("chunkLen != len(snappyMagicBody)", chunkLen, len(snappyMagicBody))
 254  				r.err = ErrSnappyCorrupt
 255  				return written, r.err
 256  			}
 257  			if !r.readFull(r.buf[:len(snappyMagicBody)], false) {
 258  				return written, r.err
 259  			}
 260  			for i := range len(snappyMagicBody) {
 261  				if r.buf[i] != snappyMagicBody[i] {
 262  					println("r.buf[i] != snappyMagicBody[i]", r.buf[i], snappyMagicBody[i], i)
 263  					r.err = ErrSnappyCorrupt
 264  					return written, r.err
 265  				}
 266  			}
 267  			continue
 268  		}
 269  
 270  		if chunkType <= 0x7f {
 271  			// Section 4.5. Reserved unskippable chunks (chunk types 0x02-0x7f).
 272  			println("chunkType <= 0x7f")
 273  			r.err = ErrSnappyUnsupported
 274  			return written, r.err
 275  		}
 276  		// Section 4.4 Padding (chunk type 0xfe).
 277  		// Section 4.6. Reserved skippable chunks (chunk types 0x80-0xfd).
 278  		if !r.readFull(r.buf[:chunkLen], false) {
 279  			return written, r.err
 280  		}
 281  	}
 282  }
 283  
 284  // decodeSnappy writes the decoding of src to dst. It assumes that the varint-encoded
 285  // length of the decompressed bytes has already been read.
 286  func decodeSnappy(blk *blockEnc, src []byte) error {
 287  	//decodeRef(make([]byte, snappyMaxBlockSize), src)
 288  	var s, length int
 289  	lits := blk.extraLits
 290  	var offset uint32
 291  	for s < len(src) {
 292  		switch src[s] & 0x03 {
 293  		case snappyTagLiteral:
 294  			x := uint32(src[s] >> 2)
 295  			switch {
 296  			case x < 60:
 297  				s++
 298  			case x == 60:
 299  				s += 2
 300  				if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
 301  					println("uint(s) > uint(len(src)", s, src)
 302  					return ErrSnappyCorrupt
 303  				}
 304  				x = uint32(src[s-1])
 305  			case x == 61:
 306  				s += 3
 307  				if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
 308  					println("uint(s) > uint(len(src)", s, src)
 309  					return ErrSnappyCorrupt
 310  				}
 311  				x = uint32(src[s-2]) | uint32(src[s-1])<<8
 312  			case x == 62:
 313  				s += 4
 314  				if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
 315  					println("uint(s) > uint(len(src)", s, src)
 316  					return ErrSnappyCorrupt
 317  				}
 318  				x = uint32(src[s-3]) | uint32(src[s-2])<<8 | uint32(src[s-1])<<16
 319  			case x == 63:
 320  				s += 5
 321  				if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
 322  					println("uint(s) > uint(len(src)", s, src)
 323  					return ErrSnappyCorrupt
 324  				}
 325  				x = uint32(src[s-4]) | uint32(src[s-3])<<8 | uint32(src[s-2])<<16 | uint32(src[s-1])<<24
 326  			}
 327  			if x > snappyMaxBlockSize {
 328  				println("x > snappyMaxBlockSize", x, snappyMaxBlockSize)
 329  				return ErrSnappyCorrupt
 330  			}
 331  			length = int(x) + 1
 332  			if length <= 0 {
 333  				println("length <= 0 ", length)
 334  
 335  				return errUnsupportedLiteralLength
 336  			}
 337  			//if length > snappyMaxBlockSize-d || uint32(length) > len(src)-s {
 338  			//	return ErrSnappyCorrupt
 339  			//}
 340  
 341  			blk.literals = append(blk.literals, src[s:s+length]...)
 342  			//println(length, "litLen")
 343  			lits += length
 344  			s += length
 345  			continue
 346  
 347  		case snappyTagCopy1:
 348  			s += 2
 349  			if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
 350  				println("uint(s) > uint(len(src)", s, len(src))
 351  				return ErrSnappyCorrupt
 352  			}
 353  			length = 4 + int(src[s-2])>>2&0x7
 354  			offset = uint32(src[s-2])&0xe0<<3 | uint32(src[s-1])
 355  
 356  		case snappyTagCopy2:
 357  			s += 3
 358  			if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
 359  				println("uint(s) > uint(len(src)", s, len(src))
 360  				return ErrSnappyCorrupt
 361  			}
 362  			length = 1 + int(src[s-3])>>2
 363  			offset = uint32(src[s-2]) | uint32(src[s-1])<<8
 364  
 365  		case snappyTagCopy4:
 366  			s += 5
 367  			if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
 368  				println("uint(s) > uint(len(src)", s, len(src))
 369  				return ErrSnappyCorrupt
 370  			}
 371  			length = 1 + int(src[s-5])>>2
 372  			offset = uint32(src[s-4]) | uint32(src[s-3])<<8 | uint32(src[s-2])<<16 | uint32(src[s-1])<<24
 373  		}
 374  
 375  		if offset <= 0 || blk.size+lits < int(offset) /*|| length > len(blk)-d */ {
 376  			println("offset <= 0 || blk.size+lits < int(offset)", offset, blk.size+lits, int(offset), blk.size, lits)
 377  
 378  			return ErrSnappyCorrupt
 379  		}
 380  
 381  		// Check if offset is one of the recent offsets.
 382  		// Adjusts the output offset accordingly.
 383  		// Gives a tiny bit of compression, typically around 1%.
 384  		if false {
 385  			offset = blk.matchOffset(offset, uint32(lits))
 386  		} else {
 387  			offset += 3
 388  		}
 389  
 390  		blk.sequences = append(blk.sequences, seq{
 391  			litLen:   uint32(lits),
 392  			offset:   offset,
 393  			matchLen: uint32(length) - zstdMinMatch,
 394  		})
 395  		blk.size += length + lits
 396  		lits = 0
 397  	}
 398  	blk.extraLits = lits
 399  	return nil
 400  }
 401  
 402  func (r *SnappyConverter) readFull(p []byte, allowEOF bool) (ok bool) {
 403  	if _, r.err = io.ReadFull(r.r, p); r.err != nil {
 404  		if r.err == io.ErrUnexpectedEOF || (r.err == io.EOF && !allowEOF) {
 405  			r.err = ErrSnappyCorrupt
 406  		}
 407  		return false
 408  	}
 409  	return true
 410  }
 411  
 412  var crcTable = crc32.MakeTable(crc32.Castagnoli)
 413  
 414  // crc implements the checksum specified in section 3 of
 415  // https://github.com/google/snappy/blob/master/framing_format.txt
 416  func snappyCRC(b []byte) uint32 {
 417  	c := crc32.Update(0, crcTable, b)
 418  	return c>>15 | c<<17 + 0xa282ead8
 419  }
 420  
 421  // snappyDecodedLen returns the length of the decoded block and the number of bytes
 422  // that the length header occupied.
 423  func snappyDecodedLen(src []byte) (blockLen, headerLen int, err error) {
 424  	v, n := binary.Uvarint(src)
 425  	if n <= 0 || v > 0xffffffff {
 426  		return 0, 0, ErrSnappyCorrupt
 427  	}
 428  
 429  	const wordSize = 32 << (^uint(0) >> 32 & 1)
 430  	if wordSize == 32 && v > 0x7fffffff {
 431  		return 0, 0, ErrSnappyTooLarge
 432  	}
 433  	return int(v), n, nil
 434  }
 435