tree.go raw
1 package mls
2
3 import (
4 "bytes"
5 "fmt"
6 "io"
7 "time"
8
9 "golang.org/x/crypto/cryptobyte"
10 )
11
12 type parentNode struct {
13 encryptionKey hpkePublicKey
14 parentHash []byte
15 unmergedLeaves []leafIndex
16 }
17
18 func (node *parentNode) unmarshal(s *cryptobyte.String) error {
19 *node = parentNode{}
20 if !readOpaqueVec(s, (*[]byte)(&node.encryptionKey)) || !readOpaqueVec(s, &node.parentHash) {
21 return io.ErrUnexpectedEOF
22 }
23 return readVector(s, func(s *cryptobyte.String) error {
24 var i leafIndex
25 if !s.ReadUint32((*uint32)(&i)) {
26 return io.ErrUnexpectedEOF
27 }
28 node.unmergedLeaves = append(node.unmergedLeaves, i)
29 return nil
30 })
31 }
32
33 func (node *parentNode) marshal(b *cryptobyte.Builder) {
34 writeOpaqueVec(b, []byte(node.encryptionKey))
35 writeOpaqueVec(b, node.parentHash)
36 writeVector(b, len(node.unmergedLeaves), func(b *cryptobyte.Builder, i int) {
37 b.AddUint32(uint32(node.unmergedLeaves[i]))
38 })
39 }
40
41 func (node *parentNode) computeParentHash(cs CipherSuite, originalSiblingTreeHash []byte) ([]byte, error) {
42 rawInput, err := marshalParentHashInput(node.encryptionKey, node.parentHash, originalSiblingTreeHash)
43 if err != nil {
44 return nil, err
45 }
46 h := cs.hash().New()
47 h.Write(rawInput)
48 return h.Sum(nil), nil
49 }
50
51 func marshalParentHashInput(encryptionKey hpkePublicKey, parentHash, originalSiblingTreeHash []byte) ([]byte, error) {
52 var b cryptobyte.Builder
53 writeOpaqueVec(&b, []byte(encryptionKey))
54 writeOpaqueVec(&b, parentHash)
55 writeOpaqueVec(&b, originalSiblingTreeHash)
56 return b.Bytes()
57 }
58
// leafNodeSource is the one-byte LeafNodeSource enum telling which structure
// a leaf node was created for.
type leafNodeSource uint8

const (
	// leafNodeSourceKeyPackage is value 1 (key_package).
	leafNodeSourceKeyPackage leafNodeSource = iota + 1
	// leafNodeSourceUpdate is value 2 (update).
	leafNodeSourceUpdate
	// leafNodeSourceCommit is value 3 (commit).
	leafNodeSourceCommit
)
66
67 func (src *leafNodeSource) unmarshal(s *cryptobyte.String) error {
68 if !s.ReadUint8((*uint8)(src)) {
69 return io.ErrUnexpectedEOF
70 }
71 switch *src {
72 case leafNodeSourceKeyPackage, leafNodeSourceUpdate, leafNodeSourceCommit:
73 return nil
74 default:
75 return fmt.Errorf("mls: invalid leaf node source %d", *src)
76 }
77 }
78
79 func (src leafNodeSource) marshal(b *cryptobyte.Builder) {
80 b.AddUint8(uint8(src))
81 }
82
83 type capabilities struct {
84 versions []protocolVersion
85 cipherSuites []CipherSuite
86 extensions []extensionType
87 proposals []proposalType
88 credentials []credentialType
89 }
90
91 func (caps *capabilities) unmarshal(s *cryptobyte.String) error {
92 *caps = capabilities{}
93
94 // Note: all unknown values here must be ignored
95
96 err := readVector(s, func(s *cryptobyte.String) error {
97 var ver protocolVersion
98 if !s.ReadUint16((*uint16)(&ver)) {
99 return io.ErrUnexpectedEOF
100 }
101 caps.versions = append(caps.versions, ver)
102 return nil
103 })
104 if err != nil {
105 return err
106 }
107
108 err = readVector(s, func(s *cryptobyte.String) error {
109 var cs CipherSuite
110 if !s.ReadUint16((*uint16)(&cs)) {
111 return io.ErrUnexpectedEOF
112 }
113 caps.cipherSuites = append(caps.cipherSuites, cs)
114 return nil
115 })
116 if err != nil {
117 return err
118 }
119
120 err = readVector(s, func(s *cryptobyte.String) error {
121 var et extensionType
122 if !s.ReadUint16((*uint16)(&et)) {
123 return io.ErrUnexpectedEOF
124 }
125 caps.extensions = append(caps.extensions, et)
126 return nil
127 })
128 if err != nil {
129 return err
130 }
131
132 err = readVector(s, func(s *cryptobyte.String) error {
133 var pt proposalType
134 if !s.ReadUint16((*uint16)(&pt)) {
135 return io.ErrUnexpectedEOF
136 }
137 caps.proposals = append(caps.proposals, pt)
138 return nil
139 })
140 if err != nil {
141 return err
142 }
143
144 err = readVector(s, func(s *cryptobyte.String) error {
145 var ct credentialType
146 if !s.ReadUint16((*uint16)(&ct)) {
147 return io.ErrUnexpectedEOF
148 }
149 caps.credentials = append(caps.credentials, ct)
150 return nil
151 })
152 if err != nil {
153 return err
154 }
155
156 return nil
157 }
158
159 func (caps *capabilities) marshal(b *cryptobyte.Builder) {
160 writeVector(b, len(caps.versions), func(b *cryptobyte.Builder, i int) {
161 b.AddUint16(uint16(caps.versions[i]))
162 })
163
164 writeVector(b, len(caps.cipherSuites), func(b *cryptobyte.Builder, i int) {
165 b.AddUint16(uint16(caps.cipherSuites[i]))
166 })
167
168 writeVector(b, len(caps.extensions), func(b *cryptobyte.Builder, i int) {
169 b.AddUint16(uint16(caps.extensions[i]))
170 })
171
172 writeVector(b, len(caps.proposals), func(b *cryptobyte.Builder, i int) {
173 b.AddUint16(uint16(caps.proposals[i]))
174 })
175
176 writeVector(b, len(caps.credentials), func(b *cryptobyte.Builder, i int) {
177 b.AddUint16(uint16(caps.credentials[i]))
178 })
179 }
180
// maxLeafNodeLifetime is the longest accepted span between a lifetime's
// not-before and not-after times (90 days).
const maxLeafNodeLifetime = 90 * 24 * time.Hour

