secret_tree.go raw

   1  package mls
   2  
   3  import (
   4  	"golang.org/x/crypto/cryptobyte"
   5  )
   6  
   7  type ratchetLabel []byte
   8  
   9  var (
  10  	ratchetLabelHandshake   = ratchetLabel("handshake")
  11  	ratchetLabelApplication = ratchetLabel("application")
  12  )
  13  
  14  func ratchetLabelFromContentType(ct contentType) ratchetLabel {
  15  	switch ct {
  16  	case contentTypeApplication:
  17  		return ratchetLabelApplication
  18  	case contentTypeProposal, contentTypeCommit:
  19  		return ratchetLabelHandshake
  20  	default:
  21  		panic("unreachable")
  22  	}
  23  }
  24  
  25  // secretTree holds tree node secrets used for the generation of encryption
  26  // keys and nonces.
  27  type secretTree [][]byte
  28  
  29  func deriveSecretTree(cs CipherSuite, n numLeaves, encryptionSecret []byte) (secretTree, error) {
  30  	tree := make(secretTree, int(n.width()))
  31  	tree.set(n.root(), encryptionSecret)
  32  	err := tree.deriveChildren(cs, n.root())
  33  	return tree, err
  34  }
  35  
  36  func (tree secretTree) deriveChildren(cs CipherSuite, x nodeIndex) error {
  37  	l, r, ok := x.children()
  38  	if !ok {
  39  		return nil
  40  	}
  41  
  42  	parentSecret := tree.get(x)
  43  	_, kdf, _ := cs.hpke().Params()
  44  	nh := uint16(kdf.ExtractSize())
  45  	leftSecret, err := cs.expandWithLabel(parentSecret, []byte("tree"), []byte("left"), nh)
  46  	if err != nil {
  47  		return err
  48  	}
  49  	rightSecret, err := cs.expandWithLabel(parentSecret, []byte("tree"), []byte("right"), nh)
  50  	if err != nil {
  51  		return err
  52  	}
  53  
  54  	tree.set(l, leftSecret)
  55  	tree.set(r, rightSecret)
  56  
  57  	if err := tree.deriveChildren(cs, l); err != nil {
  58  		return err
  59  	}
  60  	if err := tree.deriveChildren(cs, r); err != nil {
  61  		return err
  62  	}
  63  
  64  	return nil
  65  }
  66  
  67  func (tree secretTree) get(ni nodeIndex) []byte {
  68  	secret := tree[int(ni)]
  69  	if secret == nil {
  70  		panic("empty node in secret tree")
  71  	}
  72  	return secret
  73  }
  74  
  75  func (tree secretTree) set(ni nodeIndex, secret []byte) {
  76  	tree[int(ni)] = secret
  77  }
  78  
  79  // deriveRatchetRoot derives the root of a ratchet for a tree node.
  80  func (tree secretTree) deriveRatchetRoot(cs CipherSuite, ni nodeIndex, label ratchetLabel) (ratchetSecret, error) {
  81  	_, kdf, _ := cs.hpke().Params()
  82  	nh := uint16(kdf.ExtractSize())
  83  	root, err := cs.expandWithLabel(tree.get(ni), []byte(label), nil, nh)
  84  	return ratchetSecret{root, 0}, err
  85  }
  86  
  87  type ratchetSecret struct {
  88  	secret     []byte
  89  	generation uint32
  90  }
  91  
  92  func (secret ratchetSecret) deriveNonce(cs CipherSuite) ([]byte, error) {
  93  	_, _, aead := cs.hpke().Params()
  94  	nn := uint16(aead.NonceSize())
  95  	return deriveTreeSecret(cs, secret.secret, []byte("nonce"), secret.generation, nn)
  96  }
  97  
  98  func (secret ratchetSecret) deriveKey(cs CipherSuite) ([]byte, error) {
  99  	_, _, aead := cs.hpke().Params()
 100  	nk := uint16(aead.KeySize())
 101  	return deriveTreeSecret(cs, secret.secret, []byte("key"), secret.generation, nk)
 102  }
 103  
 104  func (secret ratchetSecret) deriveNext(cs CipherSuite) (ratchetSecret, error) {
 105  	_, kdf, _ := cs.hpke().Params()
 106  	nh := uint16(kdf.ExtractSize())
 107  	next, err := deriveTreeSecret(cs, secret.secret, []byte("secret"), secret.generation, nh)
 108  	return ratchetSecret{next, secret.generation + 1}, err
 109  }
 110  
 111  func deriveTreeSecret(cs CipherSuite, secret, label []byte, generation uint32, length uint16) ([]byte, error) {
 112  	var b cryptobyte.Builder
 113  	b.AddUint32(generation)
 114  	context := b.BytesOrPanic()
 115  
 116  	return cs.expandWithLabel(secret, label, context, length)
 117  }
 118