seqdec.go raw

   1  // Copyright 2019+ Klaus Post. All rights reserved.
   2  // License information can be found in the LICENSE file.
   3  // Based on work by Yann Collet, released under BSD License.
   4  
   5  package zstd
   6  
   7  import (
   8  	"errors"
   9  	"fmt"
  10  	"io"
  11  )
  12  
  13  type seq struct {
  14  	litLen   uint32
  15  	matchLen uint32
  16  	offset   uint32
  17  
  18  	// Codes are stored here for the encoder
  19  	// so they only have to be looked up once.
  20  	llCode, mlCode, ofCode uint8
  21  }
  22  
  23  type seqVals struct {
  24  	ll, ml, mo int
  25  }
  26  
  27  func (s seq) String() string {
  28  	if s.offset <= 3 {
  29  		if s.offset == 0 {
  30  			return fmt.Sprint("litLen:", s.litLen, ", matchLen:", s.matchLen+zstdMinMatch, ", offset: INVALID (0)")
  31  		}
  32  		return fmt.Sprint("litLen:", s.litLen, ", matchLen:", s.matchLen+zstdMinMatch, ", offset:", s.offset, " (repeat)")
  33  	}
  34  	return fmt.Sprint("litLen:", s.litLen, ", matchLen:", s.matchLen+zstdMinMatch, ", offset:", s.offset-3, " (new)")
  35  }
  36  
  37  type seqCompMode uint8
  38  
  39  const (
  40  	compModePredefined seqCompMode = iota
  41  	compModeRLE
  42  	compModeFSE
  43  	compModeRepeat
  44  )
  45  
  46  type sequenceDec struct {
  47  	// decoder keeps track of the current state and updates it from the bitstream.
  48  	fse    *fseDecoder
  49  	state  fseState
  50  	repeat bool
  51  }
  52  
  53  // init the state of the decoder with input from stream.
  54  func (s *sequenceDec) init(br *bitReader) error {
  55  	if s.fse == nil {
  56  		return errors.New("sequence decoder not defined")
  57  	}
  58  	s.state.init(br, s.fse.actualTableLog, s.fse.dt[:1<<s.fse.actualTableLog])
  59  	return nil
  60  }
  61  
  62  // sequenceDecs contains all 3 sequence decoders and their state.
  63  type sequenceDecs struct {
  64  	litLengths   sequenceDec
  65  	offsets      sequenceDec
  66  	matchLengths sequenceDec
  67  	prevOffset   [3]int
  68  	dict         []byte
  69  	literals     []byte
  70  	out          []byte
  71  	nSeqs        int
  72  	br           *bitReader
  73  	seqSize      int
  74  	windowSize   int
  75  	maxBits      uint8
  76  	maxSyncLen   uint64
  77  }
  78  
  79  // initialize all 3 decoders from the stream input.
  80  func (s *sequenceDecs) initialize(br *bitReader, hist *history, out []byte) error {
  81  	if err := s.litLengths.init(br); err != nil {
  82  		return errors.New("litLengths:" + err.Error())
  83  	}
  84  	if err := s.offsets.init(br); err != nil {
  85  		return errors.New("offsets:" + err.Error())
  86  	}
  87  	if err := s.matchLengths.init(br); err != nil {
  88  		return errors.New("matchLengths:" + err.Error())
  89  	}
  90  	s.br = br
  91  	s.prevOffset = hist.recentOffsets
  92  	s.maxBits = s.litLengths.fse.maxBits + s.offsets.fse.maxBits + s.matchLengths.fse.maxBits
  93  	s.windowSize = hist.windowSize
  94  	s.out = out
  95  	s.dict = nil
  96  	if hist.dict != nil {
  97  		s.dict = hist.dict.content
  98  	}
  99  	return nil
 100  }
 101  
 102  func (s *sequenceDecs) freeDecoders() {
 103  	if f := s.litLengths.fse; f != nil && !f.preDefined {
 104  		fseDecoderPool.Put(f)
 105  		s.litLengths.fse = nil
 106  	}
 107  	if f := s.offsets.fse; f != nil && !f.preDefined {
 108  		fseDecoderPool.Put(f)
 109  		s.offsets.fse = nil
 110  	}
 111  	if f := s.matchLengths.fse; f != nil && !f.preDefined {
 112  		fseDecoderPool.Put(f)
 113  		s.matchLengths.fse = nil
 114  	}
 115  }
 116  
 117  // execute will execute the decoded sequence with the provided history.
 118  // The sequence must be evaluated before being sent.
 119  func (s *sequenceDecs) execute(seqs []seqVals, hist []byte) error {
 120  	if len(s.dict) == 0 {
 121  		return s.executeSimple(seqs, hist)
 122  	}
 123  
 124  	// Ensure we have enough output size...
 125  	if len(s.out)+s.seqSize > cap(s.out) {
 126  		addBytes := s.seqSize + len(s.out)
 127  		s.out = append(s.out, make([]byte, addBytes)...)
 128  		s.out = s.out[:len(s.out)-addBytes]
 129  	}
 130  
 131  	if debugDecoder {
 132  		printf("Execute %d seqs with hist %d, dict %d, literals: %d into %d bytes\n", len(seqs), len(hist), len(s.dict), len(s.literals), s.seqSize)
 133  	}
 134  
 135  	var t = len(s.out)
 136  	out := s.out[:t+s.seqSize]
 137  
 138  	for _, seq := range seqs {
 139  		// Add literals
 140  		copy(out[t:], s.literals[:seq.ll])
 141  		t += seq.ll
 142  		s.literals = s.literals[seq.ll:]
 143  
 144  		// Copy from dictionary...
 145  		if seq.mo > t+len(hist) || seq.mo > s.windowSize {
 146  			if len(s.dict) == 0 {
 147  				return fmt.Errorf("match offset (%d) bigger than current history (%d)", seq.mo, t+len(hist))
 148  			}
 149  
 150  			// we may be in dictionary.
 151  			dictO := len(s.dict) - (seq.mo - (t + len(hist)))
 152  			if dictO < 0 || dictO >= len(s.dict) {
 153  				return fmt.Errorf("match offset (%d) bigger than current history+dict (%d)", seq.mo, t+len(hist)+len(s.dict))
 154  			}
 155  			end := dictO + seq.ml
 156  			if end > len(s.dict) {
 157  				n := len(s.dict) - dictO
 158  				copy(out[t:], s.dict[dictO:])
 159  				t += n
 160  				seq.ml -= n
 161  			} else {
 162  				copy(out[t:], s.dict[dictO:end])
 163  				t += end - dictO
 164  				continue
 165  			}
 166  		}
 167  
 168  		// Copy from history.
 169  		if v := seq.mo - t; v > 0 {
 170  			// v is the start position in history from end.
 171  			start := len(hist) - v
 172  			if seq.ml > v {
 173  				// Some goes into current block.
 174  				// Copy remainder of history
 175  				copy(out[t:], hist[start:])
 176  				t += v
 177  				seq.ml -= v
 178  			} else {
 179  				copy(out[t:], hist[start:start+seq.ml])
 180  				t += seq.ml
 181  				continue
 182  			}
 183  		}
 184  		// We must be in current buffer now
 185  		if seq.ml > 0 {
 186  			start := t - seq.mo
 187  			if seq.ml <= t-start {
 188  				// No overlap
 189  				copy(out[t:], out[start:start+seq.ml])
 190  				t += seq.ml
 191  				continue
 192  			} else {
 193  				// Overlapping copy
 194  				// Extend destination slice and copy one byte at the time.
 195  				src := out[start : start+seq.ml]
 196  				dst := out[t:]
 197  				dst = dst[:len(src)]
 198  				t += len(src)
 199  				// Destination is the space we just added.
 200  				for i := range src {
 201  					dst[i] = src[i]
 202  				}
 203  			}
 204  		}
 205  	}
 206  
 207  	// Add final literals
 208  	copy(out[t:], s.literals)
 209  	if debugDecoder {
 210  		t += len(s.literals)
 211  		if t != len(out) {
 212  			panic(fmt.Errorf("length mismatch, want %d, got %d, ss: %d", len(out), t, s.seqSize))
 213  		}
 214  	}
 215  	s.out = out
 216  
 217  	return nil
 218  }
 219  
 220  // decode sequences from the stream with the provided history.
 221  func (s *sequenceDecs) decodeSync(hist []byte) error {
 222  	supported, err := s.decodeSyncSimple(hist)
 223  	if supported {
 224  		return err
 225  	}
 226  
 227  	br := s.br
 228  	seqs := s.nSeqs
 229  	startSize := len(s.out)
 230  	// Grab full sizes tables, to avoid bounds checks.
 231  	llTable, mlTable, ofTable := s.litLengths.fse.dt[:maxTablesize], s.matchLengths.fse.dt[:maxTablesize], s.offsets.fse.dt[:maxTablesize]
 232  	llState, mlState, ofState := s.litLengths.state.state, s.matchLengths.state.state, s.offsets.state.state
 233  	out := s.out
 234  	maxBlockSize := min(s.windowSize, maxCompressedBlockSize)
 235  
 236  	if debugDecoder {
 237  		println("decodeSync: decoding", seqs, "sequences", br.remain(), "bits remain on stream")
 238  	}
 239  	for i := seqs - 1; i >= 0; i-- {
 240  		if br.overread() {
 241  			printf("reading sequence %d, exceeded available data. Overread by %d\n", seqs-i, -br.remain())
 242  			return io.ErrUnexpectedEOF
 243  		}
 244  		var ll, mo, ml int
 245  		if br.cursor > 4+((maxOffsetBits+16+16)>>3) {
 246  			// inlined function:
 247  			// ll, mo, ml = s.nextFast(br, llState, mlState, ofState)
 248  
 249  			// Final will not read from stream.
 250  			var llB, mlB, moB uint8
 251  			ll, llB = llState.final()
 252  			ml, mlB = mlState.final()
 253  			mo, moB = ofState.final()
 254  
 255  			// extra bits are stored in reverse order.
 256  			br.fillFast()
 257  			mo += br.getBits(moB)
 258  			if s.maxBits > 32 {
 259  				br.fillFast()
 260  			}
 261  			ml += br.getBits(mlB)
 262  			ll += br.getBits(llB)
 263  
 264  			if moB > 1 {
 265  				s.prevOffset[2] = s.prevOffset[1]
 266  				s.prevOffset[1] = s.prevOffset[0]
 267  				s.prevOffset[0] = mo
 268  			} else {
 269  				// mo = s.adjustOffset(mo, ll, moB)
 270  				// Inlined for rather big speedup
 271  				if ll == 0 {
 272  					// There is an exception though, when current sequence's literals_length = 0.
 273  					// In this case, repeated offsets are shifted by one, so an offset_value of 1 means Repeated_Offset2,
 274  					// an offset_value of 2 means Repeated_Offset3, and an offset_value of 3 means Repeated_Offset1 - 1_byte.
 275  					mo++
 276  				}
 277  
 278  				if mo == 0 {
 279  					mo = s.prevOffset[0]
 280  				} else {
 281  					var temp int
 282  					if mo == 3 {
 283  						temp = s.prevOffset[0] - 1
 284  					} else {
 285  						temp = s.prevOffset[mo]
 286  					}
 287  
 288  					if temp == 0 {
 289  						// 0 is not valid; input is corrupted; force offset to 1
 290  						println("WARNING: temp was 0")
 291  						temp = 1
 292  					}
 293  
 294  					if mo != 1 {
 295  						s.prevOffset[2] = s.prevOffset[1]
 296  					}
 297  					s.prevOffset[1] = s.prevOffset[0]
 298  					s.prevOffset[0] = temp
 299  					mo = temp
 300  				}
 301  			}
 302  			br.fillFast()
 303  		} else {
 304  			ll, mo, ml = s.next(br, llState, mlState, ofState)
 305  			br.fill()
 306  		}
 307  
 308  		if debugSequences {
 309  			println("Seq", seqs-i-1, "Litlen:", ll, "mo:", mo, "(abs) ml:", ml)
 310  		}
 311  
 312  		if ll > len(s.literals) {
 313  			return fmt.Errorf("unexpected literal count, want %d bytes, but only %d is available", ll, len(s.literals))
 314  		}
 315  		size := ll + ml + len(out)
 316  		if size-startSize > maxBlockSize {
 317  			return fmt.Errorf("output bigger than max block size (%d)", maxBlockSize)
 318  		}
 319  		if size > cap(out) {
 320  			// Not enough size, which can happen under high volume block streaming conditions
 321  			// but could be if destination slice is too small for sync operations.
 322  			// over-allocating here can create a large amount of GC pressure so we try to keep
 323  			// it as contained as possible
 324  			used := len(out) - startSize
 325  			addBytes := 256 + ll + ml + used>>2
 326  			// Clamp to max block size.
 327  			if used+addBytes > maxBlockSize {
 328  				addBytes = maxBlockSize - used
 329  			}
 330  			out = append(out, make([]byte, addBytes)...)
 331  			out = out[:len(out)-addBytes]
 332  		}
 333  		if ml > maxMatchLen {
 334  			return fmt.Errorf("match len (%d) bigger than max allowed length", ml)
 335  		}
 336  
 337  		// Add literals
 338  		out = append(out, s.literals[:ll]...)
 339  		s.literals = s.literals[ll:]
 340  
 341  		if mo == 0 && ml > 0 {
 342  			return fmt.Errorf("zero matchoff and matchlen (%d) > 0", ml)
 343  		}
 344  
 345  		if mo > len(out)+len(hist) || mo > s.windowSize {
 346  			if len(s.dict) == 0 {
 347  				return fmt.Errorf("match offset (%d) bigger than current history (%d)", mo, len(out)+len(hist)-startSize)
 348  			}
 349  
 350  			// we may be in dictionary.
 351  			dictO := len(s.dict) - (mo - (len(out) + len(hist)))
 352  			if dictO < 0 || dictO >= len(s.dict) {
 353  				return fmt.Errorf("match offset (%d) bigger than current history (%d)", mo, len(out)+len(hist)-startSize)
 354  			}
 355  			end := dictO + ml
 356  			if end > len(s.dict) {
 357  				out = append(out, s.dict[dictO:]...)
 358  				ml -= len(s.dict) - dictO
 359  			} else {
 360  				out = append(out, s.dict[dictO:end]...)
 361  				mo = 0
 362  				ml = 0
 363  			}
 364  		}
 365  
 366  		// Copy from history.
 367  		// TODO: Blocks without history could be made to ignore this completely.
 368  		if v := mo - len(out); v > 0 {
 369  			// v is the start position in history from end.
 370  			start := len(hist) - v
 371  			if ml > v {
 372  				// Some goes into current block.
 373  				// Copy remainder of history
 374  				out = append(out, hist[start:]...)
 375  				ml -= v
 376  			} else {
 377  				out = append(out, hist[start:start+ml]...)
 378  				ml = 0
 379  			}
 380  		}
 381  		// We must be in current buffer now
 382  		if ml > 0 {
 383  			start := len(out) - mo
 384  			if ml <= len(out)-start {
 385  				// No overlap
 386  				out = append(out, out[start:start+ml]...)
 387  			} else {
 388  				// Overlapping copy
 389  				// Extend destination slice and copy one byte at the time.
 390  				out = out[:len(out)+ml]
 391  				src := out[start : start+ml]
 392  				// Destination is the space we just added.
 393  				dst := out[len(out)-ml:]
 394  				dst = dst[:len(src)]
 395  				for i := range src {
 396  					dst[i] = src[i]
 397  				}
 398  			}
 399  		}
 400  		if i == 0 {
 401  			// This is the last sequence, so we shouldn't update state.
 402  			break
 403  		}
 404  
 405  		// Manually inlined, ~ 5-20% faster
 406  		// Update all 3 states at once. Approx 20% faster.
 407  		nBits := llState.nbBits() + mlState.nbBits() + ofState.nbBits()
 408  		if nBits == 0 {
 409  			llState = llTable[llState.newState()&maxTableMask]
 410  			mlState = mlTable[mlState.newState()&maxTableMask]
 411  			ofState = ofTable[ofState.newState()&maxTableMask]
 412  		} else {
 413  			bits := br.get32BitsFast(nBits)
 414  
 415  			lowBits := uint16(bits >> ((ofState.nbBits() + mlState.nbBits()) & 31))
 416  			llState = llTable[(llState.newState()+lowBits)&maxTableMask]
 417  
 418  			lowBits = uint16(bits >> (ofState.nbBits() & 31))
 419  			lowBits &= bitMask[mlState.nbBits()&15]
 420  			mlState = mlTable[(mlState.newState()+lowBits)&maxTableMask]
 421  
 422  			lowBits = uint16(bits) & bitMask[ofState.nbBits()&15]
 423  			ofState = ofTable[(ofState.newState()+lowBits)&maxTableMask]
 424  		}
 425  	}
 426  
 427  	if size := len(s.literals) + len(out) - startSize; size > maxBlockSize {
 428  		return fmt.Errorf("output bigger than max block size (%d)", maxBlockSize)
 429  	}
 430  
 431  	// Add final literals
 432  	s.out = append(out, s.literals...)
 433  	return br.close()
 434  }
 435  
 436  var bitMask [16]uint16
 437  
 438  func init() {
 439  	for i := range bitMask[:] {
 440  		bitMask[i] = uint16((1 << uint(i)) - 1)
 441  	}
 442  }
 443  
 444  func (s *sequenceDecs) next(br *bitReader, llState, mlState, ofState decSymbol) (ll, mo, ml int) {
 445  	// Final will not read from stream.
 446  	ll, llB := llState.final()
 447  	ml, mlB := mlState.final()
 448  	mo, moB := ofState.final()
 449  
 450  	// extra bits are stored in reverse order.
 451  	br.fill()
 452  	mo += br.getBits(moB)
 453  	if s.maxBits > 32 {
 454  		br.fill()
 455  	}
 456  	// matchlength+literal length, max 32 bits
 457  	ml += br.getBits(mlB)
 458  	ll += br.getBits(llB)
 459  	mo = s.adjustOffset(mo, ll, moB)
 460  	return
 461  }
 462  
 463  func (s *sequenceDecs) adjustOffset(offset, litLen int, offsetB uint8) int {
 464  	if offsetB > 1 {
 465  		s.prevOffset[2] = s.prevOffset[1]
 466  		s.prevOffset[1] = s.prevOffset[0]
 467  		s.prevOffset[0] = offset
 468  		return offset
 469  	}
 470  
 471  	if litLen == 0 {
 472  		// There is an exception though, when current sequence's literals_length = 0.
 473  		// In this case, repeated offsets are shifted by one, so an offset_value of 1 means Repeated_Offset2,
 474  		// an offset_value of 2 means Repeated_Offset3, and an offset_value of 3 means Repeated_Offset1 - 1_byte.
 475  		offset++
 476  	}
 477  
 478  	if offset == 0 {
 479  		return s.prevOffset[0]
 480  	}
 481  	var temp int
 482  	if offset == 3 {
 483  		temp = s.prevOffset[0] - 1
 484  	} else {
 485  		temp = s.prevOffset[offset]
 486  	}
 487  
 488  	if temp == 0 {
 489  		// 0 is not valid; input is corrupted; force offset to 1
 490  		println("temp was 0")
 491  		temp = 1
 492  	}
 493  
 494  	if offset != 1 {
 495  		s.prevOffset[2] = s.prevOffset[1]
 496  	}
 497  	s.prevOffset[1] = s.prevOffset[0]
 498  	s.prevOffset[0] = temp
 499  	return temp
 500  }
 501