secret_tree_crypto.mx raw
1 package mls
2
3 // MLS secret tree crypto operations (RFC 9420 §9).
4
5 func deriveSecretTree(cs CipherSuite, n numLeaves, encryptionSecret []byte) (secretTree, error) {
6 tree := secretTree([][]byte{:int(n.width())})
7 tree.set(n.root(), encryptionSecret)
8 err := tree.deriveChildren(cs, n.root())
9 return tree, err
10 }
11
12 func (tree secretTree) deriveChildren(cs CipherSuite, x nodeIndex) error {
13 l, r, ok := x.children()
14 if !ok {
15 return nil
16 }
17
18 parentSecret := tree.get(x)
19 nh := uint16(cs.ExtractSize())
20 leftSecret, err := cs.expandWithLabel(parentSecret, []byte("tree"), []byte("left"), nh)
21 if err != nil {
22 return err
23 }
24 rightSecret, err := cs.expandWithLabel(parentSecret, []byte("tree"), []byte("right"), nh)
25 if err != nil {
26 return err
27 }
28
29 tree.set(l, leftSecret)
30 tree.set(r, rightSecret)
31
32 if err := tree.deriveChildren(cs, l); err != nil {
33 return err
34 }
35 return tree.deriveChildren(cs, r)
36 }
37
38 func (tree secretTree) deriveRatchetRoot(cs CipherSuite, ni nodeIndex, label ratchetLabel) (ratchetSecret, error) {
39 nh := uint16(cs.ExtractSize())
40 root, err := cs.expandWithLabel(tree.get(ni), []byte(label), nil, nh)
41 return ratchetSecret{root, 0}, err
42 }
43
44 func (secret ratchetSecret) deriveNonce(cs CipherSuite) ([]byte, error) {
45 return deriveTreeSecret(cs, secret.secret, []byte("nonce"), secret.generation, uint16(cs.AEADNonceSize()))
46 }
47
48 func (secret ratchetSecret) deriveKey(cs CipherSuite) ([]byte, error) {
49 return deriveTreeSecret(cs, secret.secret, []byte("key"), secret.generation, uint16(cs.AEADKeySize()))
50 }
51
52 func (secret ratchetSecret) deriveNext(cs CipherSuite) (ratchetSecret, error) {
53 nh := uint16(cs.ExtractSize())
54 next, err := deriveTreeSecret(cs, secret.secret, []byte("secret"), secret.generation, nh)
55 return ratchetSecret{next, secret.generation + 1}, err
56 }
57
58 func deriveTreeSecret(cs CipherSuite, secret, label []byte, generation uint32, length uint16) ([]byte, error) {
59 // context = I2OSP(generation, 4)
60 context := []byte{byte(generation >> 24), byte(generation >> 16), byte(generation >> 8), byte(generation)}
61 return cs.expandWithLabel(secret, label, context, length)
62 }
63