// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

//go:build ignore

// Generate Go assembly for XORing CTR output to n blocks at once with one key.
package main

import (
	"bytes"
	"fmt"
	"os"
	"text/template"
)

  17  // First registers in their groups.
  18  const (
  19  	blockOffset    = 0
  20  	roundKeyOffset = 8
  21  	dstOffset      = 23
  22  )
  23  
  24  var tmplArm64Str = `
  25  // Code generated by ctr_arm64_gen.go. DO NOT EDIT.
  26  
  27  //go:build !purego
  28  
  29  #include "textflag.h"
  30  
  31  #define NR R9
  32  #define XK R10
  33  #define DST R11
  34  #define SRC R12
  35  #define IV_LOW_LE R16
  36  #define IV_HIGH_LE R17
  37  #define IV_LOW_BE R19
  38  #define IV_HIGH_BE R20
  39  
  40  // V0.B16 - V7.B16 are for blocks (<=8). See BLOCK_OFFSET.
  41  // V8.B16 - V22.B16 are for <=15 round keys (<=15). See ROUND_KEY_OFFSET.
  42  // V23.B16 - V30.B16 are for destinations (<=8). See DST_OFFSET.
  43  
  44  {{define "load_keys"}}
  45  	{{- range regs_batches (round_key_reg $.FirstKey) $.NKeys }}
  46  		VLD1.P {{ .Size }}(XK), [{{ .Regs }}]
  47  	{{- end }}
  48  {{ end }}
  49  
  50  {{define "enc"}}
  51  	{{ range $i := xrange $.N -}}
  52  		AESE V{{ round_key_reg $.Key}}.B16, V{{ block_reg $i }}.B16
  53  		{{- if $.WithMc }}
  54  			AESMC V{{ block_reg $i }}.B16, V{{ block_reg $i }}.B16
  55  		{{- end }}
  56  	{{ end }}
  57  {{ end }}
  58  
  59  {{ range $N := $.Sizes }}
  60  // func ctrBlocks{{$N}}Asm(nr int, xk *[60]uint32, dst *[{{$N}}*16]byte, src *[{{$N}}*16]byte, ivlo uint64, ivhi uint64)
  61  TEXT ·ctrBlocks{{ $N }}Asm(SB),NOSPLIT,$0
  62  	MOVD nr+0(FP), NR
  63  	MOVD xk+8(FP), XK
  64  	MOVD dst+16(FP), DST
  65  	MOVD src+24(FP), SRC
  66  	MOVD ivlo+32(FP), IV_LOW_LE
  67  	MOVD ivhi+40(FP), IV_HIGH_LE
  68  
  69  	{{/* Prepare plain from IV and blockIndex. */}}
  70  
  71  	{{/* Copy to plaintext registers. */}}
  72  	{{ range $i := xrange $N }}
  73  		REV IV_LOW_LE, IV_LOW_BE
  74  		REV IV_HIGH_LE, IV_HIGH_BE
  75  		{{- /* https://developer.arm.com/documentation/dui0801/g/A64-SIMD-Vector-Instructions/MOV--vector--from-general- */}}
  76  		VMOV IV_LOW_BE, V{{ block_reg $i }}.D[1]
  77  		VMOV IV_HIGH_BE, V{{ block_reg $i }}.D[0]
  78  		{{- if ne (add $i 1) $N }}
  79  			ADDS $1, IV_LOW_LE
  80  			ADC $0, IV_HIGH_LE
  81  		{{ end }}
  82  	{{ end }}
  83  
  84  	{{/* Num rounds branching. */}}
  85  	CMP $12, NR
  86  	BLT Lenc128
  87  	BEQ Lenc192
  88  
  89  	{{/* 2 extra rounds for 256-bit keys. */}}
  90  	Lenc256:
  91  	{{- template "load_keys" (load_keys_args 0 2) }}
  92  	{{- template "enc" (enc_args 0 $N true) }}
  93  	{{- template "enc" (enc_args 1 $N true) }}
  94  
  95  	{{/* 2 extra rounds for 192-bit keys. */}}
  96  	Lenc192:
  97  	{{- template "load_keys" (load_keys_args 2 2) }}
  98  	{{- template "enc" (enc_args 2 $N true) }}
  99  	{{- template "enc" (enc_args 3 $N true) }}
 100  
 101  	{{/* 10 rounds for 128-bit (with special handling for final). */}}
 102  	Lenc128:
 103  	{{- template "load_keys" (load_keys_args 4 11) }}
 104  	{{- range $r := xrange 9 }}
 105  		{{- template "enc" (enc_args (add $r 4) $N true) }}
 106  	{{ end }}
 107  	{{ template "enc" (enc_args 13 $N false) }}
 108  
 109  	{{/* We need to XOR blocks with the last round key (key 14, register V22). */}}
 110  	{{ range $i := xrange $N }}
 111  		VEOR V{{ block_reg $i }}.B16, V{{ round_key_reg 14 }}.B16, V{{ block_reg $i }}.B16
 112  	{{- end }}
 113  
 114  	{{/* XOR results to destination. */}}
 115  	{{- range regs_batches $.DstOffset $N }}
 116  		VLD1.P {{ .Size }}(SRC), [{{ .Regs }}]
 117  	{{- end }}
 118  	{{- range $i := xrange $N }}
 119  		VEOR V{{ add $.DstOffset $i }}.B16, V{{ block_reg $i }}.B16, V{{ add $.DstOffset $i }}.B16
 120  	{{- end }}
 121  	{{- range regs_batches $.DstOffset $N }}
 122  		VST1.P [{{ .Regs }}], {{ .Size }}(DST)
 123  	{{- end }}
 124  
 125  	RET
 126  {{ end }}
 127  `
 128  
 129  func main() {
 130  	type Params struct {
 131  		DstOffset int
 132  		Sizes     []int
 133  	}
 134  
 135  	params := Params{
 136  		DstOffset: dstOffset,
 137  		Sizes:     []int{1, 2, 4, 8},
 138  	}
 139  
 140  	type RegsBatch struct {
 141  		Size int
 142  		Regs string // Comma-separated list of registers.
 143  	}
 144  
 145  	type LoadKeysArgs struct {
 146  		FirstKey int
 147  		NKeys    int
 148  	}
 149  
 150  	type EncArgs struct {
 151  		Key    int
 152  		N      int
 153  		WithMc bool
 154  	}
 155  
 156  	funcs := template.FuncMap{
 157  		"add": func(a, b int) int {
 158  			return a + b
 159  		},
 160  		"xrange": func(n int) []int {
 161  			result := []int{:n}
 162  			for i := 0; i < n; i++ {
 163  				result[i] = i
 164  			}
 165  			return result
 166  		},
 167  		"block_reg": func(block int) int {
 168  			return blockOffset + block
 169  		},
 170  		"round_key_reg": func(key int) int {
 171  			return roundKeyOffset + key
 172  		},
 173  		"regs_batches": func(firstReg, nregs int) []RegsBatch {
 174  			result := []RegsBatch{:0}
 175  			for nregs != 0 {
 176  				batch := 4
 177  				if nregs < batch {
 178  					batch = nregs
 179  				}
 180  				regsList := [][]byte{:0:batch}
 181  				for j := firstReg; j < firstReg+batch; j++ {
 182  					regsList = append(regsList, fmt.Sprintf("V%d.B16", j))
 183  				}
 184  				result = append(result, RegsBatch{
 185  					Size: 16 * batch,
 186  					Regs: bytes.Join(regsList, ", "),
 187  				})
 188  				nregs -= batch
 189  				firstReg += batch
 190  			}
 191  			return result
 192  		},
 193  		"enc_args": func(key, n int, withMc bool) EncArgs {
 194  			return EncArgs{
 195  				Key:    key,
 196  				N:      n,
 197  				WithMc: withMc,
 198  			}
 199  		},
 200  		"load_keys_args": func(firstKey, nkeys int) LoadKeysArgs {
 201  			return LoadKeysArgs{
 202  				FirstKey: firstKey,
 203  				NKeys:    nkeys,
 204  			}
 205  		},
 206  	}
 207  
 208  	var tmpl = template.Must(template.New("ctr_arm64").Funcs(funcs).Parse(tmplArm64Str))
 209  
 210  	if err := tmpl.Execute(os.Stdout, params); err != nil {
 211  		panic(err)
 212  	}
 213  }
 214