package mls

// MLS ratchet tree types (RFC 9420 §7).
// Data types and serialization only — crypto operations (sign, verify,
// parentHash) go in tree_crypto.mx with the cipher suite.

import "errors"

var (
	errInvalidLeafNodeSource = errors.New("mls: invalid leaf node source")
	errInvalidNodeType       = errors.New("mls: invalid node type")
)

// bytesEqual reports whether a and b have the same length and the same
// contents. A nil slice and an empty slice compare equal.
func bytesEqual(a, b []byte) bool {
	if len(a) != len(b) {
		return false
	}
	for i := range a {
		if a[i] != b[i] {
			return false
		}
	}
	return true
}

// --- ParentNode ---

// parentNode is an interior node of the ratchet tree.
type parentNode struct {
	encryptionKey  hpkePublicKey
	parentHash     []byte
	unmergedLeaves []leafIndex
}

// unmarshal decodes a parentNode in wire order: encryption key (opaque
// vector), parent hash (opaque vector), then a vector of uint32 leaf indices.
// Any previous contents of node are discarded.
func (node *parentNode) unmarshal(r *Reader) error {
	*node = parentNode{}
	var ok bool
	node.encryptionKey, ok = r.readOpaqueVec()
	if !ok {
		return errUnexpectedEOF
	}
	node.parentHash, ok = r.readOpaqueVec()
	if !ok {
		return errUnexpectedEOF
	}
	return r.readVector(func(r *Reader) error {
		v, ok := r.readUint32()
		if !ok {
			return errUnexpectedEOF
		}
		node.unmergedLeaves = append(node.unmergedLeaves, leafIndex(v))
		return nil
	})
}

// marshal encodes node in the same field order unmarshal reads.
func (node *parentNode) marshal(w *Writer) {
	w.writeOpaqueVec([]byte(node.encryptionKey))
	w.writeOpaqueVec(node.parentHash)
	w.writeVector(len(node.unmergedLeaves), func(w *Writer, i int) {
		w.addUint32(uint32(node.unmergedLeaves[i]))
	})
}

// --- LeafNodeSource ---

// leafNodeSource records how a leaf node was created; it selects which
// optional field (lifetime or parentHash) a leafNode carries on the wire.
type leafNodeSource uint8

const (
	leafNodeSourceKeyPackage leafNodeSource = 1
	leafNodeSourceUpdate     leafNodeSource = 2
	leafNodeSourceCommit     leafNodeSource = 3
)

// unmarshal reads one byte and rejects values other than the three known
// sources with errInvalidLeafNodeSource. Note: *src is assigned before the
// check, so it holds the raw byte even when an error is returned.
func (src *leafNodeSource) unmarshal(r *Reader) error {
	b, ok := r.readByte()
	if !ok {
		return errUnexpectedEOF
	}
	*src = leafNodeSource(b)
	switch *src {
	case leafNodeSourceKeyPackage, leafNodeSourceUpdate, leafNodeSourceCommit:
		return nil
	default:
		return errInvalidLeafNodeSource
	}
}

// marshal writes src as a single byte.
func (src leafNodeSource) marshal(w *Writer) {
	w.addByte(byte(src))
}

// --- Capabilities ---

// capabilities lists the protocol features a member supports. Each field is
// serialized as a vector of uint16 code points, in declaration order.
type capabilities struct {
	versions     []protocolVersion
	cipherSuites []CipherSuite
	extensions   []extensionType
	proposals    []proposalType
	credentials  []credentialType
}

// unmarshal decodes the five uint16 vectors in wire order, replacing any
// previous contents of caps. It stops at the first error.
func (caps *capabilities) unmarshal(r *Reader) error {
	*caps = capabilities{}

	// One appender per wire vector, in the order they appear on the wire.
	appenders := []func(v uint16){
		func(v uint16) { caps.versions = append(caps.versions, protocolVersion(v)) },
		func(v uint16) { caps.cipherSuites = append(caps.cipherSuites, CipherSuite(v)) },
		func(v uint16) { caps.extensions = append(caps.extensions, extensionType(v)) },
		func(v uint16) { caps.proposals = append(caps.proposals, proposalType(v)) },
		func(v uint16) { caps.credentials = append(caps.credentials, credentialType(v)) },
	}
	for _, appendValue := range appenders {
		err := r.readVector(func(r *Reader) error {
			v, ok := r.readUint16()
			if !ok {
				return errUnexpectedEOF
			}
			appendValue(v)
			return nil
		})
		if err != nil {
			return err
		}
	}
	return nil
}

