proposal.mx raw
1 package mls
2
3 // MLS proposals (RFC 9420 §12).
4
5 import "errors"
6
// Sentinel errors reported while decoding proposals and while validating a
// proposal list for a commit.
//
// NOTE(review): proposalOrRefType.unmarshal returns errInvalidProposalOrRefType
// and several decoders in this file return errUnexpectedEOF; neither is
// declared in this block, so presumably both are declared elsewhere in the
// package — confirm.
var (
	// errInvalidProposalType is returned by proposalType.unmarshal when the
	// decoded value is not one of the proposalType constants.
	errInvalidProposalType = errors.New("mls: invalid proposal type")

	// errDupAddKey: two Add proposals carry the same leaf-node signature key.
	errDupAddKey = errors.New("mls: duplicate add proposal signature key")

	// errUpdateByCommitter: an Update proposal was sent by the committer.
	errUpdateByCommitter = errors.New("mls: update proposal from committer")

	// errDupUpdateOrRemove: more than one Update/Remove affects one leaf.
	errDupUpdateOrRemove = errors.New("mls: duplicate update/remove on same leaf")

	// errRemoveCommitter: a Remove proposal targets the committer's leaf.
	errRemoveCommitter = errors.New("mls: remove proposal targets committer")

	// errDupPSK: two PreSharedKey proposals encode the same PSK ID.
	errDupPSK = errors.New("mls: duplicate PSK proposal")

	// errDupGroupContextExts: more than one GroupContextExtensions proposal.
	errDupGroupContextExts = errors.New("mls: duplicate group context extensions proposal")

	// errReinitWithOther: a ReInit proposal is not alone in the list.
	errReinitWithOther = errors.New("mls: reinit with other proposals")

	// errExternalInitNotAllowed: the list contains an ExternalInit proposal.
	errExternalInitNotAllowed = errors.New("mls: external init not allowed")
)
18
19 // --- ProposalType ---
20
21 type proposalType uint16
22
23 const (
24 proposalTypeAdd proposalType = 0x0001
25 proposalTypeUpdate proposalType = 0x0002
26 proposalTypeRemove proposalType = 0x0003
27 proposalTypePSK proposalType = 0x0004
28 proposalTypeReinit proposalType = 0x0005
29 proposalTypeExternalInit proposalType = 0x0006
30 proposalTypeGroupContextExtensions proposalType = 0x0007
31 )
32
33 func (t *proposalType) unmarshal(r *Reader) error {
34 v, ok := r.readUint16()
35 if !ok {
36 return errUnexpectedEOF
37 }
38 *t = proposalType(v)
39 switch *t {
40 case proposalTypeAdd, proposalTypeUpdate, proposalTypeRemove,
41 proposalTypePSK, proposalTypeReinit, proposalTypeExternalInit,
42 proposalTypeGroupContextExtensions:
43 return nil
44 default:
45 return errInvalidProposalType
46 }
47 }
48
49 func (t proposalType) marshal(w *Writer) {
50 w.addUint16(uint16(t))
51 }
52
53 // --- Proposal ---
54
55 type proposal struct {
56 proposalType proposalType
57 add *add
58 update *update
59 remove *remove
60 preSharedKey *preSharedKey
61 reInit *reInit
62 externalInit *externalInit
63 groupContextExtensions *groupContextExtensions
64 }
65
66 func (prop *proposal) unmarshal(r *Reader) error {
67 *prop = proposal{}
68 if err := prop.proposalType.unmarshal(r); err != nil {
69 return err
70 }
71 switch prop.proposalType {
72 case proposalTypeAdd:
73 prop.add = &add{}
74 return prop.add.unmarshal(r)
75 case proposalTypeUpdate:
76 prop.update = &update{}
77 return prop.update.unmarshal(r)
78 case proposalTypeRemove:
79 prop.remove = &remove{}
80 return prop.remove.unmarshal(r)
81 case proposalTypePSK:
82 prop.preSharedKey = &preSharedKey{}
83 return prop.preSharedKey.unmarshal(r)
84 case proposalTypeReinit:
85 prop.reInit = &reInit{}
86 return prop.reInit.unmarshal(r)
87 case proposalTypeExternalInit:
88 prop.externalInit = &externalInit{}
89 return prop.externalInit.unmarshal(r)
90 case proposalTypeGroupContextExtensions:
91 prop.groupContextExtensions = &groupContextExtensions{}
92 return prop.groupContextExtensions.unmarshal(r)
93 default:
94 panic("unreachable")
95 }
96 }
97
98 func (prop *proposal) marshal(w *Writer) {
99 prop.proposalType.marshal(w)
100 switch prop.proposalType {
101 case proposalTypeAdd:
102 prop.add.marshal(w)
103 case proposalTypeUpdate:
104 prop.update.marshal(w)
105 case proposalTypeRemove:
106 prop.remove.marshal(w)
107 case proposalTypePSK:
108 prop.preSharedKey.marshal(w)
109 case proposalTypeReinit:
110 prop.reInit.marshal(w)
111 case proposalTypeExternalInit:
112 prop.externalInit.marshal(w)
113 case proposalTypeGroupContextExtensions:
114 prop.groupContextExtensions.marshal(w)
115 default:
116 panic("unreachable")
117 }
118 }
119
120 // --- Proposal sub-types ---
121
122 type add struct {
123 keyPackage KeyPackage
124 }
125
126 func (a *add) unmarshal(r *Reader) error {
127 *a = add{}
128 return a.keyPackage.unmarshal(r)
129 }
130
131 func (a *add) marshal(w *Writer) {
132 a.keyPackage.marshal(w)
133 }
134
135 type update struct {
136 leafNode leafNode
137 }
138
139 func (upd *update) unmarshal(r *Reader) error {
140 *upd = update{}
141 return upd.leafNode.unmarshal(r)
142 }
143
144 func (upd *update) marshal(w *Writer) {
145 upd.leafNode.marshal(w)
146 }
147
148 type remove struct {
149 removed leafIndex
150 }
151
152 func (rm *remove) unmarshal(r *Reader) error {
153 *rm = remove{}
154 v, ok := r.readUint32()
155 if !ok {
156 return errUnexpectedEOF
157 }
158 rm.removed = leafIndex(v)
159 return nil
160 }
161
162 func (rm *remove) marshal(w *Writer) {
163 w.addUint32(uint32(rm.removed))
164 }
165
166 type preSharedKey struct {
167 psk preSharedKeyID
168 }
169
170 func (psk *preSharedKey) unmarshal(r *Reader) error {
171 *psk = preSharedKey{}
172 return psk.psk.unmarshal(r)
173 }
174
175 func (psk *preSharedKey) marshal(w *Writer) {
176 psk.psk.marshal(w)
177 }
178
179 type reInit struct {
180 groupID GroupID
181 version protocolVersion
182 cipherSuite CipherSuite
183 extensions []extension
184 }
185
186 func (ri *reInit) unmarshal(r *Reader) error {
187 *ri = reInit{}
188 var ok bool
189 ri.groupID, ok = r.readOpaqueVec()
190 if !ok {
191 return errUnexpectedEOF
192 }
193 v, ok := r.readUint16()
194 if !ok {
195 return errUnexpectedEOF
196 }
197 ri.version = protocolVersion(v)
198 v, ok = r.readUint16()
199 if !ok {
200 return errUnexpectedEOF
201 }
202 ri.cipherSuite = CipherSuite(v)
203 exts, err := unmarshalExtensionVec(r)
204 if err != nil {
205 return err
206 }
207 ri.extensions = exts
208 return nil
209 }
210
211 func (ri *reInit) marshal(w *Writer) {
212 w.writeOpaqueVec([]byte(ri.groupID))
213 w.addUint16(uint16(ri.version))
214 w.addUint16(uint16(ri.cipherSuite))
215 marshalExtensionVec(w, ri.extensions)
216 }
217
218 type externalInit struct {
219 kemOutput []byte
220 }
221
222 func (ei *externalInit) unmarshal(r *Reader) error {
223 *ei = externalInit{}
224 var ok bool
225 ei.kemOutput, ok = r.readOpaqueVec()
226 if !ok {
227 return errUnexpectedEOF
228 }
229 return nil
230 }
231
232 func (ei *externalInit) marshal(w *Writer) {
233 w.writeOpaqueVec(ei.kemOutput)
234 }
235
236 type groupContextExtensions struct {
237 extensions []extension
238 }
239
240 func (exts *groupContextExtensions) unmarshal(r *Reader) error {
241 *exts = groupContextExtensions{}
242 l, err := unmarshalExtensionVec(r)
243 if err != nil {
244 return err
245 }
246 exts.extensions = l
247 return nil
248 }
249
250 func (exts *groupContextExtensions) marshal(w *Writer) {
251 marshalExtensionVec(w, exts.extensions)
252 }
253
254 // --- ProposalOrRef ---
255
256 type proposalOrRefType uint8
257
258 const (
259 proposalOrRefTypeProposal proposalOrRefType = 1
260 proposalOrRefTypeReference proposalOrRefType = 2
261 )
262
263 func (t *proposalOrRefType) unmarshal(r *Reader) error {
264 b, ok := r.readByte()
265 if !ok {
266 return errUnexpectedEOF
267 }
268 *t = proposalOrRefType(b)
269 switch *t {
270 case proposalOrRefTypeProposal, proposalOrRefTypeReference:
271 return nil
272 default:
273 return errInvalidProposalOrRefType
274 }
275 }
276
277 func (t proposalOrRefType) marshal(w *Writer) {
278 w.addByte(byte(t))
279 }
280
281 type proposalRef []byte
282
283 func (ref proposalRef) equal(other proposalRef) bool {
284 return bytesEqual([]byte(ref), []byte(other))
285 }
286
287 type proposalOrRef struct {
288 typ proposalOrRefType
289 proposal *proposal
290 reference proposalRef
291 }
292
293 func (por *proposalOrRef) unmarshal(r *Reader) error {
294 *por = proposalOrRef{}
295 if err := por.typ.unmarshal(r); err != nil {
296 return err
297 }
298 switch por.typ {
299 case proposalOrRefTypeProposal:
300 por.proposal = &proposal{}
301 return por.proposal.unmarshal(r)
302 case proposalOrRefTypeReference:
303 var ok bool
304 por.reference, ok = r.readOpaqueVec()
305 if !ok {
306 return errUnexpectedEOF
307 }
308 return nil
309 default:
310 panic("unreachable")
311 }
312 }
313
314 func (por *proposalOrRef) marshal(w *Writer) {
315 por.typ.marshal(w)
316 switch por.typ {
317 case proposalOrRefTypeProposal:
318 por.proposal.marshal(w)
319 case proposalOrRefTypeReference:
320 w.writeOpaqueVec([]byte(por.reference))
321 default:
322 panic("unreachable")
323 }
324 }
325
326 // --- Proposal list validation (RFC 9420 §12.2) ---
327
328 func verifyProposalList(proposals []proposal, senders []leafIndex, committer leafIndex) error {
329 if len(proposals) != len(senders) {
330 panic("unreachable")
331 }
332
333 addKeys := map[string]struct{}{}
334 updateOrRemove := map[leafIndex]struct{}{}
335 pskKeys := map[string]struct{}{}
336 hasGroupCtxExts := false
337
338 for i, prop := range proposals {
339 snd := senders[i]
340 switch prop.proposalType {
341 case proposalTypeAdd:
342 k := string(prop.add.keyPackage.leafNode.signatureKey)
343 if _, dup := addKeys[k]; dup {
344 return errDupAddKey
345 }
346 addKeys[k] = struct{}{}
347 case proposalTypeUpdate:
348 if snd == committer {
349 return errUpdateByCommitter
350 }
351 if _, dup := updateOrRemove[snd]; dup {
352 return errDupUpdateOrRemove
353 }
354 updateOrRemove[snd] = struct{}{}
355 case proposalTypeRemove:
356 if prop.remove.removed == committer {
357 return errRemoveCommitter
358 }
359 if _, dup := updateOrRemove[prop.remove.removed]; dup {
360 return errDupUpdateOrRemove
361 }
362 updateOrRemove[prop.remove.removed] = struct{}{}
363 case proposalTypePSK:
364 raw, err := marshalRaw(&prop.preSharedKey.psk)
365 if err != nil {
366 return err
367 }
368 k := string(raw)
369 if _, dup := pskKeys[k]; dup {
370 return errDupPSK
371 }
372 pskKeys[k] = struct{}{}
373 case proposalTypeGroupContextExtensions:
374 if hasGroupCtxExts {
375 return errDupGroupContextExts
376 }
377 hasGroupCtxExts = true
378 case proposalTypeReinit:
379 if len(proposals) > 1 {
380 return errReinitWithOther
381 }
382 case proposalTypeExternalInit:
383 return errExternalInitNotAllowed
384 }
385 }
386 return nil
387 }
388
389 func proposalListNeedsPath(proposals []proposal) bool {
390 if len(proposals) == 0 {
391 return true
392 }
393 for _, prop := range proposals {
394 switch prop.proposalType {
395 case proposalTypeUpdate, proposalTypeRemove,
396 proposalTypeExternalInit, proposalTypeGroupContextExtensions:
397 return true
398 }
399 }
400 return false
401 }
402