secret_tree.go raw
1 package mls
2
import (
	"errors"
	"math"

	"golang.org/x/crypto/cryptobyte"
)
6
// ratchetLabel is the KDF label passed to ExpandWithLabel to select which
// ratchet is rooted at a secret tree node.
type ratchetLabel []byte

// The two ratchet labels: one for handshake content, one for application
// content.
var (
	ratchetLabelHandshake   ratchetLabel = []byte("handshake")
	ratchetLabelApplication ratchetLabel = []byte("application")
)
13
14 func ratchetLabelFromContentType(ct contentType) ratchetLabel {
15 switch ct {
16 case contentTypeApplication:
17 return ratchetLabelApplication
18 case contentTypeProposal, contentTypeCommit:
19 return ratchetLabelHandshake
20 default:
21 panic("unreachable")
22 }
23 }
24
25 // secretTree holds tree node secrets used for the generation of encryption
26 // keys and nonces.
27 type secretTree [][]byte
28
29 func deriveSecretTree(cs CipherSuite, n numLeaves, encryptionSecret []byte) (secretTree, error) {
30 tree := make(secretTree, int(n.width()))
31 tree.set(n.root(), encryptionSecret)
32 err := tree.deriveChildren(cs, n.root())
33 return tree, err
34 }
35
36 func (tree secretTree) deriveChildren(cs CipherSuite, x nodeIndex) error {
37 l, r, ok := x.children()
38 if !ok {
39 return nil
40 }
41
42 parentSecret := tree.get(x)
43 _, kdf, _ := cs.hpke().Params()
44 nh := uint16(kdf.ExtractSize())
45 leftSecret, err := cs.expandWithLabel(parentSecret, []byte("tree"), []byte("left"), nh)
46 if err != nil {
47 return err
48 }
49 rightSecret, err := cs.expandWithLabel(parentSecret, []byte("tree"), []byte("right"), nh)
50 if err != nil {
51 return err
52 }
53
54 tree.set(l, leftSecret)
55 tree.set(r, rightSecret)
56
57 if err := tree.deriveChildren(cs, l); err != nil {
58 return err
59 }
60 if err := tree.deriveChildren(cs, r); err != nil {
61 return err
62 }
63
64 return nil
65 }
66
67 func (tree secretTree) get(ni nodeIndex) []byte {
68 secret := tree[int(ni)]
69 if secret == nil {
70 panic("empty node in secret tree")
71 }
72 return secret
73 }
74
75 func (tree secretTree) set(ni nodeIndex, secret []byte) {
76 tree[int(ni)] = secret
77 }
78
79 // deriveRatchetRoot derives the root of a ratchet for a tree node.
80 func (tree secretTree) deriveRatchetRoot(cs CipherSuite, ni nodeIndex, label ratchetLabel) (ratchetSecret, error) {
81 _, kdf, _ := cs.hpke().Params()
82 nh := uint16(kdf.ExtractSize())
83 root, err := cs.expandWithLabel(tree.get(ni), []byte(label), nil, nh)
84 return ratchetSecret{root, 0}, err
85 }
86
87 type ratchetSecret struct {
88 secret []byte
89 generation uint32
90 }
91
92 func (secret ratchetSecret) deriveNonce(cs CipherSuite) ([]byte, error) {
93 _, _, aead := cs.hpke().Params()
94 nn := uint16(aead.NonceSize())
95 return deriveTreeSecret(cs, secret.secret, []byte("nonce"), secret.generation, nn)
96 }
97
98 func (secret ratchetSecret) deriveKey(cs CipherSuite) ([]byte, error) {
99 _, _, aead := cs.hpke().Params()
100 nk := uint16(aead.KeySize())
101 return deriveTreeSecret(cs, secret.secret, []byte("key"), secret.generation, nk)
102 }
103
104 func (secret ratchetSecret) deriveNext(cs CipherSuite) (ratchetSecret, error) {
105 _, kdf, _ := cs.hpke().Params()
106 nh := uint16(kdf.ExtractSize())
107 next, err := deriveTreeSecret(cs, secret.secret, []byte("secret"), secret.generation, nh)
108 return ratchetSecret{next, secret.generation + 1}, err
109 }
110
111 func deriveTreeSecret(cs CipherSuite, secret, label []byte, generation uint32, length uint16) ([]byte, error) {
112 var b cryptobyte.Builder
113 b.AddUint32(generation)
114 context := b.BytesOrPanic()
115
116 return cs.expandWithLabel(secret, label, context, length)
117 }
118