// marshal encodes the five vectors in the order unmarshal reads them.
func (caps *capabilities) marshal(w *Writer) {
	w.writeVector(len(caps.versions), func(w *Writer, i int) {
		w.addUint16(uint16(caps.versions[i]))
	})
	w.writeVector(len(caps.cipherSuites), func(w *Writer, i int) {
		w.addUint16(uint16(caps.cipherSuites[i]))
	})
	w.writeVector(len(caps.extensions), func(w *Writer, i int) {
		w.addUint16(uint16(caps.extensions[i]))
	})
	w.writeVector(len(caps.proposals), func(w *Writer, i int) {
		w.addUint16(uint16(caps.proposals[i]))
	})
	w.writeVector(len(caps.credentials), func(w *Writer, i int) {
		w.addUint16(uint16(caps.credentials[i]))
	})
}

// --- Lifetime ---

// lifetime is the validity window of a KeyPackage leaf node, as two uint64
// values. verifyAt compares them against a Unix time in seconds.
type lifetime struct {
	notBefore, notAfter uint64
}

// unmarshal reads notBefore then notAfter as uint64 values.
func (lt *lifetime) unmarshal(r *Reader) error {
	*lt = lifetime{}
	var ok bool
	lt.notBefore, ok = r.readUint64()
	if !ok {
		return errUnexpectedEOF
	}
	lt.notAfter, ok = r.readUint64()
	if !ok {
		return errUnexpectedEOF
	}
	return nil
}

// marshal writes notBefore then notAfter as uint64 values.
func (lt *lifetime) marshal(w *Writer) {
	w.addUint64(lt.notBefore)
	w.addUint64(lt.notAfter)
}

// --- Extension ---

type extensionType uint16

const (
	extensionTypeApplicationID        extensionType = 0x0001
	extensionTypeRatchetTree          extensionType = 0x0002
	extensionTypeRequiredCapabilities extensionType = 0x0003
	extensionTypeExternalPub          extensionType = 0x0004
	extensionTypeExternalSenders      extensionType = 0x0005
	// Marmot: KeyPackage reusable for multiple Welcomes (MIP-00)
	ExtensionTypeLastResort extensionType = 0x000a
	// Marmot: Nostr group metadata (MIP-01)
	ExtensionTypeNostrGroupData extensionType = 0xf2ee
)

// Extension is the exported alias of extension.
type Extension = extension

// extension is a (type, opaque data) pair.
type extension struct {
	extensionType extensionType
	extensionData []byte
}

// NewExtension builds an extension of type t carrying data. The data slice
// is stored as-is, not copied.
func NewExtension(t extensionType, data []byte) extension {
	return extension{extensionType: t, extensionData: data}
}

// ExtensionType is the exported alias of extensionType.
type ExtensionType = extensionType

// unmarshalExtensionVec decodes a vector of extensions, each a uint16 type
// followed by opaque data. On error it returns a nil slice rather than
// whatever was decoded before the failure.
func unmarshalExtensionVec(r *Reader) ([]extension, error) {
	var exts []extension
	err := r.readVector(func(r *Reader) error {
		var ext extension
		v, ok := r.readUint16()
		if !ok {
			return errUnexpectedEOF
		}
		ext.extensionType = extensionType(v)
		ext.extensionData, ok = r.readOpaqueVec()
		if !ok {
			return errUnexpectedEOF
		}
		exts = append(exts, ext)
		return nil
	})
	if err != nil {
		return nil, err
	}
	return exts, nil
}

// marshalExtensionVec encodes exts in the format unmarshalExtensionVec reads.
func marshalExtensionVec(w *Writer, exts []extension) {
	w.writeVector(len(exts), func(w *Writer, i int) {
		ext := exts[i]
		w.addUint16(uint16(ext.extensionType))
		w.writeOpaqueVec(ext.extensionData)
	})
}