// lifetime is a leaf node validity window, stored as whole seconds since the
// Unix epoch.
type lifetime struct {
	notBefore uint64
	notAfter  uint64
}

// newLifetime builds a lifetime from two instants, truncated to whole Unix
// seconds.
func newLifetime(notBefore, notAfter time.Time) *lifetime {
	lt := new(lifetime)
	lt.notBefore = uint64(notBefore.Unix())
	lt.notAfter = uint64(notAfter.Unix())
	return lt
}
193
194 func (lt *lifetime) unmarshal(s *cryptobyte.String) error {
195 *lt = lifetime{}
196 if !s.ReadUint64(<.notBefore) || !s.ReadUint64(<.notAfter) {
197 return io.ErrUnexpectedEOF
198 }
199 return nil
200 }
201
202 func (lt *lifetime) marshal(b *cryptobyte.Builder) {
203 b.AddUint64(lt.notBefore)
204 b.AddUint64(lt.notAfter)
205 }
206
207 func (lt *lifetime) notBeforeTime() time.Time {
208 return time.Unix(int64(lt.notBefore), 0)
209 }
210
211 func (lt *lifetime) notAfterTime() time.Time {
212 return time.Unix(int64(lt.notAfter), 0)
213 }
214
215 // verify ensures that the lifetime is valid: it has an acceptable range and
216 // the current time is within that range.
217 func (lt *lifetime) verify(t time.Time) bool {
218 notBefore, notAfter := lt.notBeforeTime(), lt.notAfterTime()
219
220 if d := notAfter.Sub(notBefore); d <= 0 || d > maxLeafNodeLifetime {
221 return false
222 }
223
224 return t.After(notBefore) && notAfter.After(t)
225 }
226
// extensionType is the 16-bit identifier of an MLS extension.
type extensionType uint16

// Registered extension type values, see
// http://www.iana.org/assignments/mls/mls.xhtml#mls-extension-types
const (
	extensionTypeApplicationID        extensionType = 0x0001
	extensionTypeRatchetTree          extensionType = 0x0002
	extensionTypeRequiredCapabilities extensionType = 0x0003
	extensionTypeExternalPub          extensionType = 0x0004
	extensionTypeExternalSenders      extensionType = 0x0005

	// ExtensionTypeLastResort marks a KeyPackage as reusable for multiple
	// Welcome messages. Required by Marmot (MIP-00).
	ExtensionTypeLastResort extensionType = 0x000a

	// ExtensionTypeNostrGroupData carries Nostr group metadata (group ID,
	// name, admins, relays). Required by Marmot (MIP-01).
	ExtensionTypeNostrGroupData extensionType = 0xf2ee
)

// Extension holds a TLS-serialized MLS extension (type + opaque data).
type Extension = extension

// extension pairs an extension type with its opaque payload.
type extension struct {
	extensionType extensionType
	extensionData []byte
}

// NewExtension returns an extension of type t carrying data. The slice is
// stored as-is, without copying.
func NewExtension(t extensionType, data []byte) extension {
	var ext extension
	ext.extensionType = t
	ext.extensionData = data
	return ext
}

