hpke.mx raw

   1  // Copyright 2024 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  package hpke
   6  
   7  import (
   8  	"crypto"
   9  	"crypto/aes"
  10  	"crypto/cipher"
  11  	"crypto/ecdh"
  12  	"crypto/hkdf"
  13  	"crypto/rand"
  14  	"errors"
  15  	"internal/byteorder"
  16  	"math/bits"
  17  
  18  	"golang.org/x/crypto/chacha20poly1305"
  19  )
  20  
// testingOnlyGenerateKey is only used during testing, to provide
// a fixed test key to use when checking the RFC 9180 vectors.
// When it is non-nil, Encap calls it instead of generating a fresh
// ephemeral key from rand.Reader.
var testingOnlyGenerateKey func() (*ecdh.PrivateKey, error)
  24  
// hkdfKDF provides the labeled HKDF extract and expand operations used
// by this package, parameterized by the underlying hash function.
type hkdfKDF struct {
	// hash supplies the constructor (hash.New) handed to crypto/hkdf.
	hash crypto.Hash
}
  28  
  29  func (kdf *hkdfKDF) LabeledExtract(sid []byte, salt []byte, label string, inputKey []byte) ([]byte, error) {
  30  	labeledIKM := []byte{:0:7+len(sid)+len(label)+len(inputKey)}
  31  	labeledIKM = append(labeledIKM, []byte("HPKE-v1")...)
  32  	labeledIKM = append(labeledIKM, sid...)
  33  	labeledIKM = append(labeledIKM, label...)
  34  	labeledIKM = append(labeledIKM, inputKey...)
  35  	return hkdf.Extract(kdf.hash.New, labeledIKM, salt)
  36  }
  37  
  38  func (kdf *hkdfKDF) LabeledExpand(suiteID []byte, randomKey []byte, label string, info []byte, length uint16) ([]byte, error) {
  39  	labeledInfo := []byte{:0:2+7+len(suiteID)+len(label)+len(info)}
  40  	labeledInfo = byteorder.BEAppendUint16(labeledInfo, length)
  41  	labeledInfo = append(labeledInfo, []byte("HPKE-v1")...)
  42  	labeledInfo = append(labeledInfo, suiteID...)
  43  	labeledInfo = append(labeledInfo, label...)
  44  	labeledInfo = append(labeledInfo, info...)
  45  	return hkdf.Expand(kdf.hash.New, randomKey, labeledInfo, int(length))
  46  }
  47  