// findExtensionData returns the data of the first extension of type t, or
// nil if none is present.
func findExtensionData(exts []extension, t extensionType) []byte {
	for _, ext := range exts {
		if ext.extensionType == t {
			return ext.extensionData
		}
	}
	return nil
}

// --- LeafNode ---

// leafNode is a member leaf of the ratchet tree. Which optional field is
// populated depends on leafNodeSource.
type leafNode struct {
	encryptionKey  hpkePublicKey
	signatureKey   signaturePublicKey
	credential     Credential
	capabilities   capabilities
	leafNodeSource leafNodeSource
	lifetime       *lifetime // for leafNodeSourceKeyPackage
	parentHash     []byte    // for leafNodeSourceCommit
	extensions     []extension
	signature      []byte
}

// unmarshal decodes a leaf node: keys, credential, capabilities, source, the
// source-specific field (lifetime for KeyPackage, parent hash for Commit,
// nothing for Update), extensions and finally the signature.
func (node *leafNode) unmarshal(r *Reader) error {
	*node = leafNode{}
	var ok bool
	node.encryptionKey, ok = r.readOpaqueVec()
	if !ok {
		return errUnexpectedEOF
	}
	node.signatureKey, ok = r.readOpaqueVec()
	if !ok {
		return errUnexpectedEOF
	}
	if err := node.credential.unmarshal(r); err != nil {
		return err
	}
	if err := node.capabilities.unmarshal(r); err != nil {
		return err
	}
	if err := node.leafNodeSource.unmarshal(r); err != nil {
		return err
	}

	switch node.leafNodeSource {
	case leafNodeSourceKeyPackage:
		node.lifetime = &lifetime{}
		if err := node.lifetime.unmarshal(r); err != nil {
			return err
		}
	case leafNodeSourceCommit:
		node.parentHash, ok = r.readOpaqueVec()
		if !ok {
			return errUnexpectedEOF
		}
	}

	exts, err := unmarshalExtensionVec(r)
	if err != nil {
		return err
	}
	node.extensions = exts

	node.signature, ok = r.readOpaqueVec()
	if !ok {
		return errUnexpectedEOF
	}
	return nil
}

// marshalBase writes every field except the signature; it is shared by
// marshal and by leafNodeTBS. A KeyPackage-source node is expected to have a
// non-nil lifetime (a nil lifetime would be dereferenced here).
func (node *leafNode) marshalBase(w *Writer) {
	w.writeOpaqueVec([]byte(node.encryptionKey))
	w.writeOpaqueVec([]byte(node.signatureKey))
	node.credential.marshal(w)
	node.capabilities.marshal(w)
	node.leafNodeSource.marshal(w)
	switch node.leafNodeSource {
	case leafNodeSourceKeyPackage:
		node.lifetime.marshal(w)
	case leafNodeSourceCommit:
		w.writeOpaqueVec(node.parentHash)
	}
	marshalExtensionVec(w, node.extensions)
}

// marshal writes the full leaf node, signature included.
func (node *leafNode) marshal(w *Writer) {
	node.marshalBase(w)
	w.writeOpaqueVec(node.signature)
}

// --- LeafNodeTBS ---

// leafNodeTBS is the to-be-signed form of a leaf node.
type leafNodeTBS struct {
	node *leafNode

	// for leafNodeSourceUpdate and leafNodeSourceCommit
	groupID   GroupID
	leafIndex leafIndex
}

// marshal writes the unsigned leaf node and, for Update and Commit sources
// only, appends the group ID and leaf index.
func (tbs *leafNodeTBS) marshal(w *Writer) {
	tbs.node.marshalBase(w)
	switch tbs.node.leafNodeSource {
	case leafNodeSourceUpdate, leafNodeSourceCommit:
		w.writeOpaqueVec([]byte(tbs.groupID))
		w.addUint32(uint32(tbs.leafIndex))
	}
}

// --- UpdatePathNode ---

// updatePathNode is one node of an UpdatePath: a new public key plus the
// path secret encrypted to each member of the copath resolution.
type updatePathNode struct {
	encryptionKey       hpkePublicKey
	encryptedPathSecret []hpkeCiphertext
}

