checksum.go raw

   1  package tun
   2  
   3  import (
   4  	"encoding/binary"
   5  	"math/bits"
   6  )
   7  
   8  // TODO: Explore SIMD and/or other assembly optimizations.
   9  func checksumNoFold(b []byte, initial uint64) uint64 {
  10  	tmp := make([]byte, 8)
  11  	binary.NativeEndian.PutUint64(tmp, initial)
  12  	ac := binary.BigEndian.Uint64(tmp)
  13  	var carry uint64
  14  
  15  	for len(b) >= 128 {
  16  		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0)
  17  		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[8:16]), carry)
  18  		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[16:24]), carry)
  19  		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[24:32]), carry)
  20  		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[32:40]), carry)
  21  		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[40:48]), carry)
  22  		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[48:56]), carry)
  23  		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[56:64]), carry)
  24  		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[64:72]), carry)
  25  		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[72:80]), carry)
  26  		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[80:88]), carry)
  27  		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[88:96]), carry)
  28  		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[96:104]), carry)
  29  		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[104:112]), carry)
  30  		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[112:120]), carry)
  31  		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[120:128]), carry)
  32  		ac += carry
  33  		b = b[128:]
  34  	}
  35  	if len(b) >= 64 {
  36  		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0)
  37  		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[8:16]), carry)
  38  		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[16:24]), carry)
  39  		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[24:32]), carry)
  40  		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[32:40]), carry)
  41  		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[40:48]), carry)
  42  		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[48:56]), carry)
  43  		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[56:64]), carry)
  44  		ac += carry
  45  		b = b[64:]
  46  	}
  47  	if len(b) >= 32 {
  48  		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0)
  49  		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[8:16]), carry)
  50  		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[16:24]), carry)
  51  		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[24:32]), carry)
  52  		ac += carry
  53  		b = b[32:]
  54  	}
  55  	if len(b) >= 16 {
  56  		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0)
  57  		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[8:16]), carry)
  58  		ac += carry
  59  		b = b[16:]
  60  	}
  61  	if len(b) >= 8 {
  62  		ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0)
  63  		ac += carry
  64  		b = b[8:]
  65  	}
  66  	if len(b) >= 4 {
  67  		ac, carry = bits.Add64(ac, uint64(binary.NativeEndian.Uint32(b[:4])), 0)
  68  		ac += carry
  69  		b = b[4:]
  70  	}
  71  	if len(b) >= 2 {
  72  		ac, carry = bits.Add64(ac, uint64(binary.NativeEndian.Uint16(b[:2])), 0)
  73  		ac += carry
  74  		b = b[2:]
  75  	}
  76  	if len(b) == 1 {
  77  		tmp := binary.NativeEndian.Uint16([]byte{b[0], 0})
  78  		ac, carry = bits.Add64(ac, uint64(tmp), 0)
  79  		ac += carry
  80  	}
  81  
  82  	binary.NativeEndian.PutUint64(tmp, ac)
  83  	return binary.BigEndian.Uint64(tmp)
  84  }
  85  
  86  func checksum(b []byte, initial uint64) uint16 {
  87  	ac := checksumNoFold(b, initial)
  88  	ac = (ac >> 16) + (ac & 0xffff)
  89  	ac = (ac >> 16) + (ac & 0xffff)
  90  	ac = (ac >> 16) + (ac & 0xffff)
  91  	ac = (ac >> 16) + (ac & 0xffff)
  92  	return uint16(ac)
  93  }
  94  
  95  func pseudoHeaderChecksumNoFold(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint64 {
  96  	sum := checksumNoFold(srcAddr, 0)
  97  	sum = checksumNoFold(dstAddr, sum)
  98  	sum = checksumNoFold([]byte{0, protocol}, sum)
  99  	tmp := make([]byte, 2)
 100  	binary.BigEndian.PutUint16(tmp, totalLen)
 101  	return checksumNoFold(tmp, sum)
 102  }
 103