checksum_unsafe.go raw
1 // Copyright 2023 The gVisor Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 package checksum
16
17 import (
18 "encoding/binary"
19 "math/bits"
20 "unsafe"
21 )
22
23 // Note: odd indicates whether initial is a partial checksum over an odd number
24 // of bytes.
25 func calculateChecksum(buf []byte, odd bool, initial uint16) (uint16, bool) {
26 // Use a larger-than-uint16 accumulator to benefit from parallel summation
27 // as described in RFC 1071 1.2.C.
28 acc := uint64(initial)
29
30 // Handle an odd number of previously-summed bytes, and get the return
31 // value for odd.
32 if odd {
33 acc += uint64(buf[0])
34 buf = buf[1:]
35 }
36 odd = len(buf)&1 != 0
37
38 // Aligning &buf[0] below is much simpler if len(buf) >= 8; special-case
39 // smaller bufs.
40 if len(buf) < 8 {
41 if len(buf) >= 4 {
42 acc += (uint64(buf[0]) << 8) + uint64(buf[1])
43 acc += (uint64(buf[2]) << 8) + uint64(buf[3])
44 buf = buf[4:]
45 }
46 if len(buf) >= 2 {
47 acc += (uint64(buf[0]) << 8) + uint64(buf[1])
48 buf = buf[2:]
49 }
50 if len(buf) >= 1 {
51 acc += uint64(buf[0]) << 8
52 // buf = buf[1:] is skipped because it's unused and nogo will
53 // complain.
54 }
55 return reduce(acc), odd
56 }
57
58 // On little-endian architectures, multi-byte loads from buf will load
59 // bytes in the wrong order. Rather than byte-swap after each load (slow),
60 // we byte-swap the accumulator before summing any bytes and byte-swap it
61 // back before returning, which still produces the correct result as
62 // described in RFC 1071 1.2.B "Byte Order Independence".
63 //
64 // acc is at most a uint16 + a uint8, so its upper 32 bits must be 0s. We
65 // preserve this property by byte-swapping only the lower 32 bits of acc,
66 // so that additions to acc performed during alignment can't overflow.
67 acc = uint64(bswapIfLittleEndian32(uint32(acc)))
68
69 // Align &buf[0] to an 8-byte boundary.
70 bswapped := false
71 if sliceAddr(buf)&1 != 0 {
72 // Compute the rest of the partial checksum with bytes swapped, and
73 // swap back before returning; see the last paragraph of
74 // RFC 1071 1.2.B.
75 acc = uint64(bits.ReverseBytes32(uint32(acc)))
76 bswapped = true
77 // No `<< 8` here due to the byte swap we just did.
78 acc += uint64(bswapIfLittleEndian16(uint16(buf[0])))
79 buf = buf[1:]
80 }
81 if sliceAddr(buf)&2 != 0 {
82 acc += uint64(*(*uint16)(unsafe.Pointer(&buf[0])))
83 buf = buf[2:]
84 }
85 if sliceAddr(buf)&4 != 0 {
86 acc += uint64(*(*uint32)(unsafe.Pointer(&buf[0])))
87 buf = buf[4:]
88 }
89
90 // Sum 64 bytes at a time. Beyond this point, additions to acc may
91 // overflow, so we have to handle carrying.
92 for len(buf) >= 64 {
93 var carry uint64
94 acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[0])), 0)
95 acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[8])), carry)
96 acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[16])), carry)
97 acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[24])), carry)
98 acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[32])), carry)
99 acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[40])), carry)
100 acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[48])), carry)
101 acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[56])), carry)
102 acc, _ = bits.Add64(acc, 0, carry)
103 buf = buf[64:]
104 }
105
106 // Sum the remaining 0-63 bytes.
107 if len(buf) >= 32 {
108 var carry uint64
109 acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[0])), 0)
110 acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[8])), carry)
111 acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[16])), carry)
112 acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[24])), carry)
113 acc, _ = bits.Add64(acc, 0, carry)
114 buf = buf[32:]
115 }
116 if len(buf) >= 16 {
117 var carry uint64
118 acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[0])), 0)
119 acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[8])), carry)
120 acc, _ = bits.Add64(acc, 0, carry)
121 buf = buf[16:]
122 }
123 if len(buf) >= 8 {
124 var carry uint64
125 acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[0])), 0)
126 acc, _ = bits.Add64(acc, 0, carry)
127 buf = buf[8:]
128 }
129 if len(buf) >= 4 {
130 var carry uint64
131 acc, carry = bits.Add64(acc, uint64(*(*uint32)(unsafe.Pointer(&buf[0]))), 0)
132 acc, _ = bits.Add64(acc, 0, carry)
133 buf = buf[4:]
134 }
135 if len(buf) >= 2 {
136 var carry uint64
137 acc, carry = bits.Add64(acc, uint64(*(*uint16)(unsafe.Pointer(&buf[0]))), 0)
138 acc, _ = bits.Add64(acc, 0, carry)
139 buf = buf[2:]
140 }
141 if len(buf) >= 1 {
142 // bswapIfBigEndian16(buf[0]) == bswapIfLittleEndian16(buf[0]<<8).
143 var carry uint64
144 acc, carry = bits.Add64(acc, uint64(bswapIfBigEndian16(uint16(buf[0]))), 0)
145 acc, _ = bits.Add64(acc, 0, carry)
146 // buf = buf[1:] is skipped because it's unused and nogo will complain.
147 }
148
149 // Reduce the checksum to 16 bits and undo byte swaps before returning.
150 acc16 := bswapIfLittleEndian16(reduce(acc))
151 if bswapped {
152 acc16 = bits.ReverseBytes16(acc16)
153 }
154 return acc16, odd
155 }
156
// reduce folds a 64-bit ones'-complement accumulator down to 16 bits, adding
// every carry out of a narrower width back into the low-order bits
// (end-around carry).
func reduce(acc uint64) uint16 {
	// A single pass summing all four 16-bit lanes at once would expose more
	// instruction-level parallelism, but there is no bits.Add16() to carry
	// with, so the value is halved in two stages instead.

	// 64 -> 33 bits: upper half plus lower half, at most 0x1_ffff_fffe.
	hi32, lo32 := acc>>32, acc&0xffff_ffff
	sum33 := hi32 + lo32

	// 33 -> 32 bits: wrap the single carry bit around; at most 0xffff_ffff.
	sum32 := uint32(sum33 + (sum33 >> 32))

	// 32 -> 17 bits: upper half plus lower half, at most 0x1_fffe.
	sum17 := (sum32 & 0xffff) + (sum32 >> 16)

	// 17 -> 16 bits: wrap the single carry bit around; at most 0xffff.
	return uint16(sum17 + (sum17 >> 16))
}
167
// bswapIfLittleEndian32 interprets the in-memory bytes of val as a big-endian
// uint32. On a little-endian host this reverses the byte order of val; on a
// big-endian host it returns val unchanged.
func bswapIfLittleEndian32(val uint32) uint32 {
	raw := (*[4]byte)(unsafe.Pointer(&val))
	return binary.BigEndian.Uint32(raw[:])
}
171
// bswapIfLittleEndian16 interprets the in-memory bytes of val as a big-endian
// uint16. On a little-endian host this swaps the two bytes of val; on a
// big-endian host it returns val unchanged.
func bswapIfLittleEndian16(val uint16) uint16 {
	raw := (*[2]byte)(unsafe.Pointer(&val))
	return binary.BigEndian.Uint16(raw[:])
}
175
// bswapIfBigEndian16 interprets the in-memory bytes of val as a little-endian
// uint16. On a big-endian host this swaps the two bytes of val; on a
// little-endian host it returns val unchanged.
func bswapIfBigEndian16(val uint16) uint16 {
	raw := (*[2]byte)(unsafe.Pointer(&val))
	return binary.LittleEndian.Uint16(raw[:])
}
179
// sliceAddr returns the address of buf's underlying data pointer as an
// integer, for alignment arithmetic only. A nil slice yields 0.
func sliceAddr(buf []byte) uintptr {
	first := unsafe.SliceData(buf)
	return uintptr(unsafe.Pointer(first))
}
183