aes_amd64.mx raw

   1  // Copyright 2024 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  package main
   6  
   7  import (
   8  	"os"
   9  	"bytes"
  10  
  11  	. "github.com/mmcloughlin/avo/build"
  12  	"github.com/mmcloughlin/avo/ir"
  13  	. "github.com/mmcloughlin/avo/operand"
  14  	. "github.com/mmcloughlin/avo/reg"
  15  )
  16  
  17  //go:generate go run . -out ../../aes_amd64.s
  18  
  19  func main() {
  20  	Package("crypto/aes")
  21  	ConstraintExpr("!purego")
  22  	encryptBlockAsm()
  23  	decryptBlockAsm()
  24  	expandKeyAsm()
  25  	_expand_key_128()
  26  	_expand_key_192a()
  27  	_expand_key_192b()
  28  	_expand_key_256a()
  29  	_expand_key_256b()
  30  	Generate()
  31  
  32  	var internalFunctions [][]byte = [][]byte{
  33  		"·_expand_key_128<>",
  34  		"·_expand_key_192a<>",
  35  		"·_expand_key_192b<>",
  36  		"·_expand_key_256a<>",
  37  		"·_expand_key_256b<>",
  38  	}
  39  	removePeskyUnicodeDot(internalFunctions, "../../asm_amd64.s")
  40  }
  41  
  42  func encryptBlockAsm() {
  43  	Implement("encryptBlockAsm")
  44  	Attributes(NOSPLIT)
  45  	AllocLocal(0)
  46  
  47  	Load(Param("nr"), RCX)
  48  	Load(Param("xk"), RAX)
  49  	Load(Param("dst"), RDX)
  50  	Load(Param("src"), RBX)
  51  	MOVUPS(Mem{Base: AX}.Offset(0), X1)
  52  	MOVUPS(Mem{Base: BX}.Offset(0), X0)
  53  	ADDQ(Imm(16), RAX)
  54  	PXOR(X1, X0)
  55  	SUBQ(Imm(12), RCX)
  56  	JE(LabelRef("Lenc192"))
  57  	JB(LabelRef("Lenc128"))
  58  
  59  	Label("Lenc256")
  60  	MOVUPS(Mem{Base: AX}.Offset(0), X1)
  61  	AESENC(X1, X0)
  62  	MOVUPS(Mem{Base: AX}.Offset(16), X1)
  63  	AESENC(X1, X0)
  64  	ADDQ(Imm(32), RAX)
  65  
  66  	Label("Lenc192")
  67  	MOVUPS(Mem{Base: AX}.Offset(0), X1)
  68  	AESENC(X1, X0)
  69  	MOVUPS(Mem{Base: AX}.Offset(16), X1)
  70  	AESENC(X1, X0)
  71  	ADDQ(Imm(32), RAX)
  72  
  73  	Label("Lenc128")
  74  	MOVUPS(Mem{Base: AX}.Offset(0), X1)
  75  	AESENC(X1, X0)
  76  	MOVUPS(Mem{Base: AX}.Offset(16), X1)
  77  	AESENC(X1, X0)
  78  	MOVUPS(Mem{Base: AX}.Offset(32), X1)
  79  	AESENC(X1, X0)
  80  	MOVUPS(Mem{Base: AX}.Offset(48), X1)
  81  	AESENC(X1, X0)
  82  	MOVUPS(Mem{Base: AX}.Offset(64), X1)
  83  	AESENC(X1, X0)
  84  	MOVUPS(Mem{Base: AX}.Offset(80), X1)
  85  	AESENC(X1, X0)
  86  	MOVUPS(Mem{Base: AX}.Offset(96), X1)
  87  	AESENC(X1, X0)
  88  	MOVUPS(Mem{Base: AX}.Offset(112), X1)
  89  	AESENC(X1, X0)
  90  	MOVUPS(Mem{Base: AX}.Offset(128), X1)
  91  	AESENC(X1, X0)
  92  	MOVUPS(Mem{Base: AX}.Offset(144), X1)
  93  	AESENCLAST(X1, X0)
  94  	MOVUPS(X0, Mem{Base: DX}.Offset(0))
  95  	RET()
  96  }
  97  
  98  func decryptBlockAsm() {
  99  	Implement("decryptBlockAsm")
 100  	Attributes(NOSPLIT)
 101  	AllocLocal(0)
 102  
 103  	Load(Param("nr"), RCX)
 104  	Load(Param("xk"), RAX)
 105  	Load(Param("dst"), RDX)
 106  	Load(Param("src"), RBX)
 107  
 108  	MOVUPS(Mem{Base: AX}.Offset(0), X1)
 109  	MOVUPS(Mem{Base: BX}.Offset(0), X0)
 110  	ADDQ(Imm(16), RAX)
 111  	PXOR(X1, X0)
 112  	SUBQ(Imm(12), RCX)
 113  	JE(LabelRef("Ldec192"))
 114  	JB(LabelRef("Ldec128"))
 115  
 116  	Label("Ldec256")
 117  	MOVUPS(Mem{Base: AX}.Offset(0), X1)
 118  	AESDEC(X1, X0)
 119  	MOVUPS(Mem{Base: AX}.Offset(16), X1)
 120  	AESDEC(X1, X0)
 121  	ADDQ(Imm(32), RAX)
 122  
 123  	Label("Ldec192")
 124  	MOVUPS(Mem{Base: AX}.Offset(0), X1)
 125  	AESDEC(X1, X0)
 126  	MOVUPS(Mem{Base: AX}.Offset(16), X1)
 127  	AESDEC(X1, X0)
 128  	ADDQ(Imm(32), RAX)
 129  
 130  	Label("Ldec128")
 131  	MOVUPS(Mem{Base: AX}.Offset(0), X1)
 132  	AESDEC(X1, X0)
 133  	MOVUPS(Mem{Base: AX}.Offset(16), X1)
 134  	AESDEC(X1, X0)
 135  	MOVUPS(Mem{Base: AX}.Offset(32), X1)
 136  	AESDEC(X1, X0)
 137  	MOVUPS(Mem{Base: AX}.Offset(48), X1)
 138  	AESDEC(X1, X0)
 139  	MOVUPS(Mem{Base: AX}.Offset(64), X1)
 140  	AESDEC(X1, X0)
 141  	MOVUPS(Mem{Base: AX}.Offset(80), X1)
 142  	AESDEC(X1, X0)
 143  	MOVUPS(Mem{Base: AX}.Offset(96), X1)
 144  	AESDEC(X1, X0)
 145  	MOVUPS(Mem{Base: AX}.Offset(112), X1)
 146  	AESDEC(X1, X0)
 147  	MOVUPS(Mem{Base: AX}.Offset(128), X1)
 148  	AESDEC(X1, X0)
 149  	MOVUPS(Mem{Base: AX}.Offset(144), X1)
 150  	AESDECLAST(X1, X0)
 151  	MOVUPS(X0, Mem{Base: DX}.Offset(0))
 152  	RET()
 153  }
 154  
 155  // Note that round keys are stored in uint128 format, not uint32
 156  func expandKeyAsm() {
 157  	Implement("expandKeyAsm")
 158  	Attributes(NOSPLIT)
 159  	AllocLocal(0)
 160  
 161  	Load(Param("nr"), RCX)
 162  	Load(Param("key"), RAX)
 163  	Load(Param("enc"), RBX)
 164  	Load(Param("dec"), RDX)
 165  
 166  	MOVUPS(Mem{Base: AX}, X0)
 167  	Comment("enc")
 168  	MOVUPS(X0, Mem{Base: BX})
 169  	ADDQ(Imm(16), RBX)
 170  	PXOR(X4, X4) // _expand_key_* expect X4 to be zero
 171  	CMPL(ECX, Imm(12))
 172  	JE(LabelRef("Lexp_enc192"))
 173  	JB(LabelRef("Lexp_enc128"))
 174  
 175  	Lexp_enc256()
 176  	Lexp_enc192()
 177  	Lexp_enc128()
 178  	Lexp_dec()
 179  	Lexp_dec_loop()
 180  }
 181  
 182  func Lexp_enc256() {
 183  	Label("Lexp_enc256")
 184  	MOVUPS(Mem{Base: AX}.Offset(16), X2)
 185  	MOVUPS(X2, Mem{Base: BX})
 186  	ADDQ(Imm(16), RBX)
 187  
 188  	var rcon uint64 = 1
 189  	for i := 0; i < 6; i++ {
 190  		AESKEYGENASSIST(Imm(rcon), X2, X1)
 191  		CALL(LabelRef("_expand_key_256a<>(SB)"))
 192  		AESKEYGENASSIST(Imm(rcon), X0, X1)
 193  		CALL(LabelRef("_expand_key_256b<>(SB)"))
 194  		rcon <<= 1
 195  	}
 196  	AESKEYGENASSIST(Imm(0x40), X2, X1)
 197  	CALL(LabelRef("_expand_key_256a<>(SB)"))
 198  	JMP(LabelRef("Lexp_dec"))
 199  }
 200  
 201  func Lexp_enc192() {
 202  	Label("Lexp_enc192")
 203  	MOVQ(Mem{Base: AX}.Offset(16), X2)
 204  
 205  	var rcon uint64 = 1
 206  	for i := 0; i < 8; i++ {
 207  		AESKEYGENASSIST(Imm(rcon), X2, X1)
 208  		if i%2 == 0 {
 209  			CALL(LabelRef("_expand_key_192a<>(SB)"))
 210  		} else {
 211  			CALL(LabelRef("_expand_key_192b<>(SB)"))
 212  		}
 213  		rcon <<= 1
 214  	}
 215  	JMP(LabelRef("Lexp_dec"))
 216  }
 217  
 218  func Lexp_enc128() {
 219  	Label("Lexp_enc128")
 220  	var rcon uint64 = 1
 221  	for i := 0; i < 8; i++ {
 222  		AESKEYGENASSIST(Imm(rcon), X0, X1)
 223  		CALL(LabelRef("_expand_key_128<>(SB)"))
 224  		rcon <<= 1
 225  	}
 226  	AESKEYGENASSIST(Imm(0x1b), X0, X1)
 227  	CALL(LabelRef("_expand_key_128<>(SB)"))
 228  	AESKEYGENASSIST(Imm(0x36), X0, X1)
 229  	CALL(LabelRef("_expand_key_128<>(SB)"))
 230  }
 231  
 232  func Lexp_dec() {
 233  	Label("Lexp_dec")
 234  	Comment("dec")
 235  	SUBQ(Imm(16), RBX)
 236  	MOVUPS(Mem{Base: BX}, X1)
 237  	MOVUPS(X1, Mem{Base: DX})
 238  	DECQ(RCX)
 239  }
 240  
 241  func Lexp_dec_loop() {
 242  	Label("Lexp_dec_loop")
 243  	MOVUPS(Mem{Base: BX}.Offset(-16), X1)
 244  	AESIMC(X1, X0)
 245  	MOVUPS(X0, Mem{Base: DX}.Offset(16))
 246  	SUBQ(Imm(16), RBX)
 247  	ADDQ(Imm(16), RDX)
 248  	DECQ(RCX)
 249  	JNZ(LabelRef("Lexp_dec_loop"))
 250  	MOVUPS(Mem{Base: BX}.Offset(-16), X0)
 251  	MOVUPS(X0, Mem{Base: DX}.Offset(16))
 252  	RET()
 253  }
 254  
