aead.go raw

   1  package hpke
   2  
   3  import (
   4  	"crypto/cipher"
   5  	"fmt"
   6  )
   7  
// encdecContext holds the state shared by sealContext and openContext: the
// parameters fixed when the context is set up and the running nonce state
// used by calcNonce/increment.
//
// NOTE(review): field order and names are kept unchanged on purpose; the
// context may be built with a positional composite literal or serialized
// field-by-field outside this view — confirm before reordering.
type encdecContext struct {
	// Serialized parameters
	// (presumably the fields written/read by this context's marshaling
	// code, which is not visible here — verify).

	// suite is the cipher suite; Export reads its KDF (kdfID, labeledExpand).
	suite Suite
	// sharedSecret, secret and keyScheduleContext are not read by any method
	// shown here; presumably intermediate key-schedule values kept for
	// serialization — confirm against the constructor.
	sharedSecret       []byte
	secret             []byte
	keyScheduleContext []byte
	// exporterSecret is the secret that Export expands from.
	exporterSecret []byte
	// key is presumably the AEAD key behind the embedded cipher.AEAD — TODO
	// confirm against the constructor.
	key []byte
	// baseNonce is XORed byte-wise with sequenceNumber by calcNonce.
	baseNonce []byte
	// sequenceNumber is a big-endian counter (increment carries from the
	// last byte toward the first); it must be at least len(baseNonce) bytes.
	sequenceNumber []byte

	// Operational parameters

	// AEAD performs the actual encryption/decryption for Seal and Open.
	cipher.AEAD
	// nonce is a scratch buffer that calcNonce overwrites and returns on
	// every call; it must be at least len(baseNonce) bytes.
	nonce []byte
}
  23  
  24  type (
  25  	sealContext struct{ *encdecContext }
  26  	openContext struct{ *encdecContext }
  27  )
  28  
  29  // Export takes a context string exporterContext and a desired length (in
  30  // bytes), and produces a secret derived from the internal exporter secret
  31  // using the corresponding KDF Expand function. It panics if length is
  32  // greater than 255*N bytes, where N is the size (in bytes) of the KDF's
  33  // output.
  34  func (c *encdecContext) Export(exporterContext []byte, length uint) []byte {
  35  	maxLength := uint(255 * c.suite.kdfID.ExtractSize())
  36  	if length > maxLength {
  37  		panic(fmt.Errorf("output length must be lesser than %v bytes", maxLength))
  38  	}
  39  	return c.suite.labeledExpand(c.exporterSecret, []byte("sec"),
  40  		exporterContext, uint16(length))
  41  }
  42  
  43  func (c *encdecContext) Suite() Suite {
  44  	return c.suite
  45  }
  46  
  47  func (c *encdecContext) calcNonce() []byte {
  48  	for i := range c.baseNonce {
  49  		c.nonce[i] = c.baseNonce[i] ^ c.sequenceNumber[i]
  50  	}
  51  	return c.nonce
  52  }
  53  
  54  func (c *encdecContext) increment() error {
  55  	// tests whether the sequence number is all-ones, which prevents an
  56  	// overflow after the increment.
  57  	allOnes := byte(0xFF)
  58  	for i := range c.sequenceNumber {
  59  		allOnes &= c.sequenceNumber[i]
  60  	}
  61  	if allOnes == byte(0xFF) {
  62  		return ErrAEADSeqOverflows
  63  	}
  64  
  65  	// performs an increment by 1 and verifies whether the sequence overflows.
  66  	carry := uint(1)
  67  	for i := len(c.sequenceNumber) - 1; i >= 0; i-- {
  68  		sum := uint(c.sequenceNumber[i]) + carry
  69  		carry = sum >> 8
  70  		c.sequenceNumber[i] = byte(sum & 0xFF)
  71  	}
  72  	if carry != 0 {
  73  		return ErrAEADSeqOverflows
  74  	}
  75  	return nil
  76  }
  77  
  78  func (c *sealContext) Seal(pt, aad []byte) ([]byte, error) {
  79  	ct := c.AEAD.Seal(nil, c.calcNonce(), pt, aad)
  80  	err := c.increment()
  81  	if err != nil {
  82  		for i := range ct {
  83  			ct[i] = 0
  84  		}
  85  		return nil, err
  86  	}
  87  	return ct, nil
  88  }
  89  
  90  func (c *openContext) Open(ct, aad []byte) ([]byte, error) {
  91  	pt, err := c.AEAD.Open(nil, c.calcNonce(), ct, aad)
  92  	if err != nil {
  93  		return nil, err
  94  	}
  95  	err = c.increment()
  96  	if err != nil {
  97  		for i := range pt {
  98  			pt[i] = 0
  99  		}
 100  		return nil, err
 101  	}
 102  	return pt, nil
 103  }
 104