bitreader.go raw

   1  // Copyright 2019+ Klaus Post. All rights reserved.
   2  // License information can be found in the LICENSE file.
   3  // Based on work by Yann Collet, released under BSD License.
   4  
   5  package zstd
   6  
   7  import (
   8  	"errors"
   9  	"fmt"
  10  	"io"
  11  	"math/bits"
  12  
  13  	"github.com/klauspost/compress/internal/le"
  14  )
  15  
// bitReader reads a bitstream in reverse.
// The last set bit indicates the start of the stream and is used
// for aligning the input.
//
// Input bytes are loaded from the end of in towards index 0. Bytes
// in[:cursor] have not been loaded yet. Loaded bits live in value, and
// bitsRead counts how many of value's 64 bits have already been consumed.
// A bitsRead above 64 means more bits were requested than had been loaded.
//
// NOTE(review): keep the field order stable unless you have confirmed that
// nothing outside this file depends on the struct layout.
type bitReader struct {
	in       []byte // input stream; consumed from the end towards index 0
	value    uint64 // Maybe use [16]byte, but shifting is awkward.
	cursor   int    // offset where next read should end
	bitsRead uint8  // number of bits of value already consumed
}
  25  
  26  // init initializes and resets the bit reader.
  27  func (b *bitReader) init(in []byte) error {
  28  	if len(in) < 1 {
  29  		return errors.New("corrupt stream: too short")
  30  	}
  31  	b.in = in
  32  	// The highest bit of the last byte indicates where to start
  33  	v := in[len(in)-1]
  34  	if v == 0 {
  35  		return errors.New("corrupt stream, did not find end of stream")
  36  	}
  37  	b.cursor = len(in)
  38  	b.bitsRead = 64
  39  	b.value = 0
  40  	if len(in) >= 8 {
  41  		b.fillFastStart()
  42  	} else {
  43  		b.fill()
  44  		b.fill()
  45  	}
  46  	b.bitsRead += 8 - uint8(highBits(uint32(v)))
  47  	return nil
  48  }
  49  
  50  // getBits will return n bits. n can be 0.
  51  func (b *bitReader) getBits(n uint8) int {
  52  	if n == 0 /*|| b.bitsRead >= 64 */ {
  53  		return 0
  54  	}
  55  	return int(b.get32BitsFast(n))
  56  }
  57  
  58  // get32BitsFast requires that at least one bit is requested every time.
  59  // There are no checks if the buffer is filled.
  60  func (b *bitReader) get32BitsFast(n uint8) uint32 {
  61  	const regMask = 64 - 1
  62  	v := uint32((b.value << (b.bitsRead & regMask)) >> ((regMask + 1 - n) & regMask))
  63  	b.bitsRead += n
  64  	return v
  65  }
  66  
  67  // fillFast() will make sure at least 32 bits are available.
  68  // There must be at least 4 bytes available.
  69  func (b *bitReader) fillFast() {
  70  	if b.bitsRead < 32 {
  71  		return
  72  	}
  73  	b.cursor -= 4
  74  	b.value = (b.value << 32) | uint64(le.Load32(b.in, b.cursor))
  75  	b.bitsRead -= 32
  76  }
  77  
  78  // fillFastStart() assumes the bitreader is empty and there is at least 8 bytes to read.
  79  func (b *bitReader) fillFastStart() {
  80  	b.cursor -= 8
  81  	b.value = le.Load64(b.in, b.cursor)
  82  	b.bitsRead = 0
  83  }
  84  
  85  // fill() will make sure at least 32 bits are available.
  86  func (b *bitReader) fill() {
  87  	if b.bitsRead < 32 {
  88  		return
  89  	}
  90  	if b.cursor >= 4 {
  91  		b.cursor -= 4
  92  		b.value = (b.value << 32) | uint64(le.Load32(b.in, b.cursor))
  93  		b.bitsRead -= 32
  94  		return
  95  	}
  96  
  97  	b.bitsRead -= uint8(8 * b.cursor)
  98  	for b.cursor > 0 {
  99  		b.cursor -= 1
 100  		b.value = (b.value << 8) | uint64(b.in[b.cursor])
 101  	}
 102  }
 103  
 104  // finished returns true if all bits have been read from the bit stream.
 105  func (b *bitReader) finished() bool {
 106  	return b.cursor == 0 && b.bitsRead >= 64
 107  }
 108  
 109  // overread returns true if more bits have been requested than is on the stream.
 110  func (b *bitReader) overread() bool {
 111  	return b.bitsRead > 64
 112  }
 113  
 114  // remain returns the number of bits remaining.
 115  func (b *bitReader) remain() uint {
 116  	return 8*uint(b.cursor) + 64 - uint(b.bitsRead)
 117  }
 118  
 119  // close the bitstream and returns an error if out-of-buffer reads occurred.
 120  func (b *bitReader) close() error {
 121  	// Release reference.
 122  	b.in = nil
 123  	b.cursor = 0
 124  	if !b.finished() {
 125  		return fmt.Errorf("%d extra bits on block, should be 0", b.remain())
 126  	}
 127  	if b.bitsRead > 64 {
 128  		return io.ErrUnexpectedEOF
 129  	}
 130  	return nil
 131  }
 132  
 133  func highBits(val uint32) (n uint32) {
 134  	return uint32(bits.Len32(val) - 1)
 135  }
 136