// ExtensionType is exported for use by the Marmot SDK.
type ExtensionType = extensionType
261
262 func unmarshalExtensionVec(s *cryptobyte.String) ([]extension, error) {
263 var exts []extension
264 err := readVector(s, func(s *cryptobyte.String) error {
265 var ext extension
266 if !s.ReadUint16((*uint16)(&ext.extensionType)) || !readOpaqueVec(s, &ext.extensionData) {
267 return io.ErrUnexpectedEOF
268 }
269 exts = append(exts, ext)
270 return nil
271 })
272 return exts, err
273 }
274
275 func marshalExtensionVec(b *cryptobyte.Builder, exts []extension) {
276 writeVector(b, len(exts), func(b *cryptobyte.Builder, i int) {
277 ext := exts[i]
278 b.AddUint16(uint16(ext.extensionType))
279 writeOpaqueVec(b, ext.extensionData)
280 })
281 }
282
283 func findExtensionData(exts []extension, t extensionType) []byte {
284 for _, ext := range exts {
285 if ext.extensionType == t {
286 return ext.extensionData
287 }
288 }
289 return nil
290 }
291
292 type leafNode struct {
293 encryptionKey hpkePublicKey
294 signatureKey signaturePublicKey
295 credential Credential
296 capabilities capabilities
297
298 leafNodeSource leafNodeSource
299 lifetime *lifetime // for leafNodeSourceKeyPackage
300 parentHash []byte // for leafNodeSourceCommit
301
302 extensions []extension
303 signature []byte
304 }
305
306 func (node *leafNode) unmarshal(s *cryptobyte.String) error {
307 *node = leafNode{}
308
309 if !readOpaqueVec(s, (*[]byte)(&node.encryptionKey)) || !readOpaqueVec(s, (*[]byte)(&node.signatureKey)) {
310 return io.ErrUnexpectedEOF
311 }
312
313 if err := node.credential.unmarshal(s); err != nil {
314 return err
315 }
316 if err := node.capabilities.unmarshal(s); err != nil {
317 return err
318 }
319 if err := node.leafNodeSource.unmarshal(s); err != nil {
320 return err
321 }
322
323 var err error
324 switch node.leafNodeSource {
325 case leafNodeSourceKeyPackage:
326 node.lifetime = new(lifetime)
327 err = node.lifetime.unmarshal(s)
328 case leafNodeSourceCommit:
329 if !readOpaqueVec(s, &node.parentHash) {
330 err = io.ErrUnexpectedEOF
331 }
332 }
333 if err != nil {
334 return err
335 }
336
337 exts, err := unmarshalExtensionVec(s)
338 if err != nil {
339 return err
340 }
341 node.extensions = exts
342
343 if !readOpaqueVec(s, &node.signature) {
344 return io.ErrUnexpectedEOF
345 }
346
347 return nil
348 }
349
350 func (node *leafNode) marshalBase(b *cryptobyte.Builder) {
351 writeOpaqueVec(b, []byte(node.encryptionKey))
352 writeOpaqueVec(b, []byte(node.signatureKey))
353 node.credential.marshal(b)
354 node.capabilities.marshal(b)
355 node.leafNodeSource.marshal(b)
356 switch node.leafNodeSource {
357 case leafNodeSourceKeyPackage:
358 node.lifetime.marshal(b)
359 case leafNodeSourceCommit:
360 writeOpaqueVec(b, node.parentHash)
361 }
362 marshalExtensionVec(b, node.extensions)
363 }
364
365 func (node *leafNode) marshal(b *cryptobyte.Builder) {
366 node.marshalBase(b)
367 writeOpaqueVec(b, []byte(node.signature))
368 }
369
370 type leafNodeTBS struct {
371 *leafNode
372
373 // for leafNodeSourceUpdate and leafNodeSourceCommit
374 groupID GroupID
375 leafIndex leafIndex
376 }
377
378 func (node *leafNodeTBS) marshal(b *cryptobyte.Builder) {
379 node.leafNode.marshalBase(b)
380 switch node.leafNode.leafNodeSource {
381 case leafNodeSourceUpdate, leafNodeSourceCommit:
382 writeOpaqueVec(b, []byte(node.groupID))
383 b.AddUint32(uint32(node.leafIndex))
384 }
385 }
386
387 func (node *leafNode) sign(cs CipherSuite, groupID GroupID, li leafIndex, signerPriv signaturePrivateKey) error {
388 leafNodeTBS, err := marshal(&leafNodeTBS{
389 leafNode: node,
390 groupID: groupID,
391 leafIndex: li,
392 })
393 if err != nil {
394 return err
395 }
396 sig, err := cs.signWithLabel(signerPriv, []byte("LeafNodeTBS"), leafNodeTBS)
397 if err != nil {
398 return err
399 }
400 node.signature = sig
401 return nil
402 }
403
404 // verifySignature verifies the signature of the leaf node.
405 //
406 // groupID and li can be left unspecified if the leaf node source is neither
407 // update nor commit.
408 func (node *leafNode) verifySignature(cs CipherSuite, groupID GroupID, li leafIndex) bool {
409 leafNodeTBS, err := marshal(&leafNodeTBS{
410 leafNode: node,
411 groupID: groupID,
412 leafIndex: li,
413 })
414 if err != nil {
415 return false
416 }
417 return cs.verifyWithLabel(node.signatureKey, []byte("LeafNodeTBS"), leafNodeTBS, node.signature)
418 }
419
420 // verify performs leaf node validation described in section 7.3.
421 //
422 // It does not perform all checks: it does not check that the credential is
423 // valid.
424 func (node *leafNode) verify(options *leafNodeVerifyOptions) error {
425 li := options.leafIndex
426
427 if !node.verifySignature(options.cipherSuite, options.groupID, li) {
428 return fmt.Errorf("mls: leaf node signature verification failed")
429 }
430
431 // TODO: check required_capabilities group extension
432
433 if _, ok := options.supportedCreds[node.credential.credentialType]; !ok {
434 return fmt.Errorf("mls: credential type %v used by leaf node not supported by all members", node.credential.credentialType)
435 }
436
437 if node.lifetime != nil {
438 now := options.now
439 if now == nil {
440 now = time.Now
441 }
442 if t := now(); !t.IsZero() && !node.lifetime.verify(t) {
443 return fmt.Errorf("mls: lifetime verification failed (not before %v, not after %v)", node.lifetime.notBeforeTime(), node.lifetime.notAfterTime())
444 }
445 }
446
447 supportedExts := make(map[extensionType]struct{})
448 for _, et := range node.capabilities.extensions {
449 supportedExts[et] = struct{}{}
450 }
451 for _, ext := range node.extensions {
452 if _, ok := supportedExts[ext.extensionType]; !ok {
453 return fmt.Errorf("mls: extension type %d used by leaf node not supported by that leaf node", ext.extensionType)
454 }
455 }
456
457 // TODO: verify the leaf_node_source field
458
459 if _, dup := options.signatureKeys[string(node.signatureKey)]; dup {
460 return fmt.Errorf("mls: duplicate signature key in ratchet tree")
461 }
462 if _, dup := options.encryptionKeys[string(node.encryptionKey)]; dup {
463 return fmt.Errorf("mls: duplicate encryption key in ratchet tree")
464 }
465
466 return nil
467 }
468
469 type leafNodeVerifyOptions struct {
470 cipherSuite CipherSuite
471 groupID GroupID
472 leafIndex leafIndex
473 supportedCreds map[credentialType]struct{}
474 signatureKeys map[string]struct{}
475 encryptionKeys map[string]struct{}
476 now func() time.Time
477 }
478
479 type updatePathNode struct {
480 encryptionKey hpkePublicKey
481 encryptedPathSecret []hpkeCiphertext
482 }
483
484 func (node *updatePathNode) unmarshal(s *cryptobyte.String) error {
485 *node = updatePathNode{}
486
487 if !readOpaqueVec(s, (*[]byte)(&node.encryptionKey)) {
488 return io.ErrUnexpectedEOF
489 }
490
491 return readVector(s, func(s *cryptobyte.String) error {
492 var ciphertext hpkeCiphertext
493 if err := ciphertext.unmarshal(s); err != nil {
494 return err
495 }
496 node.encryptedPathSecret = append(node.encryptedPathSecret, ciphertext)
497 return nil
498 })
499 }
500
501 func (node *updatePathNode) marshal(b *cryptobyte.Builder) {
502 writeOpaqueVec(b, []byte(node.encryptionKey))
503 writeVector(b, len(node.encryptedPathSecret), func(b *cryptobyte.Builder, i int) {
504 node.encryptedPathSecret[i].marshal(b)
505 })
506 }
507
508 func decryptPathSecret(cs CipherSuite, nodePriv hpkePrivateKey, ctx *groupContext, ciphertext hpkeCiphertext) ([]byte, error) {
509 rawCtx, err := marshal(ctx)
510 if err != nil {
511 return nil, err
512 }
513 return cs.decryptWithLabel(nodePriv, []byte("UpdatePathNode"), rawCtx, ciphertext.kemOutput, ciphertext.ciphertext)
514 }
515
516 func nodePrivFromPathSecret(cs CipherSuite, pathSecret []byte, nodePub hpkePublicKey) (hpkePrivateKey, error) {
517 nodeSecret, err := cs.deriveSecret(pathSecret, []byte("node"))
518 if err != nil {
519 return nil, err
520 }
521
522 pub, priv, err := cs.deriveEncryptionKeyPair(nodeSecret)
523 if err != nil {
524 return nil, err
525 }
526
527 if !bytes.Equal(pub, nodePub) {
528 return nil, fmt.Errorf("mls: node public key mismatch")
529 }
530
531 return priv, nil
532 }
533
534 type updatePath struct {
535 leafNode leafNode
536 nodes []updatePathNode
537 }
538
539 func (up *updatePath) unmarshal(s *cryptobyte.String) error {
540 *up = updatePath{}
541
542 if err := up.leafNode.unmarshal(s); err != nil {
543 return err
544 }
545
546 return readVector(s, func(s *cryptobyte.String) error {
547 var node updatePathNode
548 if err := node.unmarshal(s); err != nil {
549 return err
550 }
551 up.nodes = append(up.nodes, node)
552 return nil
553 })
554 }
555
556 func (up *updatePath) marshal(b *cryptobyte.Builder) {
557 up.leafNode.marshal(b)
558 writeVector(b, len(up.nodes), func(b *cryptobyte.Builder, i int) {
559 up.nodes[i].marshal(b)
560 })
561 }
562
// nodeType is the one-byte tag that tells a leaf node from a parent node.
type nodeType uint8