// unmarshal reads the encryption key and the vector of ciphertexts.
func (node *updatePathNode) unmarshal(r *Reader) error {
	*node = updatePathNode{}
	var ok bool
	node.encryptionKey, ok = r.readOpaqueVec()
	if !ok {
		return errUnexpectedEOF
	}
	return r.readVector(func(r *Reader) error {
		var ct hpkeCiphertext
		if err := ct.unmarshal(r); err != nil {
			return err
		}
		node.encryptedPathSecret = append(node.encryptedPathSecret, ct)
		return nil
	})
}

// marshal writes the encryption key and the vector of ciphertexts.
func (node *updatePathNode) marshal(w *Writer) {
	w.writeOpaqueVec([]byte(node.encryptionKey))
	w.writeVector(len(node.encryptedPathSecret), func(w *Writer, i int) {
		node.encryptedPathSecret[i].marshal(w)
	})
}

// --- UpdatePath ---

// updatePath is the committer's new leaf node plus one updatePathNode per
// entry of its (filtered) direct path.
type updatePath struct {
	leafNode leafNode
	nodes    []updatePathNode
}

// unmarshal reads the leaf node followed by the vector of path nodes.
func (up *updatePath) unmarshal(r *Reader) error {
	*up = updatePath{}
	if err := up.leafNode.unmarshal(r); err != nil {
		return err
	}
	return r.readVector(func(r *Reader) error {
		var node updatePathNode
		if err := node.unmarshal(r); err != nil {
			return err
		}
		up.nodes = append(up.nodes, node)
		return nil
	})
}

// marshal writes the leaf node followed by the vector of path nodes.
func (up *updatePath) marshal(w *Writer) {
	up.leafNode.marshal(w)
	w.writeVector(len(up.nodes), func(w *Writer, i int) {
		up.nodes[i].marshal(w)
	})
}

// --- NodeType ---

// nodeType tags a serialized node as a leaf or a parent.
type nodeType uint8

const (
	nodeTypeLeaf   nodeType = 1
	nodeTypeParent nodeType = 2
)

// unmarshal reads one byte and rejects unknown values with
// errInvalidNodeType.
func (t *nodeType) unmarshal(r *Reader) error {
	b, ok := r.readByte()
	if !ok {
		return errUnexpectedEOF
	}
	*t = nodeType(b)
	switch *t {
	case nodeTypeLeaf, nodeTypeParent:
		return nil
	default:
		return errInvalidNodeType
	}
}

// marshal writes t as a single byte.
func (t nodeType) marshal(w *Writer) {
	w.addByte(byte(t))
}

// --- Node ---

// node is a tagged union: exactly one of leafNode / parentNode is set,
// selected by nodeType.
type node struct {
	nodeType   nodeType
	leafNode   *leafNode   // for nodeTypeLeaf
	parentNode *parentNode // for nodeTypeParent
}

// unmarshal reads the type tag and then the matching node body.
func (n *node) unmarshal(r *Reader) error {
	*n = node{}
	if err := n.nodeType.unmarshal(r); err != nil {
		return err
	}
	switch n.nodeType {
	case nodeTypeLeaf:
		n.leafNode = &leafNode{}
		return n.leafNode.unmarshal(r)
	case nodeTypeParent:
		n.parentNode = &parentNode{}
		return n.parentNode.unmarshal(r)
	default:
		// nodeType.unmarshal already rejected every other value.
		panic("unreachable")
	}
}

// marshal writes the type tag and then the matching node body. It panics if
// nodeType holds a value other than leaf or parent.
func (n *node) marshal(w *Writer) {
	n.nodeType.marshal(w)
	switch n.nodeType {
	case nodeTypeLeaf:
		n.leafNode.marshal(w)
	case nodeTypeParent:
		n.parentNode.marshal(w)
	default:
		panic("unreachable")
	}
}

// encryptionKey returns the HPKE public key of whichever variant n holds.
func (n *node) encryptionKey() hpkePublicKey {
	switch n.nodeType {
	case nodeTypeLeaf:
		return n.leafNode.encryptionKey
	case nodeTypeParent:
		return n.parentNode.encryptionKey
	default:
		panic("unreachable")
	}
}

// --- RatchetTree ---

// ratchetTree is the array representation of the tree, indexed by
// nodeIndex. A nil entry is a blank node.
type ratchetTree []*node

