encoder.go raw

   1  // Copyright 2019+ Klaus Post. All rights reserved.
   2  // License information can be found in the LICENSE file.
   3  // Based on work by Yann Collet, released under BSD License.
   4  
   5  package zstd
   6  
   7  import (
   8  	"crypto/rand"
   9  	"errors"
  10  	"fmt"
  11  	"io"
  12  	"math"
  13  	rdebug "runtime/debug"
  14  	"sync"
  15  
  16  	"github.com/klauspost/compress/zstd/internal/xxhash"
  17  )
  18  
// Encoder provides encoding to Zstandard.
// An Encoder can be used for either compressing a stream via the
// io.WriteCloser interface supported by the Encoder or as multiple independent
// tasks via the EncodeAll function.
// Smaller encodes are encouraged to use the EncodeAll function.
// Use NewWriter to create a new instance.
type Encoder struct {
	// o holds the options the Encoder was configured with.
	o        encoderOptions
	// encoders is the pool of encoders handed out to EncodeAll calls.
	// It is created on first use by initialize.
	encoders chan encoder
	// state is the streaming state used by Write, ReadFrom, Flush and Close.
	state    encoderState
	// init guards the one-time creation of the encoders pool.
	init     sync.Once
}
  31  
// encoder is the interface the Encoder uses to drive a concrete
// compression implementation, both for streams and for EncodeAll.
type encoder interface {
	// Encode processes src into blk. Callers in this file follow it with
	// blk.encode to produce the block output.
	Encode(blk *blockEnc, src []byte)
	// EncodeNoHist is used instead of Encode when a whole input fits in one
	// block and no dictionary is set.
	// NOTE(review): presumably skips history handling; confirm in the implementations.
	EncodeNoHist(blk *blockEnc, src []byte)
	// Block returns the block the encoder currently works on.
	Block() *blockEnc
	// CRC returns the digest that callers feed all input to when checksums
	// are enabled.
	CRC() *xxhash.Digest
	// AppendCRC appends the frame checksum to the given slice and returns it.
	AppendCRC([]byte) []byte
	// WindowSize returns the window size to put in the frame header for
	// the given content size.
	WindowSize(size int64) int32
	// UseBlock hands the encoder a block to use; the streaming path uses
	// this to transfer recent offsets to the next block.
	UseBlock(*blockEnc)
	// Reset prepares the encoder for a new frame using dictionary d (may be nil).
	// singleBlock is true when the whole input is encoded as a single block.
	Reset(d *dict, singleBlock bool)
}
  42  
