checksum_unsafe.go raw

   1  // Copyright 2023 The gVisor Authors.
   2  //
   3  // Licensed under the Apache License, Version 2.0 (the "License");
   4  // you may not use this file except in compliance with the License.
   5  // You may obtain a copy of the License at
   6  //
   7  //     http://www.apache.org/licenses/LICENSE-2.0
   8  //
   9  // Unless required by applicable law or agreed to in writing, software
  10  // distributed under the License is distributed on an "AS IS" BASIS,
  11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12  // See the License for the specific language governing permissions and
  13  // limitations under the License.
  14  
  15  package checksum
  16  
  17  import (
  18  	"encoding/binary"
  19  	"math/bits"
  20  	"unsafe"
  21  )
  22  
  23  // Note: odd indicates whether initial is a partial checksum over an odd number
  24  // of bytes.
  25  func calculateChecksum(buf []byte, odd bool, initial uint16) (uint16, bool) {
  26  	// Use a larger-than-uint16 accumulator to benefit from parallel summation
  27  	// as described in RFC 1071 1.2.C.
  28  	acc := uint64(initial)
  29  
  30  	// Handle an odd number of previously-summed bytes, and get the return
  31  	// value for odd.
  32  	if odd {
  33  		acc += uint64(buf[0])
  34  		buf = buf[1:]
  35  	}
  36  	odd = len(buf)&1 != 0
  37  
  38  	// Aligning &buf[0] below is much simpler if len(buf) >= 8; special-case
  39  	// smaller bufs.
  40  	if len(buf) < 8 {
  41  		if len(buf) >= 4 {
  42  			acc += (uint64(buf[0]) << 8) + uint64(buf[1])
  43  			acc += (uint64(buf[2]) << 8) + uint64(buf[3])
  44  			buf = buf[4:]
  45  		}
  46  		if len(buf) >= 2 {
  47  			acc += (uint64(buf[0]) << 8) + uint64(buf[1])
  48  			buf = buf[2:]
  49  		}
  50  		if len(buf) >= 1 {
  51  			acc += uint64(buf[0]) << 8
  52  			// buf = buf[1:] is skipped because it's unused and nogo will
  53  			// complain.
  54  		}
  55  		return reduce(acc), odd
  56  	}
  57  
  58  	// On little-endian architectures, multi-byte loads from buf will load
  59  	// bytes in the wrong order. Rather than byte-swap after each load (slow),
  60  	// we byte-swap the accumulator before summing any bytes and byte-swap it
  61  	// back before returning, which still produces the correct result as
  62  	// described in RFC 1071 1.2.B "Byte Order Independence".
  63  	//
  64  	// acc is at most a uint16 + a uint8, so its upper 32 bits must be 0s. We
  65  	// preserve this property by byte-swapping only the lower 32 bits of acc,
  66  	// so that additions to acc performed during alignment can't overflow.
  67  	acc = uint64(bswapIfLittleEndian32(uint32(acc)))
  68  
  69  	// Align &buf[0] to an 8-byte boundary.
  70  	bswapped := false
  71  	if sliceAddr(buf)&1 != 0 {
  72  		// Compute the rest of the partial checksum with bytes swapped, and
  73  		// swap back before returning; see the last paragraph of
  74  		// RFC 1071 1.2.B.
  75  		acc = uint64(bits.ReverseBytes32(uint32(acc)))
  76  		bswapped = true
  77  		// No `<< 8` here due to the byte swap we just did.
  78  		acc += uint64(bswapIfLittleEndian16(uint16(buf[0])))
  79  		buf = buf[1:]
  80  	}
  81  	if sliceAddr(buf)&2 != 0 {
  82  		acc += uint64(*(*uint16)(unsafe.Pointer(&buf[0])))
  83  		buf = buf[2:]
  84  	}
  85  	if sliceAddr(buf)&4 != 0 {
  86  		acc += uint64(*(*uint32)(unsafe.Pointer(&buf[0])))
  87  		buf = buf[4:]
  88  	}
  89  
  90  	// Sum 64 bytes at a time. Beyond this point, additions to acc may
  91  	// overflow, so we have to handle carrying.
  92  	for len(buf) >= 64 {
  93  		var carry uint64
  94  		acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[0])), 0)
  95  		acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[8])), carry)
  96  		acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[16])), carry)
  97  		acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[24])), carry)
  98  		acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[32])), carry)
  99  		acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[40])), carry)
 100  		acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[48])), carry)
 101  		acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[56])), carry)
 102  		acc, _ = bits.Add64(acc, 0, carry)
 103  		buf = buf[64:]
 104  	}
 105  
 106  	// Sum the remaining 0-63 bytes.
 107  	if len(buf) >= 32 {
 108  		var carry uint64
 109  		acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[0])), 0)
 110  		acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[8])), carry)
 111  		acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[16])), carry)
 112  		acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[24])), carry)
 113  		acc, _ = bits.Add64(acc, 0, carry)
 114  		buf = buf[32:]
 115  	}
 116  	if len(buf) >= 16 {
 117  		var carry uint64
 118  		acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[0])), 0)
 119  		acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[8])), carry)
 120  		acc, _ = bits.Add64(acc, 0, carry)
 121  		buf = buf[16:]
 122  	}
 123  	if len(buf) >= 8 {
 124  		var carry uint64
 125  		acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[0])), 0)
 126  		acc, _ = bits.Add64(acc, 0, carry)
 127  		buf = buf[8:]
 128  	}
 129  	if len(buf) >= 4 {
 130  		var carry uint64
 131  		acc, carry = bits.Add64(acc, uint64(*(*uint32)(unsafe.Pointer(&buf[0]))), 0)
 132  		acc, _ = bits.Add64(acc, 0, carry)
 133  		buf = buf[4:]
 134  	}
 135  	if len(buf) >= 2 {
 136  		var carry uint64
 137  		acc, carry = bits.Add64(acc, uint64(*(*uint16)(unsafe.Pointer(&buf[0]))), 0)
 138  		acc, _ = bits.Add64(acc, 0, carry)
 139  		buf = buf[2:]
 140  	}
 141  	if len(buf) >= 1 {
 142  		// bswapIfBigEndian16(buf[0]) == bswapIfLittleEndian16(buf[0]<<8).
 143  		var carry uint64
 144  		acc, carry = bits.Add64(acc, uint64(bswapIfBigEndian16(uint16(buf[0]))), 0)
 145  		acc, _ = bits.Add64(acc, 0, carry)
 146  		// buf = buf[1:] is skipped because it's unused and nogo will complain.
 147  	}
 148  
 149  	// Reduce the checksum to 16 bits and undo byte swaps before returning.
 150  	acc16 := bswapIfLittleEndian16(reduce(acc))
 151  	if bswapped {
 152  		acc16 = bits.ReverseBytes16(acc16)
 153  	}
 154  	return acc16, odd
 155  }
 156  
 157  func reduce(acc uint64) uint16 {
 158  	// Ideally we would do:
 159  	//   return uint16(acc>>48) +' uint16(acc>>32) +' uint16(acc>>16) +' uint16(acc)
 160  	// for more instruction-level parallelism; however, there is no
 161  	// bits.Add16().
 162  	acc = (acc >> 32) + (acc & 0xffff_ffff)  // at most 0x1_ffff_fffe
 163  	acc32 := uint32(acc>>32 + acc)           // at most 0xffff_ffff
 164  	acc32 = (acc32 >> 16) + (acc32 & 0xffff) // at most 0x1_fffe
 165  	return uint16(acc32>>16 + acc32)         // at most 0xffff
 166  }
 167  
 168  func bswapIfLittleEndian32(val uint32) uint32 {
 169  	return binary.BigEndian.Uint32((*[4]byte)(unsafe.Pointer(&val))[:])
 170  }
 171  
 172  func bswapIfLittleEndian16(val uint16) uint16 {
 173  	return binary.BigEndian.Uint16((*[2]byte)(unsafe.Pointer(&val))[:])
 174  }
 175  
 176  func bswapIfBigEndian16(val uint16) uint16 {
 177  	return binary.LittleEndian.Uint16((*[2]byte)(unsafe.Pointer(&val))[:])
 178  }
 179  
 180  func sliceAddr(buf []byte) uintptr {
 181  	return uintptr(unsafe.Pointer(unsafe.SliceData(buf)))
 182  }
 183