// unmarshal decodes a vector of optional nodes (absent = blank) and then
// pads with blank nodes until len(tree)+1 is a power of two.
func (tree *ratchetTree) unmarshal(r *Reader) error {
	*tree = ratchetTree{}
	err := r.readVector(func(r *Reader) error {
		present, ok := r.readOptional()
		if !ok {
			return errUnexpectedEOF
		}
		if !present {
			*tree = append(*tree, nil)
			return nil
		}
		n := &node{}
		if err := n.unmarshal(r); err != nil {
			return err
		}
		*tree = append(*tree, n)
		return nil
	})
	if err != nil {
		return err
	}

	// Pad to next power of 2 (width + 1 must be power of 2)
	for !isPowerOf2(uint(len(*tree) + 1)) {
		*tree = append(*tree, nil)
	}
	return nil
}

// marshal writes the tree as a vector of optional nodes, omitting trailing
// blank nodes (unmarshal restores them through padding).
func (tree ratchetTree) marshal(w *Writer) {
	end := len(tree)
	for end > 0 && tree[end-1] == nil {
		end--
	}
	w.writeVector(end, func(w *Writer, i int) {
		n := tree[i]
		w.writeOptional(n != nil)
		if n != nil {
			n.marshal(w)
		}
	})
}

// numLeaves returns the leaf count implied by the array width.
func (tree ratchetTree) numLeaves() numLeaves {
	return numLeavesFromWidth(uint(len(tree)))
}

// get returns the node at i (nil if blank). It panics if i is out of range.
func (tree ratchetTree) get(i nodeIndex) *node {
	return tree[int(i)]
}

// set stores nd at i (nil blanks the node). It panics if i is out of range.
func (tree ratchetTree) set(i nodeIndex, nd *node) {
	tree[int(i)] = nd
}

// getLeaf returns the leaf node at leaf index li, or nil if that leaf is
// blank.
func (tree ratchetTree) getLeaf(li leafIndex) *leafNode {
	nd := tree.get(li.nodeIndex())
	if nd == nil {
		return nil
	}
	return nd.leafNode
}

// resolve computes the resolution of node x: a non-blank node resolves to
// itself plus (for parent nodes) its unmerged leaves; a blank node resolves
// to the concatenated resolutions of its children; a blank leaf resolves to
// nothing.
func (tree ratchetTree) resolve(x nodeIndex) []nodeIndex {
	n := tree.get(x)
	if n == nil {
		l, r, ok := x.children()
		if !ok {
			return nil
		}
		return append(tree.resolve(l), tree.resolve(r)...)
	}
	res := []nodeIndex{x}
	if n.nodeType == nodeTypeParent {
		for _, li := range n.parentNode.unmergedLeaves {
			res = append(res, li.nodeIndex())
		}
	}
	return res
}

// copy returns a shallow copy of the tree: a new slice holding the same
// *node pointers. set on one tree does not affect the other, but nodes are
// shared, so in-place node mutation (such as add appending to a parent's
// unmergedLeaves) is visible through both.
func (tree ratchetTree) copy() ratchetTree {
	return append(ratchetTree{}, tree...)
}

// add inserts ln into the leftmost blank leaf, doubling the tree width when
// no blank leaf exists. It records the new leaf index in the unmergedLeaves
// of every non-blank parent on the new leaf's direct path.
func (tree *ratchetTree) add(ln *leafNode) {
	li := leafIndex(0)
	var ni nodeIndex
	found := false
	for {
		ni = li.nodeIndex()
		if int(ni) >= len(*tree) {
			break
		}
		if tree.get(ni) == nil {
			found = true
			break
		}
		li++
	}

	if !found {
		// Grow from width w to 2w+1: one more level, with the new
		// right subtree entirely blank.
		newLen := ((len(*tree) + 1) * 2) - 1
		for len(*tree) < newLen {
			*tree = append(*tree, nil)
		}
	}

	n := tree.numLeaves()
	p := ni
	for {
		var ok bool
		p, ok = n.parent(p)
		if !ok {
			break
		}
		nd := tree.get(p)
		if nd != nil {
			nd.parentNode.unmergedLeaves = append(nd.parentNode.unmergedLeaves, li)
		}
	}

	tree.set(ni, &node{
		nodeType: nodeTypeLeaf,
		leafNode: ln,
	})
}