// encoderState is the mutable state of a streaming encode.
// It is reset by Encoder.Reset.
type encoderState struct {
	// w is the destination of the compressed stream.
	w                io.Writer
	// filling accumulates input until a full block is available.
	filling          []byte
	// current holds the input of the block being encoded in concurrent mode.
	// It is also used as scratch output when a whole frame is emitted at once.
	current          []byte
	// previous holds the input of the block before current; the three
	// buffers are rotated for every block in concurrent mode.
	previous         []byte
	// encoder is the encoder used for the stream.
	encoder          encoder
	// writing is the block most recently handed to the writer goroutine.
	writing          *blockEnc
	// err is the sticky error of the stream. It is set to ErrEncoderClosed
	// after a successful Close.
	err              error
	// writeErr is the error reported by the asynchronous writer goroutine.
	writeErr         error
	// nWritten counts bytes of output produced for w.
	nWritten         int64
	// nInput counts bytes of input handed to the encoder.
	nInput           int64
	// frameContentSize is the content size announced in the frame header;
	// 0 means no size was given.
	frameContentSize int64
	// headerWritten is set once the frame header has been emitted.
	headerWritten    bool
	// eofWritten is set once the last block of the frame has been emitted.
	eofWritten       bool
	// fullFrameWritten is set when Close handled the whole frame in one
	// step, so no further checksum or padding must be written.
	fullFrameWritten bool

	// This waitgroup indicates an encode is running.
	wg sync.WaitGroup
	// This waitgroup indicates we have a block encoding/writing.
	wWg sync.WaitGroup
}
  64  
  65  // NewWriter will create a new Zstandard encoder.
  66  // If the encoder will be used for encoding blocks a nil writer can be used.
  67  func NewWriter(w io.Writer, opts ...EOption) (*Encoder, error) {
  68  	initPredefined()
  69  	var e Encoder
  70  	e.o.setDefault()
  71  	for _, o := range opts {
  72  		err := o(&e.o)
  73  		if err != nil {
  74  			return nil, err
  75  		}
  76  	}
  77  	if w != nil {
  78  		e.Reset(w)
  79  	}
  80  	return &e, nil
  81  }
  82  
  83  func (e *Encoder) initialize() {
  84  	if e.o.concurrent == 0 {
  85  		e.o.setDefault()
  86  	}
  87  	e.encoders = make(chan encoder, e.o.concurrent)
  88  	for i := 0; i < e.o.concurrent; i++ {
  89  		enc := e.o.encoder()
  90  		e.encoders <- enc
  91  	}
  92  }
  93  
  94  // Reset will re-initialize the writer and new writes will encode to the supplied writer
  95  // as a new, independent stream.
  96  func (e *Encoder) Reset(w io.Writer) {
  97  	s := &e.state
  98  	s.wg.Wait()
  99  	s.wWg.Wait()
 100  	if cap(s.filling) == 0 {
 101  		s.filling = make([]byte, 0, e.o.blockSize)
 102  	}
 103  	if e.o.concurrent > 1 {
 104  		if cap(s.current) == 0 {
 105  			s.current = make([]byte, 0, e.o.blockSize)
 106  		}
 107  		if cap(s.previous) == 0 {
 108  			s.previous = make([]byte, 0, e.o.blockSize)
 109  		}
 110  		s.current = s.current[:0]
 111  		s.previous = s.previous[:0]
 112  		if s.writing == nil {
 113  			s.writing = &blockEnc{lowMem: e.o.lowMem}
 114  			s.writing.init()
 115  		}
 116  		s.writing.initNewEncode()
 117  	}
 118  	if s.encoder == nil {
 119  		s.encoder = e.o.encoder()
 120  	}
 121  	s.filling = s.filling[:0]
 122  	s.encoder.Reset(e.o.dict, false)
 123  	s.headerWritten = false
 124  	s.eofWritten = false
 125  	s.fullFrameWritten = false
 126  	s.w = w
 127  	s.err = nil
 128  	s.nWritten = 0
 129  	s.nInput = 0
 130  	s.writeErr = nil
 131  	s.frameContentSize = 0
 132  }
 133  
 134  // ResetContentSize will reset and set a content size for the next stream.
 135  // If the bytes written does not match the size given an error will be returned
 136  // when calling Close().
 137  // This is removed when Reset is called.
 138  // Sizes <= 0 results in no content size set.
 139  func (e *Encoder) ResetContentSize(w io.Writer, size int64) {
 140  	e.Reset(w)
 141  	if size >= 0 {
 142  		e.state.frameContentSize = size
 143  	}
 144  }
 145  
 146  // Write data to the encoder.
 147  // Input data will be buffered and as the buffer fills up
 148  // content will be compressed and written to the output.
 149  // When done writing, use Close to flush the remaining output
 150  // and write CRC if requested.
 151  func (e *Encoder) Write(p []byte) (n int, err error) {
 152  	s := &e.state
 153  	if s.eofWritten {
 154  		return 0, ErrEncoderClosed
 155  	}
 156  	for len(p) > 0 {
 157  		if len(p)+len(s.filling) < e.o.blockSize {
 158  			if e.o.crc {
 159  				_, _ = s.encoder.CRC().Write(p)
 160  			}
 161  			s.filling = append(s.filling, p...)
 162  			return n + len(p), nil
 163  		}
 164  		add := p
 165  		if len(p)+len(s.filling) > e.o.blockSize {
 166  			add = add[:e.o.blockSize-len(s.filling)]
 167  		}
 168  		if e.o.crc {
 169  			_, _ = s.encoder.CRC().Write(add)
 170  		}
 171  		s.filling = append(s.filling, add...)
 172  		p = p[len(add):]
 173  		n += len(add)
 174  		if len(s.filling) < e.o.blockSize {
 175  			return n, nil
 176  		}
 177  		err := e.nextBlock(false)
 178  		if err != nil {
 179  			return n, err
 180  		}
 181  		if debugAsserts && len(s.filling) > 0 {
 182  			panic(len(s.filling))
 183  		}
 184  	}
 185  	return n, nil
 186  }
 187  
