bytebuf.go raw

   1  // Copyright 2019+ Klaus Post. All rights reserved.
   2  // License information can be found in the LICENSE file.
   3  // Based on work by Yann Collet, released under BSD License.
   4  
   5  package zstd
   6  
   7  import (
   8  	"fmt"
   9  	"io"
  10  )
  11  
  12  type byteBuffer interface {
  13  	// Read up to 8 bytes.
  14  	// Returns io.ErrUnexpectedEOF if this cannot be satisfied.
  15  	readSmall(n int) ([]byte, error)
  16  
  17  	// Read >8 bytes.
  18  	// MAY use the destination slice.
  19  	readBig(n int, dst []byte) ([]byte, error)
  20  
  21  	// Read a single byte.
  22  	readByte() (byte, error)
  23  
  24  	// Skip n bytes.
  25  	skipN(n int64) error
  26  }
  27  
  28  // in-memory buffer
  29  type byteBuf []byte
  30  
  31  func (b *byteBuf) readSmall(n int) ([]byte, error) {
  32  	if debugAsserts && n > 8 {
  33  		panic(fmt.Errorf("small read > 8 (%d). use readBig", n))
  34  	}
  35  	bb := *b
  36  	if len(bb) < n {
  37  		return nil, io.ErrUnexpectedEOF
  38  	}
  39  	r := bb[:n]
  40  	*b = bb[n:]
  41  	return r, nil
  42  }
  43  
  44  func (b *byteBuf) readBig(n int, dst []byte) ([]byte, error) {
  45  	bb := *b
  46  	if len(bb) < n {
  47  		return nil, io.ErrUnexpectedEOF
  48  	}
  49  	r := bb[:n]
  50  	*b = bb[n:]
  51  	return r, nil
  52  }
  53  
  54  func (b *byteBuf) readByte() (byte, error) {
  55  	bb := *b
  56  	if len(bb) < 1 {
  57  		return 0, io.ErrUnexpectedEOF
  58  	}
  59  	r := bb[0]
  60  	*b = bb[1:]
  61  	return r, nil
  62  }
  63  
  64  func (b *byteBuf) skipN(n int64) error {
  65  	bb := *b
  66  	if n < 0 {
  67  		return fmt.Errorf("negative skip (%d) requested", n)
  68  	}
  69  	if int64(len(bb)) < n {
  70  		return io.ErrUnexpectedEOF
  71  	}
  72  	*b = bb[n:]
  73  	return nil
  74  }
  75  
  76  // wrapper around a reader.
  77  type readerWrapper struct {
  78  	r   io.Reader
  79  	tmp [8]byte
  80  }
  81  
  82  func (r *readerWrapper) readSmall(n int) ([]byte, error) {
  83  	if debugAsserts && n > 8 {
  84  		panic(fmt.Errorf("small read > 8 (%d). use readBig", n))
  85  	}
  86  	n2, err := io.ReadFull(r.r, r.tmp[:n])
  87  	// We only really care about the actual bytes read.
  88  	if err != nil {
  89  		if err == io.EOF {
  90  			return nil, io.ErrUnexpectedEOF
  91  		}
  92  		if debugDecoder {
  93  			println("readSmall: got", n2, "want", n, "err", err)
  94  		}
  95  		return nil, err
  96  	}
  97  	return r.tmp[:n], nil
  98  }
  99  
 100  func (r *readerWrapper) readBig(n int, dst []byte) ([]byte, error) {
 101  	if cap(dst) < n {
 102  		dst = make([]byte, n)
 103  	}
 104  	n2, err := io.ReadFull(r.r, dst[:n])
 105  	if err == io.EOF && n > 0 {
 106  		err = io.ErrUnexpectedEOF
 107  	}
 108  	return dst[:n2], err
 109  }
 110  
 111  func (r *readerWrapper) readByte() (byte, error) {
 112  	n2, err := io.ReadFull(r.r, r.tmp[:1])
 113  	if err != nil {
 114  		if err == io.EOF {
 115  			err = io.ErrUnexpectedEOF
 116  		}
 117  		return 0, err
 118  	}
 119  	if n2 != 1 {
 120  		return 0, io.ErrUnexpectedEOF
 121  	}
 122  	return r.tmp[0], nil
 123  }
 124  
 125  func (r *readerWrapper) skipN(n int64) error {
 126  	n2, err := io.CopyN(io.Discard, r.r, n)
 127  	if n2 != n {
 128  		err = io.ErrUnexpectedEOF
 129  	}
 130  	return err
 131  }
 132