package mls // MLS key schedule types (RFC 9420 §8). // Data types and serialization — crypto operations go in key_schedule_crypto.mx. import "errors" var ( errInvalidPSKType = errors.New("mls: invalid PSK type") errInvalidPSKUsage = errors.New("mls: invalid resumption PSK usage") errInvalidProposalOrRefType = errors.New("mls: invalid proposal or ref type") ) // --- CipherSuite --- type CipherSuite uint16 const ( // MLS_128_DHKEMP256_AES128GCM_SHA256_P256 CipherSuite0x0001 CipherSuite = 0x0001 // MLS_128_DHKEMX25519_CHACHA20POLY1305_SHA256_Ed25519 CipherSuite0x0003 CipherSuite = 0x0003 ) // --- GroupContext --- type groupContext struct { version protocolVersion cipherSuite CipherSuite groupID GroupID epoch uint64 treeHash []byte confirmedTranscriptHash []byte extensions []extension } func (ctx *groupContext) unmarshal(r *Reader) error { *ctx = groupContext{} v, ok := r.readUint16() if !ok { return errUnexpectedEOF } ctx.version = protocolVersion(v) if ctx.version != protocolVersionMLS10 { return errInvalidVersion } v, ok = r.readUint16() if !ok { return errUnexpectedEOF } ctx.cipherSuite = CipherSuite(v) ctx.groupID, ok = r.readOpaqueVec() if !ok { return errUnexpectedEOF } ctx.epoch, ok = r.readUint64() if !ok { return errUnexpectedEOF } ctx.treeHash, ok = r.readOpaqueVec() if !ok { return errUnexpectedEOF } ctx.confirmedTranscriptHash, ok = r.readOpaqueVec() if !ok { return errUnexpectedEOF } exts, err := unmarshalExtensionVec(r) if err != nil { return err } ctx.extensions = exts return nil } func (ctx *groupContext) marshal(w *Writer) { w.addUint16(uint16(ctx.version)) w.addUint16(uint16(ctx.cipherSuite)) w.writeOpaqueVec([]byte(ctx.groupID)) w.addUint64(ctx.epoch) w.writeOpaqueVec(ctx.treeHash) w.writeOpaqueVec(ctx.confirmedTranscriptHash) marshalExtensionVec(w, ctx.extensions) } // --- Secret labels --- var ( secretLabelInit = []byte("init") secretLabelSenderData = []byte("sender data") secretLabelEncryption = []byte("encryption") secretLabelExporter = []byte("exporter") secretLabelExternal = []byte("external") secretLabelConfirm = []byte("confirm") secretLabelMembership = []byte("membership") secretLabelResumption = []byte("resumption") secretLabelAuthentication = []byte("authentication") ) // --- ConfirmedTranscriptHashInput --- type confirmedTranscriptHashInput struct { wireFormat wireFormat content framedContent signature []byte } func (input *confirmedTranscriptHashInput) marshal(w *Writer) { input.wireFormat.marshal(w) input.content.marshal(w) w.writeOpaqueVec(input.signature) } // --- PSK types --- type pskType uint8 const ( pskTypeExternal pskType = 1 pskTypeResumption pskType = 2 ) func (t *pskType) unmarshal(r *Reader) error { b, ok := r.readByte() if !ok { return errUnexpectedEOF } *t = pskType(b) switch *t { case pskTypeExternal, pskTypeResumption: return nil default: return errInvalidPSKType } } func (t pskType) marshal(w *Writer) { w.addByte(byte(t)) } type resumptionPSKUsage uint8 const ( resumptionPSKUsageApplication resumptionPSKUsage = 1 resumptionPSKUsageReinit resumptionPSKUsage = 2 resumptionPSKUsageBranch resumptionPSKUsage = 3 ) func (usage *resumptionPSKUsage) unmarshal(r *Reader) error { b, ok := r.readByte() if !ok { return errUnexpectedEOF } *usage = resumptionPSKUsage(b) switch *usage { case resumptionPSKUsageApplication, resumptionPSKUsageReinit, resumptionPSKUsageBranch: return nil default: return errInvalidPSKUsage } } func (usage resumptionPSKUsage) marshal(w *Writer) { w.addByte(byte(usage)) } // --- PreSharedKeyID --- type preSharedKeyID struct { pskType pskType pskID []byte // for pskTypeExternal usage resumptionPSKUsage // for pskTypeResumption pskGroupID GroupID // for pskTypeResumption pskEpoch uint64 // for pskTypeResumption pskNonce []byte } func (id *preSharedKeyID) unmarshal(r *Reader) error { *id = preSharedKeyID{} if err := id.pskType.unmarshal(r); err != nil { return err } switch id.pskType { case pskTypeExternal: var ok bool id.pskID, ok = r.readOpaqueVec() if !ok { return errUnexpectedEOF } case pskTypeResumption: if err := id.usage.unmarshal(r); err != nil { return err } var ok bool id.pskGroupID, ok = r.readOpaqueVec() if !ok { return errUnexpectedEOF } id.pskEpoch, ok = r.readUint64() if !ok { return errUnexpectedEOF } default: panic("unreachable") } var ok bool id.pskNonce, ok = r.readOpaqueVec() if !ok { return errUnexpectedEOF } return nil } func (id *preSharedKeyID) marshal(w *Writer) { id.pskType.marshal(w) switch id.pskType { case pskTypeExternal: w.writeOpaqueVec(id.pskID) case pskTypeResumption: id.usage.marshal(w) w.writeOpaqueVec([]byte(id.pskGroupID)) w.addUint64(id.pskEpoch) default: panic("unreachable") } w.writeOpaqueVec(id.pskNonce) }