package mls

// MLS proposals (RFC 9420 §12).

import "errors"

// Sentinel errors returned while decoding proposals and while validating the
// list of proposals covered by a commit.
var (
	errInvalidProposalType    = errors.New("mls: invalid proposal type")
	errDupAddKey              = errors.New("mls: duplicate add proposal signature key")
	errUpdateByCommitter      = errors.New("mls: update proposal from committer")
	errDupUpdateOrRemove      = errors.New("mls: duplicate update/remove on same leaf")
	errRemoveCommitter        = errors.New("mls: remove proposal targets committer")
	errDupPSK                 = errors.New("mls: duplicate PSK proposal")
	errDupGroupContextExts    = errors.New("mls: duplicate group context extensions proposal")
	errReinitWithOther        = errors.New("mls: reinit with other proposals")
	errExternalInitNotAllowed = errors.New("mls: external init not allowed")
)

// --- ProposalType ---

// proposalType is the tag that selects a proposal variant. It is written to
// the wire with Writer.addUint16 and read back with Reader.readUint16.
type proposalType uint16

const (
	proposalTypeAdd                    proposalType = 0x0001
	proposalTypeUpdate                 proposalType = 0x0002
	proposalTypeRemove                 proposalType = 0x0003
	proposalTypePSK                    proposalType = 0x0004
	proposalTypeReinit                 proposalType = 0x0005
	proposalTypeExternalInit           proposalType = 0x0006
	proposalTypeGroupContextExtensions proposalType = 0x0007
)

// unmarshal reads a proposalType from r.
//
// It returns errUnexpectedEOF if r has no uint16 left, and
// errInvalidProposalType if the value is not one of the constants above.
// *pt is assigned before the value is checked, so after
// errInvalidProposalType it holds the rejected value.
func (pt *proposalType) unmarshal(r *Reader) error {
	raw, ok := r.readUint16()
	if !ok {
		return errUnexpectedEOF
	}
	*pt = proposalType(raw)

	switch *pt {
	case proposalTypeAdd, proposalTypeUpdate, proposalTypeRemove,
		proposalTypePSK, proposalTypeReinit, proposalTypeExternalInit,
		proposalTypeGroupContextExtensions:
		return nil
	}
	return errInvalidProposalType
}

// marshal writes pt to w as a uint16.
func (pt proposalType) marshal(w *Writer) {
	w.addUint16(uint16(pt))
}

// --- Proposal ---

// proposal is a tagged union. proposalType selects which single pointer field
// carries the proposal body; unmarshal leaves every other field nil.
type proposal struct {
	proposalType           proposalType
	add                    *add
	update                 *update
	remove                 *remove
	preSharedKey           *preSharedKey
	reInit                 *reInit
	externalInit           *externalInit
	groupContextExtensions *groupContextExtensions
}

// unmarshal decodes a proposal from r: the proposalType tag, then the body of
// the variant that tag selects. Any previous contents of *p are discarded.
// Errors from the tag or from the variant's own unmarshal are returned as-is.
func (p *proposal) unmarshal(r *Reader) error {
	*p = proposal{}
	if err := p.proposalType.unmarshal(r); err != nil {
		return err
	}

	// Allocate the selected variant, store it in its field, then decode into
	// it through a shared method set.
	var body interface{ unmarshal(*Reader) error }
	switch p.proposalType {
	case proposalTypeAdd:
		p.add = new(add)
		body = p.add
	case proposalTypeUpdate:
		p.update = new(update)
		body = p.update
	case proposalTypeRemove:
		p.remove = new(remove)
		body = p.remove
	case proposalTypePSK:
		p.preSharedKey = new(preSharedKey)
		body = p.preSharedKey
	case proposalTypeReinit:
		p.reInit = new(reInit)
		body = p.reInit
	case proposalTypeExternalInit:
		p.externalInit = new(externalInit)
		body = p.externalInit
	case proposalTypeGroupContextExtensions:
		p.groupContextExtensions = new(groupContextExtensions)
		body = p.groupContextExtensions
	default:
		// proposalType.unmarshal has already rejected unknown tags.
		panic("unreachable")
	}
	return body.unmarshal(r)
}

