block.mx raw

   1  // Copyright 2023 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  package zstd
   6  
   7  import (
   8  	"io"
   9  )
  10  
  11  // debug can be set in the source to print debug info using println.
  12  const debug = false
  13  
  14  // compressedBlock decompresses a compressed block, storing the decompressed
  15  // data in r.buffer. The blockSize argument is the compressed size.
  16  // RFC 3.1.1.3.
  17  func (r *Reader) compressedBlock(blockSize int) error {
  18  	if len(r.compressedBuf) >= blockSize {
  19  		r.compressedBuf = r.compressedBuf[:blockSize]
  20  	} else {
  21  		// We know that blockSize <= 128K,
  22  		// so this won't allocate an enormous amount.
  23  		need := blockSize - len(r.compressedBuf)
  24  		r.compressedBuf = append(r.compressedBuf, []byte{:need}...)
  25  	}
  26  
  27  	if _, err := io.ReadFull(r.r, r.compressedBuf); err != nil {
  28  		return r.wrapNonEOFError(0, err)
  29  	}
  30  
  31  	data := block(r.compressedBuf)
  32  	off := 0
  33  	r.buffer = r.buffer[:0]
  34  
  35  	litoff, litbuf, err := r.readLiterals(data, off, r.literals[:0])
  36  	if err != nil {
  37  		return err
  38  	}
  39  	r.literals = litbuf
  40  
  41  	off = litoff
  42  
  43  	seqCount, off, err := r.initSeqs(data, off)
  44  	if err != nil {
  45  		return err
  46  	}
  47  
  48  	if seqCount == 0 {
  49  		// No sequences, just literals.
  50  		if off < len(data) {
  51  			return r.makeError(off, "extraneous data after no sequences")
  52  		}
  53  
  54  		r.buffer = append(r.buffer, litbuf...)
  55  
  56  		return nil
  57  	}
  58  
  59  	return r.execSeqs(data, off, litbuf, seqCount)
  60  }
  61  
  62  // seqCode is the kind of sequence codes we have to handle.
  63  type seqCode int
  64  
  65  const (
  66  	seqLiteral seqCode = iota
  67  	seqOffset
  68  	seqMatch
  69  )
  70  
  71  // seqCodeInfoData is the information needed to set up seqTables and
  72  // seqTableBits for a particular kind of sequence code.
  73  type seqCodeInfoData struct {
  74  	predefTable     []fseBaselineEntry // predefined FSE
  75  	predefTableBits int                // number of bits in predefTable
  76  	maxSym          int                // max symbol value in FSE
  77  	maxBits         int                // max bits for FSE
  78  
  79  	// toBaseline converts from an FSE table to an FSE baseline table.
  80  	toBaseline func(*Reader, int, []fseEntry, []fseBaselineEntry) error
  81  }
  82  
  83  // seqCodeInfo is the seqCodeInfoData for each kind of sequence code.
  84  var seqCodeInfo = [3]seqCodeInfoData{
  85  	seqLiteral: {
  86  		predefTable:     predefinedLiteralTable[:],
  87  		predefTableBits: 6,
  88  		maxSym:          35,
  89  		maxBits:         9,
  90  		toBaseline:      (*Reader).makeLiteralBaselineFSE,
  91  	},
  92  	seqOffset: {
  93  		predefTable:     predefinedOffsetTable[:],
  94  		predefTableBits: 5,
  95  		maxSym:          31,
  96  		maxBits:         8,
  97  		toBaseline:      (*Reader).makeOffsetBaselineFSE,
  98  	},
  99  	seqMatch: {
 100  		predefTable:     predefinedMatchTable[:],
 101  		predefTableBits: 6,
 102  		maxSym:          52,
 103  		maxBits:         9,
 104  		toBaseline:      (*Reader).makeMatchBaselineFSE,
 105  	},
 106  }
 107  
 108  // initSeqs reads the Sequences_Section_Header and sets up the FSE
 109  // tables used to read the sequence codes. It returns the number of
 110  // sequences and the new offset. RFC 3.1.1.3.2.1.
 111  func (r *Reader) initSeqs(data block, off int) (int, int, error) {
 112  	if off >= len(data) {
 113  		return 0, 0, r.makeEOFError(off)
 114  	}
 115  
 116  	seqHdr := data[off]
 117  	off++
 118  	if seqHdr == 0 {
 119  		return 0, off, nil
 120  	}
 121  
 122  	var seqCount int
 123  	if seqHdr < 128 {
 124  		seqCount = int(seqHdr)
 125  	} else if seqHdr < 255 {
 126  		if off >= len(data) {
 127  			return 0, 0, r.makeEOFError(off)
 128  		}
 129  		seqCount = ((int(seqHdr) - 128) << 8) + int(data[off])
 130  		off++
 131  	} else {
 132  		if off+1 >= len(data) {
 133  			return 0, 0, r.makeEOFError(off)
 134  		}
 135  		seqCount = int(data[off]) + (int(data[off+1]) << 8) + 0x7f00
 136  		off += 2
 137  	}
 138  
 139  	// Read the Symbol_Compression_Modes byte.
 140  
 141  	if off >= len(data) {
 142  		return 0, 0, r.makeEOFError(off)
 143  	}
 144  	symMode := data[off]
 145  	if symMode&3 != 0 {
 146  		return 0, 0, r.makeError(off, "invalid symbol compression mode")
 147  	}
 148  	off++
 149  
 150  	// Set up the FSE tables used to decode the sequence codes.
 151  
 152  	var err error
 153  	off, err = r.setSeqTable(data, off, seqLiteral, (symMode>>6)&3)
 154  	if err != nil {
 155  		return 0, 0, err
 156  	}
 157  
 158  	off, err = r.setSeqTable(data, off, seqOffset, (symMode>>4)&3)
 159  	if err != nil {
 160  		return 0, 0, err
 161  	}
 162  
 163  	off, err = r.setSeqTable(data, off, seqMatch, (symMode>>2)&3)
 164  	if err != nil {
 165  		return 0, 0, err
 166  	}
 167  
 168  	return seqCount, off, nil
 169  }
 170  
 171  // setSeqTable uses the Compression_Mode in mode to set up r.seqTables and
 172  // r.seqTableBits for kind. We store these in the Reader because one of
 173  // the modes simply reuses the value from the last block in the frame.
 174  func (r *Reader) setSeqTable(data block, off int, kind seqCode, mode byte) (int, error) {
 175  	info := &seqCodeInfo[kind]
 176  	switch mode {
 177  	case 0:
 178  		// Predefined_Mode
 179  		r.seqTables[kind] = info.predefTable
 180  		r.seqTableBits[kind] = uint8(info.predefTableBits)
 181  		return off, nil
 182  
 183  	case 1:
 184  		// RLE_Mode
 185  		if off >= len(data) {
 186  			return 0, r.makeEOFError(off)
 187  		}
 188  		rle := data[off]
 189  		off++
 190  
 191  		// Build a simple baseline table that always returns rle.
 192  
 193  		entry := []fseEntry{
 194  			{
 195  				sym:  rle,
 196  				bits: 0,
 197  				base: 0,
 198  			},
 199  		}
 200  		if cap(r.seqTableBuffers[kind]) == 0 {
 201  			r.seqTableBuffers[kind] = []fseBaselineEntry{:1<<info.maxBits}
 202  		}
 203  		r.seqTableBuffers[kind] = r.seqTableBuffers[kind][:1]
 204  		if err := info.toBaseline(r, off, entry, r.seqTableBuffers[kind]); err != nil {
 205  			return 0, err
 206  		}
 207  
 208  		r.seqTables[kind] = r.seqTableBuffers[kind]
 209  		r.seqTableBits[kind] = 0
 210  		return off, nil
 211  
 212  	case 2:
 213  		// FSE_Compressed_Mode
 214  		if cap(r.fseScratch) < 1<<info.maxBits {
 215  			r.fseScratch = []fseEntry{:1<<info.maxBits}
 216  		}
 217  		r.fseScratch = r.fseScratch[:1<<info.maxBits]
 218  
 219  		tableBits, roff, err := r.readFSE(data, off, info.maxSym, info.maxBits, r.fseScratch)
 220  		if err != nil {
 221  			return 0, err
 222  		}
 223  		r.fseScratch = r.fseScratch[:1<<tableBits]
 224  
 225  		if cap(r.seqTableBuffers[kind]) == 0 {
 226  			r.seqTableBuffers[kind] = []fseBaselineEntry{:1<<info.maxBits}
 227  		}
 228  		r.seqTableBuffers[kind] = r.seqTableBuffers[kind][:1<<tableBits]
 229  
 230  		if err := info.toBaseline(r, roff, r.fseScratch, r.seqTableBuffers[kind]); err != nil {
 231  			return 0, err
 232  		}
 233  
 234  		r.seqTables[kind] = r.seqTableBuffers[kind]
 235  		r.seqTableBits[kind] = uint8(tableBits)
 236  		return roff, nil
 237  
 238  	case 3:
 239  		// Repeat_Mode
 240  		if len(r.seqTables[kind]) == 0 {
 241  			return 0, r.makeError(off, "missing repeat sequence FSE table")
 242  		}
 243  		return off, nil
 244  	}
 245  	panic("unreachable")
 246  }
 247  
 248  // execSeqs reads and executes the sequences. RFC 3.1.1.3.2.1.2.
 249  func (r *Reader) execSeqs(data block, off int, litbuf []byte, seqCount int) error {
 250  	// Set up the initial states for the sequence code readers.
 251  
 252  	rbr, err := r.makeReverseBitReader(data, len(data)-1, off)
 253  	if err != nil {
 254  		return err
 255  	}
 256  
 257  	literalState, err := rbr.val(r.seqTableBits[seqLiteral])
 258  	if err != nil {
 259  		return err
 260  	}
 261  
 262  	offsetState, err := rbr.val(r.seqTableBits[seqOffset])
 263  	if err != nil {
 264  		return err
 265  	}
 266  
 267  	matchState, err := rbr.val(r.seqTableBits[seqMatch])
 268  	if err != nil {
 269  		return err
 270  	}
 271  
 272  	// Read and perform all the sequences. RFC 3.1.1.4.
 273  
 274  	seq := 0
 275  	for seq < seqCount {
 276  		if len(r.buffer)+len(litbuf) > 128<<10 {
 277  			return rbr.makeError("uncompressed size too big")
 278  		}
 279  
 280  		ptoffset := &r.seqTables[seqOffset][offsetState]
 281  		ptmatch := &r.seqTables[seqMatch][matchState]
 282  		ptliteral := &r.seqTables[seqLiteral][literalState]
 283  
 284  		add, err := rbr.val(ptoffset.basebits)
 285  		if err != nil {
 286  			return err
 287  		}
 288  		offset := ptoffset.baseline + add
 289  
 290  		add, err = rbr.val(ptmatch.basebits)
 291  		if err != nil {
 292  			return err
 293  		}
 294  		match := ptmatch.baseline + add
 295  
 296  		add, err = rbr.val(ptliteral.basebits)
 297  		if err != nil {
 298  			return err
 299  		}
 300  		literal := ptliteral.baseline + add
 301  
 302  		// Handle repeat offsets. RFC 3.1.1.5.
 303  		// See the comment in makeOffsetBaselineFSE.
 304  		if ptoffset.basebits > 1 {
 305  			r.repeatedOffset3 = r.repeatedOffset2
 306  			r.repeatedOffset2 = r.repeatedOffset1
 307  			r.repeatedOffset1 = offset
 308  		} else {
 309  			if literal == 0 {
 310  				offset++
 311  			}
 312  			switch offset {
 313  			case 1:
 314  				offset = r.repeatedOffset1
 315  			case 2:
 316  				offset = r.repeatedOffset2
 317  				r.repeatedOffset2 = r.repeatedOffset1
 318  				r.repeatedOffset1 = offset
 319  			case 3:
 320  				offset = r.repeatedOffset3
 321  				r.repeatedOffset3 = r.repeatedOffset2
 322  				r.repeatedOffset2 = r.repeatedOffset1
 323  				r.repeatedOffset1 = offset
 324  			case 4:
 325  				offset = r.repeatedOffset1 - 1
 326  				r.repeatedOffset3 = r.repeatedOffset2
 327  				r.repeatedOffset2 = r.repeatedOffset1
 328  				r.repeatedOffset1 = offset
 329  			}
 330  		}
 331  
 332  		seq++
 333  		if seq < seqCount {
 334  			// Update the states.
 335  			add, err = rbr.val(ptliteral.bits)
 336  			if err != nil {
 337  				return err
 338  			}
 339  			literalState = uint32(ptliteral.base) + add
 340  
 341  			add, err = rbr.val(ptmatch.bits)
 342  			if err != nil {
 343  				return err
 344  			}
 345  			matchState = uint32(ptmatch.base) + add
 346  
 347  			add, err = rbr.val(ptoffset.bits)
 348  			if err != nil {
 349  				return err
 350  			}
 351  			offsetState = uint32(ptoffset.base) + add
 352  		}
 353  
 354  		// The next sequence is now in literal, offset, match.
 355  
 356  		if debug {
 357  			println("literal", literal, "offset", offset, "match", match)
 358  		}
 359  
 360  		// Copy literal bytes from litbuf.
 361  		if literal > uint32(len(litbuf)) {
 362  			return rbr.makeError("literal byte overflow")
 363  		}
 364  		if literal > 0 {
 365  			r.buffer = append(r.buffer, litbuf[:literal]...)
 366  			litbuf = litbuf[literal:]
 367  		}
 368  
 369  		if match > 0 {
 370  			if err := r.copyFromWindow(&rbr, offset, match); err != nil {
 371  				return err
 372  			}
 373  		}
 374  	}
 375  
 376  	r.buffer = append(r.buffer, litbuf...)
 377  
 378  	if rbr.cnt != 0 {
 379  		return r.makeError(off, "extraneous data after sequences")
 380  	}
 381  
 382  	return nil
 383  }
 384  
 385  // Copy match bytes from the decoded output, or the window, at offset.
 386  func (r *Reader) copyFromWindow(rbr *reverseBitReader, offset, match uint32) error {
 387  	if offset == 0 {
 388  		return rbr.makeError("invalid zero offset")
 389  	}
 390  
 391  	// Offset may point into the buffer or the window and
 392  	// match may extend past the end of the initial buffer.
 393  	// |--r.window--|--r.buffer--|
 394  	//        |<-----offset------|
 395  	//        |------match----------->|
 396  	bufferOffset := uint32(0)
 397  	lenBlock := uint32(len(r.buffer))
 398  	if lenBlock < offset {
 399  		lenWindow := r.window.len()
 400  		copy := offset - lenBlock
 401  		if copy > lenWindow {
 402  			return rbr.makeError("offset past window")
 403  		}
 404  		windowOffset := lenWindow - copy
 405  		if copy > match {
 406  			copy = match
 407  		}
 408  		r.buffer = r.window.appendTo(r.buffer, windowOffset, windowOffset+copy)
 409  		match -= copy
 410  	} else {
 411  		bufferOffset = lenBlock - offset
 412  	}
 413  
 414  	// We are being asked to copy data that we are adding to the
 415  	// buffer in the same copy.
 416  	for match > 0 {
 417  		copy := uint32(len(r.buffer)) - bufferOffset
 418  		if copy > match {
 419  			copy = match
 420  		}
 421  		r.buffer = append(r.buffer, r.buffer[bufferOffset:bufferOffset+copy]...)
 422  		match -= copy
 423  	}
 424  	return nil
 425  }
 426