ctr_amd64_asm.mx raw

   1  // Copyright 2024 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  package main
   6  
   7  import (
   8  	"fmt"
   9  	"sync"
  10  
  11  	. "github.com/mmcloughlin/avo/build"
  12  	. "github.com/mmcloughlin/avo/operand"
  13  	. "github.com/mmcloughlin/avo/reg"
  14  )
  15  
  16  //go:generate go run . -out ../../ctr_amd64.s
  17  
  18  func main() {
  19  	Package("crypto/aes")
  20  	ConstraintExpr("!purego")
  21  
  22  	ctrBlocks(1)
  23  	ctrBlocks(2)
  24  	ctrBlocks(4)
  25  	ctrBlocks(8)
  26  
  27  	Generate()
  28  }
  29  
  30  func ctrBlocks(numBlocks int) {
  31  	Implement(fmt.Sprintf("ctrBlocks%dAsm", numBlocks))
  32  
  33  	rounds := Load(Param("nr"), GP64())
  34  	xk := Load(Param("xk"), GP64())
  35  	dst := Load(Param("dst"), GP64())
  36  	src := Load(Param("src"), GP64())
  37  	ivlo := Load(Param("ivlo"), GP64())
  38  	ivhi := Load(Param("ivhi"), GP64())
  39  
  40  	bswap := XMM()
  41  	MOVOU(bswapMask(), bswap)
  42  
  43  	blocks := []VecVirtual{:0:numBlocks}
  44  
  45  	// Lay out counter block plaintext.
  46  	for i := 0; i < numBlocks; i++ {
  47  		x := XMM()
  48  		blocks = append(blocks, x)
  49  
  50  		MOVQ(ivlo, x)
  51  		PINSRQ(Imm(1), ivhi, x)
  52  		PSHUFB(bswap, x)
  53  		if i < numBlocks-1 {
  54  			ADDQ(Imm(1), ivlo)
  55  			ADCQ(Imm(0), ivhi)
  56  		}
  57  	}
  58  
  59  	// Initial key add.
  60  	aesRoundStart(blocks, Mem{Base: xk})
  61  	ADDQ(Imm(16), xk)
  62  
  63  	// Branch based on the number of rounds.
  64  	SUBQ(Imm(12), rounds)
  65  	JE(LabelRef("enc192"))
  66  	JB(LabelRef("enc128"))
  67  
  68  	// Two extra rounds for 256-bit keys.
  69  	aesRound(blocks, Mem{Base: xk})
  70  	aesRound(blocks, Mem{Base: xk}.Offset(16))
  71  	ADDQ(Imm(32), xk)
  72  
  73  	// Two extra rounds for 192-bit keys.
  74  	Label("enc192")
  75  	aesRound(blocks, Mem{Base: xk})
  76  	aesRound(blocks, Mem{Base: xk}.Offset(16))
  77  	ADDQ(Imm(32), xk)
  78  
  79  	// 10 rounds for 128-bit keys (with special handling for the final round).
  80  	Label("enc128")
  81  	for i := 0; i < 9; i++ {
  82  		aesRound(blocks, Mem{Base: xk}.Offset(16*i))
  83  	}
  84  	aesRoundLast(blocks, Mem{Base: xk}.Offset(16*9))
  85  
  86  	// XOR state with src and write back to dst.
  87  	for i, b := range blocks {
  88  		x := XMM()
  89  
  90  		MOVUPS(Mem{Base: src}.Offset(16*i), x)
  91  		PXOR(b, x)
  92  		MOVUPS(x, Mem{Base: dst}.Offset(16*i))
  93  	}
  94  
  95  	RET()
  96  }
  97  
  98  func aesRoundStart(blocks []VecVirtual, k Mem) {
  99  	x := XMM()
 100  	MOVUPS(k, x)
 101  	for _, b := range blocks {
 102  		PXOR(x, b)
 103  	}
 104  }
 105  
 106  func aesRound(blocks []VecVirtual, k Mem) {
 107  	x := XMM()
 108  	MOVUPS(k, x)
 109  	for _, b := range blocks {
 110  		AESENC(x, b)
 111  	}
 112  }
 113  
 114  func aesRoundLast(blocks []VecVirtual, k Mem) {
 115  	x := XMM()
 116  	MOVUPS(k, x)
 117  	for _, b := range blocks {
 118  		AESENCLAST(x, b)
 119  	}
 120  }
 121  
 122  var bswapMask = sync.OnceValue(func() Mem {
 123  	bswapMask := GLOBL("bswapMask", NOPTR|RODATA)
 124  	DATA(0x00, U64(0x08090a0b0c0d0e0f))
 125  	DATA(0x08, U64(0x0001020304050607))
 126  	return bswapMask
 127  })
 128