decode.go raw

   1  // Copyright 2011 The Snappy-Go Authors. All rights reserved.
   2  // Copyright (c) 2019 Klaus Post. All rights reserved.
   3  // Use of this source code is governed by a BSD-style
   4  // license that can be found in the LICENSE file.
   5  
   6  package s2
   7  
   8  import (
   9  	"encoding/binary"
  10  	"errors"
  11  	"fmt"
  12  	"strconv"
  13  
  14  	"github.com/klauspost/compress/internal/race"
  15  )
  16  
  17  var (
  18  	// ErrCorrupt reports that the input is invalid.
  19  	ErrCorrupt = errors.New("s2: corrupt input")
  20  	// ErrCRC reports that the input failed CRC validation (streams only)
  21  	ErrCRC = errors.New("s2: corrupt input, crc mismatch")
  22  	// ErrTooLarge reports that the uncompressed length is too large.
  23  	ErrTooLarge = errors.New("s2: decoded block is too large")
  24  	// ErrUnsupported reports that the input isn't supported.
  25  	ErrUnsupported = errors.New("s2: unsupported input")
  26  )
  27  
  28  // DecodedLen returns the length of the decoded block.
  29  func DecodedLen(src []byte) (int, error) {
  30  	v, _, err := decodedLen(src)
  31  	return v, err
  32  }
  33  
  34  // decodedLen returns the length of the decoded block and the number of bytes
  35  // that the length header occupied.
  36  func decodedLen(src []byte) (blockLen, headerLen int, err error) {
  37  	v, n := binary.Uvarint(src)
  38  	if n <= 0 || v > 0xffffffff {
  39  		return 0, 0, ErrCorrupt
  40  	}
  41  
  42  	const wordSize = 32 << (^uint(0) >> 32 & 1)
  43  	if wordSize == 32 && v > 0x7fffffff {
  44  		return 0, 0, ErrTooLarge
  45  	}
  46  	return int(v), n, nil
  47  }
  48  
  49  const (
  50  	decodeErrCodeCorrupt = 1
  51  )
  52  
  53  // Decode returns the decoded form of src. The returned slice may be a sub-
  54  // slice of dst if dst was large enough to hold the entire decoded block.
  55  // Otherwise, a newly allocated slice will be returned.
  56  //
  57  // The dst and src must not overlap. It is valid to pass a nil dst.
  58  func Decode(dst, src []byte) ([]byte, error) {
  59  	dLen, s, err := decodedLen(src)
  60  	if err != nil {
  61  		return nil, err
  62  	}
  63  	if dLen <= cap(dst) {
  64  		dst = dst[:dLen]
  65  	} else {
  66  		dst = make([]byte, dLen)
  67  	}
  68  
  69  	race.WriteSlice(dst)
  70  	race.ReadSlice(src[s:])
  71  
  72  	if s2Decode(dst, src[s:]) != 0 {
  73  		return nil, ErrCorrupt
  74  	}
  75  	return dst, nil
  76  }
  77  
