decompress_amd64.go raw
1 //go:build amd64 && !appengine && !noasm && gc
2 // +build amd64,!appengine,!noasm,gc
3
// This file contains the specialisations of Decoder.Decompress4X
// and Decoder.Decompress1X that use an asm implementation of their main loops.
6 package huff0
7
8 import (
9 "errors"
10 "fmt"
11
12 "github.com/klauspost/compress/internal/cpuinfo"
13 )
14
15 // decompress4x_main_loop_x86 is an x86 assembler implementation
16 // of Decompress4X when tablelog > 8.
17 //
18 //go:noescape
19 func decompress4x_main_loop_amd64(ctx *decompress4xContext)
20
21 // decompress4x_8b_loop_x86 is an x86 assembler implementation
22 // of Decompress4X when tablelog <= 8 which decodes 4 entries
23 // per loop.
24 //
25 //go:noescape
26 func decompress4x_8b_main_loop_amd64(ctx *decompress4xContext)
27
const (
	// fallback8BitSize is the output capacity below which Decompress4X hands
	// streams using 8-bit tables to the pure Go decoder, which is faster than
	// the assembly loop for such small outputs.
	fallback8BitSize = 800
)
30
// decompress4xContext carries the state exchanged between Decompress4X and
// the 4X assembly main loops (decompress4x_main_loop_amd64 and
// decompress4x_8b_main_loop_amd64), which receive a pointer to it.
//
// NOTE(review): the assembly presumably addresses these fields by fixed
// offsets; do not reorder, add, or retype fields without checking the .s file.
type decompress4xContext struct {
	pbr      *[4]bitReaderShifted // the four per-stream bit readers
	peekBits uint8                // (64 - actualTableLog) & 63; see bitReaderShifted.peekBitsFast()
	out      *byte                // start of the output buffer
	dstEvery int                  // distance in bytes between the output regions of consecutive streams
	tbl      *dEntrySingle        // first entry of the single-symbol decoding table
	decoded  int                  // read back after the call; the caller advances every stream by decoded/4
	limit    *byte                // decoding stops when the first stream reaches this byte, keeping the last stream in bounds
}
40
41 // Decompress4X will decompress a 4X encoded stream.
42 // The length of the supplied input must match the end of a block exactly.
43 // The *capacity* of the dst slice must match the destination size of
44 // the uncompressed data exactly.
45 func (d *Decoder) Decompress4X(dst, src []byte) ([]byte, error) {
46 if len(d.dt.single) == 0 {
47 return nil, errors.New("no table loaded")
48 }
49 if len(src) < 6+(4*1) {
50 return nil, errors.New("input too small")
51 }
52
53 use8BitTables := d.actualTableLog <= 8
54 if cap(dst) < fallback8BitSize && use8BitTables {
55 return d.decompress4X8bit(dst, src)
56 }
57
58 var br [4]bitReaderShifted
59 // Decode "jump table"
60 start := 6
61 for i := range 3 {
62 length := int(src[i*2]) | (int(src[i*2+1]) << 8)
63 if start+length >= len(src) {
64 return nil, errors.New("truncated input (or invalid offset)")
65 }
66 err := br[i].init(src[start : start+length])
67 if err != nil {
68 return nil, err
69 }
70 start += length
71 }
72 err := br[3].init(src[start:])
73 if err != nil {
74 return nil, err
75 }
76
77 // destination, offset to match first output
78 dstSize := cap(dst)
79 dst = dst[:dstSize]
80 out := dst
81 dstEvery := (dstSize + 3) / 4
82
83 const tlSize = 1 << tableLogMax
84 const tlMask = tlSize - 1
85 single := d.dt.single[:tlSize]
86
87 var decoded int
88
89 if len(out) > 4*4 && !(br[0].off < 4 || br[1].off < 4 || br[2].off < 4 || br[3].off < 4) {
90 ctx := decompress4xContext{
91 pbr: &br,
92 peekBits: uint8((64 - d.actualTableLog) & 63), // see: bitReaderShifted.peekBitsFast()
93 out: &out[0],
94 dstEvery: dstEvery,
95 tbl: &single[0],
96 limit: &out[dstEvery-4], // Always stop decoding when first buffer gets here to avoid writing OOB on last.
97 }
98 if use8BitTables {
99 decompress4x_8b_main_loop_amd64(&ctx)
100 } else {
101 decompress4x_main_loop_amd64(&ctx)
102 }
103
104 decoded = ctx.decoded
105 out = out[decoded/4:]
106 }
107
108 // Decode remaining.
109 remainBytes := dstEvery - (decoded / 4)
110 for i := range br {
111 offset := dstEvery * i
112 endsAt := min(offset+remainBytes, len(out))
113 br := &br[i]
114 bitsLeft := br.remaining()
115 for bitsLeft > 0 {
116 br.fill()
117 if offset >= endsAt {
118 return nil, errors.New("corruption detected: stream overrun 4")
119 }
120
121 // Read value and increment offset.
122 val := br.peekBitsFast(d.actualTableLog)
123 v := single[val&tlMask].entry
124 nBits := uint8(v)
125 br.advance(nBits)
126 bitsLeft -= uint(nBits)
127 out[offset] = uint8(v >> 8)
128 offset++
129 }
130 if offset != endsAt {
131 return nil, fmt.Errorf("corruption detected: short output block %d, end %d != %d", i, offset, endsAt)
132 }
133 decoded += offset - dstEvery*i
134 err = br.close()
135 if err != nil {
136 return nil, err
137 }
138 }
139 if dstSize != decoded {
140 return nil, errors.New("corruption detected: short output block")
141 }
142 return dst, nil
143 }
144
145 // decompress4x_main_loop_x86 is an x86 assembler implementation
146 // of Decompress1X when tablelog > 8.
147 //
148 //go:noescape
149 func decompress1x_main_loop_amd64(ctx *decompress1xContext)
150
151 // decompress4x_main_loop_x86 is an x86 with BMI2 assembler implementation
152 // of Decompress1X when tablelog > 8.
153 //
154 //go:noescape
155 func decompress1x_main_loop_bmi2(ctx *decompress1xContext)
156
// decompress1xContext carries the state exchanged between Decompress1X and
// the 1X assembly main loops (decompress1x_main_loop_amd64 and
// decompress1x_main_loop_bmi2), which receive a pointer to it.
//
// NOTE(review): the assembly presumably addresses these fields by fixed
// offsets; do not reorder, add, or retype fields without checking the .s file.
type decompress1xContext struct {
	pbr      *bitReaderShifted // bit reader for the single stream
	peekBits uint8             // (64 - actualTableLog) & 63; see bitReaderShifted.peekBitsFast()
	out      *byte             // start of the output buffer
	outCap   int               // capacity of the output buffer in bytes (cap(dst))
	tbl      *dEntrySingle     // first entry of the single-symbol decoding table
	decoded  int               // read back after the call: bytes written, or error_max_decoded_size_exeeded
}
165
const (
	// error_max_decoded_size_exeeded is the sentinel the 1X assembly loops
	// store in decompress1xContext.decoded when the output would exceed
	// outCap; Decompress1X maps it to ErrMaxDecodedSizeExceeded.
	error_max_decoded_size_exeeded = -1
)
168
169 // Decompress1X will decompress a 1X encoded stream.
170 // The cap of the output buffer will be the maximum decompressed size.
171 // The length of the supplied input must match the end of a block exactly.
172 func (d *Decoder) Decompress1X(dst, src []byte) ([]byte, error) {
173 if len(d.dt.single) == 0 {
174 return nil, errors.New("no table loaded")
175 }
176 var br bitReaderShifted
177 err := br.init(src)
178 if err != nil {
179 return dst, err
180 }
181 maxDecodedSize := cap(dst)
182 dst = dst[:maxDecodedSize]
183
184 const tlSize = 1 << tableLogMax
185 const tlMask = tlSize - 1
186
187 if maxDecodedSize >= 4 {
188 ctx := decompress1xContext{
189 pbr: &br,
190 out: &dst[0],
191 outCap: maxDecodedSize,
192 peekBits: uint8((64 - d.actualTableLog) & 63), // see: bitReaderShifted.peekBitsFast()
193 tbl: &d.dt.single[0],
194 }
195
196 if cpuinfo.HasBMI2() {
197 decompress1x_main_loop_bmi2(&ctx)
198 } else {
199 decompress1x_main_loop_amd64(&ctx)
200 }
201 if ctx.decoded == error_max_decoded_size_exeeded {
202 return nil, ErrMaxDecodedSizeExceeded
203 }
204
205 dst = dst[:ctx.decoded]
206 }
207
208 // br < 8, so uint8 is fine
209 bitsLeft := uint8(br.off)*8 + 64 - br.bitsRead
210 for bitsLeft > 0 {
211 br.fill()
212 if len(dst) >= maxDecodedSize {
213 br.close()
214 return nil, ErrMaxDecodedSizeExceeded
215 }
216 v := d.dt.single[br.peekBitsFast(d.actualTableLog)&tlMask]
217 nBits := uint8(v.entry)
218 br.advance(nBits)
219 bitsLeft -= nBits
220 dst = append(dst, uint8(v.entry>>8))
221 }
222 return dst, br.close()
223 }
224