const (
	// nodeTypeLeaf is value 1.
	nodeTypeLeaf nodeType = iota + 1
	// nodeTypeParent is value 2.
	nodeTypeParent
)
569
570 func (t *nodeType) unmarshal(s *cryptobyte.String) error {
571 if !s.ReadUint8((*uint8)(t)) {
572 return io.ErrUnexpectedEOF
573 }
574 switch *t {
575 case nodeTypeLeaf, nodeTypeParent:
576 return nil
577 default:
578 return fmt.Errorf("mls: invalid node type %d", *t)
579 }
580 }
581
582 func (t nodeType) marshal(b *cryptobyte.Builder) {
583 b.AddUint8(uint8(t))
584 }
585
586 type node struct {
587 nodeType nodeType
588 leafNode *leafNode // for nodeTypeLeaf
589 parentNode *parentNode // for nodeTypeParent
590 }
591
592 func (n *node) unmarshal(s *cryptobyte.String) error {
593 *n = node{}
594
595 if err := n.nodeType.unmarshal(s); err != nil {
596 return err
597 }
598
599 switch n.nodeType {
600 case nodeTypeLeaf:
601 n.leafNode = new(leafNode)
602 return n.leafNode.unmarshal(s)
603 case nodeTypeParent:
604 n.parentNode = new(parentNode)
605 return n.parentNode.unmarshal(s)
606 default:
607 panic("unreachable")
608 }
609 }
610
611 func (n *node) marshal(b *cryptobyte.Builder) {
612 n.nodeType.marshal(b)
613 switch n.nodeType {
614 case nodeTypeLeaf:
615 n.leafNode.marshal(b)
616 case nodeTypeParent:
617 n.parentNode.marshal(b)
618 default:
619 panic("unreachable")
620 }
621 }
622
623 func (n *node) encryptionKey() hpkePublicKey {
624 switch n.nodeType {
625 case nodeTypeLeaf:
626 return n.leafNode.encryptionKey
627 case nodeTypeParent:
628 return n.parentNode.encryptionKey
629 default:
630 panic("unreachable")
631 }
632 }
633
634 // ratchetTree is a ratchet tree represented as complete balanced binary tree,
635 // stored with the array-based scheme described in appendix C.
636 //
637 // The length of the tree plus 1 is guaranteed to be a power of 2.
638 type ratchetTree []*node
639
640 func (tree *ratchetTree) unmarshal(s *cryptobyte.String) error {
641 *tree = ratchetTree{}
642 err := readVector(s, func(s *cryptobyte.String) error {
643 var n *node
644 var hasNode bool
645 if !readOptional(s, &hasNode) {
646 return io.ErrUnexpectedEOF
647 } else if hasNode {
648 n = new(node)
649 if err := n.unmarshal(s); err != nil {
650 return err
651 }
652 }
653 *tree = append(*tree, n)
654 return nil
655 })
656 if err != nil {
657 return err
658 }
659
660 // The raw tree doesn't include blank nodes at the end, fill it until next
661 // power of 2
662 for !isPowerOf2(uint32(len(*tree) + 1)) {
663 *tree = append(*tree, nil)
664 }
665
666 return nil
667 }
668
669 func (tree ratchetTree) marshal(b *cryptobyte.Builder) {
670 end := len(tree)
671 for end > 0 && tree[end-1] == nil {
672 end--
673 }
674
675 writeVector(b, len(tree[:end]), func(b *cryptobyte.Builder, i int) {
676 n := tree[i]
677 writeOptional(b, n != nil)
678 if n != nil {
679 n.marshal(b)
680 }
681 })
682 }
683
684 func (tree ratchetTree) copy() ratchetTree {
685 newTree := make(ratchetTree, len(tree))
686 copy(newTree, tree)
687 return newTree
688 }
689
690 // get returns the node at the provided index.
691 //
692 // nil is returned for blank nodes. get panics if the index is out of range.
693 func (tree ratchetTree) get(i nodeIndex) *node {
694 return tree[int(i)]
695 }
696
697 func (tree ratchetTree) set(i nodeIndex, node *node) {
698 tree[int(i)] = node
699 }
700
701 func (tree ratchetTree) getLeaf(li leafIndex) *leafNode {
702 node := tree.get(li.nodeIndex())
703 if node == nil {
704 return nil
705 }
706 if node.nodeType != nodeTypeLeaf {
707 panic("unreachable")
708 }
709 return node.leafNode
710 }
711
712 // resolve computes the resolution of a node.
713 func (tree ratchetTree) resolve(x nodeIndex) []nodeIndex {
714 n := tree.get(x)
715 if n == nil {
716 l, r, ok := x.children()
717 if !ok {
718 return nil // leaf
719 }
720 return append(tree.resolve(l), tree.resolve(r)...)
721 } else {
722 res := []nodeIndex{x}
723 if n.nodeType == nodeTypeParent {
724 for _, leafIndex := range n.parentNode.unmergedLeaves {
725 res = append(res, leafIndex.nodeIndex())
726 }
727 }
728 return res
729 }
730 }
731
732 func (tree ratchetTree) supportedCreds() map[credentialType]struct{} {
733 numMembers := 0
734 supportedCredsCount := make(map[credentialType]int)
735 for li := leafIndex(0); li < leafIndex(tree.numLeaves()); li++ {
736 node := tree.getLeaf(li)
737 if node == nil {
738 continue
739 }
740
741 numMembers++
742 for _, ct := range node.capabilities.credentials {
743 supportedCredsCount[ct]++
744 }
745 }
746
747 supportedCreds := make(map[credentialType]struct{})
748 for ct, n := range supportedCredsCount {
749 if n == numMembers {
750 supportedCreds[ct] = struct{}{}
751 }
752 }
753
754 return supportedCreds
755 }
756
757 func (tree ratchetTree) keys() (signatureKeys, encryptionKeys map[string]struct{}) {
758 signatureKeys = make(map[string]struct{})
759 encryptionKeys = make(map[string]struct{})
760 for li := leafIndex(0); li < leafIndex(tree.numLeaves()); li++ {
761 node := tree.getLeaf(li)
762 if node == nil {
763 continue
764 }
765 signatureKeys[string(node.signatureKey)] = struct{}{}
766 encryptionKeys[string(node.encryptionKey)] = struct{}{}
767 }
768 return signatureKeys, encryptionKeys
769 }
770
771 // verifyIntegrity verifies the integrity of the ratchet tree, as described in
772 // section 12.4.3.1.
773 //
774 // This function does not perform full leaf node validation. In particular:
775 //
776 // - It doesn't check that credentials are valid.
777 // - It doesn't check the lifetime field.
778 func (tree ratchetTree) verifyIntegrity(ctx *groupContext, now func() time.Time) error {
779 cs := ctx.cipherSuite
780 numLeaves := tree.numLeaves()
781
782 if h, err := tree.computeRootTreeHash(cs); err != nil {
783 return err
784 } else if !bytes.Equal(h, ctx.treeHash) {
785 return fmt.Errorf("mls: tree hash verification failed")
786 }
787
788 if !tree.verifyParentHashes(cs) {
789 return fmt.Errorf("mls: parent hashes verification failed")
790 }
791
792 supportedCreds := tree.supportedCreds()
793 signatureKeys := make(map[string]struct{})
794 encryptionKeys := make(map[string]struct{})
795 for li := leafIndex(0); li < leafIndex(numLeaves); li++ {
796 node := tree.getLeaf(li)
797 if node == nil {
798 continue
799 }
800
801 err := node.verify(&leafNodeVerifyOptions{
802 cipherSuite: cs,
803 groupID: ctx.groupID,
804 leafIndex: li,
805 supportedCreds: supportedCreds,
806 signatureKeys: signatureKeys,
807 encryptionKeys: encryptionKeys,
808 now: now,
809 })
810 if err != nil {
811 return fmt.Errorf("leaf node at index %v: %v", li, err)
812 }
813
814 signatureKeys[string(node.signatureKey)] = struct{}{}
815 encryptionKeys[string(node.encryptionKey)] = struct{}{}
816 }
817
818 for i, node := range tree {
819 if node == nil || node.nodeType != nodeTypeParent {
820 continue
821 }
822 p := nodeIndex(i)
823 for _, unmergedLeaf := range node.parentNode.unmergedLeaves {
824 x := unmergedLeaf.nodeIndex()
825 for {
826 var ok bool
827 if x, ok = numLeaves.parent(x); !ok {
828 return fmt.Errorf("mls: unmerged leaf %v is not a descendant of the parent node at index %v", unmergedLeaf, p)
829 } else if x == p {
830 break
831 }
832
833 intermediateNode := tree.get(x)
834 if intermediateNode != nil && !hasUnmergedLeaf(intermediateNode.parentNode, unmergedLeaf) {
835 return fmt.Errorf("mls: non-blank intermediate node at index %v is missing unmerged leaf %v", x, unmergedLeaf)
836 }
837 }
838 }
839
840 if _, dup := encryptionKeys[string(node.parentNode.encryptionKey)]; dup {
841 return fmt.Errorf("mls: duplicate encryption key in ratchet tree")
842 }
843 encryptionKeys[string(node.parentNode.encryptionKey)] = struct{}{}
844 }
845
846 return nil
847 }
848
849 func hasUnmergedLeaf(node *parentNode, unmergedLeaf leafIndex) bool {
850 for _, li := range node.unmergedLeaves {
851 if li == unmergedLeaf {
852 return true
853 }
854 }
855 return false
856 }
857
858 func (tree ratchetTree) computeRootTreeHash(cs CipherSuite) ([]byte, error) {
859 return tree.computeTreeHash(cs, tree.numLeaves().root(), nil)
860 }
861
862 func (tree ratchetTree) computeTreeHash(cs CipherSuite, x nodeIndex, exclude map[leafIndex]struct{}) ([]byte, error) {
863 n := tree.get(x)
864
865 var b cryptobyte.Builder
866 if li, ok := x.leafIndex(); ok {
867 _, excluded := exclude[li]
868
869 var l *leafNode
870 if n != nil && !excluded {
871 l = n.leafNode
872 if l == nil {
873 panic("unreachable")
874 }
875 }
876
877 marshalLeafNodeHashInput(&b, li, l)
878 } else {
879 left, right, ok := x.children()
880 if !ok {
881 panic("unreachable")
882 }
883
884 leftHash, err := tree.computeTreeHash(cs, left, exclude)
885 if err != nil {
886 return nil, err
887 }
888 rightHash, err := tree.computeTreeHash(cs, right, exclude)
889 if err != nil {
890 return nil, err
891 }
892
893 var p *parentNode
894 if n != nil {
895 p = n.parentNode
896 if p == nil {
897 panic("unreachable")
898 }
899
900 if len(p.unmergedLeaves) > 0 && len(exclude) > 0 {
901 unmergedLeaves := make([]leafIndex, 0, len(p.unmergedLeaves))
902 for _, li := range p.unmergedLeaves {
903 if _, excluded := exclude[li]; !excluded {
904 unmergedLeaves = append(unmergedLeaves, li)
905 }
906 }
907
908 filteredParent := *p
909 filteredParent.unmergedLeaves = unmergedLeaves
910 p = &filteredParent
911 }
912 }
913
914 marshalParentNodeHashInput(&b, p, leftHash, rightHash)
915 }
916 in, err := b.Bytes()
917 if err != nil {
918 return nil, err
919 }
920
921 h := cs.hash().New()
922 h.Write(in)
923 return h.Sum(nil), nil
924 }
925
926 func marshalLeafNodeHashInput(b *cryptobyte.Builder, i leafIndex, node *leafNode) {
927 b.AddUint8(uint8(nodeTypeLeaf))
928 b.AddUint32(uint32(i))
929 writeOptional(b, node != nil)
930 if node != nil {
931 node.marshal(b)
932 }
933 }
934
935 func marshalParentNodeHashInput(b *cryptobyte.Builder, node *parentNode, leftHash, rightHash []byte) {
936 b.AddUint8(uint8(nodeTypeParent))
937 writeOptional(b, node != nil)
938 if node != nil {
939 node.marshal(b)
940 }
941 writeOpaqueVec(b, leftHash)
942 writeOpaqueVec(b, rightHash)
943 }
944
945 func (tree ratchetTree) verifyParentHashes(cs CipherSuite) bool {
946 for i, node := range tree {
947 if node == nil {
948 continue
949 }
950
951 x := nodeIndex(i)
952 l, r, ok := x.children()
953 if !ok {
954 continue
955 }
956
957 parentNode := node.parentNode
958 exclude := make(map[leafIndex]struct{}, len(parentNode.unmergedLeaves))
959 for _, li := range parentNode.unmergedLeaves {
960 exclude[li] = struct{}{}
961 }
962
963 leftTreeHash, err := tree.computeTreeHash(cs, l, exclude)
964 if err != nil {
965 return false
966 }
967 rightTreeHash, err := tree.computeTreeHash(cs, r, exclude)
968 if err != nil {
969 return false
970 }
971
972 leftParentHash, err := parentNode.computeParentHash(cs, rightTreeHash)
973 if err != nil {
974 return false
975 }
976 rightParentHash, err := parentNode.computeParentHash(cs, leftTreeHash)
977 if err != nil {
978 return false
979 }
980
981 isLeftDescendant := tree.findParentHash(tree.resolve(l), leftParentHash)
982 isRightDescendant := tree.findParentHash(tree.resolve(r), rightParentHash)
983 if isLeftDescendant == isRightDescendant {
984 return false
985 }
986 }
987 return true
988 }
989
990 func (tree ratchetTree) findParentHash(nodeIndices []nodeIndex, parentHash []byte) bool {
991 for _, x := range nodeIndices {
992 node := tree.get(x)
993 if node == nil {
994 continue
995 }
996 var h []byte
997 switch node.nodeType {
998 case nodeTypeLeaf:
999 h = node.leafNode.parentHash
1000 case nodeTypeParent:
1001 h = node.parentNode.parentHash
1002 }
1003 if bytes.Equal(h, parentHash) {
1004 return true
1005 }
1006 }
1007 return false
1008 }
1009
1010 func (tree ratchetTree) numLeaves() numLeaves {
1011 return numLeavesFromWidth(uint32(len(tree)))
1012 }
1013
1014 func (tree ratchetTree) findLeaf(node *leafNode) (leafIndex, bool) {
1015 for li := leafIndex(0); li < leafIndex(tree.numLeaves()); li++ {
1016 n := tree.getLeaf(li)
1017 if n == nil {
1018 continue
1019 }
1020
1021 // Encryption keys are unique
1022 if !bytes.Equal(n.encryptionKey, node.encryptionKey) {
1023 continue
1024 }
1025
1026 // Make sure both nodes are identical
1027 raw1, err1 := marshal(node)
1028 raw2, err2 := marshal(n)
1029 return li, err1 == nil && err2 == nil && bytes.Equal(raw1, raw2)
1030 }
1031 return 0, false
1032 }
1033
1034 func (tree *ratchetTree) add(leafNode *leafNode) {
1035 li := leafIndex(0)
1036 var ni nodeIndex
1037 found := false
1038 for {
1039 ni = li.nodeIndex()
1040 if int(ni) >= len(*tree) {
1041 break
1042 }
1043 if tree.get(ni) == nil {
1044 found = true
1045 break
1046 }
1047 li++
1048 }
1049 if !found {
1050 newLen := ((len(*tree) + 1) * 2) - 1
1051 for len(*tree) < newLen {
1052 *tree = append(*tree, nil)
1053 }
1054 }
1055
1056 numLeaves := tree.numLeaves()
1057 p := ni
1058 for {
1059 var ok bool
1060 p, ok = numLeaves.parent(p)
1061 if !ok {
1062 break
1063 }
1064 node := tree.get(p)
1065 if node != nil {
1066 node.parentNode.unmergedLeaves = append(node.parentNode.unmergedLeaves, li)
1067 }
1068 }
1069
1070 tree.set(ni, &node{
1071 nodeType: nodeTypeLeaf,
1072 leafNode: leafNode,
1073 })
1074 }
1075
1076 func (tree ratchetTree) update(li leafIndex, leafNode *leafNode) {
1077 ni := li.nodeIndex()
1078
1079 tree.set(ni, &node{
1080 nodeType: nodeTypeLeaf,
1081 leafNode: leafNode,
1082 })
1083
1084 numLeaves := tree.numLeaves()
1085 for {
1086 var ok bool
1087 ni, ok = numLeaves.parent(ni)
1088 if !ok {
1089 break
1090 }
1091
1092 tree.set(ni, nil)
1093 }
1094 }
1095
1096 func (tree *ratchetTree) remove(li leafIndex) {
1097 ni := li.nodeIndex()
1098
1099 numLeaves := tree.numLeaves()
1100 for {
1101 tree.set(ni, nil)
1102
1103 var ok bool
1104 ni, ok = numLeaves.parent(ni)
1105 if !ok {
1106 break
1107 }
1108 }
1109
1110 li = leafIndex(numLeaves - 1)
1111 lastPowerOf2 := len(*tree) + 1
1112 for {
1113 ni = li.nodeIndex()
1114 if tree.get(ni) != nil {
1115 break
1116 }
1117
1118 if isPowerOf2(uint32(ni)) {
1119 lastPowerOf2 = int(ni)
1120 }
1121
1122 if li == 0 {
1123 *tree = nil
1124 return
1125 }
1126 li--
1127 }
1128
1129 if lastPowerOf2 < len(*tree)+1 {
1130 *tree = (*tree)[:lastPowerOf2-1]
1131 }
1132 }
1133
1134 func (tree ratchetTree) filteredDirectPath(x nodeIndex) []nodeIndex {
1135 numLeaves := tree.numLeaves()
1136
1137 var path []nodeIndex
1138 for {
1139 p, ok := numLeaves.parent(x)
1140 if !ok {
1141 break
1142 }
1143
1144 s, ok := numLeaves.sibling(x)
1145 if !ok {
1146 panic("unreachable")
1147 }
1148
1149 if len(tree.resolve(s)) > 0 {
1150 path = append(path, p)
1151 }
1152
1153 x = p
1154 }
1155
1156 return path
1157 }
1158
1159 func (tree ratchetTree) mergeUpdatePath(cs CipherSuite, senderLeafIndex leafIndex, path *updatePath) error {
1160 senderNodeIndex := senderLeafIndex.nodeIndex()
1161 numLeaves := tree.numLeaves()
1162
1163 directPath := numLeaves.directPath(senderNodeIndex)
1164 for _, ni := range directPath {
1165 tree.set(ni, nil)
1166 }
1167
1168 filteredDirectPath := tree.filteredDirectPath(senderNodeIndex)
1169 if len(filteredDirectPath) != len(path.nodes) {
1170 return fmt.Errorf("mls: UpdatePath has %v nodes, but filtered direct path has %v nodes", len(path.nodes), len(filteredDirectPath))
1171 }
1172 for i, ni := range filteredDirectPath {
1173 pathNode := path.nodes[i]
1174 tree.set(ni, &node{
1175 nodeType: nodeTypeParent,
1176 parentNode: &parentNode{
1177 encryptionKey: pathNode.encryptionKey,
1178 },
1179 })
1180 }
1181
1182 // Compute parent hashes, from root to leaf
1183 var prevParentHash []byte
1184 for i := len(filteredDirectPath) - 1; i >= 0; i-- {
1185 ni := filteredDirectPath[i]
1186 node := tree.get(ni).parentNode
1187
1188 l, r, ok := ni.children()
1189 if !ok {
1190 panic("unreachable")
1191 }
1192
1193 s := l
1194 found := false
1195 for _, ni := range directPath {
1196 if ni == s {
1197 found = true
1198 break
1199 }
1200 }
1201 if s == senderNodeIndex || found {
1202 s = r
1203 }
1204
1205 treeHash, err := tree.computeTreeHash(cs, s, nil)
1206 if err != nil {
1207 return err
1208 }
1209
1210 node.parentHash = prevParentHash
1211 h, err := node.computeParentHash(cs, treeHash)
1212 if err != nil {
1213 return err
1214 }
1215 prevParentHash = h
1216 }
1217
1218 if !bytes.Equal(path.leafNode.parentHash, prevParentHash) {
1219 return fmt.Errorf("mls: parent hash mismatch for update path's leaf node")
1220 }
1221
1222 tree.set(senderNodeIndex, &node{
1223 nodeType: nodeTypeLeaf,
1224 leafNode: &path.leafNode,
1225 })
1226
1227 return nil
1228 }
1229
1230 func (tree ratchetTree) decryptPathSecrets(cs CipherSuite, groupCtx *groupContext, senderLeafIndex, recipientLeafIndex leafIndex, path *updatePath, privTree []hpkePrivateKey) ([]byte, error) {
1231 senderNodeIndex := senderLeafIndex.nodeIndex()
1232 recipientNodeIndex := recipientLeafIndex.nodeIndex()
1233
1234 senderFilteredDirectPath := tree.filteredDirectPath(senderNodeIndex)
1235 if len(path.nodes) != len(senderFilteredDirectPath) {
1236 return nil, fmt.Errorf("mls: invalid UpdatePath length")
1237 }
1238
1239 // Identify a node in the filtered direct path for which the recipient is
1240 // in the subtree of the non-updated child
1241 recipientAncestorIndex := -1
1242 recipientAncestor := commonAncestor(senderNodeIndex, recipientNodeIndex)
1243 for i, ni := range senderFilteredDirectPath {
1244 if ni == recipientAncestor {
1245 recipientAncestorIndex = i
1246 break
1247 }
1248 }
1249 if recipientAncestorIndex < 0 {
1250 return nil, fmt.Errorf("mls: cannot find recipient ancestor")
1251 }
1252 updatePathNode := path.nodes[recipientAncestorIndex]
1253
1254 // Find the copath node
1255 ancestor := commonAncestor(senderNodeIndex, recipientNodeIndex)
1256 var (
1257 copathNode nodeIndex
1258 ok bool
1259 )
1260 if recipientNodeIndex < senderNodeIndex {
1261 copathNode, ok = ancestor.left()
1262 } else {
1263 copathNode, ok = ancestor.right()
1264 }
1265 if !ok {
1266 panic("unreachable")
1267 }
1268
1269 copathResolution := tree.resolve(copathNode)
1270 if len(updatePathNode.encryptedPathSecret) != len(copathResolution) {
1271 return nil, fmt.Errorf("mls: invalid UpdatePathNode.encrypted_path_secret length")
1272 }
1273
1274 // Identify a node in the resolution of the copath node for which we have
1275 // a private key
1276 var nodePriv hpkePrivateKey
1277 resolutionIndex := -1
1278 for i, ni := range copathResolution {
1279 if p := privTree[int(ni)]; p != nil {
1280 nodePriv = p
1281 resolutionIndex = i
1282 break
1283 }
1284 }
1285 if nodePriv == nil {
1286 return nil, fmt.Errorf("mls: no private key found")
1287 }
1288 ciphertext := updatePathNode.encryptedPathSecret[resolutionIndex]
1289
1290 // Decrypt the path secret using the private key from the resolution node
1291 pathSecret, err := decryptPathSecret(cs, nodePriv, groupCtx, ciphertext)
1292 if err != nil {
1293 return nil, fmt.Errorf("failed to decrypt path secret: %v", err)
1294 }
1295 nodePub := tree.get(recipientAncestor).encryptionKey()
1296 nodePriv, err = nodePrivFromPathSecret(cs, pathSecret, nodePub)
1297 if err != nil {
1298 return nil, fmt.Errorf("failed to derive node %v private key from path secret: %v", recipientAncestor, err)
1299 }
1300 privTree[int(recipientAncestor)] = nodePriv
1301
1302 // Derive path secrets for ancestors of that node in the sender's filtered
1303 // direct path
1304 for _, ni := range senderFilteredDirectPath[recipientAncestorIndex+1:] {
1305 pathSecret, err = cs.deriveSecret(pathSecret, []byte("path"))
1306 if err != nil {
1307 return nil, fmt.Errorf("failed to derive path secret: %v", err)
1308 }
1309 nodePriv, err := nodePrivFromPathSecret(cs, pathSecret, tree.get(ni).encryptionKey())
1310 if err != nil {
1311 return nil, fmt.Errorf("failed to derive node %v private key from path secret: %v", ni, err)
1312 }
1313 privTree[int(ni)] = nodePriv
1314 }
1315
1316 commitSecret, err := cs.deriveSecret(pathSecret, []byte("path"))
1317 if err != nil {
1318 return nil, fmt.Errorf("failed to derive commit secret: %v", err)
1319 }
1320
1321 return commitSecret, nil
1322 }
1323
1324 func (tree *ratchetTree) apply(proposals []proposal, senders []leafIndex) {
1325 // Apply all update proposals
1326 for i, prop := range proposals {
1327 if prop.proposalType == proposalTypeUpdate {
1328 tree.update(senders[i], &prop.update.leafNode)
1329 }
1330 }
1331
1332 // Apply all remove proposals
1333 for _, prop := range proposals {
1334 if prop.proposalType == proposalTypeRemove {
1335 tree.remove(prop.remove.removed)
1336 }
1337 }
1338
1339 // Apply all add proposals
1340 for _, prop := range proposals {
1341 if prop.proposalType == proposalTypeAdd {
1342 tree.add(&prop.add.keyPackage.leafNode)
1343 }
1344 }
1345 }
1346