// _expand_key_128 emits the internal helper _expand_key_128<>, which derives
// the next round key from the one in X0.
//
// On entry X0 holds the previous round key and X1 holds an AESKEYGENASSIST
// result (computed from X0 by Lexp_enc128, or from X2 when reached through
// _expand_key_256a, which tail-jumps here). The new round key is left in X0,
// stored at (BX), and BX is advanced by 16. X1 is clobbered. X4 is scratch and
// is expected to be zero on entry (expandKeyAsm zeroes it); the SHUFPS
// immediates below never change X4's low dword, so it stays zero across calls.
func _expand_key_128() {
	Function("_expand_key_128<>")
	Attributes(NOSPLIT)
	AllocLocal(0)

	// Broadcast dword 3 of the AESKEYGENASSIST result to all lanes.
	PSHUFD(Imm(0xff), X1, X1)
	// Two shift-and-XOR steps through X4 turn X0's dwords (w0,w1,w2,w3) into
	// the running XOR (w0, w0^w1, w0^w1^w2, w0^w1^w2^w3).
	SHUFPS(Imm(0x10), X0, X4)
	PXOR(X4, X0)
	SHUFPS(Imm(0x8c), X0, X4)
	PXOR(X4, X0)
	// Mix in the broadcast word to finish the new round key.
	PXOR(X1, X0)
	MOVUPS(X0, Mem{Base: BX})
	ADDQ(Imm(16), RBX)
	RET()
}
 270  