// nextBlock will synchronize and start compressing input in e.state.filling.
// If an error has occurred during encoding it will be returned.
//
// When final is true the block is marked as the last block of the frame,
// unless the end of the frame has already been written. If nothing has been
// written for the frame yet, a final call emits the whole frame at once via
// encodeAll, or nothing at all for empty input when fullZero is not set.
//
// With e.o.concurrent == 1 the block is encoded and written before returning.
// Otherwise the input buffers are rotated and the block is encoded and
// written by background goroutines; their failures are stored in
// s.err / s.writeErr and surface on a later call.
func (e *Encoder) nextBlock(final bool) error {
	s := &e.state
	// Wait for current block.
	s.wg.Wait()
	if s.err != nil {
		return s.err
	}
	if len(s.filling) > e.o.blockSize {
		return fmt.Errorf("block > maxStoreBlockSize")
	}
	if !s.headerWritten {
		// If we have a single block encode, do a sync compression.
		if final && len(s.filling) == 0 && !e.o.fullZero {
			// Nothing was written and empty frames are not requested:
			// mark the stream as done without producing any output.
			s.headerWritten = true
			s.fullFrameWritten = true
			s.eofWritten = true
			return nil
		}
		if final && len(s.filling) > 0 {
			// All input is available: encode it as one complete frame,
			// using s.current as scratch space for the output.
			s.current = e.encodeAll(s.encoder, s.filling, s.current[:0])
			var n2 int
			n2, s.err = s.w.Write(s.current)
			if s.err != nil {
				return s.err
			}
			s.nWritten += int64(n2)
			s.nInput += int64(len(s.filling))
			s.current = s.current[:0]
			s.filling = s.filling[:0]
			s.headerWritten = true
			s.fullFrameWritten = true
			s.eofWritten = true
			return nil
		}

		// Streaming frame: write the frame header before the first block.
		var tmp [maxHeaderSize]byte
		fh := frameHeader{
			ContentSize:   uint64(s.frameContentSize),
			WindowSize:    uint32(s.encoder.WindowSize(s.frameContentSize)),
			SingleSegment: false,
			Checksum:      e.o.crc,
			DictID:        e.o.dict.ID(),
		}

		dst := fh.appendTo(tmp[:0])
		s.headerWritten = true
		// Make sure no background write is using s.w.
		s.wWg.Wait()
		var n2 int
		n2, s.err = s.w.Write(dst)
		if s.err != nil {
			return s.err
		}
		s.nWritten += int64(n2)
	}
	if s.eofWritten {
		// Ensure we only write it once.
		final = false
	}

	if len(s.filling) == 0 {
		// Final block, but no data.
		if final {
			// Terminate the frame with a block produced by encodeRaw(nil)
			// that has the last-block flag set.
			enc := s.encoder
			blk := enc.Block()
			blk.reset(nil)
			blk.last = true
			blk.encodeRaw(nil)
			s.wWg.Wait()
			_, s.err = s.w.Write(blk.output)
			s.nWritten += int64(len(blk.output))
			s.eofWritten = true
		}
		return s.err
	}

	// SYNC:
	if e.o.concurrent == 1 {
		src := s.filling
		s.nInput += int64(len(s.filling))
		if debugEncoder {
			println("Adding sync block,", len(src), "bytes, final:", final)
		}
		enc := s.encoder
		blk := enc.Block()
		blk.reset(nil)
		enc.Encode(blk, src)
		blk.last = final
		if final {
			s.eofWritten = true
		}

		s.err = blk.encode(src, e.o.noEntropy, !e.o.allLitEntropy)
		if s.err != nil {
			return s.err
		}
		_, s.err = s.w.Write(blk.output)
		s.nWritten += int64(len(blk.output))
		s.filling = s.filling[:0]
		return s.err
	}

	// Move blocks forward.
	// The filled buffer becomes current, current becomes previous and the
	// old previous buffer is reused, emptied, for new input.
	s.filling, s.current, s.previous = s.previous[:0], s.filling, s.current
	s.nInput += int64(len(s.current))
	s.wg.Add(1)
	if final {
		s.eofWritten = true
	}
	go func(src []byte) {
		if debugEncoder {
			println("Adding block,", len(src), "bytes, final:", final)
		}
		defer func() {
			// Turn a panic in the encoder into a stream error and always
			// signal that the encode has finished.
			if r := recover(); r != nil {
				s.err = fmt.Errorf("panic while encoding: %v", r)
				rdebug.PrintStack()
			}
			s.wg.Done()
		}()
		enc := s.encoder
		blk := enc.Block()
		enc.Encode(blk, src)
		blk.last = final
		// Wait for pending writes.
		s.wWg.Wait()
		if s.writeErr != nil {
			s.err = s.writeErr
			return
		}
		// Transfer encoders from previous write block.
		blk.swapEncoders(s.writing)
		// Transfer recent offsets to next.
		enc.UseBlock(s.writing)
		s.writing = blk
		s.wWg.Add(1)
		go func() {
			defer func() {
				// As above, but for the entropy coding and write step.
				if r := recover(); r != nil {
					s.writeErr = fmt.Errorf("panic while encoding/writing: %v", r)
					rdebug.PrintStack()
				}
				s.wWg.Done()
			}()
			s.writeErr = blk.encode(src, e.o.noEntropy, !e.o.allLitEntropy)
			if s.writeErr != nil {
				return
			}
			_, s.writeErr = s.w.Write(blk.output)
			s.nWritten += int64(len(blk.output))
		}()
	}(s.current)
	return nil
}
 343  
 344  // ReadFrom reads data from r until EOF or error.
 345  // The return value n is the number of bytes read.
 346  // Any error except io.EOF encountered during the read is also returned.
 347  //
 348  // The Copy function uses ReaderFrom if available.
 349  func (e *Encoder) ReadFrom(r io.Reader) (n int64, err error) {
 350  	if debugEncoder {
 351  		println("Using ReadFrom")
 352  	}
 353  
 354  	// Flush any current writes.
 355  	if len(e.state.filling) > 0 {
 356  		if err := e.nextBlock(false); err != nil {
 357  			return 0, err
 358  		}
 359  	}
 360  	e.state.filling = e.state.filling[:e.o.blockSize]
 361  	src := e.state.filling
 362  	for {
 363  		n2, err := r.Read(src)
 364  		if e.o.crc {
 365  			_, _ = e.state.encoder.CRC().Write(src[:n2])
 366  		}
 367  		// src is now the unfilled part...
 368  		src = src[n2:]
 369  		n += int64(n2)
 370  		switch err {
 371  		case io.EOF:
 372  			e.state.filling = e.state.filling[:len(e.state.filling)-len(src)]
 373  			if debugEncoder {
 374  				println("ReadFrom: got EOF final block:", len(e.state.filling))
 375  			}
 376  			return n, nil
 377  		case nil:
 378  		default:
 379  			if debugEncoder {
 380  				println("ReadFrom: got error:", err)
 381  			}
 382  			e.state.err = err
 383  			return n, err
 384  		}
 385  		if len(src) > 0 {
 386  			if debugEncoder {
 387  				println("ReadFrom: got space left in source:", len(src))
 388  			}
 389  			continue
 390  		}
 391  		err = e.nextBlock(false)
 392  		if err != nil {
 393  			return n, err
 394  		}
 395  		e.state.filling = e.state.filling[:e.o.blockSize]
 396  		src = e.state.filling
 397  	}
 398  }
 399  
 400  // Flush will send the currently written data to output
 401  // and block until everything has been written.
 402  // This should only be used on rare occasions where pushing the currently queued data is critical.
 403  func (e *Encoder) Flush() error {
 404  	s := &e.state
 405  	if len(s.filling) > 0 {
 406  		err := e.nextBlock(false)
 407  		if err != nil {
 408  			// Ignore Flush after Close.
 409  			if errors.Is(s.err, ErrEncoderClosed) {
 410  				return nil
 411  			}
 412  			return err
 413  		}
 414  	}
 415  	s.wg.Wait()
 416  	s.wWg.Wait()
 417  	if s.err != nil {
 418  		// Ignore Flush after Close.
 419  		if errors.Is(s.err, ErrEncoderClosed) {
 420  			return nil
 421  		}
 422  		return s.err
 423  	}
 424  	return s.writeErr
 425  }
 426  
 427  // Close will flush the final output and close the stream.
 428  // The function will block until everything has been written.
 429  // The Encoder can still be re-used after calling this.
 430  func (e *Encoder) Close() error {
 431  	s := &e.state
 432  	if s.encoder == nil {
 433  		return nil
 434  	}
 435  	err := e.nextBlock(true)
 436  	if err != nil {
 437  		if errors.Is(s.err, ErrEncoderClosed) {
 438  			return nil
 439  		}
 440  		return err
 441  	}
 442  	if s.frameContentSize > 0 {
 443  		if s.nInput != s.frameContentSize {
 444  			return fmt.Errorf("frame content size %d given, but %d bytes was written", s.frameContentSize, s.nInput)
 445  		}
 446  	}
 447  	if e.state.fullFrameWritten {
 448  		return s.err
 449  	}
 450  	s.wg.Wait()
 451  	s.wWg.Wait()
 452  
 453  	if s.err != nil {
 454  		return s.err
 455  	}
 456  	if s.writeErr != nil {
 457  		return s.writeErr
 458  	}
 459  
 460  	// Write CRC
 461  	if e.o.crc && s.err == nil {
 462  		// heap alloc.
 463  		var tmp [4]byte
 464  		_, s.err = s.w.Write(s.encoder.AppendCRC(tmp[:0]))
 465  		s.nWritten += 4
 466  	}
 467  
 468  	// Add padding with content from crypto/rand.Reader
 469  	if s.err == nil && e.o.pad > 0 {
 470  		add := calcSkippableFrame(s.nWritten, int64(e.o.pad))
 471  		frame, err := skippableFrame(s.filling[:0], add, rand.Reader)
 472  		if err != nil {
 473  			return err
 474  		}
 475  		_, s.err = s.w.Write(frame)
 476  	}
 477  	if s.err == nil {
 478  		s.err = ErrEncoderClosed
 479  		return nil
 480  	}
 481  
 482  	return s.err
 483  }
 484  
 485  // EncodeAll will encode all input in src and append it to dst.
 486  // This function can be called concurrently, but each call will only run on a single goroutine.
 487  // If empty input is given, nothing is returned, unless WithZeroFrames is specified.
 488  // Encoded blocks can be concatenated and the result will be the combined input stream.
 489  // Data compressed with EncodeAll can be decoded with the Decoder,
 490  // using either a stream or DecodeAll.
 491  func (e *Encoder) EncodeAll(src, dst []byte) []byte {
 492  	e.init.Do(e.initialize)
 493  	enc := <-e.encoders
 494  	defer func() {
 495  		e.encoders <- enc
 496  	}()
 497  	return e.encodeAll(enc, src, dst)
 498  }
 499  
