kembase.go raw

   1  package hpke
   2  
   3  import (
   4  	"crypto"
   5  	"crypto/rand"
   6  	"encoding/binary"
   7  	"io"
   8  
   9  	"github.com/cloudflare/circl/kem"
  10  	"golang.org/x/crypto/hkdf"
  11  )
  12  
// dhKEM lists the group-specific operations that dhKemBase needs in order to
// run the encapsulation and decapsulation routines in this file.
type dhKEM interface {
	// sizeDH reports the length, in bytes, of the buffer handed to calcDH;
	// callers in this file allocate the dh slice with this size.
	sizeDH() int
	// calcDH is called with a private key and a public key and a destination
	// buffer dh; the buffer is afterwards fed to extractExpand, so the
	// implementation is expected to write the Diffie-Hellman output into it.
	calcDH(dh []byte, sk kem.PrivateKey, pk kem.PublicKey) error
	// SeedSize reports how many random bytes Encapsulate and AuthEncapsulate
	// draw to form the seed given to DeriveKeyPair.
	SeedSize() int
	// DeriveKeyPair produces a public/private key pair from seed.
	// NOTE(review): behavior on a seed of unexpected length is defined by
	// the implementation and is not visible here.
	DeriveKeyPair(seed []byte) (kem.PublicKey, kem.PrivateKey)
	// UnmarshalBinaryPrivateKey decodes a private key from data.
	UnmarshalBinaryPrivateKey(data []byte) (kem.PrivateKey, error)
	// UnmarshalBinaryPublicKey decodes a public key from data; coreDecap
	// uses it to parse the ciphertext.
	UnmarshalBinaryPublicKey(data []byte) (kem.PublicKey, error)
}
  21  
// kemBase holds the identifier, name and hash function of a KEM, and provides
// the labeled HKDF helpers built on them.
type kemBase struct {
	// id is the KEM identifier; its 16-bit value is encoded into the suite ID.
	id KEM
	// name is the string returned by Name.
	name string
	// Hash supplies the hash constructor used for HKDF and, through Size,
	// the length of the shared secret.
	crypto.Hash
}
  27  
