decode_other.go raw

   1  // Copyright 2016 The Snappy-Go Authors. All rights reserved.
   2  // Copyright (c) 2019 Klaus Post. All rights reserved.
   3  // Use of this source code is governed by a BSD-style
   4  // license that can be found in the LICENSE file.
   5  
   6  //go:build (!amd64 && !arm64) || appengine || !gc || noasm
   7  // +build !amd64,!arm64 appengine !gc noasm
   8  
   9  package s2
  10  
  11  import (
  12  	"fmt"
  13  	"strconv"
  14  
  15  	"github.com/klauspost/compress/internal/le"
  16  )
  17  
  18  // decode writes the decoding of src to dst. It assumes that the varint-encoded
  19  // length of the decompressed bytes has already been read, and that len(dst)
  20  // equals that length.
  21  //
  22  // It returns 0 on success or a decodeErrCodeXxx error code on failure.
  23  func s2Decode(dst, src []byte) int {
  24  	const debug = false
  25  	if debug {
  26  		fmt.Println("Starting decode, dst len:", len(dst))
  27  	}
  28  	var d, s, length int
  29  	offset := 0
  30  
  31  	// As long as we can read at least 5 bytes...
  32  	for s < len(src)-5 {
  33  		// Removing bounds checks is SLOWER, when if doing
  34  		// in := src[s:s+5]
  35  		// Checked on Go 1.18
  36  		switch src[s] & 0x03 {
  37  		case tagLiteral:
  38  			x := uint32(src[s] >> 2)
  39  			switch {
  40  			case x < 60:
  41  				s++
  42  			case x == 60:
  43  				x = uint32(src[s+1])
  44  				s += 2
  45  			case x == 61:
  46  				x = uint32(le.Load16(src, s+1))
  47  				s += 3
  48  			case x == 62:
  49  				// Load as 32 bit and shift down.
  50  				x = le.Load32(src, s)
  51  				x >>= 8
  52  				s += 4
  53  			case x == 63:
  54  				x = le.Load32(src, s+1)
  55  				s += 5
  56  			}
  57  			length = int(x) + 1
  58  			if length > len(dst)-d || length > len(src)-s || (strconv.IntSize == 32 && length <= 0) {
  59  				if debug {
  60  					fmt.Println("corrupt: lit size", length)
  61  				}
  62  				return decodeErrCodeCorrupt
  63  			}
  64  			if debug {
  65  				fmt.Println("literals, length:", length, "d-after:", d+length)
  66  			}
  67  
  68  			copy(dst[d:], src[s:s+length])
  69  			d += length
  70  			s += length
  71  			continue
  72  
  73  		case tagCopy1:
  74  			s += 2
  75  			toffset := int(uint32(src[s-2])&0xe0<<3 | uint32(src[s-1]))
  76  			length = int(src[s-2]) >> 2 & 0x7
  77  			if toffset == 0 {
  78  				if debug {
  79  					fmt.Print("(repeat) ")
  80  				}
  81  				// keep last offset
  82  				switch length {
  83  				case 5:
  84  					length = int(src[s]) + 4
  85  					s += 1
  86  				case 6:
  87  					length = int(le.Load16(src, s)) + 1<<8
  88  					s += 2
  89  				case 7:
  90  					in := src[s : s+3]
  91  					length = int((uint32(in[2])<<16)|(uint32(in[1])<<8)|uint32(in[0])) + (1 << 16)
  92  					s += 3
  93  				default: // 0-> 4
  94  				}
  95  			} else {
  96  				offset = toffset
  97  			}
  98  			length += 4
  99  		case tagCopy2:
 100  			offset = int(le.Load16(src, s+1))
 101  			length = 1 + int(src[s])>>2
 102  			s += 3
 103  
 104  		case tagCopy4:
 105  			offset = int(le.Load32(src, s+1))
 106  			length = 1 + int(src[s])>>2
 107  			s += 5
 108  		}
 109  
 110  		if offset <= 0 || d < offset || length > len(dst)-d {
 111  			if debug {
 112  				fmt.Println("corrupt: match, length", length, "offset:", offset, "dst avail:", len(dst)-d, "dst pos:", d)
 113  			}
 114  
 115  			return decodeErrCodeCorrupt
 116  		}
 117  
 118  		if debug {
 119  			fmt.Println("copy, length:", length, "offset:", offset, "d-after:", d+length)
 120  		}
 121  
 122  		// Copy from an earlier sub-slice of dst to a later sub-slice.
 123  		// If no overlap, use the built-in copy:
 124  		if offset > length {
 125  			copy(dst[d:d+length], dst[d-offset:])
 126  			d += length
 127  			continue
 128  		}
 129  
 130  		// Unlike the built-in copy function, this byte-by-byte copy always runs
 131  		// forwards, even if the slices overlap. Conceptually, this is:
 132  		//
 133  		// d += forwardCopy(dst[d:d+length], dst[d-offset:])
 134  		//
 135  		// We align the slices into a and b and show the compiler they are the same size.
 136  		// This allows the loop to run without bounds checks.
 137  		a := dst[d : d+length]
 138  		b := dst[d-offset:]
 139  		b = b[:len(a)]
 140  		for i := range a {
 141  			a[i] = b[i]
 142  		}
 143  		d += length
 144  	}
 145  
 146  	// Remaining with extra checks...
 147  	for s < len(src) {
 148  		switch src[s] & 0x03 {
 149  		case tagLiteral:
 150  			x := uint32(src[s] >> 2)
 151  			switch {
 152  			case x < 60:
 153  				s++
 154  			case x == 60:
 155  				s += 2
 156  				if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
 157  					return decodeErrCodeCorrupt
 158  				}
 159  				x = uint32(src[s-1])
 160  			case x == 61:
 161  				s += 3
 162  				if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
 163  					return decodeErrCodeCorrupt
 164  				}
 165  				x = uint32(src[s-2]) | uint32(src[s-1])<<8
 166  			case x == 62:
 167  				s += 4
 168  				if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
 169  					return decodeErrCodeCorrupt
 170  				}
 171  				x = uint32(src[s-3]) | uint32(src[s-2])<<8 | uint32(src[s-1])<<16
 172  			case x == 63:
 173  				s += 5
 174  				if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
 175  					return decodeErrCodeCorrupt
 176  				}
 177  				x = uint32(src[s-4]) | uint32(src[s-3])<<8 | uint32(src[s-2])<<16 | uint32(src[s-1])<<24
 178  			}
 179  			length = int(x) + 1
 180  			if length > len(dst)-d || length > len(src)-s || (strconv.IntSize == 32 && length <= 0) {
 181  				if debug {
 182  					fmt.Println("corrupt: lit size", length)
 183  				}
 184  				return decodeErrCodeCorrupt
 185  			}
 186  			if debug {
 187  				fmt.Println("literals, length:", length, "d-after:", d+length)
 188  			}
 189  
 190  			copy(dst[d:], src[s:s+length])
 191  			d += length
 192  			s += length
 193  			continue
 194  
 195  		case tagCopy1:
 196  			s += 2
 197  			if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
 198  				return decodeErrCodeCorrupt
 199  			}
 200  			length = int(src[s-2]) >> 2 & 0x7
 201  			toffset := int(uint32(src[s-2])&0xe0<<3 | uint32(src[s-1]))
 202  			if toffset == 0 {
 203  				if debug {
 204  					fmt.Print("(repeat) ")
 205  				}
 206  				// keep last offset
 207  				switch length {
 208  				case 5:
 209  					s += 1
 210  					if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
 211  						return decodeErrCodeCorrupt
 212  					}
 213  					length = int(uint32(src[s-1])) + 4
 214  				case 6:
 215  					s += 2
 216  					if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
 217  						return decodeErrCodeCorrupt
 218  					}
 219  					length = int(uint32(src[s-2])|(uint32(src[s-1])<<8)) + (1 << 8)
 220  				case 7:
 221  					s += 3
 222  					if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
 223  						return decodeErrCodeCorrupt
 224  					}
 225  					length = int(uint32(src[s-3])|(uint32(src[s-2])<<8)|(uint32(src[s-1])<<16)) + (1 << 16)
 226  				default: // 0-> 4
 227  				}
 228  			} else {
 229  				offset = toffset
 230  			}
 231  			length += 4
 232  		case tagCopy2:
 233  			s += 3
 234  			if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
 235  				return decodeErrCodeCorrupt
 236  			}
 237  			length = 1 + int(src[s-3])>>2
 238  			offset = int(uint32(src[s-2]) | uint32(src[s-1])<<8)
 239  
 240  		case tagCopy4:
 241  			s += 5
 242  			if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
 243  				return decodeErrCodeCorrupt
 244  			}
 245  			length = 1 + int(src[s-5])>>2
 246  			offset = int(uint32(src[s-4]) | uint32(src[s-3])<<8 | uint32(src[s-2])<<16 | uint32(src[s-1])<<24)
 247  		}
 248  
 249  		if offset <= 0 || d < offset || length > len(dst)-d {
 250  			if debug {
 251  				fmt.Println("corrupt: match, length", length, "offset:", offset, "dst avail:", len(dst)-d, "dst pos:", d)
 252  			}
 253  			return decodeErrCodeCorrupt
 254  		}
 255  
 256  		if debug {
 257  			fmt.Println("copy, length:", length, "offset:", offset, "d-after:", d+length)
 258  		}
 259  
 260  		// Copy from an earlier sub-slice of dst to a later sub-slice.
 261  		// If no overlap, use the built-in copy:
 262  		if offset > length {
 263  			copy(dst[d:d+length], dst[d-offset:])
 264  			d += length
 265  			continue
 266  		}
 267  
 268  		// Unlike the built-in copy function, this byte-by-byte copy always runs
 269  		// forwards, even if the slices overlap. Conceptually, this is:
 270  		//
 271  		// d += forwardCopy(dst[d:d+length], dst[d-offset:])
 272  		//
 273  		// We align the slices into a and b and show the compiler they are the same size.
 274  		// This allows the loop to run without bounds checks.
 275  		a := dst[d : d+length]
 276  		b := dst[d-offset:]
 277  		b = b[:len(a)]
 278  		for i := range a {
 279  			a[i] = b[i]
 280  		}
 281  		d += length
 282  	}
 283  
 284  	if d != len(dst) {
 285  		return decodeErrCodeCorrupt
 286  	}
 287  	return 0
 288  }
 289