writer.mx raw

   1  // Copyright 2011 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  package lzw
   6  
   7  import (
   8  	"bufio"
   9  	"errors"
  10  	"fmt"
  11  	"io"
  12  )
  13  
  14  // A writer is a buffered, flushable writer.
  15  type writer interface {
  16  	io.ByteWriter
  17  	Flush() error
  18  }
  19  
  20  const (
  21  	// A code is a 12 bit value, stored as a uint32 when encoding to avoid
  22  	// type conversions when shifting bits.
  23  	maxCode     = 1<<12 - 1
  24  	invalidCode = 1<<32 - 1
  25  	// There are 1<<12 possible codes, which is an upper bound on the number of
  26  	// valid hash table entries at any given point in time. tableSize is 4x that.
  27  	tableSize = 4 * 1 << 12
  28  	tableMask = tableSize - 1
  29  	// A hash table entry is a uint32. Zero is an invalid entry since the
  30  	// lower 12 bits of a valid entry must be a non-literal code.
  31  	invalidEntry = 0
  32  )
  33  
  34  // Writer is an LZW compressor. It writes the compressed form of the data
  35  // to an underlying writer (see [NewWriter]).
  36  type Writer struct {
  37  	// w is the writer that compressed bytes are written to.
  38  	w writer
  39  	// litWidth is the width in bits of literal codes.
  40  	litWidth uint
  41  	// order, write, bits, nBits and width are the state for
  42  	// converting a code stream into a byte stream.
  43  	order Order
  44  	write func(*Writer, uint32) error
  45  	nBits uint
  46  	width uint
  47  	bits  uint32
  48  	// hi is the code implied by the next code emission.
  49  	// overflow is the code at which hi overflows the code width.
  50  	hi, overflow uint32
  51  	// savedCode is the accumulated code at the end of the most recent Write
  52  	// call. It is equal to invalidCode if there was no such call.
  53  	savedCode uint32
  54  	// err is the first error encountered during writing. Closing the writer
  55  	// will make any future Write calls return errClosed
  56  	err error
  57  	// table is the hash table from 20-bit keys to 12-bit values. Each table
  58  	// entry contains key<<12|val and collisions resolve by linear probing.
  59  	// The keys consist of a 12-bit code prefix and an 8-bit byte suffix.
  60  	// The values are a 12-bit code.
  61  	table [tableSize]uint32
  62  }
  63  
  64  // writeLSB writes the code c for "Least Significant Bits first" data.
  65  func (w *Writer) writeLSB(c uint32) error {
  66  	w.bits |= c << w.nBits
  67  	w.nBits += w.width
  68  	for w.nBits >= 8 {
  69  		if err := w.w.WriteByte(uint8(w.bits)); err != nil {
  70  			return err
  71  		}
  72  		w.bits >>= 8
  73  		w.nBits -= 8
  74  	}
  75  	return nil
  76  }
  77  
  78  // writeMSB writes the code c for "Most Significant Bits first" data.
  79  func (w *Writer) writeMSB(c uint32) error {
  80  	w.bits |= c << (32 - w.width - w.nBits)
  81  	w.nBits += w.width
  82  	for w.nBits >= 8 {
  83  		if err := w.w.WriteByte(uint8(w.bits >> 24)); err != nil {
  84  			return err
  85  		}
  86  		w.bits <<= 8
  87  		w.nBits -= 8
  88  	}
  89  	return nil
  90  }
  91  
  92  // errOutOfCodes is an internal error that means that the writer has run out
  93  // of unused codes and a clear code needs to be sent next.
  94  var errOutOfCodes = errors.New("lzw: out of codes")
  95  
  96  // incHi increments e.hi and checks for both overflow and running out of
  97  // unused codes. In the latter case, incHi sends a clear code, resets the
  98  // writer state and returns errOutOfCodes.
  99  func (w *Writer) incHi() error {
 100  	w.hi++
 101  	if w.hi == w.overflow {
 102  		w.width++
 103  		w.overflow <<= 1
 104  	}
 105  	if w.hi == maxCode {
 106  		clear := uint32(1) << w.litWidth
 107  		if err := w.write(w, clear); err != nil {
 108  			return err
 109  		}
 110  		w.width = w.litWidth + 1
 111  		w.hi = clear + 1
 112  		w.overflow = clear << 1
 113  		for i := range w.table {
 114  			w.table[i] = invalidEntry
 115  		}
 116  		return errOutOfCodes
 117  	}
 118  	return nil
 119  }
 120  
 121  // Write writes a compressed representation of p to w's underlying writer.
 122  func (w *Writer) Write(p []byte) (n int, err error) {
 123  	if w.err != nil {
 124  		return 0, w.err
 125  	}
 126  	if len(p) == 0 {
 127  		return 0, nil
 128  	}
 129  	if maxLit := uint8(1<<w.litWidth - 1); maxLit != 0xff {
 130  		for _, x := range p {
 131  			if x > maxLit {
 132  				w.err = errors.New("lzw: input byte too large for the litWidth")
 133  				return 0, w.err
 134  			}
 135  		}
 136  	}
 137  	n = len(p)
 138  	code := w.savedCode
 139  	if code == invalidCode {
 140  		// This is the first write; send a clear code.
 141  		// https://www.w3.org/Graphics/GIF/spec-gif89a.txt Appendix F
 142  		// "Variable-Length-Code LZW Compression" says that "Encoders should
 143  		// output a Clear code as the first code of each image data stream".
 144  		//
 145  		// LZW compression isn't only used by GIF, but it's cheap to follow
 146  		// that directive unconditionally.
 147  		clear := uint32(1) << w.litWidth
 148  		if err := w.write(w, clear); err != nil {
 149  			return 0, err
 150  		}
 151  		// After the starting clear code, the next code sent (for non-empty
 152  		// input) is always a literal code.
 153  		code, p = uint32(p[0]), p[1:]
 154  	}
 155  loop:
 156  	for _, x := range p {
 157  		literal := uint32(x)
 158  		key := code<<8 | literal
 159  		// If there is a hash table hit for this key then we continue the loop
 160  		// and do not emit a code yet.
 161  		hash := (key>>12 ^ key) & tableMask
 162  		for h, t := hash, w.table[hash]; t != invalidEntry; {
 163  			if key == t>>12 {
 164  				code = t & maxCode
 165  				continue loop
 166  			}
 167  			h = (h + 1) & tableMask
 168  			t = w.table[h]
 169  		}
 170  		// Otherwise, write the current code, and literal becomes the start of
 171  		// the next emitted code.
 172  		if w.err = w.write(w, code); w.err != nil {
 173  			return 0, w.err
 174  		}
 175  		code = literal
 176  		// Increment e.hi, the next implied code. If we run out of codes, reset
 177  		// the writer state (including clearing the hash table) and continue.
 178  		if err1 := w.incHi(); err1 != nil {
 179  			if err1 == errOutOfCodes {
 180  				continue
 181  			}
 182  			w.err = err1
 183  			return 0, w.err
 184  		}
 185  		// Otherwise, insert key -> e.hi into the map that e.table represents.
 186  		for {
 187  			if w.table[hash] == invalidEntry {
 188  				w.table[hash] = (key << 12) | w.hi
 189  				break
 190  			}
 191  			hash = (hash + 1) & tableMask
 192  		}
 193  	}
 194  	w.savedCode = code
 195  	return n, nil
 196  }
 197  
 198  // Close closes the [Writer], flushing any pending output. It does not close
 199  // w's underlying writer.
 200  func (w *Writer) Close() error {
 201  	if w.err != nil {
 202  		if w.err == errClosed {
 203  			return nil
 204  		}
 205  		return w.err
 206  	}
 207  	// Make any future calls to Write return errClosed.
 208  	w.err = errClosed
 209  	// Write the savedCode if valid.
 210  	if w.savedCode != invalidCode {
 211  		if err := w.write(w, w.savedCode); err != nil {
 212  			return err
 213  		}
 214  		if err := w.incHi(); err != nil && err != errOutOfCodes {
 215  			return err
 216  		}
 217  	} else {
 218  		// Write the starting clear code, as w.Write did not.
 219  		clear := uint32(1) << w.litWidth
 220  		if err := w.write(w, clear); err != nil {
 221  			return err
 222  		}
 223  	}
 224  	// Write the eof code.
 225  	eof := uint32(1)<<w.litWidth + 1
 226  	if err := w.write(w, eof); err != nil {
 227  		return err
 228  	}
 229  	// Write the final bits.
 230  	if w.nBits > 0 {
 231  		if w.order == MSB {
 232  			w.bits >>= 24
 233  		}
 234  		if err := w.w.WriteByte(uint8(w.bits)); err != nil {
 235  			return err
 236  		}
 237  	}
 238  	return w.w.Flush()
 239  }
 240  
 241  // Reset clears the [Writer]'s state and allows it to be reused again
 242  // as a new [Writer].
 243  func (w *Writer) Reset(dst io.Writer, order Order, litWidth int) {
 244  	*w = Writer{}
 245  	w.init(dst, order, litWidth)
 246  }
 247  
 248  // NewWriter creates a new [io.WriteCloser].
 249  // Writes to the returned [io.WriteCloser] are compressed and written to w.
 250  // It is the caller's responsibility to call Close on the WriteCloser when
 251  // finished writing.
 252  // The number of bits to use for literal codes, litWidth, must be in the
 253  // range [2,8] and is typically 8. Input bytes must be less than 1<<litWidth.
 254  //
 255  // It is guaranteed that the underlying type of the returned [io.WriteCloser]
 256  // is a *[Writer].
 257  func NewWriter(w io.Writer, order Order, litWidth int) io.WriteCloser {
 258  	return newWriter(w, order, litWidth)
 259  }
 260  
 261  func newWriter(dst io.Writer, order Order, litWidth int) *Writer {
 262  	w := &Writer{}
 263  	w.init(dst, order, litWidth)
 264  	return w
 265  }
 266  
 267  func (w *Writer) init(dst io.Writer, order Order, litWidth int) {
 268  	switch order {
 269  	case LSB:
 270  		w.write = (*Writer).writeLSB
 271  	case MSB:
 272  		w.write = (*Writer).writeMSB
 273  	default:
 274  		w.err = errors.New("lzw: unknown order")
 275  		return
 276  	}
 277  	if litWidth < 2 || 8 < litWidth {
 278  		w.err = fmt.Errorf("lzw: litWidth %d out of range", litWidth)
 279  		return
 280  	}
 281  	bw, ok := dst.(writer)
 282  	if !ok && dst != nil {
 283  		bw = bufio.NewWriter(dst)
 284  	}
 285  	w.w = bw
 286  	lw := uint(litWidth)
 287  	w.order = order
 288  	w.width = 1 + lw
 289  	w.litWidth = lw
 290  	w.hi = 1<<lw + 1
 291  	w.overflow = 1 << (lw + 1)
 292  	w.savedCode = invalidCode
 293  }
 294