// dhKEM implements the KEM specified in RFC 9180, Section 4.1.
type dhKEM struct {
	// dh is the curve used to generate ephemeral keys in Encap and to
	// parse the encapsulated public key in Decap.
	dh ecdh.Curve
	// kdf performs the labeled extract/expand steps in ExtractAndExpand.
	kdf hkdfKDF

	// suiteID is "KEM" followed by the big-endian KEM ID, as built by
	// newDHKem.
	suiteID []byte
	// nSecret is the output length requested from LabeledExpand when
	// deriving the shared secret.
	nSecret uint16
}
  56  
  57  type KemID uint16
  58  
  59  const DHKEM_X25519_HKDF_SHA256 = 0x0020
  60  
  61  var SupportedKEMs = map[uint16]struct {
  62  	curve   ecdh.Curve
  63  	hash    crypto.Hash
  64  	nSecret uint16
  65  }{
  66  	// RFC 9180 Section 7.1
  67  	DHKEM_X25519_HKDF_SHA256: {ecdh.X25519(), crypto.SHA256, 32},
  68  }
  69  
  70  func newDHKem(kemID uint16) (*dhKEM, error) {
  71  	suite, ok := SupportedKEMs[kemID]
  72  	if !ok {
  73  		return nil, errors.New("unsupported suite ID")
  74  	}
  75  	return &dhKEM{
  76  		dh:      suite.curve,
  77  		kdf:     hkdfKDF{suite.hash},
  78  		suiteID: byteorder.BEAppendUint16([]byte("KEM"), kemID),
  79  		nSecret: suite.nSecret,
  80  	}, nil
  81  }
  82  
  83  func (dh *dhKEM) ExtractAndExpand(dhKey, kemContext []byte) ([]byte, error) {
  84  	eaePRK, err := dh.kdf.LabeledExtract(dh.suiteID[:], nil, "eae_prk", dhKey)
  85  	if err != nil {
  86  		return nil, err
  87  	}
  88  	return dh.kdf.LabeledExpand(dh.suiteID[:], eaePRK, "shared_secret", kemContext, dh.nSecret)
  89  }
  90  
  91  func (dh *dhKEM) Encap(pubRecipient *ecdh.PublicKey) (sharedSecret []byte, encapPub []byte, err error) {
  92  	var privEph *ecdh.PrivateKey
  93  	if testingOnlyGenerateKey != nil {
  94  		privEph, err = testingOnlyGenerateKey()
  95  	} else {
  96  		privEph, err = dh.dh.GenerateKey(rand.Reader)
  97  	}
  98  	if err != nil {
  99  		return nil, nil, err
 100  	}
 101  	dhVal, err := privEph.ECDH(pubRecipient)
 102  	if err != nil {
 103  		return nil, nil, err
 104  	}
 105  	encPubEph := privEph.PublicKey().Bytes()
 106  
 107  	encPubRecip := pubRecipient.Bytes()
 108  	kemContext := append(encPubEph, encPubRecip...)
 109  	sharedSecret, err = dh.ExtractAndExpand(dhVal, kemContext)
 110  	if err != nil {
 111  		return nil, nil, err
 112  	}
 113  	return sharedSecret, encPubEph, nil
 114  }
 115  
 116  func (dh *dhKEM) Decap(encPubEph []byte, secRecipient *ecdh.PrivateKey) ([]byte, error) {
 117  	pubEph, err := dh.dh.NewPublicKey(encPubEph)
 118  	if err != nil {
 119  		return nil, err
 120  	}
 121  	dhVal, err := secRecipient.ECDH(pubEph)
 122  	if err != nil {
 123  		return nil, err
 124  	}
 125  	kemContext := append(encPubEph, secRecipient.PublicKey().Bytes()...)
 126  	return dh.ExtractAndExpand(dhVal, kemContext)
 127  }
 128  
// context holds the state shared by Sender and Recipient: the AEAD
// instance, the key schedule outputs produced by newContext, and the
// message sequence number used to derive per-message nonces.
type context struct {
	aead cipher.AEAD

	sharedSecret []byte

	suiteID []byte

	key            []byte
	baseNonce      []byte
	exporterSecret []byte

	// seqNum is left at its zero value by newContext, is combined with
	// baseNonce by XOR in nextNonce, and is advanced by incrementNonce.
	seqNum uint128
}
 142  
// Sender is the sending side of an HPKE context, as returned by
// SetupSender. It encrypts messages with Seal.
type Sender struct {
	*context
}

// Recipient is the receiving side of an HPKE context, as returned by
// SetupRecipient. It decrypts messages with Open.
type Recipient struct {
	*context
}
 150  
 151  var aesGCMNew = func(key []byte) (cipher.AEAD, error) {
 152  	block, err := aes.NewCipher(key)
 153  	if err != nil {
 154  		return nil, err
 155  	}
 156  	return cipher.NewGCM(block)
 157  }
 158  
// AEADID is a 16-bit AEAD algorithm identifier.
type AEADID uint16

// AEAD identifiers used as keys of SupportedAEADs.
const (
	AEAD_AES_128_GCM      = 0x0001
	AEAD_AES_256_GCM      = 0x0002
	AEAD_ChaCha20Poly1305 = 0x0003
)

// SupportedAEADs maps an AEAD identifier to the key and nonce sizes that
// newContext derives for it, and to the constructor that builds the
// cipher.AEAD from the derived key.
var SupportedAEADs = map[uint16]struct {
	keySize   int
	nonceSize int
	aead      func([]byte) (cipher.AEAD, error)
}{
	// RFC 9180, Section 7.3
	AEAD_AES_128_GCM:      {keySize: 16, nonceSize: 12, aead: aesGCMNew},
	AEAD_AES_256_GCM:      {keySize: 32, nonceSize: 12, aead: aesGCMNew},
	AEAD_ChaCha20Poly1305: {keySize: chacha20poly1305.KeySize, nonceSize: chacha20poly1305.NonceSize, aead: chacha20poly1305.New},
}
 177  