// _expand_key_192a emits the internal helper _expand_key_192a<>, one step of
// the AES-192 key schedule that writes 32 bytes of round-key material.
//
// On entry X0 holds four schedule words, X2's low 64 bits hold the two
// trailing words (first loaded by MOVQ in Lexp_enc192), and X1 holds
// AESKEYGENASSIST computed from X2. X0 and X2 are advanced in place. The
// function then stores [old X2 low half, new X0 low half] at (BX) and
// [new X0 high half, new X2 low half] at 16(BX), and advances BX by 32.
// Clobbers X1, X3, X5, X6; X4 is scratch and must be zero as in
// _expand_key_128.
func _expand_key_192a() {
	Function("_expand_key_192a<>")
	Attributes(NOSPLIT)
	AllocLocal(0)

	// Broadcast dword 1 of the AESKEYGENASSIST result (derived from X2[1]).
	PSHUFD(Imm(0x55), X1, X1)
	// Running XOR of X0's dwords, then mix in the broadcast word.
	SHUFPS(Imm(0x10), X0, X4)
	PXOR(X4, X0)
	SHUFPS(Imm(0x8c), X0, X4)
	PXOR(X4, X0)
	PXOR(X1, X0)

	// Update X2's two words from the new X0[3]; X6 keeps the old X2 so it can
	// be written out below.
	MOVAPS(X2, X5)
	MOVAPS(X2, X6)
	PSLLDQ(Imm(0x4), X5)
	PSHUFD(Imm(0xff), X0, X3)
	PXOR(X3, X2)
	PXOR(X5, X2)

	// Repack the 64-bit halves into two 16-byte stores.
	MOVAPS(X0, X1)
	SHUFPS(Imm(0x44), X0, X6)
	MOVUPS(X6, Mem{Base: BX})
	SHUFPS(Imm(0x4e), X2, X1)
	MOVUPS(X1, Mem{Base: BX}.Offset(16))
	ADDQ(Imm(32), RBX)
	RET()
}
 298  
