decompress_amd64.go raw

   1  //go:build amd64 && !appengine && !noasm && gc
   2  // +build amd64,!appengine,!noasm,gc
   3  
   4  // This file contains the specialisations of Decoder.Decompress4X
   5  // and Decoder.Decompress1X that use an asm implementation of their main loops.
   6  package huff0
   7  
   8  import (
   9  	"errors"
  10  	"fmt"
  11  
  12  	"github.com/klauspost/compress/internal/cpuinfo"
  13  )
  14  
  15  // decompress4x_main_loop_amd64 is an x86 assembler implementation
  16  // of the Decompress4X main loop, used when tablelog > 8.
      // Only the declaration lives in this file. Decompress4X fills in ctx
      // before the call and reads ctx.decoded once it returns.
  17  //
  18  //go:noescape
  19  func decompress4x_main_loop_amd64(ctx *decompress4xContext)
  20  
  21  // decompress4x_8b_main_loop_amd64 is an x86 assembler implementation
  22  // of the Decompress4X main loop, used when tablelog <= 8, which decodes
  23  // 4 entries per loop.
      // Only the declaration lives in this file. Decompress4X reaches it only
      // when the destination capacity is at least fallback8BitSize; smaller
      // destinations are routed to decompress4X8bit instead.
  24  //
  25  //go:noescape
  26  func decompress4x_8b_main_loop_amd64(ctx *decompress4xContext)
  27  
  28  // fallback8BitSize is the destination capacity, in bytes, below which
      // Decompress4X sends streams with tablelog <= 8 to decompress4X8bit
      // instead of the assembly loop (the Go version is faster at that size).
  29  const fallback8BitSize = 800
  30  
      // decompress4xContext is the argument block that Decompress4X passes to
      // the decompress4x_main_loop_amd64 / decompress4x_8b_main_loop_amd64
      // assembly loops.
      // NOTE(review): the assembly presumably reads these fields at fixed
      // offsets, so field order and types should stay in sync with it — confirm.
  31  type decompress4xContext struct {
      	// pbr points at the four bit readers, one per input stream.
  32  	pbr      *[4]bitReaderShifted
      	// peekBits is set by Decompress4X to (64 - tablelog) & 63;
      	// see bitReaderShifted.peekBitsFast().
  33  	peekBits uint8
      	// out points at the first byte of the destination buffer.
  34  	out      *byte
      	// dstEvery is the distance in bytes between the starts of the four
      	// output segments, computed as (destination size + 3) / 4.
  35  	dstEvery int
      	// tbl points at the first entry of the d.dt.single decoding table.
  36  	tbl      *dEntrySingle
      	// decoded is read back by Decompress4X after the loop returns and is
      	// treated as the total number of bytes decoded across all four streams.
  37  	decoded  int
      	// limit points at out[dstEvery-4]; decoding stops when the first
      	// stream's output gets there, to avoid writing out of bounds on the last.
  38  	limit    *byte
  39  }
  40  
  41  // Decompress4X will decompress a 4X encoded stream.
  42  // The length of the supplied input must match the end of a block exactly.
  43  // The *capacity* of the dst slice must match the destination size of
  44  // the uncompressed data exactly.
  45  func (d *Decoder) Decompress4X(dst, src []byte) ([]byte, error) {
  46  	if len(d.dt.single) == 0 {
  47  		return nil, errors.New("no table loaded")
  48  	}
  49  	if len(src) < 6+(4*1) {
  50  		return nil, errors.New("input too small")
  51  	}
  52  
  53  	use8BitTables := d.actualTableLog <= 8
  54  	if cap(dst) < fallback8BitSize && use8BitTables {
  55  		return d.decompress4X8bit(dst, src)
  56  	}
  57  
  58  	var br [4]bitReaderShifted
  59  	// Decode "jump table"
  60  	start := 6
  61  	for i := range 3 {
  62  		length := int(src[i*2]) | (int(src[i*2+1]) << 8)
  63  		if start+length >= len(src) {
  64  			return nil, errors.New("truncated input (or invalid offset)")
  65  		}
  66  		err := br[i].init(src[start : start+length])
  67  		if err != nil {
  68  			return nil, err
  69  		}
  70  		start += length
  71  	}
  72  	err := br[3].init(src[start:])
  73  	if err != nil {
  74  		return nil, err
  75  	}
  76  
  77  	// destination, offset to match first output
  78  	dstSize := cap(dst)
  79  	dst = dst[:dstSize]
  80  	out := dst
  81  	dstEvery := (dstSize + 3) / 4
  82  
  83  	const tlSize = 1 << tableLogMax
  84  	const tlMask = tlSize - 1
  85  	single := d.dt.single[:tlSize]
  86  
  87  	var decoded int
  88  
  89  	if len(out) > 4*4 && !(br[0].off < 4 || br[1].off < 4 || br[2].off < 4 || br[3].off < 4) {
  90  		ctx := decompress4xContext{
  91  			pbr:      &br,
  92  			peekBits: uint8((64 - d.actualTableLog) & 63), // see: bitReaderShifted.peekBitsFast()
  93  			out:      &out[0],
  94  			dstEvery: dstEvery,
  95  			tbl:      &single[0],
  96  			limit:    &out[dstEvery-4], // Always stop decoding when first buffer gets here to avoid writing OOB on last.
  97  		}
  98  		if use8BitTables {
  99  			decompress4x_8b_main_loop_amd64(&ctx)
 100  		} else {
 101  			decompress4x_main_loop_amd64(&ctx)
 102  		}
 103  
 104  		decoded = ctx.decoded
 105  		out = out[decoded/4:]
 106  	}
 107  
 108  	// Decode remaining.
 109  	remainBytes := dstEvery - (decoded / 4)
 110  	for i := range br {
 111  		offset := dstEvery * i
 112  		endsAt := min(offset+remainBytes, len(out))
 113  		br := &br[i]
 114  		bitsLeft := br.remaining()
 115  		for bitsLeft > 0 {
 116  			br.fill()
 117  			if offset >= endsAt {
 118  				return nil, errors.New("corruption detected: stream overrun 4")
 119  			}
 120  
 121  			// Read value and increment offset.
 122  			val := br.peekBitsFast(d.actualTableLog)
 123  			v := single[val&tlMask].entry
 124  			nBits := uint8(v)
 125  			br.advance(nBits)
 126  			bitsLeft -= uint(nBits)
 127  			out[offset] = uint8(v >> 8)
 128  			offset++
 129  		}
 130  		if offset != endsAt {
 131  			return nil, fmt.Errorf("corruption detected: short output block %d, end %d != %d", i, offset, endsAt)
 132  		}
 133  		decoded += offset - dstEvery*i
 134  		err = br.close()
 135  		if err != nil {
 136  			return nil, err
 137  		}
 138  	}
 139  	if dstSize != decoded {
 140  		return nil, errors.New("corruption detected: short output block")
 141  	}
 142  	return dst, nil
 143  }
 144  
 145  // decompress1x_main_loop_amd64 is an x86 assembler implementation
 146  // of the Decompress1X main loop.
      // Only the declaration lives in this file. Decompress1X calls it when
      // cpuinfo.HasBMI2() reports false and the destination capacity is at
      // least 4, then reads ctx.decoded once it returns.
 147  //
 148  //go:noescape
 149  func decompress1x_main_loop_amd64(ctx *decompress1xContext)
 150  
 151  // decompress1x_main_loop_bmi2 is an x86 with BMI2 assembler implementation
 152  // of the Decompress1X main loop.
      // Only the declaration lives in this file. Decompress1X calls it when
      // cpuinfo.HasBMI2() reports true and the destination capacity is at
      // least 4, then reads ctx.decoded once it returns.
 153  //
 154  //go:noescape
 155  func decompress1x_main_loop_bmi2(ctx *decompress1xContext)
 156  
      // decompress1xContext is the argument block that Decompress1X passes to
      // the decompress1x_main_loop_amd64 / decompress1x_main_loop_bmi2
      // assembly loops.
      // NOTE(review): the assembly presumably reads these fields at fixed
      // offsets, so field order and types should stay in sync with it — confirm.
 157  type decompress1xContext struct {
      	// pbr points at the bit reader for the input stream.
 158  	pbr      *bitReaderShifted
      	// peekBits is set by Decompress1X to (64 - tablelog) & 63;
      	// see bitReaderShifted.peekBitsFast().
 159  	peekBits uint8
      	// out points at the first byte of the destination buffer.
 160  	out      *byte
      	// outCap is the capacity of the destination buffer, i.e. the maximum
      	// decoded size.
 161  	outCap   int
      	// tbl points at the first entry of the d.dt.single decoding table.
 162  	tbl      *dEntrySingle
      	// decoded is read back by Decompress1X after the loop returns: either
      	// error_max_decoded_size_exeeded, or the length the destination is
      	// resliced to.
 163  	decoded  int
 164  }
 165  
 166  // Error reported by asm implementations.
      // Decompress1X compares decompress1xContext.decoded against this value
      // after the assembly loop returns and, on a match, returns
      // ErrMaxDecodedSizeExceeded.
 167  const error_max_decoded_size_exeeded = -1
 168  
 169  // Decompress1X will decompress a 1X encoded stream.
 170  // The cap of the output buffer will be the maximum decompressed size.
 171  // The length of the supplied input must match the end of a block exactly.
 172  func (d *Decoder) Decompress1X(dst, src []byte) ([]byte, error) {
 173  	if len(d.dt.single) == 0 {
 174  		return nil, errors.New("no table loaded")
 175  	}
 176  	var br bitReaderShifted
 177  	err := br.init(src)
 178  	if err != nil {
 179  		return dst, err
 180  	}
 181  	maxDecodedSize := cap(dst)
 182  	dst = dst[:maxDecodedSize]
 183  
 184  	const tlSize = 1 << tableLogMax
 185  	const tlMask = tlSize - 1
 186  
 187  	if maxDecodedSize >= 4 {
 188  		ctx := decompress1xContext{
 189  			pbr:      &br,
 190  			out:      &dst[0],
 191  			outCap:   maxDecodedSize,
 192  			peekBits: uint8((64 - d.actualTableLog) & 63), // see: bitReaderShifted.peekBitsFast()
 193  			tbl:      &d.dt.single[0],
 194  		}
 195  
 196  		if cpuinfo.HasBMI2() {
 197  			decompress1x_main_loop_bmi2(&ctx)
 198  		} else {
 199  			decompress1x_main_loop_amd64(&ctx)
 200  		}
 201  		if ctx.decoded == error_max_decoded_size_exeeded {
 202  			return nil, ErrMaxDecodedSizeExceeded
 203  		}
 204  
 205  		dst = dst[:ctx.decoded]
 206  	}
 207  
 208  	// br < 8, so uint8 is fine
 209  	bitsLeft := uint8(br.off)*8 + 64 - br.bitsRead
 210  	for bitsLeft > 0 {
 211  		br.fill()
 212  		if len(dst) >= maxDecodedSize {
 213  			br.close()
 214  			return nil, ErrMaxDecodedSizeExceeded
 215  		}
 216  		v := d.dt.single[br.peekBitsFast(d.actualTableLog)&tlMask]
 217  		nBits := uint8(v.entry)
 218  		br.advance(nBits)
 219  		bitsLeft -= nBits
 220  		dst = append(dst, uint8(v.entry>>8))
 221  	}
 222  	return dst, br.close()
 223  }
 224