key_schedule.mx raw
1 package mls
2
3 // MLS key schedule types (RFC 9420 §8).
4 // Data types and serialization — crypto operations go in key_schedule_crypto.mx.
5
6 import "errors"
7
// Sentinel decoding errors. Callers can compare against them with errors.Is.
var (
	// errInvalidPSKType is returned when a psktype byte is not one of the
	// pskType constants.
	errInvalidPSKType = errors.New("mls: invalid PSK type")

	// errInvalidPSKUsage is returned when a resumption PSK usage byte is not
	// one of the resumptionPSKUsage constants.
	errInvalidPSKUsage = errors.New("mls: invalid resumption PSK usage")

	// errInvalidProposalOrRefType reports an unknown ProposalOrRef tag.
	// NOTE(review): not referenced in this part of the file; presumably used
	// by proposal decoding elsewhere in the package.
	errInvalidProposalOrRefType = errors.New("mls: invalid proposal or ref type")
)
13
14 // --- CipherSuite ---
15
// CipherSuite identifies an MLS cipher suite by its two-byte code point in
// the IANA "MLS Cipher Suites" registry defined by RFC 9420.
type CipherSuite uint16

const (
	// CipherSuite0x0001 is MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519.
	//
	// NOTE(review): the P-256 suite (MLS_128_DHKEMP256_AES128GCM_SHA256_P256)
	// is code point 0x0002, not 0x0001. Confirm that key_schedule_crypto.mx
	// implements this constant with X25519, AES-128-GCM, SHA-256 and Ed25519.
	CipherSuite0x0001 CipherSuite = 0x0001

	// CipherSuite0x0003 is MLS_128_DHKEMX25519_CHACHA20POLY1305_SHA256_Ed25519.
	CipherSuite0x0003 CipherSuite = 0x0003
)
24
25 // --- GroupContext ---
26
27 type groupContext struct {
28 version protocolVersion
29 cipherSuite CipherSuite
30 groupID GroupID
31 epoch uint64
32 treeHash []byte
33 confirmedTranscriptHash []byte
34 extensions []extension
35 }
36
37 func (ctx *groupContext) unmarshal(r *Reader) error {
38 *ctx = groupContext{}
39
40 v, ok := r.readUint16()
41 if !ok {
42 return errUnexpectedEOF
43 }
44 ctx.version = protocolVersion(v)
45 if ctx.version != protocolVersionMLS10 {
46 return errInvalidVersion
47 }
48
49 v, ok = r.readUint16()
50 if !ok {
51 return errUnexpectedEOF
52 }
53 ctx.cipherSuite = CipherSuite(v)
54
55 ctx.groupID, ok = r.readOpaqueVec()
56 if !ok {
57 return errUnexpectedEOF
58 }
59 ctx.epoch, ok = r.readUint64()
60 if !ok {
61 return errUnexpectedEOF
62 }
63 ctx.treeHash, ok = r.readOpaqueVec()
64 if !ok {
65 return errUnexpectedEOF
66 }
67 ctx.confirmedTranscriptHash, ok = r.readOpaqueVec()
68 if !ok {
69 return errUnexpectedEOF
70 }
71
72 exts, err := unmarshalExtensionVec(r)
73 if err != nil {
74 return err
75 }
76 ctx.extensions = exts
77 return nil
78 }
79
80 func (ctx *groupContext) marshal(w *Writer) {
81 w.addUint16(uint16(ctx.version))
82 w.addUint16(uint16(ctx.cipherSuite))
83 w.writeOpaqueVec([]byte(ctx.groupID))
84 w.addUint64(ctx.epoch)
85 w.writeOpaqueVec(ctx.treeHash)
86 w.writeOpaqueVec(ctx.confirmedTranscriptHash)
87 marshalExtensionVec(w, ctx.extensions)
88 }
89
90 // --- Secret labels ---
91
// Derivation labels for the per-epoch secrets of the RFC 9420 §8 key
// schedule. The byte contents are protocol constants and must not change.
var (
	secretLabelInit           = []byte("init")           // init_secret for the next epoch
	secretLabelSenderData     = []byte("sender data")    // sender_data_secret
	secretLabelEncryption     = []byte("encryption")     // encryption_secret
	secretLabelExporter       = []byte("exporter")       // exporter_secret
	secretLabelExternal       = []byte("external")       // external_secret
	secretLabelConfirm        = []byte("confirm")        // confirmation_key
	secretLabelMembership     = []byte("membership")     // membership_key
	secretLabelResumption     = []byte("resumption")     // resumption_psk
	secretLabelAuthentication = []byte("authentication") // epoch_authenticator
)
103
104 // --- ConfirmedTranscriptHashInput ---
105
106 type confirmedTranscriptHashInput struct {
107 wireFormat wireFormat
108 content framedContent
109 signature []byte
110 }
111
112 func (input *confirmedTranscriptHashInput) marshal(w *Writer) {
113 input.wireFormat.marshal(w)
114 input.content.marshal(w)
115 w.writeOpaqueVec(input.signature)
116 }
117
118 // --- PSK types ---
119
// pskType is the one-byte psktype tag that selects which PreSharedKeyID
// variant follows on the wire. Only the constants below are valid;
// pskType.unmarshal rejects any other byte, including zero.
type pskType uint8

