package mls // MLS tree crypto operations (RFC 9420 §7). import "errors" var ( errNodeKeyMismatch = errors.New("mls: node public key mismatch") ) func (node *parentNode) computeParentHash(cs CipherSuite, originalSiblingTreeHash []byte) ([]byte, error) { raw, err := marshalParentHashInput(node.encryptionKey, node.parentHash, originalSiblingTreeHash) if err != nil { return nil, err } return cs.hash(raw), nil } func marshalParentHashInput(encryptionKey hpkePublicKey, parentHash, originalSiblingTreeHash []byte) ([]byte, error) { var w Writer w.writeOpaqueVec([]byte(encryptionKey)) w.writeOpaqueVec(parentHash) w.writeOpaqueVec(originalSiblingTreeHash) return w.bytes() } func (node *leafNode) sign(cs CipherSuite, groupID GroupID, li leafIndex, signerPriv signaturePrivateKey) error { tbs, err := marshalRaw(&leafNodeTBS{ node: node, groupID: groupID, leafIndex: li, }) if err != nil { return err } sig, err := cs.signWithLabel(signerPriv, []byte("LeafNodeTBS"), tbs) if err != nil { return err } node.signature = sig return nil } func (node *leafNode) verifySignature(cs CipherSuite, groupID GroupID, li leafIndex) bool { tbs, err := marshalRaw(&leafNodeTBS{ node: node, groupID: groupID, leafIndex: li, }) if err != nil { return false } return cs.verifyWithLabel(node.signatureKey, []byte("LeafNodeTBS"), tbs, node.signature) } func decryptPathSecret(cs CipherSuite, nodePriv hpkePrivateKey, ctx *groupContext, ct hpkeCiphertext) ([]byte, error) { rawCtx, err := marshalRaw(ctx) if err != nil { return nil, err } return cs.decryptWithLabel(nodePriv, []byte("UpdatePathNode"), rawCtx, ct.kemOutput, ct.ciphertext) } func nodePrivFromPathSecret(cs CipherSuite, pathSecret []byte, nodePub hpkePublicKey) (hpkePrivateKey, error) { nodeSecret, err := cs.deriveSecret(pathSecret, []byte("node")) if err != nil { return nil, err } pub, priv, err := cs.deriveEncryptionKeyPair(nodeSecret) if err != nil { return nil, err } if !bytesEqual(pub, nodePub) { return nil, errNodeKeyMismatch } return priv, nil } // computeRootTreeHash computes the tree hash for integrity verification. func (tree ratchetTree) computeRootTreeHash(cs CipherSuite) ([]byte, error) { n := tree.numLeaves() return tree.computeTreeHash(cs, n.root(), n, nil) } func (tree ratchetTree) computeTreeHash(cs CipherSuite, x nodeIndex, n numLeaves, exclude map[leafIndex]bool) ([]byte, error) { if x.isLeaf() { return tree.computeLeafTreeHash(cs, x, exclude) } return tree.computeParentTreeHash(cs, x, n, exclude) } func (tree ratchetTree) computeLeafTreeHash(cs CipherSuite, x nodeIndex, exclude map[leafIndex]bool) ([]byte, error) { var w Writer li, _ := x.leafIndex() w.addUint32(uint32(li)) nd := tree.get(x) if exclude[li] { nd = nil } w.writeOptional(nd != nil) if nd != nil { nd.leafNode.marshal(&w) } input, err := w.bytes() if err != nil { return nil, err } return cs.hash(input), nil } func (tree ratchetTree) computeParentTreeHash(cs CipherSuite, x nodeIndex, n numLeaves, exclude map[leafIndex]bool) ([]byte, error) { l, r, ok := x.children() if !ok { return nil, errUnexpectedEOF } leftHash, err := tree.computeTreeHash(cs, l, n, exclude) if err != nil { return nil, err } rightHash, err := tree.computeTreeHash(cs, r, n, exclude) if err != nil { return nil, err } var w Writer nd := tree.get(x) if nd != nil && len(exclude) > 0 { // Filter unmerged leaves for parent hash verification pn := nd.parentNode if pn != nil && len(pn.unmergedLeaves) > 0 { filtered := []leafIndex{:0:len(pn.unmergedLeaves)} for _, li := range pn.unmergedLeaves { if !exclude[li] { filtered = append(filtered, li) } } filteredNode := *pn filteredNode.unmergedLeaves = filtered w.writeOptional(true) filteredNode.marshal(&w) w.writeOpaqueVec(leftHash) w.writeOpaqueVec(rightHash) input, err := w.bytes() if err != nil { return nil, err } return cs.hash(input), nil } } w.writeOptional(nd != nil) if nd != nil { nd.parentNode.marshal(&w) } w.writeOpaqueVec(leftHash) w.writeOpaqueVec(rightHash) input, err := w.bytes() if err != nil { return nil, err } return cs.hash(input), nil } // verifyParentHashes verifies the parent hash chain (RFC 9420 §7.9.2). func (tree ratchetTree) verifyParentHashes(cs CipherSuite) bool { n := tree.numLeaves() for i, nd := range tree { if nd == nil { continue } x := nodeIndex(i) l, r, ok := x.children() if !ok { continue } pn := nd.parentNode exclude := map[leafIndex]bool{} for _, li := range pn.unmergedLeaves { exclude[li] = true } leftTreeHash, err := tree.computeTreeHash(cs, l, n, exclude) if err != nil { return false } rightTreeHash, err := tree.computeTreeHash(cs, r, n, exclude) if err != nil { return false } leftParentHash, err := pn.computeParentHash(cs, rightTreeHash) if err != nil { return false } rightParentHash, err := pn.computeParentHash(cs, leftTreeHash) if err != nil { return false } isLeft := tree.findParentHash(tree.resolve(l), leftParentHash) isRight := tree.findParentHash(tree.resolve(r), rightParentHash) if isLeft == isRight { return false } } return true } // verifyIntegrity verifies the ratchet tree (RFC 9420 §12.4.3.1). func (tree ratchetTree) verifyIntegrity(ctx *groupContext, nowUnix int64) error { cs := ctx.cipherSuite n := tree.numLeaves() h, err := tree.computeRootTreeHash(cs) if err != nil { return err } if !bytesEqual(h, ctx.treeHash) { return errors.New("mls: tree hash verification failed") } if !tree.verifyParentHashes(cs) { return errors.New("mls: parent hash verification failed") } supportedCreds := tree.supportedCreds() sigKeys := map[string]bool{} encKeys := map[string]bool{} for li := leafIndex(0); li < leafIndex(n); li++ { nd := tree.getLeaf(li) if nd == nil { continue } err := nd.verify(&leafNodeVerifyOptions{ cipherSuite: cs, groupID: ctx.groupID, leafIndex: li, supportedCreds: supportedCreds, signatureKeys: sigKeys, encryptionKeys: encKeys, nowUnix: nowUnix, }) if err != nil { return err } sigKeys[string(nd.signatureKey)] = true encKeys[string(nd.encryptionKey)] = true } // Check unmerged leaf ancestry for i, nd := range tree { if nd == nil || nd.nodeType != nodeTypeParent { continue } p := nodeIndex(i) for _, ul := range nd.parentNode.unmergedLeaves { x := ul.nodeIndex() for { var ok bool x, ok = n.parent(x) if !ok { return errors.New("mls: unmerged leaf not descendant of parent") } if x == p { break } intermediate := tree.get(x) if intermediate != nil && !hasUnmergedLeaf(intermediate.parentNode, ul) { return errors.New("mls: intermediate node missing unmerged leaf") } } } if encKeys[string(nd.parentNode.encryptionKey)] { return errors.New("mls: duplicate encryption key in tree") } encKeys[string(nd.parentNode.encryptionKey)] = true } return nil } // mergeUpdatePath applies an update path to the tree (RFC 9420 §7.5). func (tree ratchetTree) mergeUpdatePath(cs CipherSuite, senderLI leafIndex, path *updatePath) error { senderNI := senderLI.nodeIndex() n := tree.numLeaves() directPath := n.directPath(senderNI) for _, ni := range directPath { tree.set(ni, nil) } filteredDP := tree.filteredDirectPath(senderNI) if len(filteredDP) != len(path.nodes) { return errors.New("mls: update path length mismatch") } for i, ni := range filteredDP { tree.set(ni, &node{ nodeType: nodeTypeParent, parentNode: &parentNode{ encryptionKey: path.nodes[i].encryptionKey, }, }) } // Compute parent hashes root-to-leaf var prevParentHash []byte for i := len(filteredDP) - 1; i >= 0; i-- { ni := filteredDP[i] pn := tree.get(ni).parentNode l, r, ok := ni.children() if !ok { panic("unreachable") } s := l found := false for _, dp := range directPath { if dp == s { found = true break } } if s == senderNI || found { s = r } treeHash, err := tree.computeTreeHash(cs, s, n, nil) if err != nil { return err } pn.parentHash = prevParentHash h, err := pn.computeParentHash(cs, treeHash) if err != nil { return err } prevParentHash = h } if !bytesEqual(path.leafNode.parentHash, prevParentHash) { return errors.New("mls: parent hash mismatch for update path leaf node") } tree.set(senderNI, &node{ nodeType: nodeTypeLeaf, leafNode: &path.leafNode, }) return nil } // decryptPathSecrets decrypts path secrets from an update path (RFC 9420 §7.6). func (tree ratchetTree) decryptPathSecrets(cs CipherSuite, ctx *groupContext, senderLI, recipientLI leafIndex, path *updatePath, privTree []hpkePrivateKey) ([]byte, error) { senderNI := senderLI.nodeIndex() recipientNI := recipientLI.nodeIndex() senderFDP := tree.filteredDirectPath(senderNI) if len(path.nodes) != len(senderFDP) { return nil, errors.New("mls: invalid update path length") } // Find the common ancestor in the filtered direct path recipientAncestor := commonAncestor(senderNI, recipientNI) recipientAncestorIdx := -1 for i, ni := range senderFDP { if ni == recipientAncestor { recipientAncestorIdx = i break } } if recipientAncestorIdx < 0 { return nil, errors.New("mls: cannot find recipient ancestor") } upNode := path.nodes[recipientAncestorIdx] // Find the copath node ancestor := commonAncestor(senderNI, recipientNI) var copathNode nodeIndex var ok bool if recipientNI < senderNI { copathNode, ok = ancestor.left() } else { copathNode, ok = ancestor.right() } if !ok { panic("unreachable") } copathRes := tree.resolve(copathNode) if len(upNode.encryptedPathSecret) != len(copathRes) { return nil, errors.New("mls: invalid encrypted path secret length") } // Find a node in the resolution for which we have a private key var nodePriv hpkePrivateKey resIdx := -1 for i, ni := range copathRes { if p := privTree[int(ni)]; p != nil { nodePriv = p resIdx = i break } } if nodePriv == nil { return nil, errors.New("mls: no private key found") } pathSecret, err := decryptPathSecret(cs, nodePriv, ctx, upNode.encryptedPathSecret[resIdx]) if err != nil { return nil, err } nodePub := tree.get(recipientAncestor).encryptionKey() nodePriv, err = nodePrivFromPathSecret(cs, pathSecret, nodePub) if err != nil { return nil, err } privTree[int(recipientAncestor)] = nodePriv // Derive path secrets for remaining ancestors for _, ni := range senderFDP[recipientAncestorIdx+1:] { pathSecret, err = cs.deriveSecret(pathSecret, []byte("path")) if err != nil { return nil, err } nodePriv, err = nodePrivFromPathSecret(cs, pathSecret, tree.get(ni).encryptionKey()) if err != nil { return nil, err } privTree[int(ni)] = nodePriv } commitSecret, err := cs.deriveSecret(pathSecret, []byte("path")) if err != nil { return nil, err } return commitSecret, nil }