// _expand_key_192b emits the internal helper _expand_key_192b<>, the AES-192
// schedule step that alternates with _expand_key_192a.
//
// It advances X0 and X2 exactly like _expand_key_192a, but stores only the
// new X0 at (BX) and advances BX by 16. The updated X2 stays in the register;
// the next _expand_key_192a call writes it out. Clobbers X1, X3, X5; X4 is
// scratch and must be zero as in _expand_key_128.
func _expand_key_192b() {
	Function("_expand_key_192b<>")
	Attributes(NOSPLIT)
	AllocLocal(0)

	// Broadcast dword 1 of the AESKEYGENASSIST result, then fold it into the
	// running XOR of X0's dwords.
	PSHUFD(Imm(0x55), X1, X1)
	SHUFPS(Imm(0x10), X0, X4)
	PXOR(X4, X0)
	SHUFPS(Imm(0x8c), X0, X4)
	PXOR(X4, X0)
	PXOR(X1, X0)

	// Update X2's two words from the new X0[3].
	MOVAPS(X2, X5)
	PSLLDQ(Imm(0x4), X5)
	PSHUFD(Imm(0xff), X0, X3)
	PXOR(X3, X2)
	PXOR(X5, X2)

	MOVUPS(X0, Mem{Base: BX})
	ADDQ(Imm(16), RBX)
	RET()
}
 321  
