fe_amd64_asm.go
1 // Copyright (c) 2021 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package main
6
7 import (
8 "fmt"
9
10 . "github.com/mmcloughlin/avo/build"
11 . "github.com/mmcloughlin/avo/gotypes"
12 . "github.com/mmcloughlin/avo/operand"
13 . "github.com/mmcloughlin/avo/reg"
14 )
15
16 //go:generate go run . -out ../fe_amd64.s -stubs ../fe_amd64.go -pkg field
17
// main drives avo: it declares the generated package and its build
// constraint, emits the feMul and feSquare assembly routines, and
// writes the output files named in the go:generate directive above.
func main() {
	Package("crypto/internal/fips140/edwards25519/field")
	// The assembly is skipped under -tags purego.
	ConstraintExpr("!purego")
	feMul()
	feSquare()
	Generate()
}
25
26 type namedComponent struct {
27 Component
28 name []byte
29 }
30
31 func (c namedComponent) String() string { return c.name }
32
33 type uint128 struct {
34 name []byte
35 hi, lo GPVirtual
36 }
37
38 func (c uint128) String() string { return c.name }
39
40 func feSquare() {
41 TEXT("feSquare", NOSPLIT, "func(out, a *Element)")
42 Doc("feSquare sets out = a * a. It works like feSquareGeneric.")
43 Pragma("noescape")
44
45 a := Dereference(Param("a"))
46 l0 := namedComponent{a.Field("l0"), "l0"}
47 l1 := namedComponent{a.Field("l1"), "l1"}
48 l2 := namedComponent{a.Field("l2"), "l2"}
49 l3 := namedComponent{a.Field("l3"), "l3"}
50 l4 := namedComponent{a.Field("l4"), "l4"}
51
52 // r0 = l0×l0 + 19×2×(l1×l4 + l2×l3)
53 r0 := uint128{"r0", GP64(), GP64()}
54 mul64(r0, 1, l0, l0)
55 addMul64(r0, 38, l1, l4)
56 addMul64(r0, 38, l2, l3)
57
58 // r1 = 2×l0×l1 + 19×2×l2×l4 + 19×l3×l3
59 r1 := uint128{"r1", GP64(), GP64()}
60 mul64(r1, 2, l0, l1)
61 addMul64(r1, 38, l2, l4)
62 addMul64(r1, 19, l3, l3)
63
64 // r2 = = 2×l0×l2 + l1×l1 + 19×2×l3×l4
65 r2 := uint128{"r2", GP64(), GP64()}
66 mul64(r2, 2, l0, l2)
67 addMul64(r2, 1, l1, l1)
68 addMul64(r2, 38, l3, l4)
69
70 // r3 = = 2×l0×l3 + 2×l1×l2 + 19×l4×l4
71 r3 := uint128{"r3", GP64(), GP64()}
72 mul64(r3, 2, l0, l3)
73 addMul64(r3, 2, l1, l2)
74 addMul64(r3, 19, l4, l4)
75
76 // r4 = = 2×l0×l4 + 2×l1×l3 + l2×l2
77 r4 := uint128{"r4", GP64(), GP64()}
78 mul64(r4, 2, l0, l4)
79 addMul64(r4, 2, l1, l3)
80 addMul64(r4, 1, l2, l2)
81
82 Comment("First reduction chain")
83 maskLow51Bits := GP64()
84 MOVQ(Imm((1<<51)-1), maskLow51Bits)
85 c0, r0lo := shiftRightBy51(&r0)
86 c1, r1lo := shiftRightBy51(&r1)
87 c2, r2lo := shiftRightBy51(&r2)
88 c3, r3lo := shiftRightBy51(&r3)
89 c4, r4lo := shiftRightBy51(&r4)
90 maskAndAdd(r0lo, maskLow51Bits, c4, 19)
91 maskAndAdd(r1lo, maskLow51Bits, c0, 1)
92 maskAndAdd(r2lo, maskLow51Bits, c1, 1)
93 maskAndAdd(r3lo, maskLow51Bits, c2, 1)
94 maskAndAdd(r4lo, maskLow51Bits, c3, 1)
95
96 Comment("Second reduction chain (carryPropagate)")
97 // c0 = r0 >> 51
98 MOVQ(r0lo, c0)
99 SHRQ(Imm(51), c0)
100 // c1 = r1 >> 51
101 MOVQ(r1lo, c1)
102 SHRQ(Imm(51), c1)
103 // c2 = r2 >> 51
104 MOVQ(r2lo, c2)
105 SHRQ(Imm(51), c2)
106 // c3 = r3 >> 51
107 MOVQ(r3lo, c3)
108 SHRQ(Imm(51), c3)
109 // c4 = r4 >> 51
110 MOVQ(r4lo, c4)
111 SHRQ(Imm(51), c4)
112 maskAndAdd(r0lo, maskLow51Bits, c4, 19)
113 maskAndAdd(r1lo, maskLow51Bits, c0, 1)
114 maskAndAdd(r2lo, maskLow51Bits, c1, 1)
115 maskAndAdd(r3lo, maskLow51Bits, c2, 1)
116 maskAndAdd(r4lo, maskLow51Bits, c3, 1)
117
118 Comment("Store output")
119 out := Dereference(Param("out"))
120 Store(r0lo, out.Field("l0"))
121 Store(r1lo, out.Field("l1"))
122 Store(r2lo, out.Field("l2"))
123 Store(r3lo, out.Field("l3"))
124 Store(r4lo, out.Field("l4"))
125
126 RET()
127 }
128
129 func feMul() {
130 TEXT("feMul", NOSPLIT, "func(out, a, b *Element)")
131 Doc("feMul sets out = a * b. It works like feMulGeneric.")
132 Pragma("noescape")
133
134 a := Dereference(Param("a"))
135 a0 := namedComponent{a.Field("l0"), "a0"}
136 a1 := namedComponent{a.Field("l1"), "a1"}
137 a2 := namedComponent{a.Field("l2"), "a2"}
138 a3 := namedComponent{a.Field("l3"), "a3"}
139 a4 := namedComponent{a.Field("l4"), "a4"}
140
141 b := Dereference(Param("b"))
142 b0 := namedComponent{b.Field("l0"), "b0"}
143 b1 := namedComponent{b.Field("l1"), "b1"}
144 b2 := namedComponent{b.Field("l2"), "b2"}
145 b3 := namedComponent{b.Field("l3"), "b3"}
146 b4 := namedComponent{b.Field("l4"), "b4"}
147
148 // r0 = a0×b0 + 19×(a1×b4 + a2×b3 + a3×b2 + a4×b1)
149 r0 := uint128{"r0", GP64(), GP64()}
150 mul64(r0, 1, a0, b0)
151 addMul64(r0, 19, a1, b4)
152 addMul64(r0, 19, a2, b3)
153 addMul64(r0, 19, a3, b2)
154 addMul64(r0, 19, a4, b1)
155
156 // r1 = a0×b1 + a1×b0 + 19×(a2×b4 + a3×b3 + a4×b2)
157 r1 := uint128{"r1", GP64(), GP64()}
158 mul64(r1, 1, a0, b1)
159 addMul64(r1, 1, a1, b0)
160 addMul64(r1, 19, a2, b4)
161 addMul64(r1, 19, a3, b3)
162 addMul64(r1, 19, a4, b2)
163
164 // r2 = a0×b2 + a1×b1 + a2×b0 + 19×(a3×b4 + a4×b3)
165 r2 := uint128{"r2", GP64(), GP64()}
166 mul64(r2, 1, a0, b2)
167 addMul64(r2, 1, a1, b1)
168 addMul64(r2, 1, a2, b0)
169 addMul64(r2, 19, a3, b4)
170 addMul64(r2, 19, a4, b3)
171
172 // r3 = a0×b3 + a1×b2 + a2×b1 + a3×b0 + 19×a4×b4
173 r3 := uint128{"r3", GP64(), GP64()}
174 mul64(r3, 1, a0, b3)
175 addMul64(r3, 1, a1, b2)
176 addMul64(r3, 1, a2, b1)
177 addMul64(r3, 1, a3, b0)
178 addMul64(r3, 19, a4, b4)
179
180 // r4 = a0×b4 + a1×b3 + a2×b2 + a3×b1 + a4×b0
181 r4 := uint128{"r4", GP64(), GP64()}
182 mul64(r4, 1, a0, b4)
183 addMul64(r4, 1, a1, b3)
184 addMul64(r4, 1, a2, b2)
185 addMul64(r4, 1, a3, b1)
186 addMul64(r4, 1, a4, b0)
187
188 Comment("First reduction chain")
189 maskLow51Bits := GP64()
190 MOVQ(Imm((1<<51)-1), maskLow51Bits)
191 c0, r0lo := shiftRightBy51(&r0)
192 c1, r1lo := shiftRightBy51(&r1)
193 c2, r2lo := shiftRightBy51(&r2)
194 c3, r3lo := shiftRightBy51(&r3)
195 c4, r4lo := shiftRightBy51(&r4)
196 maskAndAdd(r0lo, maskLow51Bits, c4, 19)
197 maskAndAdd(r1lo, maskLow51Bits, c0, 1)
198 maskAndAdd(r2lo, maskLow51Bits, c1, 1)
199 maskAndAdd(r3lo, maskLow51Bits, c2, 1)
200 maskAndAdd(r4lo, maskLow51Bits, c3, 1)
201
202 Comment("Second reduction chain (carryPropagate)")
203 // c0 = r0 >> 51
204 MOVQ(r0lo, c0)
205 SHRQ(Imm(51), c0)
206 // c1 = r1 >> 51
207 MOVQ(r1lo, c1)
208 SHRQ(Imm(51), c1)
209 // c2 = r2 >> 51
210 MOVQ(r2lo, c2)
211 SHRQ(Imm(51), c2)
212 // c3 = r3 >> 51
213 MOVQ(r3lo, c3)
214 SHRQ(Imm(51), c3)
215 // c4 = r4 >> 51
216 MOVQ(r4lo, c4)
217 SHRQ(Imm(51), c4)
218 maskAndAdd(r0lo, maskLow51Bits, c4, 19)
219 maskAndAdd(r1lo, maskLow51Bits, c0, 1)
220 maskAndAdd(r2lo, maskLow51Bits, c1, 1)
221 maskAndAdd(r3lo, maskLow51Bits, c2, 1)
222 maskAndAdd(r4lo, maskLow51Bits, c3, 1)
223
224 Comment("Store output")
225 out := Dereference(Param("out"))
226 Store(r0lo, out.Field("l0"))
227 Store(r1lo, out.Field("l1"))
228 Store(r2lo, out.Field("l2"))
229 Store(r3lo, out.Field("l3"))
230 Store(r4lo, out.Field("l4"))
231
232 RET()
233 }
234
235 // mul64 sets r to i * aX * bX.
236 func mul64(r uint128, i int, aX, bX namedComponent) {
237 switch i {
238 case 1:
239 Comment(fmt.Sprintf("%s = %s×%s", r, aX, bX))
240 Load(aX, RAX)
241 case 2:
242 Comment(fmt.Sprintf("%s = 2×%s×%s", r, aX, bX))
243 Load(aX, RAX)
244 SHLQ(Imm(1), RAX)
245 default:
246 panic("unsupported i value")
247 }
248 MULQ(mustAddr(bX)) // RDX, RAX = RAX * bX
249 MOVQ(RAX, r.lo)
250 MOVQ(RDX, r.hi)
251 }
252
253 // addMul64 sets r to r + i * aX * bX.
254 func addMul64(r uint128, i uint64, aX, bX namedComponent) {
255 switch i {
256 case 1:
257 Comment(fmt.Sprintf("%s += %s×%s", r, aX, bX))
258 Load(aX, RAX)
259 case 2:
260 Comment(fmt.Sprintf("%s += %d×%s×%s", r, i, aX, bX))
261 Load(aX, RAX)
262 SHLQ(U8(1), RAX)
263 case 19:
264 Comment(fmt.Sprintf("%s += %d×%s×%s", r, i, aX, bX))
265 // 19 * v ==> v + (v+v*8)*2
266 tmp := Load(aX, GP64())
267 LEAQ(Mem{Base: tmp, Index: tmp, Scale: 8}, RAX)
268 LEAQ(Mem{Base: tmp, Index: RAX, Scale: 2}, RAX)
269 case 38:
270 Comment(fmt.Sprintf("%s += %d×%s×%s", r, i, aX, bX))
271 // 38 * v ==> (v + (v+v*8)*2) * 2
272 tmp := Load(aX, GP64())
273 LEAQ(Mem{Base: tmp, Index: tmp, Scale: 8}, RAX)
274 LEAQ(Mem{Base: tmp, Index: RAX, Scale: 2}, RAX)
275 SHLQ(U8(1), RAX)
276 default:
277 Comment(fmt.Sprintf("%s += %d×%s×%s", r, i, aX, bX))
278 IMUL3Q(Imm(i), Load(aX, GP64()), RAX)
279 }
280 MULQ(mustAddr(bX)) // RDX, RAX = RAX * bX
281 ADDQ(RAX, r.lo)
282 ADCQ(RDX, r.hi)
283 }
284
285 // shiftRightBy51 returns r >> 51 and r.lo.
286 //
287 // After this function is called, the uint128 may not be used anymore.
288 func shiftRightBy51(r *uint128) (out, lo GPVirtual) {
289 out = r.hi
290 lo = r.lo
291 SHLQ(Imm(64-51), r.lo, r.hi)
292 r.lo, r.hi = nil, nil // make sure the uint128 is unusable
293 return
294 }
295
296 // maskAndAdd sets r = r&mask + c*i.
297 func maskAndAdd(r, mask, c GPVirtual, i uint64) {
298 ANDQ(mask, r)
299 if i != 1 {
300 IMUL3Q(Imm(i), c, c)
301 }
302 ADDQ(c, r)
303 }
304
305 func mustAddr(c Component) Op {
306 b, err := c.Resolve()
307 if err != nil {
308 panic(err)
309 }
310 return b.Addr
311 }
312