// encodeAll encodes src as one complete frame using enc, appends the frame
// to dst and returns the extended slice.
// For empty src nothing is appended unless e.o.fullZero is set, in which
// case a minimal frame with a single empty raw block is appended.
// Input that fits in one block is encoded as a single block straight into
// dst; larger input is split into blocks of at most e.o.blockSize bytes.
// A checksum and padding are appended when configured.
// Errors from block encoding or padding generation cause a panic.
func (e *Encoder) encodeAll(enc encoder, src, dst []byte) []byte {
	if len(src) == 0 {
		if e.o.fullZero {
			// Add frame header.
			fh := frameHeader{
				ContentSize:   0,
				WindowSize:    MinWindowSize,
				SingleSegment: true,
				// Adding a checksum would be a waste of space.
				Checksum: false,
				DictID:   0,
			}
			dst = fh.appendTo(dst)

			// Write raw block as last one only.
			var blk blockHeader
			blk.setSize(0)
			blk.setType(blockTypeRaw)
			blk.setLast(true)
			dst = blk.appendTo(dst)
		}
		return dst
	}

	// Use single segments when above minimum window and below window size.
	single := len(src) <= e.o.windowSize && len(src) > MinWindowSize
	if e.o.single != nil {
		// An explicitly configured value overrides the heuristic.
		single = *e.o.single
	}
	fh := frameHeader{
		ContentSize:   uint64(len(src)),
		WindowSize:    uint32(enc.WindowSize(int64(len(src)))),
		SingleSegment: single,
		Checksum:      e.o.crc,
		DictID:        e.o.dict.ID(),
	}

	// If less than 1MB, allocate a buffer up front.
	if len(dst) == 0 && cap(dst) == 0 && len(src) < 1<<20 && !e.o.lowMem {
		dst = make([]byte, 0, len(src))
	}
	dst = fh.appendTo(dst)

	// If we can do everything in one block, prefer that.
	if len(src) <= e.o.blockSize {
		enc.Reset(e.o.dict, true)
		// Slightly faster with no history and everything in one block.
		if e.o.crc {
			_, _ = enc.CRC().Write(src)
		}
		blk := enc.Block()
		blk.last = true
		if e.o.dict == nil {
			enc.EncodeNoHist(blk, src)
		} else {
			enc.Encode(blk, src)
		}

		// Temporarily point the block's output at dst so the encoded block
		// is appended there directly; the block's own buffer is restored
		// afterwards.
		oldout := blk.output
		// Output directly to dst
		blk.output = dst

		err := blk.encode(src, e.o.noEntropy, !e.o.allLitEntropy)
		if err != nil {
			panic(err)
		}
		dst = blk.output
		blk.output = oldout
	} else {
		enc.Reset(e.o.dict, false)
		blk := enc.Block()
		for len(src) > 0 {
			// Take at most one block worth of input per iteration.
			todo := src
			if len(todo) > e.o.blockSize {
				todo = todo[:e.o.blockSize]
			}
			src = src[len(todo):]
			if e.o.crc {
				_, _ = enc.CRC().Write(todo)
			}
			blk.pushOffsets()
			enc.Encode(blk, todo)
			if len(src) == 0 {
				// No input left, so this is the last block of the frame.
				blk.last = true
			}
			err := blk.encode(todo, e.o.noEntropy, !e.o.allLitEntropy)
			if err != nil {
				panic(err)
			}
			dst = append(dst, blk.output...)
			blk.reset(nil)
		}
	}
	if e.o.crc {
		dst = enc.AppendCRC(dst)
	}
	// Add padding with content from crypto/rand.Reader
	if e.o.pad > 0 {
		add := calcSkippableFrame(int64(len(dst)), int64(e.o.pad))
		var err error
		dst, err = skippableFrame(dst, add, rand.Reader)
		if err != nil {
			panic(err)
		}
	}
	return dst
}
 609  
 610  // MaxEncodedSize returns the expected maximum
 611  // size of an encoded block or stream.
 612  func (e *Encoder) MaxEncodedSize(size int) int {
 613  	frameHeader := 4 + 2 // magic + frame header & window descriptor
 614  	if e.o.dict != nil {
 615  		frameHeader += 4
 616  	}
 617  	// Frame content size:
 618  	if size < 256 {
 619  		frameHeader++
 620  	} else if size < 65536+256 {
 621  		frameHeader += 2
 622  	} else if size < math.MaxInt32 {
 623  		frameHeader += 4
 624  	} else {
 625  		frameHeader += 8
 626  	}
 627  	// Final crc
 628  	if e.o.crc {
 629  		frameHeader += 4
 630  	}
 631  
 632  	// Max overhead is 3 bytes/block.
 633  	// There cannot be 0 blocks.
 634  	blocks := (size + e.o.blockSize) / e.o.blockSize
 635  
 636  	// Combine, add padding.
 637  	maxSz := frameHeader + 3*blocks + size
 638  	if e.o.pad > 1 {
 639  		maxSz += calcSkippableFrame(int64(maxSz), int64(e.o.pad))
 640  	}
 641  	return maxSz
 642  }
 643