huffman.mx raw

   1  // Copyright 2014 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  package hpack
   6  
   7  import (
   8  	"bytes"
   9  	"errors"
  10  	"io"
  11  	"sync"
  12  )
  13  
  14  var bufPool = sync.Pool{
  15  	New: func() interface{} { return &bytes.Buffer{} },
  16  }
  17  
  18  // HuffmanDecode decodes the string in v and writes the expanded
  19  // result to w, returning the number of bytes written to w and the
  20  // Write call's return value. At most one Write call is made.
  21  func HuffmanDecode(w io.Writer, v []byte) (int, error) {
  22  	buf := bufPool.Get().(*bytes.Buffer)
  23  	buf.Reset()
  24  	defer bufPool.Put(buf)
  25  	if err := huffmanDecode(buf, 0, v); err != nil {
  26  		return 0, err
  27  	}
  28  	return w.Write(buf.Bytes())
  29  }
  30  
  31  // HuffmanDecodeToString decodes the string in v.
  32  func HuffmanDecodeToString(v []byte) ([]byte, error) {
  33  	buf := bufPool.Get().(*bytes.Buffer)
  34  	buf.Reset()
  35  	defer bufPool.Put(buf)
  36  	if err := huffmanDecode(buf, 0, v); err != nil {
  37  		return "", err
  38  	}
  39  	return buf.String(), nil
  40  }
  41  
  42  // ErrInvalidHuffman is returned for errors found decoding
  43  // Huffman-encoded strings.
  44  var ErrInvalidHuffman = errors.New("hpack: invalid Huffman-encoded data")
  45  
  46  // huffmanDecode decodes v to buf.
  47  // If maxLen is greater than 0, attempts to write more to buf than
  48  // maxLen bytes will return ErrStringLength.
  49  func huffmanDecode(buf *bytes.Buffer, maxLen int, v []byte) error {
  50  	rootHuffmanNode := getRootHuffmanNode()
  51  	n := rootHuffmanNode
  52  	// cur is the bit buffer that has not been fed into n.
  53  	// cbits is the number of low order bits in cur that are valid.
  54  	// sbits is the number of bits of the symbol prefix being decoded.
  55  	cur, cbits, sbits := uint(0), uint8(0), uint8(0)
  56  	for _, b := range v {
  57  		cur = cur<<8 | uint(b)
  58  		cbits += 8
  59  		sbits += 8
  60  		for cbits >= 8 {
  61  			idx := byte(cur >> (cbits - 8))
  62  			n = n.children[idx]
  63  			if n == nil {
  64  				return ErrInvalidHuffman
  65  			}
  66  			if n.children == nil {
  67  				if maxLen != 0 && buf.Len() == maxLen {
  68  					return ErrStringLength
  69  				}
  70  				buf.WriteByte(n.sym)
  71  				cbits -= n.codeLen
  72  				n = rootHuffmanNode
  73  				sbits = cbits
  74  			} else {
  75  				cbits -= 8
  76  			}
  77  		}
  78  	}
  79  	for cbits > 0 {
  80  		n = n.children[byte(cur<<(8-cbits))]
  81  		if n == nil {
  82  			return ErrInvalidHuffman
  83  		}
  84  		if n.children != nil || n.codeLen > cbits {
  85  			break
  86  		}
  87  		if maxLen != 0 && buf.Len() == maxLen {
  88  			return ErrStringLength
  89  		}
  90  		buf.WriteByte(n.sym)
  91  		cbits -= n.codeLen
  92  		n = rootHuffmanNode
  93  		sbits = cbits
  94  	}
  95  	if sbits > 7 {
  96  		// Either there was an incomplete symbol, or overlong padding.
  97  		// Both are decoding errors per RFC 7541 section 5.2.
  98  		return ErrInvalidHuffman
  99  	}
 100  	if mask := uint(1<<cbits - 1); cur&mask != mask {
 101  		// Trailing bits must be a prefix of EOS per RFC 7541 section 5.2.
 102  		return ErrInvalidHuffman
 103  	}
 104  
 105  	return nil
 106  }
 107  
 108  // incomparable is a zero-width, non-comparable type. Adding it to a struct
 109  // makes that struct also non-comparable, and generally doesn't add
 110  // any size (as long as it's first).
 111  type incomparable [0]func()
 112  
 113  type node struct {
 114  	_ incomparable
 115  
 116  	// children is non-nil for internal nodes
 117  	children *[256]*node
 118  
 119  	// The following are only valid if children is nil:
 120  	codeLen uint8 // number of bits that led to the output of sym
 121  	sym     byte  // output symbol
 122  }
 123  
 124  func newInternalNode() *node {
 125  	return &node{children: &[256]*node{}}
 126  }
 127  
 128  var (
 129  	buildRootOnce       sync.Once
 130  	lazyRootHuffmanNode *node
 131  )
 132  
 133  func getRootHuffmanNode() *node {
 134  	buildRootOnce.Do(buildRootHuffmanNode)
 135  	return lazyRootHuffmanNode
 136  }
 137  
 138  func buildRootHuffmanNode() {
 139  	if len(huffmanCodes) != 256 {
 140  		panic("unexpected size")
 141  	}
 142  	lazyRootHuffmanNode = newInternalNode()
 143  	// allocate a leaf node for each of the 256 symbols
 144  	leaves := &[256]node{}
 145  
 146  	for sym, code := range huffmanCodes {
 147  		codeLen := huffmanCodeLen[sym]
 148  
 149  		cur := lazyRootHuffmanNode
 150  		for codeLen > 8 {
 151  			codeLen -= 8
 152  			i := uint8(code >> codeLen)
 153  			if cur.children[i] == nil {
 154  				cur.children[i] = newInternalNode()
 155  			}
 156  			cur = cur.children[i]
 157  		}
 158  		shift := 8 - codeLen
 159  		start, end := int(uint8(code<<shift)), int(1<<shift)
 160  
 161  		leaves[sym].sym = byte(sym)
 162  		leaves[sym].codeLen = codeLen
 163  		for i := start; i < start+end; i++ {
 164  			cur.children[i] = &leaves[sym]
 165  		}
 166  	}
 167  }
 168  
 169  // AppendHuffmanString appends s, as encoded in Huffman codes, to dst
 170  // and returns the extended buffer.
 171  func AppendHuffmanString(dst []byte, s []byte) []byte {
 172  	// This relies on the maximum huffman code length being 30 (See tables.go huffmanCodeLen array)
 173  	// So if a uint64 buffer has less than 32 valid bits can always accommodate another huffmanCode.
 174  	var (
 175  		x uint64 // buffer
 176  		n uint   // number valid of bits present in x
 177  	)
 178  	for i := 0; i < len(s); i++ {
 179  		c := s[i]
 180  		n += uint(huffmanCodeLen[c])
 181  		x <<= huffmanCodeLen[c] % 64
 182  		x |= uint64(huffmanCodes[c])
 183  		if n >= 32 {
 184  			n %= 32             // Normally would be -= 32 but %= 32 informs compiler 0 <= n <= 31 for upcoming shift
 185  			y := uint32(x >> n) // Compiler doesn't combine memory writes if y isn't uint32
 186  			dst = append(dst, byte(y>>24), byte(y>>16), byte(y>>8), byte(y))
 187  		}
 188  	}
 189  	// Add padding bits if necessary
 190  	if over := n % 8; over > 0 {
 191  		const (
 192  			eosCode    = 0x3fffffff
 193  			eosNBits   = 30
 194  			eosPadByte = eosCode >> (eosNBits - 8)
 195  		)
 196  		pad := 8 - over
 197  		x = (x << pad) | (eosPadByte >> over)
 198  		n += pad // 8 now divides into n exactly
 199  	}
 200  	// n in (0, 8, 16, 24, 32)
 201  	switch n / 8 {
 202  	case 0:
 203  		return dst
 204  	case 1:
 205  		return append(dst, byte(x))
 206  	case 2:
 207  		y := uint16(x)
 208  		return append(dst, byte(y>>8), byte(y))
 209  	case 3:
 210  		y := uint16(x >> 8)
 211  		return append(dst, byte(y>>8), byte(y), byte(x))
 212  	}
 213  	//	case 4:
 214  	y := uint32(x)
 215  	return append(dst, byte(y>>24), byte(y>>16), byte(y>>8), byte(y))
 216  }
 217  
 218  // HuffmanEncodeLength returns the number of bytes required to encode
 219  // s in Huffman codes. The result is round up to byte boundary.
 220  func HuffmanEncodeLength(s []byte) uint64 {
 221  	n := uint64(0)
 222  	for i := 0; i < len(s); i++ {
 223  		n += uint64(huffmanCodeLen[s[i]])
 224  	}
 225  	return (n + 7) / 8
 226  }
 227