io.mx raw

   1  // Copyright 2022 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  // Package saferio provides I/O functions that avoid allocating large
   6  // amounts of memory unnecessarily. This is intended for packages that
   7  // read data from an [io.Reader] where the size is part of the input
   8  // data but the input may be corrupt, or may be provided by an
   9  // untrustworthy attacker.
  10  package saferio
  11  
  12  import (
  13  	"io"
  14  	"unsafe"
  15  )
  16  
  17  // chunk is an arbitrary limit on how much memory we are willing
  18  // to allocate without concern.
  19  const chunk = 10 << 20 // 10M
  20  
  21  // ReadData reads n bytes from the input stream, but avoids allocating
  22  // all n bytes if n is large. This avoids crashing the program by
  23  // allocating all n bytes in cases where n is incorrect.
  24  //
  25  // The error is io.EOF only if no bytes were read.
  26  // If an io.EOF happens after reading some but not all the bytes,
  27  // ReadData returns io.ErrUnexpectedEOF.
  28  func ReadData(r io.Reader, n uint64) ([]byte, error) {
  29  	if int64(n) < 0 || n != uint64(int(n)) {
  30  		// n is too large to fit in int, so we can't allocate
  31  		// a buffer large enough. Treat this as a read failure.
  32  		return nil, io.ErrUnexpectedEOF
  33  	}
  34  
  35  	if n < chunk {
  36  		buf := []byte{:n}
  37  		_, err := io.ReadFull(r, buf)
  38  		if err != nil {
  39  			return nil, err
  40  		}
  41  		return buf, nil
  42  	}
  43  
  44  	var buf []byte
  45  	buf1 := []byte{:chunk}
  46  	for n > 0 {
  47  		next := n
  48  		if next > chunk {
  49  			next = chunk
  50  		}
  51  		_, err := io.ReadFull(r, buf1[:next])
  52  		if err != nil {
  53  			if len(buf) > 0 && err == io.EOF {
  54  				err = io.ErrUnexpectedEOF
  55  			}
  56  			return nil, err
  57  		}
  58  		buf = append(buf, buf1[:next]...)
  59  		n -= next
  60  	}
  61  	return buf, nil
  62  }
  63  
  64  // ReadDataAt reads n bytes from the input stream at off, but avoids
  65  // allocating all n bytes if n is large. This avoids crashing the program
  66  // by allocating all n bytes in cases where n is incorrect.
  67  func ReadDataAt(r io.ReaderAt, n uint64, off int64) ([]byte, error) {
  68  	if int64(n) < 0 || n != uint64(int(n)) {
  69  		// n is too large to fit in int, so we can't allocate
  70  		// a buffer large enough. Treat this as a read failure.
  71  		return nil, io.ErrUnexpectedEOF
  72  	}
  73  
  74  	if n < chunk {
  75  		buf := []byte{:n}
  76  		_, err := r.ReadAt(buf, off)
  77  		if err != nil {
  78  			// io.SectionReader can return EOF for n == 0,
  79  			// but for our purposes that is a success.
  80  			if err != io.EOF || n > 0 {
  81  				return nil, err
  82  			}
  83  		}
  84  		return buf, nil
  85  	}
  86  
  87  	var buf []byte
  88  	buf1 := []byte{:chunk}
  89  	for n > 0 {
  90  		next := n
  91  		if next > chunk {
  92  			next = chunk
  93  		}
  94  		_, err := r.ReadAt(buf1[:next], off)
  95  		if err != nil {
  96  			return nil, err
  97  		}
  98  		buf = append(buf, buf1[:next]...)
  99  		n -= next
 100  		off += int64(next)
 101  	}
 102  	return buf, nil
 103  }
 104  
 105  // SliceCapWithSize returns the capacity to use when allocating a slice.
 106  // After the slice is allocated with the capacity, it should be
 107  // built using append. This will avoid allocating too much memory
 108  // if the capacity is large and incorrect.
 109  //
 110  // A negative result means that the value is always too big.
 111  func SliceCapWithSize(size, c uint64) int {
 112  	if int64(c) < 0 || c != uint64(int(c)) {
 113  		return -1
 114  	}
 115  	if size > 0 && c > (1<<64-1)/size {
 116  		return -1
 117  	}
 118  	if c*size > chunk {
 119  		c = chunk / size
 120  		if c == 0 {
 121  			c = 1
 122  		}
 123  	}
 124  	return int(c)
 125  }
 126  
 127  // SliceCap is like SliceCapWithSize but using generics.
 128  func SliceCap[E any](c uint64) int {
 129  	var v E
 130  	size := uint64(unsafe.Sizeof(v))
 131  	return SliceCapWithSize(size, c)
 132  }
 133