tree_crypto.mx raw
1 package mls
2
3 // MLS tree crypto operations (RFC 9420 §7).
4
5 import "errors"
6
// errNodeKeyMismatch is returned when the HPKE public key derived from a path
// secret differs from the public key stored for that node in the tree.
var errNodeKeyMismatch = errors.New("mls: node public key mismatch")
10
11 func (node *parentNode) computeParentHash(cs CipherSuite, originalSiblingTreeHash []byte) ([]byte, error) {
12 raw, err := marshalParentHashInput(node.encryptionKey, node.parentHash, originalSiblingTreeHash)
13 if err != nil {
14 return nil, err
15 }
16 return cs.hash(raw), nil
17 }
18
19 func marshalParentHashInput(encryptionKey hpkePublicKey, parentHash, originalSiblingTreeHash []byte) ([]byte, error) {
20 var w Writer
21 w.writeOpaqueVec([]byte(encryptionKey))
22 w.writeOpaqueVec(parentHash)
23 w.writeOpaqueVec(originalSiblingTreeHash)
24 return w.bytes()
25 }
26
27 func (node *leafNode) sign(cs CipherSuite, groupID GroupID, li leafIndex, signerPriv signaturePrivateKey) error {
28 tbs, err := marshalRaw(&leafNodeTBS{
29 node: node,
30 groupID: groupID,
31 leafIndex: li,
32 })
33 if err != nil {
34 return err
35 }
36 sig, err := cs.signWithLabel(signerPriv, []byte("LeafNodeTBS"), tbs)
37 if err != nil {
38 return err
39 }
40 node.signature = sig
41 return nil
42 }
43
44 func (node *leafNode) verifySignature(cs CipherSuite, groupID GroupID, li leafIndex) bool {
45 tbs, err := marshalRaw(&leafNodeTBS{
46 node: node,
47 groupID: groupID,
48 leafIndex: li,
49 })
50 if err != nil {
51 return false
52 }
53 return cs.verifyWithLabel(node.signatureKey, []byte("LeafNodeTBS"), tbs, node.signature)
54 }
55
56 func decryptPathSecret(cs CipherSuite, nodePriv hpkePrivateKey, ctx *groupContext, ct hpkeCiphertext) ([]byte, error) {
57 rawCtx, err := marshalRaw(ctx)
58 if err != nil {
59 return nil, err
60 }
61 return cs.decryptWithLabel(nodePriv, []byte("UpdatePathNode"), rawCtx, ct.kemOutput, ct.ciphertext)
62 }
63
64 func nodePrivFromPathSecret(cs CipherSuite, pathSecret []byte, nodePub hpkePublicKey) (hpkePrivateKey, error) {
65 nodeSecret, err := cs.deriveSecret(pathSecret, []byte("node"))
66 if err != nil {
67 return nil, err
68 }
69 pub, priv, err := cs.deriveEncryptionKeyPair(nodeSecret)
70 if err != nil {
71 return nil, err
72 }
73 if !bytesEqual(pub, nodePub) {
74 return nil, errNodeKeyMismatch
75 }
76 return priv, nil
77 }
78
79 // computeRootTreeHash computes the tree hash for integrity verification.
80 func (tree ratchetTree) computeRootTreeHash(cs CipherSuite) ([]byte, error) {
81 n := tree.numLeaves()
82 return tree.computeTreeHash(cs, n.root(), n, nil)
83 }
84
85 func (tree ratchetTree) computeTreeHash(cs CipherSuite, x nodeIndex, n numLeaves, exclude map[leafIndex]bool) ([]byte, error) {
86 if x.isLeaf() {
87 return tree.computeLeafTreeHash(cs, x, exclude)
88 }
89 return tree.computeParentTreeHash(cs, x, n, exclude)
90 }
91
92 func (tree ratchetTree) computeLeafTreeHash(cs CipherSuite, x nodeIndex, exclude map[leafIndex]bool) ([]byte, error) {
93 var w Writer
94 li, _ := x.leafIndex()
95 w.addUint32(uint32(li))
96 nd := tree.get(x)
97 if exclude[li] {
98 nd = nil
99 }
100 w.writeOptional(nd != nil)
101 if nd != nil {
102 nd.leafNode.marshal(&w)
103 }
104 input, err := w.bytes()
105 if err != nil {
106 return nil, err
107 }
108 return cs.hash(input), nil
109 }
110
111 func (tree ratchetTree) computeParentTreeHash(cs CipherSuite, x nodeIndex, n numLeaves, exclude map[leafIndex]bool) ([]byte, error) {
112 l, r, ok := x.children()
113 if !ok {
114 return nil, errUnexpectedEOF
115 }
116 leftHash, err := tree.computeTreeHash(cs, l, n, exclude)
117 if err != nil {
118 return nil, err
119 }
120 rightHash, err := tree.computeTreeHash(cs, r, n, exclude)
121 if err != nil {
122 return nil, err
123 }
124
125 var w Writer
126 nd := tree.get(x)
127 if nd != nil && len(exclude) > 0 {
128 // Filter unmerged leaves for parent hash verification
129 pn := nd.parentNode
130 if pn != nil && len(pn.unmergedLeaves) > 0 {
131 filtered := []leafIndex{:0:len(pn.unmergedLeaves)}
132 for _, li := range pn.unmergedLeaves {
133 if !exclude[li] {
134 filtered = append(filtered, li)
135 }
136 }
137 filteredNode := *pn
138 filteredNode.unmergedLeaves = filtered
139 w.writeOptional(true)
140 filteredNode.marshal(&w)
141 w.writeOpaqueVec(leftHash)
142 w.writeOpaqueVec(rightHash)
143 input, err := w.bytes()
144 if err != nil {
145 return nil, err
146 }
147 return cs.hash(input), nil
148 }
149 }
150 w.writeOptional(nd != nil)
151 if nd != nil {
152 nd.parentNode.marshal(&w)
153 }
154 w.writeOpaqueVec(leftHash)
155 w.writeOpaqueVec(rightHash)
156 input, err := w.bytes()
157 if err != nil {
158 return nil, err
159 }
160 return cs.hash(input), nil
161 }
162
163 // verifyParentHashes verifies the parent hash chain (RFC 9420 §7.9.2).
164 func (tree ratchetTree) verifyParentHashes(cs CipherSuite) bool {
165 n := tree.numLeaves()
166 for i, nd := range tree {
167 if nd == nil {
168 continue
169 }
170 x := nodeIndex(i)
171 l, r, ok := x.children()
172 if !ok {
173 continue
174 }
175
176 pn := nd.parentNode
177 exclude := map[leafIndex]bool{}
178 for _, li := range pn.unmergedLeaves {
179 exclude[li] = true
180 }
181
182 leftTreeHash, err := tree.computeTreeHash(cs, l, n, exclude)
183 if err != nil {
184 return false
185 }
186 rightTreeHash, err := tree.computeTreeHash(cs, r, n, exclude)
187 if err != nil {
188 return false
189 }
190
191 leftParentHash, err := pn.computeParentHash(cs, rightTreeHash)
192 if err != nil {
193 return false
194 }
195 rightParentHash, err := pn.computeParentHash(cs, leftTreeHash)
196 if err != nil {
197 return false
198 }
199
200 isLeft := tree.findParentHash(tree.resolve(l), leftParentHash)
201 isRight := tree.findParentHash(tree.resolve(r), rightParentHash)
202 if isLeft == isRight {
203 return false
204 }
205 }
206 return true
207 }
208
209 // verifyIntegrity verifies the ratchet tree (RFC 9420 §12.4.3.1).
210 func (tree ratchetTree) verifyIntegrity(ctx *groupContext, nowUnix int64) error {
211 cs := ctx.cipherSuite
212 n := tree.numLeaves()
213
214 h, err := tree.computeRootTreeHash(cs)
215 if err != nil {
216 return err
217 }
218 if !bytesEqual(h, ctx.treeHash) {
219 return errors.New("mls: tree hash verification failed")
220 }
221 if !tree.verifyParentHashes(cs) {
222 return errors.New("mls: parent hash verification failed")
223 }
224
225 supportedCreds := tree.supportedCreds()
226 sigKeys := map[string]bool{}
227 encKeys := map[string]bool{}
228 for li := leafIndex(0); li < leafIndex(n); li++ {
229 nd := tree.getLeaf(li)
230 if nd == nil {
231 continue
232 }
233 err := nd.verify(&leafNodeVerifyOptions{
234 cipherSuite: cs,
235 groupID: ctx.groupID,
236 leafIndex: li,
237 supportedCreds: supportedCreds,
238 signatureKeys: sigKeys,
239 encryptionKeys: encKeys,
240 nowUnix: nowUnix,
241 })
242 if err != nil {
243 return err
244 }
245 sigKeys[string(nd.signatureKey)] = true
246 encKeys[string(nd.encryptionKey)] = true
247 }
248
249 // Check unmerged leaf ancestry
250 for i, nd := range tree {
251 if nd == nil || nd.nodeType != nodeTypeParent {
252 continue
253 }
254 p := nodeIndex(i)
255 for _, ul := range nd.parentNode.unmergedLeaves {
256 x := ul.nodeIndex()
257 for {
258 var ok bool
259 x, ok = n.parent(x)
260 if !ok {
261 return errors.New("mls: unmerged leaf not descendant of parent")
262 }
263 if x == p {
264 break
265 }
266 intermediate := tree.get(x)
267 if intermediate != nil && !hasUnmergedLeaf(intermediate.parentNode, ul) {
268 return errors.New("mls: intermediate node missing unmerged leaf")
269 }
270 }
271 }
272 if encKeys[string(nd.parentNode.encryptionKey)] {
273 return errors.New("mls: duplicate encryption key in tree")
274 }
275 encKeys[string(nd.parentNode.encryptionKey)] = true
276 }
277 return nil
278 }
279
280 // mergeUpdatePath applies an update path to the tree (RFC 9420 §7.5).
281 func (tree ratchetTree) mergeUpdatePath(cs CipherSuite, senderLI leafIndex, path *updatePath) error {
282 senderNI := senderLI.nodeIndex()
283 n := tree.numLeaves()
284
285 directPath := n.directPath(senderNI)
286 for _, ni := range directPath {
287 tree.set(ni, nil)
288 }
289
290 filteredDP := tree.filteredDirectPath(senderNI)
291 if len(filteredDP) != len(path.nodes) {
292 return errors.New("mls: update path length mismatch")
293 }
294 for i, ni := range filteredDP {
295 tree.set(ni, &node{
296 nodeType: nodeTypeParent,
297 parentNode: &parentNode{
298 encryptionKey: path.nodes[i].encryptionKey,
299 },
300 })
301 }
302
303 // Compute parent hashes root-to-leaf
304 var prevParentHash []byte
305 for i := len(filteredDP) - 1; i >= 0; i-- {
306 ni := filteredDP[i]
307 pn := tree.get(ni).parentNode
308
309 l, r, ok := ni.children()
310 if !ok {
311 panic("unreachable")
312 }
313 s := l
314 found := false
315 for _, dp := range directPath {
316 if dp == s {
317 found = true
318 break
319 }
320 }
321 if s == senderNI || found {
322 s = r
323 }
324
325 treeHash, err := tree.computeTreeHash(cs, s, n, nil)
326 if err != nil {
327 return err
328 }
329
330 pn.parentHash = prevParentHash
331 h, err := pn.computeParentHash(cs, treeHash)
332 if err != nil {
333 return err
334 }
335 prevParentHash = h
336 }
337
338 if !bytesEqual(path.leafNode.parentHash, prevParentHash) {
339 return errors.New("mls: parent hash mismatch for update path leaf node")
340 }
341
342 tree.set(senderNI, &node{
343 nodeType: nodeTypeLeaf,
344 leafNode: &path.leafNode,
345 })
346 return nil
347 }
348
349 // decryptPathSecrets decrypts path secrets from an update path (RFC 9420 §7.6).
350 func (tree ratchetTree) decryptPathSecrets(cs CipherSuite, ctx *groupContext, senderLI, recipientLI leafIndex, path *updatePath, privTree []hpkePrivateKey) ([]byte, error) {
351 senderNI := senderLI.nodeIndex()
352 recipientNI := recipientLI.nodeIndex()
353
354 senderFDP := tree.filteredDirectPath(senderNI)
355 if len(path.nodes) != len(senderFDP) {
356 return nil, errors.New("mls: invalid update path length")
357 }
358
359 // Find the common ancestor in the filtered direct path
360 recipientAncestor := commonAncestor(senderNI, recipientNI)
361 recipientAncestorIdx := -1
362 for i, ni := range senderFDP {
363 if ni == recipientAncestor {
364 recipientAncestorIdx = i
365 break
366 }
367 }
368 if recipientAncestorIdx < 0 {
369 return nil, errors.New("mls: cannot find recipient ancestor")
370 }
371 upNode := path.nodes[recipientAncestorIdx]
372
373 // Find the copath node
374 ancestor := commonAncestor(senderNI, recipientNI)
375 var copathNode nodeIndex
376 var ok bool
377 if recipientNI < senderNI {
378 copathNode, ok = ancestor.left()
379 } else {
380 copathNode, ok = ancestor.right()
381 }
382 if !ok {
383 panic("unreachable")
384 }
385
386 copathRes := tree.resolve(copathNode)
387 if len(upNode.encryptedPathSecret) != len(copathRes) {
388 return nil, errors.New("mls: invalid encrypted path secret length")
389 }
390
391 // Find a node in the resolution for which we have a private key
392 var nodePriv hpkePrivateKey
393 resIdx := -1
394 for i, ni := range copathRes {
395 if p := privTree[int(ni)]; p != nil {
396 nodePriv = p
397 resIdx = i
398 break
399 }
400 }
401 if nodePriv == nil {
402 return nil, errors.New("mls: no private key found")
403 }
404
405 pathSecret, err := decryptPathSecret(cs, nodePriv, ctx, upNode.encryptedPathSecret[resIdx])
406 if err != nil {
407 return nil, err
408 }
409 nodePub := tree.get(recipientAncestor).encryptionKey()
410 nodePriv, err = nodePrivFromPathSecret(cs, pathSecret, nodePub)
411 if err != nil {
412 return nil, err
413 }
414 privTree[int(recipientAncestor)] = nodePriv
415
416 // Derive path secrets for remaining ancestors
417 for _, ni := range senderFDP[recipientAncestorIdx+1:] {
418 pathSecret, err = cs.deriveSecret(pathSecret, []byte("path"))
419 if err != nil {
420 return nil, err
421 }
422 nodePriv, err = nodePrivFromPathSecret(cs, pathSecret, tree.get(ni).encryptionKey())
423 if err != nil {
424 return nil, err
425 }
426 privTree[int(ni)] = nodePriv
427 }
428
429 commitSecret, err := cs.deriveSecret(pathSecret, []byte("path"))
430 if err != nil {
431 return nil, err
432 }
433 return commitSecret, nil
434 }
435