// _expand_key_256a emits the internal helper _expand_key_256a<>. It is a pure
// tail jump into _expand_key_128<>: the "a" half of an AES-256 step is the same
// X0 update as AES-128, with X1 computed from X2 by the caller (Lexp_enc256).
func _expand_key_256a() {
	Function("_expand_key_256a<>")
	Attributes(NOSPLIT)
	AllocLocal(0)

	// Hack to get Avo to emit:
	// 	JMP _expand_key_128<>(SB)
	// The target is a symbol in another TEXT block rather than a local label,
	// so the instruction is built directly as raw IR.
	Instruction(&ir.Instruction{
		Opcode: "JMP",
		Operands: []Op{
			LabelRef("_expand_key_128<>(SB)"),
		},
	})
}
 336  
// _expand_key_256b emits the internal helper _expand_key_256b<>, the "b" half
// of an AES-256 key-schedule step.
//
// On entry X2 holds the previous odd round key and X1 holds AESKEYGENASSIST
// computed from X0 (see Lexp_enc256). The new round key is left in X2, stored
// at (BX), and BX is advanced by 16. X1 is clobbered; X4 is scratch and must
// be zero as in _expand_key_128.
func _expand_key_256b() {
	Function("_expand_key_256b<>")
	Attributes(NOSPLIT)
	AllocLocal(0)

	// Broadcast dword 2 of the AESKEYGENASSIST result (SubWord without
	// RotWord or rcon).
	PSHUFD(Imm(0xaa), X1, X1)
	// Running XOR of X2's dwords through X4, then mix in the broadcast word.
	SHUFPS(Imm(0x10), X2, X4)
	PXOR(X4, X2)
	SHUFPS(Imm(0x8c), X2, X4)
	PXOR(X4, X2)
	PXOR(X1, X2)

	MOVUPS(X2, Mem{Base: BX})
	ADDQ(Imm(16), RBX)
	RET()
}
 353  
 354  const ThatPeskyUnicodeDot = "\u00b7"
 355  
 356  // removePeskyUnicodeDot strips the dot from the relevant TEXT directives such that they
 357  // can exist as internal assembly functions
 358  //
 359  // Avo v0.6.0 does not support the generation of internal assembly functions. Go's unicode
 360  // dot tells the compiler to link a TEXT symbol to a function in the current Go package
 361  // (or another package if specified). Avo unconditionally prepends the unicode dot to all
 362  // TEXT symbols, making it impossible to emit an internal function without this hack.
 363  //
 364  // There is a pending PR to add internal functions to Avo:
 365  // https://github.com/mmcloughlin/avo/pull/443
 366  //
 367  // If merged it should allow the usage of InternalFunction("NAME") for the specified functions
 368  func removePeskyUnicodeDot(internalFunctions [][]byte, target []byte) {
 369  	bytes, err := os.ReadFile(target)
 370  	if err != nil {
 371  		panic(err)
 372  	}
 373  
 374  	content := []byte(bytes)
 375  
 376  	for _, from := range internalFunctions {
 377  		to := bytes.ReplaceAll(from, ThatPeskyUnicodeDot, "")
 378  		content = bytes.ReplaceAll(content, from, to)
 379  	}
 380  
 381  	err = os.WriteFile(target, []byte(content), 0644)
 382  	if err != nil {
 383  		panic(err)
 384  	}
 385  }
 386