// dhKemBase combines the generic KEM helpers of kemBase with the
// group-specific operations of a dhKEM; the encapsulation and decapsulation
// methods in this file are defined on it.
type dhKemBase struct {
	kemBase
	dhKEM
}
  32  
  33  func (k kemBase) Name() string       { return k.name }
  34  func (k kemBase) SharedKeySize() int { return k.Hash.Size() }
  35  
  36  func (k kemBase) getSuiteID() (sid [5]byte) {
  37  	sid[0], sid[1], sid[2] = 'K', 'E', 'M'
  38  	binary.BigEndian.PutUint16(sid[3:5], uint16(k.id))
  39  	return
  40  }
  41  
  42  func (k kemBase) extractExpand(dh, kemCtx []byte) []byte {
  43  	eaePkr := k.labeledExtract([]byte(""), []byte("eae_prk"), dh)
  44  	return k.labeledExpand(
  45  		eaePkr,
  46  		[]byte("shared_secret"),
  47  		kemCtx,
  48  		uint16(k.Size()),
  49  	)
  50  }
  51  
  52  func (k kemBase) labeledExtract(salt, label, info []byte) []byte {
  53  	suiteID := k.getSuiteID()
  54  	labeledIKM := append(append(append(append(
  55  		make([]byte, 0, len(versionLabel)+len(suiteID)+len(label)+len(info)),
  56  		versionLabel...),
  57  		suiteID[:]...),
  58  		label...),
  59  		info...)
  60  	return hkdf.Extract(k.New, labeledIKM, salt)
  61  }
  62  
  63  func (k kemBase) labeledExpand(prk, label, info []byte, l uint16) []byte {
  64  	suiteID := k.getSuiteID()
  65  	labeledInfo := make(
  66  		[]byte,
  67  		2,
  68  		2+len(versionLabel)+len(suiteID)+len(label)+len(info),
  69  	)
  70  	binary.BigEndian.PutUint16(labeledInfo[0:2], l)
  71  	labeledInfo = append(append(append(append(labeledInfo,
  72  		versionLabel...),
  73  		suiteID[:]...),
  74  		label...),
  75  		info...)
  76  	b := make([]byte, l)
  77  	rd := hkdf.Expand(k.New, prk, labeledInfo)
  78  	if _, err := io.ReadFull(rd, b); err != nil {
  79  		panic(err)
  80  	}
  81  	return b
  82  }
  83  
  84  func (k dhKemBase) AuthEncapsulate(pkr kem.PublicKey, sks kem.PrivateKey) (
  85  	ct []byte, ss []byte, err error,
  86  ) {
  87  	seed := make([]byte, k.SeedSize())
  88  	_, err = io.ReadFull(rand.Reader, seed)
  89  	if err != nil {
  90  		return nil, nil, err
  91  	}
  92  
  93  	return k.authEncap(pkr, sks, seed)
  94  }
  95  
  96  func (k dhKemBase) Encapsulate(pkr kem.PublicKey) (
  97  	ct []byte, ss []byte, err error,
  98  ) {
  99  	seed := make([]byte, k.SeedSize())
 100  	_, err = io.ReadFull(rand.Reader, seed)
 101  	if err != nil {
 102  		return nil, nil, err
 103  	}
 104  
 105  	return k.encap(pkr, seed)
 106  }
 107  
 108  func (k dhKemBase) AuthEncapsulateDeterministically(
 109  	pkr kem.PublicKey, sks kem.PrivateKey, seed []byte,
 110  ) (ct, ss []byte, err error) {
 111  	return k.authEncap(pkr, sks, seed)
 112  }
 113  
 114  func (k dhKemBase) EncapsulateDeterministically(
 115  	pkr kem.PublicKey, seed []byte,
 116  ) (ct, ss []byte, err error) {
 117  	return k.encap(pkr, seed)
 118  }
 119  
 120  func (k dhKemBase) encap(
 121  	pkR kem.PublicKey,
 122  	seed []byte,
 123  ) (ct []byte, ss []byte, err error) {
 124  	dh := make([]byte, k.sizeDH())
 125  	enc, kemCtx, err := k.coreEncap(dh, pkR, seed)
 126  	if err != nil {
 127  		return nil, nil, err
 128  	}
 129  	ss = k.extractExpand(dh, kemCtx)
 130  	return enc, ss, nil
 131  }
 132  
 133  func (k dhKemBase) authEncap(
 134  	pkR kem.PublicKey,
 135  	skS kem.PrivateKey,
 136  	seed []byte,
 137  ) (ct []byte, ss []byte, err error) {
 138  	dhLen := k.sizeDH()
 139  	dh := make([]byte, 2*dhLen)
 140  	enc, kemCtx, err := k.coreEncap(dh[:dhLen], pkR, seed)
 141  	if err != nil {
 142  		return nil, nil, err
 143  	}
 144  
 145  	err = k.calcDH(dh[dhLen:], skS, pkR)
 146  	if err != nil {
 147  		return nil, nil, err
 148  	}
 149  
 150  	pkS := skS.Public()
 151  	pkSm, err := pkS.MarshalBinary()
 152  	if err != nil {
 153  		return nil, nil, err
 154  	}
 155  	kemCtx = append(kemCtx, pkSm...)
 156  
 157  	ss = k.extractExpand(dh, kemCtx)
 158  	return enc, ss, nil
 159  }
 160  
 161  func (k dhKemBase) coreEncap(
 162  	dh []byte,
 163  	pkR kem.PublicKey,
 164  	seed []byte,
 165  ) (enc []byte, kemCtx []byte, err error) {
 166  	pkE, skE := k.DeriveKeyPair(seed)
 167  	err = k.calcDH(dh, skE, pkR)
 168  	if err != nil {
 169  		return nil, nil, err
 170  	}
 171  
 172  	enc, err = pkE.MarshalBinary()
 173  	if err != nil {
 174  		return nil, nil, err
 175  	}
 176  	pkRm, err := pkR.MarshalBinary()
 177  	if err != nil {
 178  		return nil, nil, err
 179  	}
 180  	kemCtx = append(append([]byte{}, enc...), pkRm...)
 181  
 182  	return enc, kemCtx, nil
 183  }
 184  
 185  func (k dhKemBase) Decapsulate(skr kem.PrivateKey, ct []byte) ([]byte, error) {
 186  	dh := make([]byte, k.sizeDH())
 187  	kemCtx, err := k.coreDecap(dh, skr, ct)
 188  	if err != nil {
 189  		return nil, err
 190  	}
 191  	return k.extractExpand(dh, kemCtx), nil
 192  }
 193  
 194  func (k dhKemBase) AuthDecapsulate(
 195  	skR kem.PrivateKey,
 196  	ct []byte,
 197  	pkS kem.PublicKey,
 198  ) ([]byte, error) {
 199  	dhLen := k.sizeDH()
 200  	dh := make([]byte, 2*dhLen)
 201  	kemCtx, err := k.coreDecap(dh[:dhLen], skR, ct)
 202  	if err != nil {
 203  		return nil, err
 204  	}
 205  
 206  	err = k.calcDH(dh[dhLen:], skR, pkS)
 207  	if err != nil {
 208  		return nil, err
 209  	}
 210  
 211  	pkSm, err := pkS.MarshalBinary()
 212  	if err != nil {
 213  		return nil, err
 214  	}
 215  	kemCtx = append(kemCtx, pkSm...)
 216  	return k.extractExpand(dh, kemCtx), nil
 217  }
 218  
 219  func (k dhKemBase) coreDecap(
 220  	dh []byte,
 221  	skR kem.PrivateKey,
 222  	ct []byte,
 223  ) ([]byte, error) {
 224  	pkE, err := k.UnmarshalBinaryPublicKey(ct)
 225  	if err != nil {
 226  		return nil, err
 227  	}
 228  
 229  	err = k.calcDH(dh, skR, pkE)
 230  	if err != nil {
 231  		return nil, err
 232  	}
 233  
 234  	pkR := skR.Public()
 235  	pkRm, err := pkR.MarshalBinary()
 236  	if err != nil {
 237  		return nil, err
 238  	}
 239  
 240  	return append(append([]byte{}, ct...), pkRm...), nil
 241  }
 242