dict.go raw

   1  package zstd
   2  
   3  import (
   4  	"bytes"
   5  	"encoding/binary"
   6  	"errors"
   7  	"fmt"
   8  	"io"
   9  	"math"
  10  	"sort"
  11  
  12  	"github.com/klauspost/compress/huff0"
  13  )
  14  
  15  type dict struct {
  16  	id uint32
  17  
  18  	litEnc              *huff0.Scratch
  19  	llDec, ofDec, mlDec sequenceDec
  20  	offsets             [3]int
  21  	content             []byte
  22  }
  23  
  24  const dictMagic = "\x37\xa4\x30\xec"
  25  
  26  // Maximum dictionary size for the reference implementation (1.5.3) is 2 GiB.
  27  const dictMaxLength = 1 << 31
  28  
  29  // ID returns the dictionary id or 0 if d is nil.
  30  func (d *dict) ID() uint32 {
  31  	if d == nil {
  32  		return 0
  33  	}
  34  	return d.id
  35  }
  36  
  37  // ContentSize returns the dictionary content size or 0 if d is nil.
  38  func (d *dict) ContentSize() int {
  39  	if d == nil {
  40  		return 0
  41  	}
  42  	return len(d.content)
  43  }
  44  
  45  // Content returns the dictionary content.
  46  func (d *dict) Content() []byte {
  47  	if d == nil {
  48  		return nil
  49  	}
  50  	return d.content
  51  }
  52  
  53  // Offsets returns the initial offsets.
  54  func (d *dict) Offsets() [3]int {
  55  	if d == nil {
  56  		return [3]int{}
  57  	}
  58  	return d.offsets
  59  }
  60  
  61  // LitEncoder returns the literal encoder.
  62  func (d *dict) LitEncoder() *huff0.Scratch {
  63  	if d == nil {
  64  		return nil
  65  	}
  66  	return d.litEnc
  67  }
  68  
  69  // Load a dictionary as described in
  70  // https://github.com/facebook/zstd/blob/master/doc/zstd_compression_format.md#dictionary-format
  71  func loadDict(b []byte) (*dict, error) {
  72  	// Check static field size.
  73  	if len(b) <= 8+(3*4) {
  74  		return nil, io.ErrUnexpectedEOF
  75  	}
  76  	d := dict{
  77  		llDec: sequenceDec{fse: &fseDecoder{}},
  78  		ofDec: sequenceDec{fse: &fseDecoder{}},
  79  		mlDec: sequenceDec{fse: &fseDecoder{}},
  80  	}
  81  	if string(b[:4]) != dictMagic {
  82  		return nil, ErrMagicMismatch
  83  	}
  84  	d.id = binary.LittleEndian.Uint32(b[4:8])
  85  	if d.id == 0 {
  86  		return nil, errors.New("dictionaries cannot have ID 0")
  87  	}
  88  
  89  	// Read literal table
  90  	var err error
  91  	d.litEnc, b, err = huff0.ReadTable(b[8:], nil)
  92  	if err != nil {
  93  		return nil, fmt.Errorf("loading literal table: %w", err)
  94  	}
  95  	d.litEnc.Reuse = huff0.ReusePolicyMust
  96  
  97  	br := byteReader{
  98  		b:   b,
  99  		off: 0,
 100  	}
 101  	readDec := func(i tableIndex, dec *fseDecoder) error {
 102  		if err := dec.readNCount(&br, uint16(maxTableSymbol[i])); err != nil {
 103  			return err
 104  		}
 105  		if br.overread() {
 106  			return io.ErrUnexpectedEOF
 107  		}
 108  		err = dec.transform(symbolTableX[i])
 109  		if err != nil {
 110  			println("Transform table error:", err)
 111  			return err
 112  		}
 113  		if debugDecoder || debugEncoder {
 114  			println("Read table ok", "symbolLen:", dec.symbolLen)
 115  		}
 116  		// Set decoders as predefined so they aren't reused.
 117  		dec.preDefined = true
 118  		return nil
 119  	}
 120  
 121  	if err := readDec(tableOffsets, d.ofDec.fse); err != nil {
 122  		return nil, err
 123  	}
 124  	if err := readDec(tableMatchLengths, d.mlDec.fse); err != nil {
 125  		return nil, err
 126  	}
 127  	if err := readDec(tableLiteralLengths, d.llDec.fse); err != nil {
 128  		return nil, err
 129  	}
 130  	if br.remain() < 12 {
 131  		return nil, io.ErrUnexpectedEOF
 132  	}
 133  
 134  	d.offsets[0] = int(br.Uint32())
 135  	br.advance(4)
 136  	d.offsets[1] = int(br.Uint32())
 137  	br.advance(4)
 138  	d.offsets[2] = int(br.Uint32())
 139  	br.advance(4)
 140  	if d.offsets[0] <= 0 || d.offsets[1] <= 0 || d.offsets[2] <= 0 {
 141  		return nil, errors.New("invalid offset in dictionary")
 142  	}
 143  	d.content = make([]byte, br.remain())
 144  	copy(d.content, br.unread())
 145  	if d.offsets[0] > len(d.content) || d.offsets[1] > len(d.content) || d.offsets[2] > len(d.content) {
 146  		return nil, fmt.Errorf("initial offset bigger than dictionary content size %d, offsets: %v", len(d.content), d.offsets)
 147  	}
 148  
 149  	return &d, nil
 150  }
 151  
 152  // InspectDictionary loads a zstd dictionary and provides functions to inspect the content.
 153  func InspectDictionary(b []byte) (interface {
 154  	ID() uint32
 155  	ContentSize() int
 156  	Content() []byte
 157  	Offsets() [3]int
 158  	LitEncoder() *huff0.Scratch
 159  }, error) {
 160  	initPredefined()
 161  	d, err := loadDict(b)
 162  	return d, err
 163  }
 164  
 165  type BuildDictOptions struct {
 166  	// Dictionary ID.
 167  	ID uint32
 168  
 169  	// Content to use to create dictionary tables.
 170  	Contents [][]byte
 171  
 172  	// History to use for all blocks.
 173  	History []byte
 174  
 175  	// Offsets to use.
 176  	Offsets [3]int
 177  
 178  	// CompatV155 will make the dictionary compatible with Zstd v1.5.5 and earlier.
 179  	// See https://github.com/facebook/zstd/issues/3724
 180  	CompatV155 bool
 181  
 182  	// Use the specified encoder level.
 183  	// The dictionary will be built using the specified encoder level,
 184  	// which will reflect speed and make the dictionary tailored for that level.
 185  	// If not set SpeedBestCompression will be used.
 186  	Level EncoderLevel
 187  
 188  	// DebugOut will write stats and other details here if set.
 189  	DebugOut io.Writer
 190  }
 191  
 192  func BuildDict(o BuildDictOptions) ([]byte, error) {
 193  	initPredefined()
 194  	hist := o.History
 195  	contents := o.Contents
 196  	debug := o.DebugOut != nil
 197  	println := func(args ...any) {
 198  		if o.DebugOut != nil {
 199  			fmt.Fprintln(o.DebugOut, args...)
 200  		}
 201  	}
 202  	printf := func(s string, args ...any) {
 203  		if o.DebugOut != nil {
 204  			fmt.Fprintf(o.DebugOut, s, args...)
 205  		}
 206  	}
 207  	print := func(args ...any) {
 208  		if o.DebugOut != nil {
 209  			fmt.Fprint(o.DebugOut, args...)
 210  		}
 211  	}
 212  
 213  	if int64(len(hist)) > dictMaxLength {
 214  		return nil, fmt.Errorf("dictionary of size %d > %d", len(hist), int64(dictMaxLength))
 215  	}
 216  	if len(hist) < 8 {
 217  		return nil, fmt.Errorf("dictionary of size %d < %d", len(hist), 8)
 218  	}
 219  	if len(contents) == 0 {
 220  		return nil, errors.New("no content provided")
 221  	}
 222  	d := dict{
 223  		id:      o.ID,
 224  		litEnc:  nil,
 225  		llDec:   sequenceDec{},
 226  		ofDec:   sequenceDec{},
 227  		mlDec:   sequenceDec{},
 228  		offsets: o.Offsets,
 229  		content: hist,
 230  	}
 231  	block := blockEnc{lowMem: false}
 232  	block.init()
 233  	enc := encoder(&bestFastEncoder{fastBase: fastBase{maxMatchOff: int32(maxMatchLen), bufferReset: math.MaxInt32 - int32(maxMatchLen*2), lowMem: false}})
 234  	if o.Level != 0 {
 235  		eOpts := encoderOptions{
 236  			level:      o.Level,
 237  			blockSize:  maxMatchLen,
 238  			windowSize: maxMatchLen,
 239  			dict:       &d,
 240  			lowMem:     false,
 241  		}
 242  		enc = eOpts.encoder()
 243  	} else {
 244  		o.Level = SpeedBestCompression
 245  	}
 246  	var (
 247  		remain [256]int
 248  		ll     [256]int
 249  		ml     [256]int
 250  		of     [256]int
 251  	)
 252  	addValues := func(dst *[256]int, src []byte) {
 253  		for _, v := range src {
 254  			dst[v]++
 255  		}
 256  	}
 257  	addHist := func(dst *[256]int, src *[256]uint32) {
 258  		for i, v := range src {
 259  			dst[i] += int(v)
 260  		}
 261  	}
 262  	seqs := 0
 263  	nUsed := 0
 264  	litTotal := 0
 265  	newOffsets := make(map[uint32]int, 1000)
 266  	for _, b := range contents {
 267  		block.reset(nil)
 268  		if len(b) < 8 {
 269  			continue
 270  		}
 271  		nUsed++
 272  		enc.Reset(&d, true)
 273  		enc.Encode(&block, b)
 274  		addValues(&remain, block.literals)
 275  		litTotal += len(block.literals)
 276  		if len(block.sequences) == 0 {
 277  			continue
 278  		}
 279  		seqs += len(block.sequences)
 280  		block.genCodes()
 281  		addHist(&ll, block.coders.llEnc.Histogram())
 282  		addHist(&ml, block.coders.mlEnc.Histogram())
 283  		addHist(&of, block.coders.ofEnc.Histogram())
 284  		for i, seq := range block.sequences {
 285  			if i > 3 {
 286  				break
 287  			}
 288  			offset := seq.offset
 289  			if offset == 0 {
 290  				continue
 291  			}
 292  			if int(offset) >= len(o.History) {
 293  				continue
 294  			}
 295  			if offset > 3 {
 296  				newOffsets[offset-3]++
 297  			} else {
 298  				newOffsets[uint32(o.Offsets[offset-1])]++
 299  			}
 300  		}
 301  	}
 302  	// Find most used offsets.
 303  	var sortedOffsets []uint32
 304  	for k := range newOffsets {
 305  		sortedOffsets = append(sortedOffsets, k)
 306  	}
 307  	sort.Slice(sortedOffsets, func(i, j int) bool {
 308  		a, b := sortedOffsets[i], sortedOffsets[j]
 309  		if a == b {
 310  			// Prefer the longer offset
 311  			return sortedOffsets[i] > sortedOffsets[j]
 312  		}
 313  		return newOffsets[sortedOffsets[i]] > newOffsets[sortedOffsets[j]]
 314  	})
 315  	if len(sortedOffsets) > 3 {
 316  		if debug {
 317  			print("Offsets:")
 318  			for i, v := range sortedOffsets {
 319  				if i > 20 {
 320  					break
 321  				}
 322  				printf("[%d: %d],", v, newOffsets[v])
 323  			}
 324  			println("")
 325  		}
 326  
 327  		sortedOffsets = sortedOffsets[:3]
 328  	}
 329  	for i, v := range sortedOffsets {
 330  		o.Offsets[i] = int(v)
 331  	}
 332  	if debug {
 333  		println("New repeat offsets", o.Offsets)
 334  	}
 335  
 336  	if nUsed == 0 || seqs == 0 {
 337  		return nil, fmt.Errorf("%d blocks, %d sequences found", nUsed, seqs)
 338  	}
 339  	if debug {
 340  		println("Sequences:", seqs, "Blocks:", nUsed, "Literals:", litTotal)
 341  	}
 342  	if seqs/nUsed < 512 {
 343  		// Use 512 as minimum.
 344  		nUsed = seqs / 512
 345  		if nUsed == 0 {
 346  			nUsed = 1
 347  		}
 348  	}
 349  	copyHist := func(dst *fseEncoder, src *[256]int) ([]byte, error) {
 350  		hist := dst.Histogram()
 351  		var maxSym uint8
 352  		var maxCount int
 353  		var fakeLength int
 354  		for i, v := range src {
 355  			if v > 0 {
 356  				v = v / nUsed
 357  				if v == 0 {
 358  					v = 1
 359  				}
 360  			}
 361  			if v > maxCount {
 362  				maxCount = v
 363  			}
 364  			if v != 0 {
 365  				maxSym = uint8(i)
 366  			}
 367  			fakeLength += v
 368  			hist[i] = uint32(v)
 369  		}
 370  
 371  		// Ensure we aren't trying to represent RLE.
 372  		if maxCount == fakeLength {
 373  			for i := range hist {
 374  				if uint8(i) == maxSym {
 375  					fakeLength++
 376  					maxSym++
 377  					hist[i+1] = 1
 378  					if maxSym > 1 {
 379  						break
 380  					}
 381  				}
 382  				if hist[0] == 0 {
 383  					fakeLength++
 384  					hist[i] = 1
 385  					if maxSym > 1 {
 386  						break
 387  					}
 388  				}
 389  			}
 390  		}
 391  
 392  		dst.HistogramFinished(maxSym, maxCount)
 393  		dst.reUsed = false
 394  		dst.useRLE = false
 395  		err := dst.normalizeCount(fakeLength)
 396  		if err != nil {
 397  			return nil, err
 398  		}
 399  		if debug {
 400  			println("RAW:", dst.count[:maxSym+1], "NORM:", dst.norm[:maxSym+1], "LEN:", fakeLength)
 401  		}
 402  		return dst.writeCount(nil)
 403  	}
 404  	if debug {
 405  		print("Literal lengths: ")
 406  	}
 407  	llTable, err := copyHist(block.coders.llEnc, &ll)
 408  	if err != nil {
 409  		return nil, err
 410  	}
 411  	if debug {
 412  		print("Match lengths: ")
 413  	}
 414  	mlTable, err := copyHist(block.coders.mlEnc, &ml)
 415  	if err != nil {
 416  		return nil, err
 417  	}
 418  	if debug {
 419  		print("Offsets: ")
 420  	}
 421  	ofTable, err := copyHist(block.coders.ofEnc, &of)
 422  	if err != nil {
 423  		return nil, err
 424  	}
 425  
 426  	// Literal table
 427  	avgSize := min(litTotal, huff0.BlockSizeMax/2)
 428  	huffBuff := make([]byte, 0, avgSize)
 429  	// Target size
 430  	div := max(litTotal/avgSize, 1)
 431  	if debug {
 432  		println("Huffman weights:")
 433  	}
 434  	for i, n := range remain[:] {
 435  		if n > 0 {
 436  			n = n / div
 437  			// Allow all entries to be represented.
 438  			if n == 0 {
 439  				n = 1
 440  			}
 441  			huffBuff = append(huffBuff, bytes.Repeat([]byte{byte(i)}, n)...)
 442  			if debug {
 443  				printf("[%d: %d], ", i, n)
 444  			}
 445  		}
 446  	}
 447  	if o.CompatV155 && remain[255]/div == 0 {
 448  		huffBuff = append(huffBuff, 255)
 449  	}
 450  	scratch := &huff0.Scratch{TableLog: 11}
 451  	for tries := range 255 {
 452  		scratch = &huff0.Scratch{TableLog: 11}
 453  		_, _, err = huff0.Compress1X(huffBuff, scratch)
 454  		if err == nil {
 455  			break
 456  		}
 457  		if debug {
 458  			printf("Try %d: Huffman error: %v\n", tries+1, err)
 459  		}
 460  		huffBuff = huffBuff[:0]
 461  		if tries == 250 {
 462  			if debug {
 463  				println("Huffman: Bailing out with predefined table")
 464  			}
 465  
 466  			// Bail out.... Just generate something
 467  			huffBuff = append(huffBuff, bytes.Repeat([]byte{255}, 10000)...)
 468  			for i := range 128 {
 469  				huffBuff = append(huffBuff, byte(i))
 470  			}
 471  			continue
 472  		}
 473  		if errors.Is(err, huff0.ErrIncompressible) {
 474  			// Try truncating least common.
 475  			for i, n := range remain[:] {
 476  				if n > 0 {
 477  					n = n / (div * (i + 1))
 478  					if n > 0 {
 479  						huffBuff = append(huffBuff, bytes.Repeat([]byte{byte(i)}, n)...)
 480  					}
 481  				}
 482  			}
 483  			if o.CompatV155 && len(huffBuff) > 0 && huffBuff[len(huffBuff)-1] != 255 {
 484  				huffBuff = append(huffBuff, 255)
 485  			}
 486  			if len(huffBuff) == 0 {
 487  				huffBuff = append(huffBuff, 0, 255)
 488  			}
 489  		}
 490  		if errors.Is(err, huff0.ErrUseRLE) {
 491  			for i, n := range remain[:] {
 492  				n = n / (div * (i + 1))
 493  				// Allow all entries to be represented.
 494  				if n == 0 {
 495  					n = 1
 496  				}
 497  				huffBuff = append(huffBuff, bytes.Repeat([]byte{byte(i)}, n)...)
 498  			}
 499  		}
 500  	}
 501  
 502  	var out bytes.Buffer
 503  	out.Write([]byte(dictMagic))
 504  	out.Write(binary.LittleEndian.AppendUint32(nil, o.ID))
 505  	out.Write(scratch.OutTable)
 506  	if debug {
 507  		println("huff table:", len(scratch.OutTable), "bytes")
 508  		println("of table:", len(ofTable), "bytes")
 509  		println("ml table:", len(mlTable), "bytes")
 510  		println("ll table:", len(llTable), "bytes")
 511  	}
 512  	out.Write(ofTable)
 513  	out.Write(mlTable)
 514  	out.Write(llTable)
 515  	out.Write(binary.LittleEndian.AppendUint32(nil, uint32(o.Offsets[0])))
 516  	out.Write(binary.LittleEndian.AppendUint32(nil, uint32(o.Offsets[1])))
 517  	out.Write(binary.LittleEndian.AppendUint32(nil, uint32(o.Offsets[2])))
 518  	out.Write(hist)
 519  	if debug {
 520  		_, err := loadDict(out.Bytes())
 521  		if err != nil {
 522  			panic(err)
 523  		}
 524  		i, err := InspectDictionary(out.Bytes())
 525  		if err != nil {
 526  			panic(err)
 527  		}
 528  		println("ID:", i.ID())
 529  		println("Content size:", i.ContentSize())
 530  		println("Encoder:", i.LitEncoder() != nil)
 531  		println("Offsets:", i.Offsets())
 532  		var totalSize int
 533  		for _, b := range contents {
 534  			totalSize += len(b)
 535  		}
 536  
 537  		encWith := func(opts ...EOption) int {
 538  			enc, err := NewWriter(nil, opts...)
 539  			if err != nil {
 540  				panic(err)
 541  			}
 542  			defer enc.Close()
 543  			var dst []byte
 544  			var totalSize int
 545  			for _, b := range contents {
 546  				dst = enc.EncodeAll(b, dst[:0])
 547  				totalSize += len(dst)
 548  			}
 549  			return totalSize
 550  		}
 551  		plain := encWith(WithEncoderLevel(o.Level))
 552  		withDict := encWith(WithEncoderLevel(o.Level), WithEncoderDict(out.Bytes()))
 553  		println("Input size:", totalSize)
 554  		println("Plain Compressed:", plain)
 555  		println("Dict Compressed:", withDict)
 556  		println("Saved:", plain-withDict, (plain-withDict)/len(contents), "bytes per input (rounded down)")
 557  	}
 558  	return out.Bytes(), nil
 559  }
 560