// marshal writes p as its proposalType tag followed by the selected
// variant's body.
//
// The tag is written before the variant is looked up. It then panics if the
// tag is unknown. The variant pointer matching the tag is used without a nil
// check.
func (p *proposal) marshal(w *Writer) {
	p.proposalType.marshal(w)

	var body interface{ marshal(*Writer) }
	switch p.proposalType {
	case proposalTypeAdd:
		body = p.add
	case proposalTypeUpdate:
		body = p.update
	case proposalTypeRemove:
		body = p.remove
	case proposalTypePSK:
		body = p.preSharedKey
	case proposalTypeReinit:
		body = p.reInit
	case proposalTypeExternalInit:
		body = p.externalInit
	case proposalTypeGroupContextExtensions:
		body = p.groupContextExtensions
	default:
		panic("unreachable")
	}
	body.marshal(w)
}

// --- Proposal sub-types ---

// add is the body of an Add proposal: the KeyPackage of the client to add.
type add struct {
	keyPackage KeyPackage
}

// unmarshal resets a and decodes its KeyPackage from r.
func (a *add) unmarshal(r *Reader) error {
	*a = add{}
	return a.keyPackage.unmarshal(r)
}

// marshal writes a's KeyPackage to w.
func (a *add) marshal(w *Writer) {
	a.keyPackage.marshal(w)
}

// update is the body of an Update proposal: a replacement leaf node.
// verifyProposalList treats it as applying to the proposal sender's leaf.
type update struct {
	leafNode leafNode
}

// unmarshal resets u and decodes its leaf node from r.
func (u *update) unmarshal(r *Reader) error {
	*u = update{}
	return u.leafNode.unmarshal(r)
}

// marshal writes u's leaf node to w.
func (u *update) marshal(w *Writer) {
	u.leafNode.marshal(w)
}

// remove is the body of a Remove proposal: the index of the leaf to remove,
// encoded as a uint32.
type remove struct {
	removed leafIndex
}

// unmarshal resets rm and reads the removed leaf index from r, returning
// errUnexpectedEOF if no uint32 is available.
func (rm *remove) unmarshal(r *Reader) error {
	*rm = remove{}
	if idx, ok := r.readUint32(); ok {
		rm.removed = leafIndex(idx)
		return nil
	}
	return errUnexpectedEOF
}

// marshal writes the removed leaf index to w as a uint32.
func (rm *remove) marshal(w *Writer) {
	w.addUint32(uint32(rm.removed))
}

// preSharedKey is the body of a PreSharedKey proposal: the ID of the PSK to
// inject.
type preSharedKey struct {
	psk preSharedKeyID
}

// unmarshal resets pk and decodes its PreSharedKeyID from r.
func (pk *preSharedKey) unmarshal(r *Reader) error {
	*pk = preSharedKey{}
	return pk.psk.unmarshal(r)
}

// marshal writes pk's PreSharedKeyID to w.
func (pk *preSharedKey) marshal(w *Writer) {
	pk.psk.marshal(w)
}

// reInit is the body of a ReInit proposal. On the wire it is, in order:
//   - an opaque vector holding the group ID
//   - a uint16 protocol version
//   - a uint16 cipher suite
//   - an extension vector
type reInit struct {
	groupID     GroupID
	version     protocolVersion
	cipherSuite CipherSuite
	extensions  []extension
}

// unmarshal resets ri and decodes its fields from r in wire order.
//
// Each field is stored as soon as it is read. On error, ri therefore keeps
// the fields decoded before the failure. Missing data yields
// errUnexpectedEOF; extension-vector errors are returned unchanged.
func (ri *reInit) unmarshal(r *Reader) error {
	*ri = reInit{}

	var ok bool
	if ri.groupID, ok = r.readOpaqueVec(); !ok {
		return errUnexpectedEOF
	}

	version, ok := r.readUint16()
	if !ok {
		return errUnexpectedEOF
	}
	ri.version = protocolVersion(version)

	suite, ok := r.readUint16()
	if !ok {
		return errUnexpectedEOF
	}
	ri.cipherSuite = CipherSuite(suite)

	extensions, err := unmarshalExtensionVec(r)
	if err != nil {
		return err
	}
	ri.extensions = extensions
	return nil
}