// update replaces the leaf at li with ln and blanks every node on its direct
// path up to and including the root.
func (tree ratchetTree) update(li leafIndex, ln *leafNode) {
	ni := li.nodeIndex()
	tree.set(ni, &node{
		nodeType: nodeTypeLeaf,
		leafNode: ln,
	})

	n := tree.numLeaves()
	for {
		var ok bool
		ni, ok = n.parent(ni)
		if !ok {
			break
		}
		tree.set(ni, nil)
	}
}

// remove blanks the leaf at li and its whole direct path, then truncates the
// tree. It scans leaves right to left, remembers the smallest power-of-two
// node index seen among trailing blank leaves, and cuts the array to just
// before it. If every leaf is blank the tree becomes nil.
func (tree *ratchetTree) remove(li leafIndex) {
	ni := li.nodeIndex()
	n := tree.numLeaves()
	for {
		tree.set(ni, nil)
		var ok bool
		ni, ok = n.parent(ni)
		if !ok {
			break
		}
	}

	li = leafIndex(n - 1)
	lastPowerOf2 := len(*tree) + 1
	for {
		ni = li.nodeIndex()
		if tree.get(ni) != nil {
			break
		}
		if isPowerOf2(uint(ni)) {
			lastPowerOf2 = int(ni)
		}
		if li == 0 {
			*tree = nil
			return
		}
		li--
	}

	if lastPowerOf2 < len(*tree)+1 {
		*tree = (*tree)[:lastPowerOf2-1]
	}
}

// apply applies proposals to the tree in three passes: all Updates (using
// senders[i] as the updating leaf for proposals[i]), then all Removes, then
// all Adds. senders must be at least as long as proposals.
func (tree *ratchetTree) apply(proposals []proposal, senders []leafIndex) {
	for i, prop := range proposals {
		if prop.proposalType == proposalTypeUpdate {
			tree.update(senders[i], &prop.update.leafNode)
		}
	}
	for _, prop := range proposals {
		if prop.proposalType == proposalTypeRemove {
			tree.remove(prop.remove.removed)
		}
	}
	for _, prop := range proposals {
		if prop.proposalType == proposalTypeAdd {
			tree.add(&prop.add.keyPackage.leafNode)
		}
	}
}

// findLeaf looks for a leaf equal to ln. It stops at the first leaf whose
// encryption key matches ln's and reports whether the two serialize to the
// same bytes. In that case it returns that leaf's index even when the answer
// is false. If no leaf key matches it returns (0, false).
func (tree ratchetTree) findLeaf(ln *leafNode) (leafIndex, bool) {
	for li := leafIndex(0); li < leafIndex(tree.numLeaves()); li++ {
		nd := tree.getLeaf(li)
		if nd == nil {
			continue
		}
		if !bytesEqual(nd.encryptionKey, ln.encryptionKey) {
			continue
		}
		raw1, err1 := marshalRaw(ln)
		raw2, err2 := marshalRaw(nd)
		return li, err1 == nil && err2 == nil && bytesEqual(raw1, raw2)
	}
	return 0, false
}

// keys returns the sets of signature keys and encryption keys used by the
// non-blank leaves, keyed by the raw key bytes.
func (tree ratchetTree) keys() (sigKeys, encKeys map[string]bool) {
	sigKeys = map[string]bool{}
	encKeys = map[string]bool{}
	for li := leafIndex(0); li < leafIndex(tree.numLeaves()); li++ {
		nd := tree.getLeaf(li)
		if nd == nil {
			continue
		}
		sigKeys[string(nd.signatureKey)] = true
		encKeys[string(nd.encryptionKey)] = true
	}
	return sigKeys, encKeys
}

// supportedCreds returns the credential types listed in the capabilities of
// every non-blank leaf. For a tree with no members the result is empty.
func (tree ratchetTree) supportedCreds() map[credentialType]bool {
	numMembers := 0
	counts := map[credentialType]int{}
	for li := leafIndex(0); li < leafIndex(tree.numLeaves()); li++ {
		nd := tree.getLeaf(li)
		if nd == nil {
			continue
		}
		numMembers++
		for _, ct := range nd.capabilities.credentials {
			counts[ct]++
		}
	}

	result := map[credentialType]bool{}
	for ct, c := range counts {
		if c == numMembers {
			result[ct] = true
		}
	}
	return result
}