// s2DecodeDict writes the decoding of src to dst, resolving back-references
// that reach before the start of dst against dict. It assumes that the
// varint-encoded length of the decompressed bytes has already been read, and
// that len(dst) equals that length.
//
// If dict is nil it simply delegates to s2Decode.
//
// Otherwise, these dictionary rules apply:
//   - The repeat offset starts at len(dict.dict)-dict.repeat, so a leading
//     repeat copy can read from the dictionary.
//   - A copy whose offset exceeds the number of bytes written so far
//     (d < offset) is served entirely from the tail of dict.dict. It must
//     not run past the end of the dictionary.
//   - Such dictionary references are rejected once d exceeds
//     MaxDictSrcOffset.
//
// It returns 0 on success or a decodeErrCodeXxx error code on failure.
func s2DecodeDict(dst, src []byte, dict *Dict) int {
	if dict == nil {
		return s2Decode(dst, src)
	}
	// Set debug to true for trace output; debugErrs then also reports why
	// input was rejected. Both are constants, so disabled branches compile away.
	const debug = false
	const debugErrs = debug

	if debug {
		fmt.Println("Starting decode, dst len:", len(dst))
	}
	// d: write position in dst. s: read position in src. length: size of the
	// current literal or copy. offset persists across iterations so that
	// repeat copies (tagCopy1 with a zero offset) can reuse it.
	var d, s, length int
	offset := len(dict.dict) - dict.repeat

	// Fast path. The loop condition guarantees s+5 < len(src). The widest
	// fixed-size header read below is 5 bytes starting at the tag byte
	// (literal x == 63, tagCopy1 repeat code 7, tagCopy4), so those header
	// reads need no explicit bounds checks. Literal and copy lengths are still
	// validated before use.
	for s < len(src)-5 {
		// Removing bounds checks (e.g. by slicing in := src[s:s+5] up front)
		// is SLOWER.
		// Checked on Go 1.18
		switch src[s] & 0x03 {
		case tagLiteral:
			// x is the upper 6 bits of the tag byte. Below 60 it is
			// length-1 itself. 60..63 mean length-1 is stored in the next
			// 1..4 bytes, little-endian.
			x := uint32(src[s] >> 2)
			switch {
			case x < 60:
				s++
			case x == 60:
				s += 2
				x = uint32(src[s-1])
			case x == 61:
				in := src[s : s+3]
				x = uint32(in[1]) | uint32(in[2])<<8
				s += 3
			case x == 62:
				in := src[s : s+4]
				// Load as 32 bit and shift down, discarding the tag byte in[0].
				x = uint32(in[0]) | uint32(in[1])<<8 | uint32(in[2])<<16 | uint32(in[3])<<24
				x >>= 8
				s += 4
			case x == 63:
				in := src[s : s+5]
				x = uint32(in[1]) | uint32(in[2])<<8 | uint32(in[3])<<16 | uint32(in[4])<<24
				s += 5
			}
			length = int(x) + 1
			if debug {
				fmt.Println("literals, length:", length, "d-after:", d+length)
			}
			// The literal must fit in both dst and the remaining src. On
			// 32-bit platforms int(x)+1 can wrap to <= 0 for large x.
			if length > len(dst)-d || length > len(src)-s || (strconv.IntSize == 32 && length <= 0) {
				if debugErrs {
					fmt.Println("corrupt literal: length:", length, "d-left:", len(dst)-d, "src-left:", len(src)-s)
				}
				return decodeErrCodeCorrupt
			}

			copy(dst[d:], src[s:s+length])
			d += length
			s += length
			continue

		case tagCopy1:
			// 2-byte copy. Bits 5-7 of the tag byte are the high 3 bits of an
			// 11-bit offset and the next byte holds its low 8 bits. Bits 2-4
			// are a 3-bit length code.
			s += 2
			toffset := int(uint32(src[s-2])&0xe0<<3 | uint32(src[s-1]))
			length = int(src[s-2]) >> 2 & 0x7
			if toffset == 0 {
				if debug {
					fmt.Print("(repeat) ")
				}
				// keep last offset
				// A zero offset marks a repeat of the previous offset.
				// Length codes 0-4 mean length = code+4. Codes 5, 6 and 7
				// read 1, 2 or 3 extra little-endian bytes; the bias added
				// here is followed by the common +4 below.
				switch length {
				case 5:
					length = int(src[s]) + 4
					s += 1
				case 6:
					in := src[s : s+2]
					length = int(uint32(in[0])|(uint32(in[1])<<8)) + (1 << 8)
					s += 2
				case 7:
					in := src[s : s+3]
					length = int((uint32(in[2])<<16)|(uint32(in[1])<<8)|uint32(in[0])) + (1 << 16)
					s += 3
				default: // 0-> 4
				}
			} else {
				offset = toffset
			}
			length += 4
		case tagCopy2:
			// 3-byte copy: length is (tag>>2)+1, then a 16-bit
			// little-endian offset.
			in := src[s : s+3]
			offset = int(uint32(in[1]) | uint32(in[2])<<8)
			length = 1 + int(in[0])>>2
			s += 3

		case tagCopy4:
			// 5-byte copy: length is (tag>>2)+1, then a 32-bit
			// little-endian offset.
			in := src[s : s+5]
			offset = int(uint32(in[1]) | uint32(in[2])<<8 | uint32(in[3])<<16 | uint32(in[4])<<24)
			length = 1 + int(in[0])>>2
			s += 5
		}

		// offset <= 0 rejects two cases: a non-positive starting repeat
		// offset, and a 32-bit offset that wrapped negative when converted
		// to int on 32-bit platforms. The copy must also fit in dst.
		if offset <= 0 || length > len(dst)-d {
			if debugErrs {
				fmt.Println("match error; offset:", offset, "length:", length, "dst-left:", len(dst)-d)
			}
			return decodeErrCodeCorrupt
		}

		// copy from dict
		// The reference starts before dst[0], so it is resolved against the
		// end of dict.dict. The whole copy must lie inside the dictionary,
		// and such references are only accepted while d <= MaxDictSrcOffset.
		if d < offset {
			if d > MaxDictSrcOffset {
				if debugErrs {
					fmt.Println("dict after", MaxDictSrcOffset, "d:", d, "offset:", offset, "length:", length)
				}
				return decodeErrCodeCorrupt
			}
			startOff := len(dict.dict) - offset + d
			if startOff < 0 || startOff+length > len(dict.dict) {
				if debugErrs {
					fmt.Printf("offset (%d) + length (%d) bigger than dict (%d)\n", offset, length, len(dict.dict))
				}
				return decodeErrCodeCorrupt
			}
			if debug {
				fmt.Println("dict copy, length:", length, "offset:", offset, "d-after:", d+length, "dict start offset:", startOff)
			}
			copy(dst[d:d+length], dict.dict[startOff:])
			d += length
			continue
		}

		if debug {
			fmt.Println("copy, length:", length, "offset:", offset, "d-after:", d+length)
		}

		// Copy from an earlier sub-slice of dst to a later sub-slice.
		// If no overlap, use the built-in copy:
		if offset > length {
			copy(dst[d:d+length], dst[d-offset:])
			d += length
			continue
		}

		// Unlike the built-in copy function, this byte-by-byte copy always runs
		// forwards, even if the slices overlap. Conceptually, this is:
		//
		// d += forwardCopy(dst[d:d+length], dst[d-offset:])
		//
		// We align the slices into a and b and show the compiler they are the same size.
		// This allows the loop to run without bounds checks.
		a := dst[d : d+length]
		b := dst[d-offset:]
		b = b[:len(a)]
		for i := range a {
			a[i] = b[i]
		}
		d += length
	}

	// Remaining with extra checks...
	// Tail path for the last few bytes of src. It decodes the same format as
	// the loop above, but advances s first and verifies s <= len(src) before
	// reading any header byte.
	for s < len(src) {
		switch src[s] & 0x03 {
		case tagLiteral:
			// Same literal header layout as the fast path.
			x := uint32(src[s] >> 2)
			switch {
			case x < 60:
				s++
			case x == 60:
				s += 2
				if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
					if debugErrs {
						fmt.Println("src went oob")
					}
					return decodeErrCodeCorrupt
				}
				x = uint32(src[s-1])
			case x == 61:
				s += 3
				if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
					if debugErrs {
						fmt.Println("src went oob")
					}
					return decodeErrCodeCorrupt
				}
				x = uint32(src[s-2]) | uint32(src[s-1])<<8
			case x == 62:
				s += 4
				if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
					if debugErrs {
						fmt.Println("src went oob")
					}
					return decodeErrCodeCorrupt
				}
				x = uint32(src[s-3]) | uint32(src[s-2])<<8 | uint32(src[s-1])<<16
			case x == 63:
				s += 5
				if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
					if debugErrs {
						fmt.Println("src went oob")
					}
					return decodeErrCodeCorrupt
				}
				x = uint32(src[s-4]) | uint32(src[s-3])<<8 | uint32(src[s-2])<<16 | uint32(src[s-1])<<24
			}
			length = int(x) + 1
			if length > len(dst)-d || length > len(src)-s || (strconv.IntSize == 32 && length <= 0) {
				if debugErrs {
					fmt.Println("corrupt literal: length:", length, "d-left:", len(dst)-d, "src-left:", len(src)-s)
				}
				return decodeErrCodeCorrupt
			}
			if debug {
				fmt.Println("literals, length:", length, "d-after:", d+length)
			}

			copy(dst[d:], src[s:s+length])
			d += length
			s += length
			continue

		case tagCopy1:
			// Same 2-byte copy / repeat layout as the fast path.
			s += 2
			if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
				if debugErrs {
					fmt.Println("src went oob")
				}
				return decodeErrCodeCorrupt
			}
			length = int(src[s-2]) >> 2 & 0x7
			toffset := int(uint32(src[s-2])&0xe0<<3 | uint32(src[s-1]))
			if toffset == 0 {
				if debug {
					fmt.Print("(repeat) ")
				}
				// keep last offset
				switch length {
				case 5:
					s += 1
					if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
						if debugErrs {
							fmt.Println("src went oob")
						}
						return decodeErrCodeCorrupt
					}
					length = int(uint32(src[s-1])) + 4
				case 6:
					s += 2
					if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
						if debugErrs {
							fmt.Println("src went oob")
						}
						return decodeErrCodeCorrupt
					}
					length = int(uint32(src[s-2])|(uint32(src[s-1])<<8)) + (1 << 8)
				case 7:
					s += 3
					if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
						if debugErrs {
							fmt.Println("src went oob")
						}
						return decodeErrCodeCorrupt
					}
					length = int(uint32(src[s-3])|(uint32(src[s-2])<<8)|(uint32(src[s-1])<<16)) + (1 << 16)
				default: // 0-> 4
				}
			} else {
				offset = toffset
			}
			length += 4
		case tagCopy2:
			s += 3
			if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
				if debugErrs {
					fmt.Println("src went oob")
				}
				return decodeErrCodeCorrupt
			}
			length = 1 + int(src[s-3])>>2
			offset = int(uint32(src[s-2]) | uint32(src[s-1])<<8)

		case tagCopy4:
			s += 5
			if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
				if debugErrs {
					fmt.Println("src went oob")
				}
				return decodeErrCodeCorrupt
			}
			length = 1 + int(src[s-5])>>2
			offset = int(uint32(src[s-4]) | uint32(src[s-3])<<8 | uint32(src[s-2])<<16 | uint32(src[s-1])<<24)
		}

		if offset <= 0 || length > len(dst)-d {
			if debugErrs {
				fmt.Println("match error; offset:", offset, "length:", length, "dst-left:", len(dst)-d)
			}
			return decodeErrCodeCorrupt
		}

		// copy from dict
		// Same dictionary rules as the fast path. rOff equals the fast
		// path's startOff; the two range checks are split here only for
		// distinct debug messages.
		if d < offset {
			if d > MaxDictSrcOffset {
				if debugErrs {
					fmt.Println("dict after", MaxDictSrcOffset, "d:", d, "offset:", offset, "length:", length)
				}
				return decodeErrCodeCorrupt
			}
			rOff := len(dict.dict) - (offset - d)
			// NOTE(review): this prints len(dict.dict)-rOff, i.e. offset-d
			// (distance from the dictionary's end), not rOff itself.
			if debug {
				fmt.Println("starting dict entry from dict offset", len(dict.dict)-rOff)
			}
			if rOff+length > len(dict.dict) {
				if debugErrs {
					fmt.Println("err: END offset", rOff+length, "bigger than dict", len(dict.dict), "dict offset:", rOff, "length:", length)
				}
				return decodeErrCodeCorrupt
			}
			if rOff < 0 {
				if debugErrs {
					fmt.Println("err: START offset", rOff, "less than 0", len(dict.dict), "dict offset:", rOff, "length:", length)
				}
				return decodeErrCodeCorrupt
			}
			copy(dst[d:d+length], dict.dict[rOff:])
			d += length
			continue
		}

		if debug {
			fmt.Println("copy, length:", length, "offset:", offset, "d-after:", d+length)
		}

		// Copy from an earlier sub-slice of dst to a later sub-slice.
		// If no overlap, use the built-in copy:
		if offset > length {
			copy(dst[d:d+length], dst[d-offset:])
			d += length
			continue
		}

		// Unlike the built-in copy function, this byte-by-byte copy always runs
		// forwards, even if the slices overlap. Conceptually, this is:
		//
		// d += forwardCopy(dst[d:d+length], dst[d-offset:])
		//
		// We align the slices into a and b and show the compiler they are the same size.
		// This allows the loop to run without bounds checks.
		a := dst[d : d+length]
		b := dst[d-offset:]
		b = b[:len(a)]
		for i := range a {
			a[i] = b[i]
		}
		d += length
	}

	// All of src has been consumed; every byte of dst must have been produced.
	if d != len(dst) {
		if debugErrs {
			fmt.Println("wanted length", len(dst), "got", d)
		}
		return decodeErrCodeCorrupt
	}
	return 0
}
 444