package mls // MLS framing types (RFC 9420 §6). // Pure serialization — crypto operations (sign/encrypt/decrypt) are // in framing_crypto.mx once the cipher suite layer is available. import "errors" var ( errInvalidContentType = errors.New("mls: invalid content type") errInvalidSenderType = errors.New("mls: invalid sender type") errInvalidWireFormat = errors.New("mls: invalid wire format") errInvalidVersion = errors.New("mls: invalid protocol version") errNonZeroPadding = errors.New("mls: padding contains non-zero bytes") ) // --- Protocol version --- type protocolVersion uint16 const ( protocolVersionMLS10 protocolVersion = 1 ) // --- Content type --- type contentType uint8 const ( contentTypeApplication contentType = 1 contentTypeProposal contentType = 2 contentTypeCommit contentType = 3 ) func (ct *contentType) unmarshal(r *Reader) error { b, ok := r.readByte() if !ok { return errUnexpectedEOF } *ct = contentType(b) switch *ct { case contentTypeApplication, contentTypeProposal, contentTypeCommit: return nil default: return errInvalidContentType } } func (ct contentType) marshal(w *Writer) { w.addByte(byte(ct)) } // --- Sender type --- type senderType uint8 const ( senderTypeMember senderType = 1 senderTypeExternal senderType = 2 senderTypeNewMemberProposal senderType = 3 senderTypeNewMemberCommit senderType = 4 ) func (st *senderType) unmarshal(r *Reader) error { b, ok := r.readByte() if !ok { return errUnexpectedEOF } *st = senderType(b) switch *st { case senderTypeMember, senderTypeExternal, senderTypeNewMemberProposal, senderTypeNewMemberCommit: return nil default: return errInvalidSenderType } } func (st senderType) marshal(w *Writer) { w.addByte(byte(st)) } // --- Sender --- type sender struct { senderType senderType leafIndex leafIndex // for senderTypeMember senderIndex uint32 // for senderTypeExternal } func (snd *sender) unmarshal(r *Reader) error { *snd = sender{} if err := snd.senderType.unmarshal(r); err != nil { return err } switch snd.senderType { case senderTypeMember: v, ok := r.readUint32() if !ok { return errUnexpectedEOF } snd.leafIndex = leafIndex(v) case senderTypeExternal: v, ok := r.readUint32() if !ok { return errUnexpectedEOF } snd.senderIndex = v } return nil } func (snd *sender) marshal(w *Writer) { snd.senderType.marshal(w) switch snd.senderType { case senderTypeMember: w.addUint32(uint32(snd.leafIndex)) case senderTypeExternal: w.addUint32(snd.senderIndex) } } // --- Wire format --- type wireFormat uint16 // http://www.iana.org/assignments/mls/mls.xhtml#mls-wire-formats const ( wireFormatMLSPublicMessage wireFormat = 0x0001 wireFormatMLSPrivateMessage wireFormat = 0x0002 wireFormatMLSWelcome wireFormat = 0x0003 wireFormatMLSGroupInfo wireFormat = 0x0004 wireFormatMLSKeyPackage wireFormat = 0x0005 ) func (wf *wireFormat) unmarshal(r *Reader) error { v, ok := r.readUint16() if !ok { return errUnexpectedEOF } *wf = wireFormat(v) switch *wf { case wireFormatMLSPublicMessage, wireFormatMLSPrivateMessage, wireFormatMLSWelcome, wireFormatMLSGroupInfo, wireFormatMLSKeyPackage: return nil default: return errInvalidWireFormat } } func (wf wireFormat) marshal(w *Writer) { w.addUint16(uint16(wf)) } // --- GroupID --- // GroupID is an application-specific group identifier. type GroupID []byte func (ref GroupID) equal(other GroupID) bool { if len(ref) != len(other) { return false } for i := range ref { if ref[i] != other[i] { return false } } return true } // --- FramedContent --- type framedContent struct { groupID GroupID epoch uint64 sender sender authenticatedData []byte contentType contentType applicationData []byte // for contentTypeApplication proposal *proposal // for contentTypeProposal commit *commit // for contentTypeCommit } func (content *framedContent) unmarshal(r *Reader) error { *content = framedContent{} var ok bool content.groupID, ok = r.readOpaqueVec() if !ok { return errUnexpectedEOF } content.epoch, ok = r.readUint64() if !ok { return errUnexpectedEOF } if err := content.sender.unmarshal(r); err != nil { return err } content.authenticatedData, ok = r.readOpaqueVec() if !ok { return errUnexpectedEOF } if err := content.contentType.unmarshal(r); err != nil { return err } switch content.contentType { case contentTypeApplication: content.applicationData, ok = r.readOpaqueVec() if !ok { return errUnexpectedEOF } return nil case contentTypeProposal: content.proposal = &proposal{} return content.proposal.unmarshal(r) case contentTypeCommit: content.commit = &commit{} return content.commit.unmarshal(r) default: panic("unreachable") } } func (content *framedContent) marshal(w *Writer) { w.writeOpaqueVec([]byte(content.groupID)) w.addUint64(content.epoch) content.sender.marshal(w) w.writeOpaqueVec(content.authenticatedData) content.contentType.marshal(w) switch content.contentType { case contentTypeApplication: w.writeOpaqueVec(content.applicationData) case contentTypeProposal: content.proposal.marshal(w) case contentTypeCommit: content.commit.marshal(w) default: panic("unreachable") } } // --- MLSMessage (top-level wire message) --- type mlsMessage struct { version protocolVersion wireFormat wireFormat publicMessage *publicMessage // for wireFormatMLSPublicMessage privateMessage *privateMessage // for wireFormatMLSPrivateMessage welcome *Welcome // for wireFormatMLSWelcome groupInfo *groupInfo // for wireFormatMLSGroupInfo keyPackage *KeyPackage // for wireFormatMLSKeyPackage } func (msg *mlsMessage) unmarshal(r *Reader) error { *msg = mlsMessage{} v, ok := r.readUint16() if !ok { return errUnexpectedEOF } msg.version = protocolVersion(v) if msg.version != protocolVersionMLS10 { return errInvalidVersion } if err := msg.wireFormat.unmarshal(r); err != nil { return err } switch msg.wireFormat { case wireFormatMLSPublicMessage: msg.publicMessage = &publicMessage{} return msg.publicMessage.unmarshal(r) case wireFormatMLSPrivateMessage: msg.privateMessage = &privateMessage{} return msg.privateMessage.unmarshal(r) case wireFormatMLSWelcome: msg.welcome = &Welcome{} return msg.welcome.unmarshal(r) case wireFormatMLSGroupInfo: msg.groupInfo = &groupInfo{} return msg.groupInfo.unmarshal(r) case wireFormatMLSKeyPackage: msg.keyPackage = &KeyPackage{} return msg.keyPackage.unmarshal(r) default: panic("unreachable") } } func (msg *mlsMessage) marshal(w *Writer) { w.addUint16(uint16(msg.version)) msg.wireFormat.marshal(w) switch msg.wireFormat { case wireFormatMLSPublicMessage: msg.publicMessage.marshal(w) case wireFormatMLSPrivateMessage: msg.privateMessage.marshal(w) case wireFormatMLSWelcome: msg.welcome.marshal(w) case wireFormatMLSGroupInfo: msg.groupInfo.marshal(w) case wireFormatMLSKeyPackage: msg.keyPackage.marshal(w) default: panic("unreachable") } } // --- FramedContentAuthData --- type framedContentAuthData struct { signature []byte confirmationTag []byte // for contentTypeCommit } func (authData *framedContentAuthData) unmarshal(r *Reader, ct contentType) error { *authData = framedContentAuthData{} var ok bool authData.signature, ok = r.readOpaqueVec() if !ok { return errUnexpectedEOF } if ct == contentTypeCommit { authData.confirmationTag, ok = r.readOpaqueVec() if !ok { return errUnexpectedEOF } } return nil } func (authData *framedContentAuthData) marshal(w *Writer, ct contentType) { w.writeOpaqueVec(authData.signature) if ct == contentTypeCommit { w.writeOpaqueVec(authData.confirmationTag) } } // --- AuthenticatedContent --- type authenticatedContent struct { wireFormat wireFormat content framedContent auth framedContentAuthData } func (authContent *authenticatedContent) unmarshal(r *Reader) error { if err := authContent.wireFormat.unmarshal(r); err != nil { return err } if err := authContent.content.unmarshal(r); err != nil { return err } return authContent.auth.unmarshal(r, authContent.content.contentType) } func (authContent *authenticatedContent) marshal(w *Writer) { authContent.wireFormat.marshal(w) authContent.content.marshal(w) authContent.auth.marshal(w, authContent.content.contentType) } func (authContent *authenticatedContent) confirmedTranscriptHashInput() *confirmedTranscriptHashInput { return &confirmedTranscriptHashInput{ wireFormat: authContent.wireFormat, content: authContent.content, signature: authContent.auth.signature, } } func (authContent *authenticatedContent) framedContentTBS(ctx *groupContext) *framedContentTBS { return &framedContentTBS{ version: protocolVersionMLS10, wireFormat: authContent.wireFormat, content: authContent.content, context: ctx, } } // --- FramedContentTBS (to-be-signed) --- type framedContentTBS struct { version protocolVersion wireFormat wireFormat content framedContent context *groupContext // for senderTypeMember and senderTypeNewMemberCommit } func (content *framedContentTBS) marshal(w *Writer) { w.addUint16(uint16(content.version)) content.wireFormat.marshal(w) content.content.marshal(w) switch content.content.sender.senderType { case senderTypeMember, senderTypeNewMemberCommit: content.context.marshal(w) } } // --- PublicMessage --- type publicMessage struct { content framedContent auth framedContentAuthData membershipTag []byte // for senderTypeMember } func (msg *publicMessage) unmarshal(r *Reader) error { *msg = publicMessage{} if err := msg.content.unmarshal(r); err != nil { return err } if err := msg.auth.unmarshal(r, msg.content.contentType); err != nil { return err } if msg.content.sender.senderType == senderTypeMember { var ok bool msg.membershipTag, ok = r.readOpaqueVec() if !ok { return errUnexpectedEOF } } return nil } func (msg *publicMessage) marshal(w *Writer) { msg.content.marshal(w) msg.auth.marshal(w, msg.content.contentType) if msg.content.sender.senderType == senderTypeMember { w.writeOpaqueVec(msg.membershipTag) } } func (msg *publicMessage) authenticatedContent() *authenticatedContent { return &authenticatedContent{ wireFormat: wireFormatMLSPublicMessage, content: msg.content, auth: msg.auth, } } func (msg *publicMessage) authenticatedContentTBM(ctx *groupContext) *authenticatedContentTBM { return &authenticatedContentTBM{ contentTBS: *msg.authenticatedContent().framedContentTBS(ctx), auth: msg.auth, } } // --- AuthenticatedContentTBM (to-be-MACed) --- type authenticatedContentTBM struct { contentTBS framedContentTBS auth framedContentAuthData } func (tbm *authenticatedContentTBM) marshal(w *Writer) { tbm.contentTBS.marshal(w) tbm.auth.marshal(w, tbm.contentTBS.content.contentType) } // --- PrivateMessage --- type privateMessage struct { groupID GroupID epoch uint64 contentType contentType authenticatedData []byte encryptedSenderData []byte ciphertext []byte } func (msg *privateMessage) unmarshal(r *Reader) error { *msg = privateMessage{} var ok bool msg.groupID, ok = r.readOpaqueVec() if !ok { return errUnexpectedEOF } msg.epoch, ok = r.readUint64() if !ok { return errUnexpectedEOF } if err := msg.contentType.unmarshal(r); err != nil { return err } msg.authenticatedData, ok = r.readOpaqueVec() if !ok { return errUnexpectedEOF } msg.encryptedSenderData, ok = r.readOpaqueVec() if !ok { return errUnexpectedEOF } msg.ciphertext, ok = r.readOpaqueVec() if !ok { return errUnexpectedEOF } return nil } func (msg *privateMessage) marshal(w *Writer) { w.writeOpaqueVec([]byte(msg.groupID)) w.addUint64(msg.epoch) msg.contentType.marshal(w) w.writeOpaqueVec(msg.authenticatedData) w.writeOpaqueVec(msg.encryptedSenderData) w.writeOpaqueVec(msg.ciphertext) } func (msg *privateMessage) authenticatedContent(sd *senderData, content *privateMessageContent) *authenticatedContent { return content.authenticatedContent(&framedContent{ groupID: msg.groupID, epoch: msg.epoch, sender: sender{ senderType: senderTypeMember, leafIndex: sd.leafIndex, }, authenticatedData: msg.authenticatedData, contentType: msg.contentType, applicationData: content.applicationData, proposal: content.proposal, commit: content.commit, }) } // --- Sender data AAD --- type senderDataAAD struct { groupID GroupID epoch uint64 contentType contentType } func (aad *senderDataAAD) marshal(w *Writer) { w.writeOpaqueVec([]byte(aad.groupID)) w.addUint64(aad.epoch) aad.contentType.marshal(w) } // --- Private content AAD --- type privateContentAAD struct { groupID GroupID epoch uint64 contentType contentType authenticatedData []byte } func (aad *privateContentAAD) marshal(w *Writer) { w.writeOpaqueVec([]byte(aad.groupID)) w.addUint64(aad.epoch) aad.contentType.marshal(w) w.writeOpaqueVec(aad.authenticatedData) } // --- PrivateMessageContent --- type privateMessageContent struct { applicationData []byte // for contentTypeApplication proposal *proposal // for contentTypeProposal commit *commit // for contentTypeCommit auth framedContentAuthData } func (content *privateMessageContent) unmarshal(r *Reader, ct contentType) error { *content = privateMessageContent{} var err error switch ct { case contentTypeApplication: var ok bool content.applicationData, ok = r.readOpaqueVec() if !ok { err = errUnexpectedEOF } case contentTypeProposal: content.proposal = &proposal{} err = content.proposal.unmarshal(r) case contentTypeCommit: content.commit = &commit{} err = content.commit.unmarshal(r) default: panic("unreachable") } if err != nil { return err } return content.auth.unmarshal(r, ct) } func (content *privateMessageContent) marshal(w *Writer, ct contentType) { switch ct { case contentTypeApplication: w.writeOpaqueVec(content.applicationData) case contentTypeProposal: content.proposal.marshal(w) case contentTypeCommit: content.commit.marshal(w) default: panic("unreachable") } content.auth.marshal(w, ct) } func (content *privateMessageContent) authenticatedContent(fc *framedContent) *authenticatedContent { return &authenticatedContent{ wireFormat: wireFormatMLSPrivateMessage, content: *fc, auth: content.auth, } } // --- SenderData --- type senderData struct { leafIndex leafIndex generation uint32 reuseGuard [4]byte } func (data *senderData) unmarshal(r *Reader) error { v, ok := r.readUint32() if !ok { return errUnexpectedEOF } data.leafIndex = leafIndex(v) data.generation, ok = r.readUint32() if !ok { return errUnexpectedEOF } guard, ok := r.readN(4) if !ok { return errUnexpectedEOF } data.reuseGuard[0] = guard[0] data.reuseGuard[1] = guard[1] data.reuseGuard[2] = guard[2] data.reuseGuard[3] = guard[3] return nil } func (data *senderData) marshal(w *Writer) { w.addUint32(uint32(data.leafIndex)) w.addUint32(data.generation) w.addBytes(data.reuseGuard[:]) }