// filteredDirectPath returns x's direct path with every parent dropped
// whose copath child (the sibling of the path node below it) has an empty
// resolution.
func (tree ratchetTree) filteredDirectPath(x nodeIndex) []nodeIndex {
	n := tree.numLeaves()
	var path []nodeIndex
	for {
		p, ok := n.parent(x)
		if !ok {
			break
		}
		s, ok := n.sibling(x)
		if !ok {
			panic("unreachable")
		}
		if len(tree.resolve(s)) > 0 {
			path = append(path, p)
		}
		x = p
	}
	return path
}

// hasUnmergedLeaf reports whether target is in pn.unmergedLeaves.
func hasUnmergedLeaf(pn *parentNode, target leafIndex) bool {
	for _, li := range pn.unmergedLeaves {
		if li == target {
			return true
		}
	}
	return false
}

// findParentHash reports whether any non-blank node among nodeIndices
// carries a parentHash equal to ph.
func (tree ratchetTree) findParentHash(nodeIndices []nodeIndex, ph []byte) bool {
	for _, x := range nodeIndices {
		nd := tree.get(x)
		if nd == nil {
			continue
		}
		var h []byte
		switch nd.nodeType {
		case nodeTypeLeaf:
			h = nd.leafNode.parentHash
		case nodeTypeParent:
			h = nd.parentNode.parentHash
		}
		if bytesEqual(h, ph) {
			return true
		}
	}
	return false
}

// --- LeafNode verification ---

// leafNodeVerifyOptions carries the group context needed to validate a
// leaf node.
type leafNodeVerifyOptions struct {
	cipherSuite    CipherSuite
	groupID        GroupID
	leafIndex      leafIndex
	supportedCreds map[credentialType]bool
	signatureKeys  map[string]bool
	encryptionKeys map[string]bool
	nowUnix        int64 // 0 = skip lifetime check
}

// verify checks ln against opts in this order: signature, credential type
// support, lifetime (only when ln has one and opts.nowUnix != 0), that every
// extension type is listed in ln's own capabilities, and that neither key
// is already in use. It returns the first failure.
func (ln *leafNode) verify(opts *leafNodeVerifyOptions) error {
	if !ln.verifySignature(opts.cipherSuite, opts.groupID, opts.leafIndex) {
		return errors.New("mls: leaf node signature verification failed")
	}

	if !opts.supportedCreds[ln.credential.credentialType] {
		return errors.New("mls: credential type not supported by all members")
	}

	if ln.lifetime != nil && opts.nowUnix != 0 {
		if !ln.lifetime.verifyAt(opts.nowUnix) {
			return errors.New("mls: lifetime verification failed")
		}
	}

	supportedExts := map[extensionType]bool{}
	for _, et := range ln.capabilities.extensions {
		supportedExts[et] = true
	}
	for _, ext := range ln.extensions {
		if !supportedExts[ext.extensionType] {
			return errors.New("mls: extension type not supported by leaf node")
		}
	}

	if opts.signatureKeys[string(ln.signatureKey)] {
		return errors.New("mls: duplicate signature key")
	}
	if opts.encryptionKeys[string(ln.encryptionKey)] {
		return errors.New("mls: duplicate encryption key")
	}

	return nil
}

const maxLeafNodeLifetime = 90 * 24 * 3600 // 90 days in seconds

// verifyAt reports whether nowUnix lies strictly inside (notBefore, notAfter)
// and the window is non-empty and at most maxLeafNodeLifetime seconds long.
// All arithmetic is done on the unsigned wire values: converting fields
// >= 2^63 to int64 would wrap them negative and corrupt the comparisons.
func (lt *lifetime) verifyAt(nowUnix int64) bool {
	if lt.notAfter <= lt.notBefore {
		return false
	}
	if lt.notAfter-lt.notBefore > maxLeafNodeLifetime {
		return false
	}
	if nowUnix < 0 {
		// No valid window can contain a pre-epoch time.
		return false
	}
	now := uint64(nowUnix)
	return now > lt.notBefore && lt.notAfter > now
}