// marshal writes ri to w in the same field order unmarshal reads it.
func (ri *reInit) marshal(w *Writer) {
	w.writeOpaqueVec([]byte(ri.groupID))
	w.addUint16(uint16(ri.version))
	w.addUint16(uint16(ri.cipherSuite))
	marshalExtensionVec(w, ri.extensions)
}

// externalInit is the body of an ExternalInit proposal: a KEM output carried
// as an opaque vector.
type externalInit struct {
	kemOutput []byte
}

// unmarshal resets ei and reads the KEM output from r, returning
// errUnexpectedEOF if the opaque vector cannot be read.
func (ei *externalInit) unmarshal(r *Reader) error {
	*ei = externalInit{}
	var ok bool
	if ei.kemOutput, ok = r.readOpaqueVec(); !ok {
		return errUnexpectedEOF
	}
	return nil
}

// marshal writes the KEM output to w as an opaque vector.
func (ei *externalInit) marshal(w *Writer) {
	w.writeOpaqueVec(ei.kemOutput)
}

// groupContextExtensions is the body of a GroupContextExtensions proposal:
// the extension list to install in the group context.
type groupContextExtensions struct {
	extensions []extension
}

// unmarshal resets gce and decodes its extension vector from r.
func (gce *groupContextExtensions) unmarshal(r *Reader) error {
	*gce = groupContextExtensions{}
	list, err := unmarshalExtensionVec(r)
	if err != nil {
		return err
	}
	gce.extensions = list
	return nil
}

// marshal writes gce's extension vector to w.
func (gce *groupContextExtensions) marshal(w *Writer) {
	marshalExtensionVec(w, gce.extensions)
}

// --- ProposalOrRef ---

// proposalOrRefType is the one-byte tag of a proposalOrRef.
type proposalOrRefType uint8

const (
	proposalOrRefTypeProposal  proposalOrRefType = 1
	proposalOrRefTypeReference proposalOrRefType = 2
)

// unmarshal reads the tag byte from r.
//
// It returns errUnexpectedEOF if no byte is left, and
// errInvalidProposalOrRefType for values other than the two constants above.
// *ot is assigned before the check.
func (ot *proposalOrRefType) unmarshal(r *Reader) error {
	b, ok := r.readByte()
	if !ok {
		return errUnexpectedEOF
	}
	*ot = proposalOrRefType(b)
	if *ot != proposalOrRefTypeProposal && *ot != proposalOrRefTypeReference {
		return errInvalidProposalOrRefType
	}
	return nil
}

// marshal writes ot to w as a single byte.
func (ot proposalOrRefType) marshal(w *Writer) {
	w.addByte(byte(ot))
}

// proposalRef refers to a proposal by an opaque byte string. It is presumably
// the hash-based ProposalRef of RFC 9420; confirm against the code that
// computes it.
type proposalRef []byte

// equal reports whether ref and other hold the same bytes, as decided by
// bytesEqual.
func (ref proposalRef) equal(other proposalRef) bool {
	return bytesEqual([]byte(ref), []byte(other))
}

// proposalOrRef is either an inline proposal or a reference to one; typ
// selects which of proposal or reference is meaningful.
type proposalOrRef struct {
	typ       proposalOrRefType
	proposal  *proposal
	reference proposalRef
}

// unmarshal resets por, reads the tag from r, and then decodes either an
// inline proposal or an opaque-vector reference.
func (por *proposalOrRef) unmarshal(r *Reader) error {
	*por = proposalOrRef{}
	if err := por.typ.unmarshal(r); err != nil {
		return err
	}

	switch por.typ {
	case proposalOrRefTypeProposal:
		por.proposal = new(proposal)
		return por.proposal.unmarshal(r)
	case proposalOrRefTypeReference:
		var ok bool
		if por.reference, ok = r.readOpaqueVec(); !ok {
			return errUnexpectedEOF
		}
		return nil
	}
	// proposalOrRefType.unmarshal has already rejected unknown tags.
	panic("unreachable")
}