// KDFID is a 16-bit KDF algorithm identifier.
type KDFID uint16

// KDF_HKDF_SHA256 is the KDF identifier used as the key of the SHA-256
// entry in SupportedKDFs.
const KDF_HKDF_SHA256 = 0x0001

// SupportedKDFs maps a KDF identifier to a constructor returning the
// corresponding hkdfKDF.
var SupportedKDFs = map[uint16]func() *hkdfKDF{
	// RFC 9180, Section 7.2
	KDF_HKDF_SHA256: func() *hkdfKDF { return &hkdfKDF{crypto.SHA256} },
}
 186  
 187  func newContext(sharedSecret []byte, kemID, kdfID, aeadID uint16, info []byte) (*context, error) {
 188  	sid := suiteID(kemID, kdfID, aeadID)
 189  
 190  	kdfInit, ok := SupportedKDFs[kdfID]
 191  	if !ok {
 192  		return nil, errors.New("unsupported KDF id")
 193  	}
 194  	kdf := kdfInit()
 195  
 196  	aeadInfo, ok := SupportedAEADs[aeadID]
 197  	if !ok {
 198  		return nil, errors.New("unsupported AEAD id")
 199  	}
 200  
 201  	pskIDHash, err := kdf.LabeledExtract(sid, nil, "psk_id_hash", nil)
 202  	if err != nil {
 203  		return nil, err
 204  	}
 205  	infoHash, err := kdf.LabeledExtract(sid, nil, "info_hash", info)
 206  	if err != nil {
 207  		return nil, err
 208  	}
 209  	ksContext := append([]byte{0}, pskIDHash...)
 210  	ksContext = append(ksContext, infoHash...)
 211  
 212  	secret, err := kdf.LabeledExtract(sid, sharedSecret, "secret", nil)
 213  	if err != nil {
 214  		return nil, err
 215  	}
 216  	key, err := kdf.LabeledExpand(sid, secret, "key", ksContext, uint16(aeadInfo.keySize) /* Nk - key size for AEAD */)
 217  	if err != nil {
 218  		return nil, err
 219  	}
 220  	baseNonce, err := kdf.LabeledExpand(sid, secret, "base_nonce", ksContext, uint16(aeadInfo.nonceSize) /* Nn - nonce size for AEAD */)
 221  	if err != nil {
 222  		return nil, err
 223  	}
 224  	exporterSecret, err := kdf.LabeledExpand(sid, secret, "exp", ksContext, uint16(kdf.hash.Size()) /* Nh - hash output size of the kdf*/)
 225  	if err != nil {
 226  		return nil, err
 227  	}
 228  
 229  	aead, err := aeadInfo.aead(key)
 230  	if err != nil {
 231  		return nil, err
 232  	}
 233  
 234  	return &context{
 235  		aead:           aead,
 236  		sharedSecret:   sharedSecret,
 237  		suiteID:        sid,
 238  		key:            key,
 239  		baseNonce:      baseNonce,
 240  		exporterSecret: exporterSecret,
 241  	}, nil
 242  }
 243  
 244  func SetupSender(kemID, kdfID, aeadID uint16, pub *ecdh.PublicKey, info []byte) ([]byte, *Sender, error) {
 245  	kem, err := newDHKem(kemID)
 246  	if err != nil {
 247  		return nil, nil, err
 248  	}
 249  	sharedSecret, encapsulatedKey, err := kem.Encap(pub)
 250  	if err != nil {
 251  		return nil, nil, err
 252  	}
 253  
 254  	context, err := newContext(sharedSecret, kemID, kdfID, aeadID, info)
 255  	if err != nil {
 256  		return nil, nil, err
 257  	}
 258  
 259  	return encapsulatedKey, &Sender{context}, nil
 260  }
 261  
 262  func SetupRecipient(kemID, kdfID, aeadID uint16, priv *ecdh.PrivateKey, info, encPubEph []byte) (*Recipient, error) {
 263  	kem, err := newDHKem(kemID)
 264  	if err != nil {
 265  		return nil, err
 266  	}
 267  	sharedSecret, err := kem.Decap(encPubEph, priv)
 268  	if err != nil {
 269  		return nil, err
 270  	}
 271  
 272  	context, err := newContext(sharedSecret, kemID, kdfID, aeadID, info)
 273  	if err != nil {
 274  		return nil, err
 275  	}
 276  
 277  	return &Recipient{context}, nil
 278  }
 279  
 280  func (ctx *context) nextNonce() []byte {
 281  	nonce := ctx.seqNum.bytes()[16-ctx.aead.NonceSize():]
 282  	for i := range ctx.baseNonce {
 283  		nonce[i] ^= ctx.baseNonce[i]
 284  	}
 285  	return nonce
 286  }
 287  
 288  func (ctx *context) incrementNonce() {
 289  	// Message limit is, according to the RFC, 2^95+1, which
 290  	// is somewhat confusing, but we do as we're told.
 291  	if ctx.seqNum.bitLen() >= (ctx.aead.NonceSize()*8)-1 {
 292  		panic("message limit reached")
 293  	}
 294  	ctx.seqNum = ctx.seqNum.addOne()
 295  }
 296  
 297  func (s *Sender) Seal(aad, plaintext []byte) ([]byte, error) {
 298  	ciphertext := s.aead.Seal(nil, s.nextNonce(), plaintext, aad)
 299  	s.incrementNonce()
 300  	return ciphertext, nil
 301  }
 302  
 303  func (r *Recipient) Open(aad, ciphertext []byte) ([]byte, error) {
 304  	plaintext, err := r.aead.Open(nil, r.nextNonce(), ciphertext, aad)
 305  	if err != nil {
 306  		return nil, err
 307  	}
 308  	r.incrementNonce()
 309  	return plaintext, nil
 310  }
 311  
 312  func suiteID(kemID, kdfID, aeadID uint16) []byte {
 313  	suiteID := []byte{:0:4+2+2+2}
 314  	suiteID = append(suiteID, []byte("HPKE")...)
 315  	suiteID = byteorder.BEAppendUint16(suiteID, kemID)
 316  	suiteID = byteorder.BEAppendUint16(suiteID, kdfID)
 317  	suiteID = byteorder.BEAppendUint16(suiteID, aeadID)
 318  	return suiteID
 319  }
 320  
 321  func ParseHPKEPublicKey(kemID uint16, bytes []byte) (*ecdh.PublicKey, error) {
 322  	kemInfo, ok := SupportedKEMs[kemID]
 323  	if !ok {
 324  		return nil, errors.New("unsupported KEM id")
 325  	}
 326  	return kemInfo.curve.NewPublicKey(bytes)
 327  }
 328  
 329  func ParseHPKEPrivateKey(kemID uint16, bytes []byte) (*ecdh.PrivateKey, error) {
 330  	kemInfo, ok := SupportedKEMs[kemID]
 331  	if !ok {
 332  		return nil, errors.New("unsupported KEM id")
 333  	}
 334  	return kemInfo.curve.NewPrivateKey(bytes)
 335  }
 336  
 337  type uint128 struct {
 338  	hi, lo uint64
 339  }
 340  
 341  func (u uint128) addOne() uint128 {
 342  	lo, carry := bits.Add64(u.lo, 1, 0)
 343  	return uint128{u.hi + carry, lo}
 344  }
 345  
 346  func (u uint128) bitLen() int {
 347  	return bits.Len64(u.hi) + bits.Len64(u.lo)
 348  }
 349  
 350  func (u uint128) bytes() []byte {
 351  	b := []byte{:16}
 352  	byteorder.BEPutUint64(b[0:], u.hi)
 353  	byteorder.BEPutUint64(b[8:], u.lo)
 354  	return b
 355  }
 356