util.go raw

   1  package hpke
   2  
   3  import (
   4  	"encoding/binary"
   5  	"errors"
   6  	"fmt"
   7  )
   8  
   9  func (st state) keySchedule(ss, info, psk, pskID []byte) (*encdecContext, error) {
  10  	if err := st.verifyPSKInputs(psk, pskID); err != nil {
  11  		return nil, err
  12  	}
  13  
  14  	pskIDHash := st.labeledExtract(nil, []byte("psk_id_hash"), pskID)
  15  	infoHash := st.labeledExtract(nil, []byte("info_hash"), info)
  16  	keySchCtx := append(append(
  17  		[]byte{st.modeID},
  18  		pskIDHash...),
  19  		infoHash...)
  20  
  21  	secret := st.labeledExtract(ss, []byte("secret"), psk)
  22  
  23  	Nk := uint16(st.aeadID.KeySize())
  24  	key := st.labeledExpand(secret, []byte("key"), keySchCtx, Nk)
  25  
  26  	aead, err := st.aeadID.New(key)
  27  	if err != nil {
  28  		return nil, err
  29  	}
  30  
  31  	Nn := uint16(aead.NonceSize())
  32  	baseNonce := st.labeledExpand(secret, []byte("base_nonce"), keySchCtx, Nn)
  33  	exporterSecret := st.labeledExpand(
  34  		secret,
  35  		[]byte("exp"),
  36  		keySchCtx,
  37  		uint16(st.kdfID.ExtractSize()),
  38  	)
  39  
  40  	return &encdecContext{
  41  		st.Suite,
  42  		ss,
  43  		secret,
  44  		keySchCtx,
  45  		exporterSecret,
  46  		key,
  47  		baseNonce,
  48  		make([]byte, Nn),
  49  		aead,
  50  		make([]byte, Nn),
  51  	}, nil
  52  }
  53  
  54  func (st state) verifyPSKInputs(psk, pskID []byte) error {
  55  	gotPSK := psk != nil
  56  	gotPSKID := pskID != nil
  57  	if gotPSK != gotPSKID {
  58  		return errors.New("inconsistent PSK inputs")
  59  	}
  60  	switch st.modeID {
  61  	case modeBase | modeAuth:
  62  		if gotPSK {
  63  			return errors.New("PSK input provided when not needed")
  64  		}
  65  	case modePSK | modeAuthPSK:
  66  		if !gotPSK {
  67  			return errors.New("missing required PSK input")
  68  		}
  69  	}
  70  	return nil
  71  }
  72  
  73  // Params returns the codepoints for the algorithms comprising the suite.
  74  func (suite Suite) Params() (KEM, KDF, AEAD) {
  75  	return suite.kemID, suite.kdfID, suite.aeadID
  76  }
  77  
  78  func (suite Suite) String() string {
  79  	return fmt.Sprintf(
  80  		"kem_id: %v kdf_id: %v aead_id: %v",
  81  		suite.kemID, suite.kdfID, suite.aeadID,
  82  	)
  83  }
  84  
  85  func (suite Suite) getSuiteID() (id [10]byte) {
  86  	id[0], id[1], id[2], id[3] = 'H', 'P', 'K', 'E'
  87  	binary.BigEndian.PutUint16(id[4:6], uint16(suite.kemID))
  88  	binary.BigEndian.PutUint16(id[6:8], uint16(suite.kdfID))
  89  	binary.BigEndian.PutUint16(id[8:10], uint16(suite.aeadID))
  90  	return
  91  }
  92  
  93  func (suite Suite) isValid() bool {
  94  	return suite.kemID.IsValid() &&
  95  		suite.kdfID.IsValid() &&
  96  		suite.aeadID.IsValid()
  97  }
  98  
  99  func (suite Suite) labeledExtract(salt, label, ikm []byte) []byte {
 100  	suiteID := suite.getSuiteID()
 101  	labeledIKM := append(append(append(append(
 102  		make([]byte, 0, len(versionLabel)+len(suiteID)+len(label)+len(ikm)),
 103  		versionLabel...),
 104  		suiteID[:]...),
 105  		label...),
 106  		ikm...)
 107  	return suite.kdfID.Extract(labeledIKM, salt)
 108  }
 109  
 110  func (suite Suite) labeledExpand(prk, label, info []byte, l uint16) []byte {
 111  	suiteID := suite.getSuiteID()
 112  	labeledInfo := make([]byte,
 113  		2, 2+len(versionLabel)+len(suiteID)+len(label)+len(info))
 114  	binary.BigEndian.PutUint16(labeledInfo[0:2], l)
 115  	labeledInfo = append(append(append(append(labeledInfo,
 116  		versionLabel...),
 117  		suiteID[:]...),
 118  		label...),
 119  		info...)
 120  	return suite.kdfID.Expand(prk, labeledInfo, uint(l))
 121  }
 122