1 // Copyright 2019+ Klaus Post. All rights reserved.
2 // License information can be found in the LICENSE file.
3 // Based on work by Yann Collet, released under BSD License.
4 5 package zstd
6 7 import (
8 "errors"
9 "fmt"
10 "io"
11 "math/bits"
12 13 "github.com/klauspost/compress/internal/le"
14 )
// bitReader reads a bitstream in reverse.
// The highest set bit of the final input byte is a marker: it indicates
// where the stream starts and is used for aligning the input.
//
// Input bytes are loaded from the end of in toward its start. Bits are
// consumed from the most significant end of value, and newly loaded bytes
// are shifted in at the least significant end.
type bitReader struct {
	in       []byte // complete input; in[:cursor] has not been loaded into value yet
	value    uint64 // Maybe use [16]byte, but shifting is awkward.
	cursor   int    // offset where next read should end
	bitsRead uint8  // bits of value already consumed; 64 = register empty, >64 = over-read
}
25 26 // init initializes and resets the bit reader.
27 func (b *bitReader) init(in []byte) error {
28 if len(in) < 1 {
29 return errors.New("corrupt stream: too short")
30 }
31 b.in = in
32 // The highest bit of the last byte indicates where to start
33 v := in[len(in)-1]
34 if v == 0 {
35 return errors.New("corrupt stream, did not find end of stream")
36 }
37 b.cursor = len(in)
38 b.bitsRead = 64
39 b.value = 0
40 if len(in) >= 8 {
41 b.fillFastStart()
42 } else {
43 b.fill()
44 b.fill()
45 }
46 b.bitsRead += 8 - uint8(highBits(uint32(v)))
47 return nil
48 }
49 50 // getBits will return n bits. n can be 0.
51 func (b *bitReader) getBits(n uint8) int {
52 if n == 0 /*|| b.bitsRead >= 64 */ {
53 return 0
54 }
55 return int(b.get32BitsFast(n))
56 }
57 58 // get32BitsFast requires that at least one bit is requested every time.
59 // There are no checks if the buffer is filled.
60 func (b *bitReader) get32BitsFast(n uint8) uint32 {
61 const regMask = 64 - 1
62 v := uint32((b.value << (b.bitsRead & regMask)) >> ((regMask + 1 - n) & regMask))
63 b.bitsRead += n
64 return v
65 }
66 67 // fillFast() will make sure at least 32 bits are available.
68 // There must be at least 4 bytes available.
69 func (b *bitReader) fillFast() {
70 if b.bitsRead < 32 {
71 return
72 }
73 b.cursor -= 4
74 b.value = (b.value << 32) | uint64(le.Load32(b.in, b.cursor))
75 b.bitsRead -= 32
76 }
77 78 // fillFastStart() assumes the bitreader is empty and there is at least 8 bytes to read.
79 func (b *bitReader) fillFastStart() {
80 b.cursor -= 8
81 b.value = le.Load64(b.in, b.cursor)
82 b.bitsRead = 0
83 }
84 85 // fill() will make sure at least 32 bits are available.
86 func (b *bitReader) fill() {
87 if b.bitsRead < 32 {
88 return
89 }
90 if b.cursor >= 4 {
91 b.cursor -= 4
92 b.value = (b.value << 32) | uint64(le.Load32(b.in, b.cursor))
93 b.bitsRead -= 32
94 return
95 }
96 97 b.bitsRead -= uint8(8 * b.cursor)
98 for b.cursor > 0 {
99 b.cursor -= 1
100 b.value = (b.value << 8) | uint64(b.in[b.cursor])
101 }
102 }
103 104 // finished returns true if all bits have been read from the bit stream.
105 func (b *bitReader) finished() bool {
106 return b.cursor == 0 && b.bitsRead >= 64
107 }
108 109 // overread returns true if more bits have been requested than is on the stream.
110 func (b *bitReader) overread() bool {
111 return b.bitsRead > 64
112 }
113 114 // remain returns the number of bits remaining.
115 func (b *bitReader) remain() uint {
116 return 8*uint(b.cursor) + 64 - uint(b.bitsRead)
117 }
118 119 // close the bitstream and returns an error if out-of-buffer reads occurred.
120 func (b *bitReader) close() error {
121 // Release reference.
122 b.in = nil
123 b.cursor = 0
124 if !b.finished() {
125 return fmt.Errorf("%d extra bits on block, should be 0", b.remain())
126 }
127 if b.bitsRead > 64 {
128 return io.ErrUnexpectedEOF
129 }
130 return nil
131 }
// highBits returns the zero-based index of the most significant set bit in
// val. For val == 0 there is no set bit, and the unsigned result wraps to
// 0xFFFFFFFF.
func highBits(val uint32) (n uint32) {
	n = uint32(bits.Len32(val))
	n-- // Len32 counts significant bits; convert the count to an index.
	return n
}
136