// marshal writes por's tag followed by either the inline proposal or the
// reference bytes. It panics on an unknown tag, after the tag has been
// written.
func (por *proposalOrRef) marshal(w *Writer) {
	por.typ.marshal(w)
	switch por.typ {
	case proposalOrRefTypeProposal:
		por.proposal.marshal(w)
	case proposalOrRefTypeReference:
		w.writeOpaqueVec([]byte(por.reference))
	default:
		panic("unreachable")
	}
}

// --- Proposal list validation (RFC 9420 §12.2) ---

// verifyProposalList checks the proposals covered by a commit against the
// rules that can be evaluated from the proposals, their senders, and the
// committer's leaf index alone.
//
// senders[i] is treated as the leaf that sent proposals[i]. The function
// panics if the two slices differ in length.
//
// It returns the first violation found, scanning in order:
//   - errDupAddKey: two Adds whose leaf nodes share a signature key
//   - errUpdateByCommitter: an Update sent by the committer
//   - errRemoveCommitter: a Remove targeting the committer
//   - errDupUpdateOrRemove: two Update/Remove proposals hitting the same leaf
//   - errDupPSK: two PreSharedKey proposals with identical serialized IDs
//   - errDupGroupContextExts: more than one GroupContextExtensions
//   - errReinitWithOther: a ReInit in a list of more than one proposal
//   - errExternalInitNotAllowed: any ExternalInit
//
// Errors from marshalRaw are returned unchanged.
func verifyProposalList(proposals []proposal, senders []leafIndex, committer leafIndex) error {
	if len(senders) != len(proposals) {
		panic("unreachable")
	}

	seenAddKeys := make(map[string]struct{})
	touchedLeaves := make(map[leafIndex]struct{})
	seenPSKs := make(map[string]struct{})
	sawGroupCtxExts := false

	// claimLeaf records that an Update or Remove applies to leaf, failing if
	// an earlier Update or Remove already claimed it.
	claimLeaf := func(leaf leafIndex) error {
		if _, dup := touchedLeaves[leaf]; dup {
			return errDupUpdateOrRemove
		}
		touchedLeaves[leaf] = struct{}{}
		return nil
	}

	for i := range proposals {
		prop := &proposals[i]
		switch prop.proposalType {
		case proposalTypeAdd:
			key := string(prop.add.keyPackage.leafNode.signatureKey)
			if _, dup := seenAddKeys[key]; dup {
				return errDupAddKey
			}
			seenAddKeys[key] = struct{}{}

		case proposalTypeUpdate:
			sender := senders[i]
			if sender == committer {
				return errUpdateByCommitter
			}
			if err := claimLeaf(sender); err != nil {
				return err
			}

		case proposalTypeRemove:
			target := prop.remove.removed
			if target == committer {
				return errRemoveCommitter
			}
			if err := claimLeaf(target); err != nil {
				return err
			}

		case proposalTypePSK:
			// The serialized PreSharedKeyID acts as the identity for the
			// duplicate check.
			raw, err := marshalRaw(&prop.preSharedKey.psk)
			if err != nil {
				return err
			}
			key := string(raw)
			if _, dup := seenPSKs[key]; dup {
				return errDupPSK
			}
			seenPSKs[key] = struct{}{}

		case proposalTypeGroupContextExtensions:
			if sawGroupCtxExts {
				return errDupGroupContextExts
			}
			sawGroupCtxExts = true

		case proposalTypeReinit:
			if len(proposals) > 1 {
				return errReinitWithOther
			}

		case proposalTypeExternalInit:
			return errExternalInitNotAllowed
		}
	}
	return nil
}

// proposalListNeedsPath reports whether a commit covering proposals must
// carry a path. That is the case for an empty list, or when any proposal is
// an Update, Remove, ExternalInit, or GroupContextExtensions. Add,
// PreSharedKey, and ReInit alone do not require one.
func proposalListNeedsPath(proposals []proposal) bool {
	if len(proposals) == 0 {
		return true
	}
	for i := range proposals {
		switch proposals[i].proposalType {
		case proposalTypeUpdate, proposalTypeRemove,
			proposalTypeExternalInit, proposalTypeGroupContextExtensions:
			return true
		}
	}
	return false
}