const (
	pskTypeExternal   pskType = 1 // key identified by an opaque pskID
	pskTypeResumption pskType = 2 // key identified by group ID and epoch
)
126
127 func (t *pskType) unmarshal(r *Reader) error {
128 b, ok := r.readByte()
129 if !ok {
130 return errUnexpectedEOF
131 }
132 *t = pskType(b)
133 switch *t {
134 case pskTypeExternal, pskTypeResumption:
135 return nil
136 default:
137 return errInvalidPSKType
138 }
139 }
140
141 func (t pskType) marshal(w *Writer) {
142 w.addByte(byte(t))
143 }
144
// resumptionPSKUsage is the one-byte usage tag carried by a resumption
// PreSharedKeyID. The values follow the ResumptionPSKUsage enumeration in
// RFC 9420. resumptionPSKUsage.unmarshal rejects any other byte, including
// zero.
type resumptionPSKUsage uint8

const (
	resumptionPSKUsageApplication resumptionPSKUsage = 1 // "application"
	resumptionPSKUsageReinit      resumptionPSKUsage = 2 // "reinit"
	resumptionPSKUsageBranch      resumptionPSKUsage = 3 // "branch"
)
152
153 func (usage *resumptionPSKUsage) unmarshal(r *Reader) error {
154 b, ok := r.readByte()
155 if !ok {
156 return errUnexpectedEOF
157 }
158 *usage = resumptionPSKUsage(b)
159 switch *usage {
160 case resumptionPSKUsageApplication, resumptionPSKUsageReinit, resumptionPSKUsageBranch:
161 return nil
162 default:
163 return errInvalidPSKUsage
164 }
165 }
166
167 func (usage resumptionPSKUsage) marshal(w *Writer) {
168 w.addByte(byte(usage))
169 }
170
171 // --- PreSharedKeyID ---
172
173 type preSharedKeyID struct {
174 pskType pskType
175
176 pskID []byte // for pskTypeExternal
177
178 usage resumptionPSKUsage // for pskTypeResumption
179 pskGroupID GroupID // for pskTypeResumption
180 pskEpoch uint64 // for pskTypeResumption
181
182 pskNonce []byte
183 }
184
185 func (id *preSharedKeyID) unmarshal(r *Reader) error {
186 *id = preSharedKeyID{}
187 if err := id.pskType.unmarshal(r); err != nil {
188 return err
189 }
190
191 switch id.pskType {
192 case pskTypeExternal:
193 var ok bool
194 id.pskID, ok = r.readOpaqueVec()
195 if !ok {
196 return errUnexpectedEOF
197 }
198 case pskTypeResumption:
199 if err := id.usage.unmarshal(r); err != nil {
200 return err
201 }
202 var ok bool
203 id.pskGroupID, ok = r.readOpaqueVec()
204 if !ok {
205 return errUnexpectedEOF
206 }
207 id.pskEpoch, ok = r.readUint64()
208 if !ok {
209 return errUnexpectedEOF
210 }
211 default:
212 panic("unreachable")
213 }
214
215 var ok bool
216 id.pskNonce, ok = r.readOpaqueVec()
217 if !ok {
218 return errUnexpectedEOF
219 }
220 return nil
221 }
222
223 func (id *preSharedKeyID) marshal(w *Writer) {
224 id.pskType.marshal(w)
225 switch id.pskType {
226 case pskTypeExternal:
227 w.writeOpaqueVec(id.pskID)
228 case pskTypeResumption:
229 id.usage.marshal(w)
230 w.writeOpaqueVec([]byte(id.pskGroupID))
231 w.addUint64(id.pskEpoch)
232 default:
233 panic("unreachable")
234 }
235 w.writeOpaqueVec(id.pskNonce)
236 }
237