group.go raw
1 package mls
2
3 import (
4 "crypto/rand"
5 "fmt"
6 "io"
7 "time"
8
9 "golang.org/x/crypto/cryptobyte"
10 )
11
12 type pendingProposal struct {
13 ref proposalRef
14 proposal *proposal
15 sender leafIndex
16 }
17
18 // A Group is a high-level API for an MLS group.
19 type Group struct {
20 tree ratchetTree
21 groupContext groupContext
22
23 interimTranscriptHash []byte
24 pskSecret []byte
25 epochSecret []byte
26 initSecret []byte
27
28 myLeafIndex leafIndex
29 privTree []hpkePrivateKey
30 signaturePriv signaturePrivateKey
31
32 pendingProposals []pendingProposal
33 }
34
35 // Epoch returns the current MLS epoch number. The epoch increments on every
36 // Commit (including application message encrypt/decrypt that triggers ratcheting).
37 func (group *Group) Epoch() uint64 {
38 return group.groupContext.epoch
39 }
40
41 // ExporterSecret derives the exporter secret from the current epoch secret.
42 // This is needed by NIP-EE to derive the NIP-44 conversation key for
43 // encrypting kind 445 group message content.
44 func (group *Group) ExporterSecret() ([]byte, error) {
45 return group.groupContext.cipherSuite.deriveSecret(
46 group.epochSecret, secretLabelExporter,
47 )
48 }
49
50 // GroupContextExtensions returns the extensions from the group context.
51 // Use this to extract application-specific data like NostrGroupData (0xf2ee).
52 func (group *Group) GroupContextExtensions() []extension {
53 return group.groupContext.extensions
54 }
55
56 // FindGroupContextExtension returns the data for the extension with the given
57 // type, or nil if not found.
58 func (group *Group) FindGroupContextExtension(t extensionType) []byte {
59 return findExtensionData(group.groupContext.extensions, t)
60 }
61
62 // Marshal serializes the full Group state (including private keys and epoch
63 // secrets) so it can be persisted and restored later. This is NOT a wire
64 // format — it's for local storage only. The output contains sensitive key
65 // material and must be encrypted at rest.
66 func (group *Group) Marshal() ([]byte, error) {
67 var b cryptobyte.Builder
68
69 // 1. groupContext (TLS-serialized)
70 group.groupContext.marshal(&b)
71
72 // 2. ratchetTree
73 group.tree.marshal(&b)
74
75 // 3. Secrets
76 writeOpaqueVec(&b, group.interimTranscriptHash)
77 writeOpaqueVec(&b, group.pskSecret)
78 writeOpaqueVec(&b, group.epochSecret)
79 writeOpaqueVec(&b, group.initSecret)
80
81 // 4. My identity within the group
82 b.AddUint32(uint32(group.myLeafIndex))
83 writeOpaqueVec(&b, []byte(group.signaturePriv))
84
85 // 5. Private tree (HPKE private keys, indexed by node position)
86 writeVector(&b, len(group.privTree), func(b *cryptobyte.Builder, i int) {
87 writeOpaqueVec(b, []byte(group.privTree[i]))
88 })
89
90 return b.Bytes()
91 }
92
93 // UnmarshalGroup restores a Group from bytes produced by Marshal.
94 func UnmarshalGroup(raw []byte) (*Group, error) {
95 s := cryptobyte.String(raw)
96 g := &Group{}
97
98 // 1. groupContext
99 if err := g.groupContext.unmarshal(&s); err != nil {
100 return nil, fmt.Errorf("mls: unmarshal group context: %w", err)
101 }
102
103 // 2. ratchetTree
104 if err := g.tree.unmarshal(&s); err != nil {
105 return nil, fmt.Errorf("mls: unmarshal ratchet tree: %w", err)
106 }
107
108 // 3. Secrets
109 if !readOpaqueVec(&s, &g.interimTranscriptHash) ||
110 !readOpaqueVec(&s, &g.pskSecret) ||
111 !readOpaqueVec(&s, &g.epochSecret) ||
112 !readOpaqueVec(&s, &g.initSecret) {
113 return nil, fmt.Errorf("mls: unmarshal secrets: unexpected EOF")
114 }
115
116 // 4. My identity
117 if !s.ReadUint32((*uint32)(&g.myLeafIndex)) {
118 return nil, fmt.Errorf("mls: unmarshal leaf index: unexpected EOF")
119 }
120 var sigPriv []byte
121 if !readOpaqueVec(&s, &sigPriv) {
122 return nil, fmt.Errorf("mls: unmarshal signature priv: unexpected EOF")
123 }
124 g.signaturePriv = signaturePrivateKey(sigPriv)
125
126 // 5. Private tree
127 if err := readVector(&s, func(s *cryptobyte.String) error {
128 var k []byte
129 if !readOpaqueVec(s, &k) {
130 return io.ErrUnexpectedEOF
131 }
132 g.privTree = append(g.privTree, hpkePrivateKey(k))
133 return nil
134 }); err != nil {
135 return nil, fmt.Errorf("mls: unmarshal priv tree: %w", err)
136 }
137
138 return g, nil
139 }
140
141 // GroupID returns the MLS group ID.
142 func (group *Group) GroupID() GroupID {
143 return group.groupContext.groupID
144 }
145
146 // DeriveExporter exports keying material from the group's exporter secret
147 // using the MLS exporter derivation (RFC 9420 Section 8).
148 func (group *Group) DeriveExporter(label, context []byte, length uint16) ([]byte, error) {
149 exporterSecret, err := group.ExporterSecret()
150 if err != nil {
151 return nil, err
152 }
153 return deriveExporter(group.groupContext.cipherSuite, exporterSecret, label, context, length)
154 }
155
156 // GroupOptions configures group creation.
157 type GroupOptions struct {
158 // Extensions are included in the group context. For Marmot, this
159 // should include a NostrGroupData extension (0xf2ee).
160 Extensions []extension
161 }
162
163 // CreateGroup creates a new group with a single member.
164 func CreateGroup(groupID GroupID, keyPairPkg *KeyPairPackage) (*Group, error) {
165 return CreateGroupWithOptions(groupID, keyPairPkg, nil)
166 }
167
168 // CreateGroupWithOptions creates a new group with custom group context extensions.
169 func CreateGroupWithOptions(groupID GroupID, keyPairPkg *KeyPairPackage, opts *GroupOptions) (*Group, error) {
170 cs := keyPairPkg.Public.cipherSuite
171
172 tree := make(ratchetTree, 1)
173 tree.add(&keyPairPkg.Public.leafNode)
174
175 privTree := make([]hpkePrivateKey, len(tree))
176 privTree[0] = keyPairPkg.Private.EncryptionKey
177
178 treeHash, err := tree.computeRootTreeHash(cs)
179 if err != nil {
180 return nil, fmt.Errorf("failed to compute root tree hash: %v", err)
181 }
182
183 confirmedTranscriptHash := make([]byte, cs.hash().Size())
184
185 _, kdf, _ := cs.hpke().Params()
186 epochSecret := make([]byte, kdf.ExtractSize())
187 if _, err := rand.Read(epochSecret); err != nil {
188 return nil, fmt.Errorf("failed to generate epoch secret: %v", err)
189 }
190
191 var ctxExts []extension
192 if opts != nil {
193 ctxExts = opts.Extensions
194 }
195
196 groupCtx := groupContext{
197 version: keyPairPkg.Public.version,
198 cipherSuite: keyPairPkg.Public.cipherSuite,
199 groupID: groupID,
200 epoch: 0,
201 treeHash: treeHash,
202 confirmedTranscriptHash: confirmedTranscriptHash,
203 extensions: ctxExts,
204 }
205
206 confirmationTag, err := groupCtx.signConfirmationTag(epochSecret)
207 if err != nil {
208 return nil, fmt.Errorf("failed to sign confirmation tag: %v", err)
209 }
210
211 interimTranscriptHash, err := nextInterimTranscriptHash(cs, confirmedTranscriptHash, confirmationTag)
212 if err != nil {
213 return nil, fmt.Errorf("failed to compute initial interim transcript hash: %v", err)
214 }
215
216 pskSecret, err := extractPSKSecret(cs, nil, nil)
217 if err != nil {
218 return nil, fmt.Errorf("failed to extract PSK secret: %v", err)
219 }
220
221 initSecret, err := groupCtx.cipherSuite.deriveSecret(epochSecret, secretLabelInit)
222 if err != nil {
223 return nil, fmt.Errorf("failed to derive init secret: %v", err)
224 }
225
226 return &Group{
227 tree: tree,
228 privTree: privTree,
229 myLeafIndex: 0,
230 signaturePriv: keyPairPkg.Private.SignatureKey,
231 groupContext: groupCtx,
232 interimTranscriptHash: interimTranscriptHash,
233 pskSecret: pskSecret,
234 epochSecret: epochSecret,
235 initSecret: initSecret,
236 }, nil
237 }
238
239 // GroupFromWelcome creates a new group from a welcome message.
240 func GroupFromWelcome(welcome *Welcome, keyPairPkg *KeyPairPackage) (*Group, error) {
241 keyPkgRef, err := keyPairPkg.Public.GenerateRef()
242 if err != nil {
243 return nil, fmt.Errorf("failed to generate key package ref: %v", err)
244 }
245
246 groupSecrets, err := welcome.decryptGroupSecrets(keyPkgRef, keyPairPkg.Private.InitKey)
247 if err != nil {
248 return nil, fmt.Errorf("failed to decrypt group secrets: %v", err)
249 }
250
251 if !groupSecrets.verifySingleReinitOrBranchPSK() {
252 return nil, fmt.Errorf("mls: more than one key has usage reinit or branch in group secrets")
253 }
254
255 if len(groupSecrets.psks) != 0 {
256 return nil, fmt.Errorf("mls: group secret PSKs are not yet supported")
257 }
258
259 return groupFromSecrets(welcome, keyPairPkg, groupSecrets, nil)
260 }
261
// groupFromSecretsOptions carries optional inputs for groupFromSecrets.
type groupFromSecretsOptions struct {
	// rawTree, when non-nil, is used instead of the GroupInfo's ratchet
	// tree extension.
	rawTree []byte
	// psks are passed to extractPSKSecret together with the PSK IDs from the
	// group secrets.
	psks [][]byte
	// now is passed to the ratchet tree integrity check.
	now func() time.Time
}
267
268 func groupFromSecrets(welcome *Welcome, keyPairPkg *KeyPairPackage, groupSecrets *groupSecrets, options *groupFromSecretsOptions) (*Group, error) {
269 if options == nil {
270 options = new(groupFromSecretsOptions)
271 }
272
273 pskSecret, err := extractPSKSecret(welcome.cipherSuite, groupSecrets.psks, options.psks)
274 if err != nil {
275 return nil, fmt.Errorf("failed to extract PSK secret: %v", err)
276 }
277
278 groupInfo, err := welcome.decryptGroupInfo(groupSecrets.joinerSecret, pskSecret)
279 if err != nil {
280 return nil, fmt.Errorf("failed to decrypt group info: %v", err)
281 }
282
283 rawTree := options.rawTree
284 if rawTree == nil {
285 rawTree = findExtensionData(groupInfo.extensions, extensionTypeRatchetTree)
286 }
287 if rawTree == nil {
288 return nil, fmt.Errorf("mls: missing ratchet tree")
289 }
290
291 var tree ratchetTree
292 if err := unmarshal(rawTree, &tree); err != nil {
293 return nil, fmt.Errorf("failed to unmarshal ratchet tree: %v", err)
294 }
295
296 signerNode := tree.getLeaf(groupInfo.signer)
297 if signerNode == nil {
298 return nil, fmt.Errorf("mls: signer node is blank")
299 } else if !groupInfo.verifySignature(signerNode.signatureKey) {
300 return nil, fmt.Errorf("mls: failed to verify signer node signature")
301 }
302 if !groupInfo.verifyConfirmationTag(groupSecrets.joinerSecret, pskSecret) {
303 return nil, fmt.Errorf("mls: failed to verify confirmation tag")
304 }
305 if groupInfo.groupContext.cipherSuite != welcome.cipherSuite {
306 return nil, fmt.Errorf("mls: group info cipher suite doesn't match key package")
307 }
308
309 if err := tree.verifyIntegrity(&groupInfo.groupContext, options.now); err != nil {
310 return nil, fmt.Errorf("failed to verify ratchet tree integrity: %v", err)
311 }
312
313 // TODO: perform other group info verification steps
314
315 groupCtx := groupInfo.groupContext
316
317 epochSecret, err := groupCtx.extractEpochSecret(groupSecrets.joinerSecret, pskSecret)
318 if err != nil {
319 return nil, fmt.Errorf("failed to extract epoch secret: %v", err)
320 }
321
322 initSecret, err := groupCtx.cipherSuite.deriveSecret(epochSecret, secretLabelInit)
323 if err != nil {
324 return nil, fmt.Errorf("failed to derive init secret: %v", err)
325 }
326
327 interimTranscriptHash, err := nextInterimTranscriptHash(groupCtx.cipherSuite, groupCtx.confirmedTranscriptHash, groupInfo.confirmationTag)
328 if err != nil {
329 return nil, fmt.Errorf("failed to compute next interim transcript hash: %v", err)
330 }
331
332 myLeafIndex, ok := tree.findLeaf(&keyPairPkg.Public.leafNode)
333 if !ok {
334 return nil, fmt.Errorf("mls: failed to find my leaf node in ratchet tree")
335 }
336
337 privTree := make([]hpkePrivateKey, len(tree))
338 privTree[int(myLeafIndex.nodeIndex())] = keyPairPkg.Private.EncryptionKey
339
340 if groupSecrets.pathSecret != nil {
341 nodeIndex := commonAncestor(myLeafIndex.nodeIndex(), groupInfo.signer.nodeIndex())
342 err := processPathSecret(groupCtx.cipherSuite, tree, privTree, groupSecrets.pathSecret, nodeIndex)
343 if err != nil {
344 return nil, fmt.Errorf("failed to process path secret: %v", err)
345 }
346 }
347
348 return &Group{
349 tree: tree,
350 groupContext: groupCtx,
351 interimTranscriptHash: interimTranscriptHash,
352 pskSecret: pskSecret,
353 epochSecret: epochSecret,
354 initSecret: initSecret,
355 myLeafIndex: myLeafIndex,
356 privTree: privTree,
357 signaturePriv: keyPairPkg.Private.SignatureKey,
358 }, nil
359 }
360
361 func processPathSecret(cs CipherSuite, tree ratchetTree, privTree []hpkePrivateKey, pathSecret []byte, nodeIndex nodeIndex) error {
362 nodePriv, err := nodePrivFromPathSecret(cs, pathSecret, tree.get(nodeIndex).encryptionKey())
363 if err != nil {
364 return fmt.Errorf("failed to derive node %v private key from path secret: %v", nodeIndex, err)
365 }
366 privTree[int(nodeIndex)] = nodePriv
367
368 for {
369 var ok bool
370 nodeIndex, ok = tree.numLeaves().parent(nodeIndex)
371 if !ok {
372 break
373 }
374
375 pathSecret, err := cs.deriveSecret(pathSecret, []byte("path"))
376 if err != nil {
377 return fmt.Errorf("failed to derive path secret: %v", err)
378 }
379
380 nodePriv, err := nodePrivFromPathSecret(cs, pathSecret, tree.get(nodeIndex).encryptionKey())
381 if err != nil {
382 return fmt.Errorf("failed to derive node %v private key from path secret: %v", nodeIndex, err)
383 }
384 privTree[int(nodeIndex)] = nodePriv
385 }
386
387 return nil
388 }
389
390 // UnmarshalAndProcessMessage decodes a raw MLS message intended for the group
391 // and processes it.
392 //
393 // If the MLS message contains encrypted application data, the decrypted data
394 // is returned.
395 func (group *Group) UnmarshalAndProcessMessage(raw []byte) (plaintext []byte, selfSent bool, err error) {
396 var msg mlsMessage
397 if err := unmarshal([]byte(raw), &msg); err != nil {
398 return nil, false, fmt.Errorf("failed to unmarshal MLS message: %v", err)
399 }
400
401 switch msg.wireFormat {
402 case wireFormatMLSPublicMessage:
403 return nil, false, group.processPublicMessage(msg.publicMessage)
404 case wireFormatMLSPrivateMessage:
405 return group.processPrivateMessage(msg.privateMessage)
406 default:
407 // TODO: support other wire formats
408 return nil, false, fmt.Errorf("mls: unsupported wire format: %v", msg.wireFormat)
409 }
410 }
411
412 func (group *Group) processPublicMessage(pubMsg *publicMessage) error {
413 authContent, err := group.verifyPublicMessage(pubMsg)
414 if err != nil {
415 return fmt.Errorf("failed to verify public message: %v", err)
416 }
417
418 switch authContent.content.contentType {
419 case contentTypeProposal:
420 return group.processProposal(authContent)
421 case contentTypeCommit:
422 return group.processCommit(authContent, nil, nil, nil)
423 case contentTypeApplication:
424 return fmt.Errorf("mls: application content type must be encrypted")
425 default:
426 // TODO: support other content types
427 return fmt.Errorf("mls: unsupported content type: %v", authContent.content.contentType)
428 }
429 }
430
431 func (group *Group) verifyPublicMessage(pubMsg *publicMessage) (*authenticatedContent, error) {
432 if !pubMsg.content.groupID.Equal(group.groupContext.groupID) {
433 return nil, fmt.Errorf("mls: message group ID mismatch")
434 }
435 if pubMsg.content.epoch != group.groupContext.epoch {
436 return nil, fmt.Errorf("mls: epoch mismatch: got %v, want %v", pubMsg.content.epoch, group.groupContext.epoch)
437 }
438
439 if pubMsg.content.sender.senderType != senderTypeMember {
440 // TODO: support other sender types
441 return nil, fmt.Errorf("mls: unsupported sender type: %v", pubMsg.content.sender.senderType)
442 }
443 senderLeafIndex := pubMsg.content.sender.leafIndex
444 // TODO: check tree length
445 senderNode := group.tree.getLeaf(senderLeafIndex)
446 if senderNode == nil {
447 return nil, fmt.Errorf("mls: blank leaf node for sender")
448 }
449
450 authContent := pubMsg.authenticatedContent()
451 if !authContent.verifySignature([]byte(senderNode.signatureKey), &group.groupContext) {
452 return nil, fmt.Errorf("mls: failed to verify public message signature")
453 }
454
455 membershipKey, err := group.groupContext.cipherSuite.deriveSecret(group.epochSecret, secretLabelMembership)
456 if err != nil {
457 return nil, fmt.Errorf("failed to derive membership key: %v", err)
458 } else if !pubMsg.verifyMembershipTag(membershipKey, &group.groupContext) {
459 return nil, fmt.Errorf("failed to verify membership tag")
460 }
461
462 return authContent, nil
463 }
464
465 func (group *Group) processPrivateMessage(privMsg *privateMessage) ([]byte, bool, error) {
466 cs := group.groupContext.cipherSuite
467
468 if !privMsg.groupID.Equal(group.groupContext.groupID) {
469 return nil, false, fmt.Errorf("mls: message group ID mismatch")
470 }
471 if privMsg.epoch != group.groupContext.epoch {
472 return nil, false, fmt.Errorf("mls: epoch mismatch: got %v, want %v", privMsg.epoch, group.groupContext.epoch)
473 }
474
475 senderDataSecret, err := cs.deriveSecret(group.epochSecret, secretLabelSenderData)
476 if err != nil {
477 return nil, false, fmt.Errorf("failed to derive sender data secret: %v", err)
478 }
479
480 senderData, err := privMsg.decryptSenderData(cs, senderDataSecret)
481 if err != nil {
482 return nil, false, fmt.Errorf("failed to decrypt sender data: %v", err)
483 }
484
485 encryptionSecret, err := cs.deriveSecret(group.epochSecret, secretLabelEncryption)
486 if err != nil {
487 return nil, false, fmt.Errorf("failed to derive encryption secret: %v", err)
488 }
489
490 secretTree, err := deriveSecretTree(cs, group.tree.numLeaves(), encryptionSecret)
491 if err != nil {
492 return nil, false, fmt.Errorf("failed to erive secret tree: %v", err)
493 }
494
495 label := ratchetLabelFromContentType(privMsg.contentType)
496 secret, err := secretTree.deriveRatchetRoot(cs, senderData.leafIndex.nodeIndex(), label)
497 if err != nil {
498 return nil, false, fmt.Errorf("failed to derive secret ratchet tree root: %v", err)
499 }
500
501 // TODO: limit number of iterations
502 // TODO: erase knowledge about used generations to ensure forward secrecy
503 for secret.generation != senderData.generation {
504 secret, err = secret.deriveNext(cs)
505 if err != nil {
506 return nil, false, fmt.Errorf("failed to derive next ratchet secret: %v", err)
507 }
508 }
509
510 privContent, err := privMsg.decryptContent(cs, secret, senderData.reuseGuard)
511 if err != nil {
512 return nil, false, fmt.Errorf("failed to decrypt private message content: %v", err)
513 }
514
515 signerNode := group.tree.getLeaf(senderData.leafIndex)
516 if signerNode == nil {
517 return nil, false, fmt.Errorf("mls: signer node is blank")
518 }
519
520 authContent := privMsg.authenticatedContent(senderData, privContent)
521 if !authContent.verifySignature(signerNode.signatureKey, &group.groupContext) {
522 return nil, false, fmt.Errorf("failed to verify private message content signature: %v", err)
523 }
524
525 selfSent := senderData.leafIndex == group.myLeafIndex
526
527 switch authContent.content.contentType {
528 case contentTypeProposal:
529 return nil, false, group.processProposal(authContent)
530 case contentTypeCommit:
531 return nil, false, group.processCommit(authContent, nil, nil, nil)
532 case contentTypeApplication:
533 return authContent.content.applicationData, selfSent, nil
534 default:
535 // TODO: support other content types
536 return nil, false, fmt.Errorf("mls: unsupported content type: %v", authContent.content.contentType)
537 }
538 }
539
540 func (group *Group) processProposal(authContent *authenticatedContent) error {
541 if authContent.content.contentType != contentTypeProposal {
542 panic("mls: expected a proposal")
543 }
544 proposal := authContent.content.proposal
545
546 ref, err := authContent.generateProposalRef(group.groupContext.cipherSuite)
547 if err != nil {
548 return fmt.Errorf("failed to generate proposal ref: %v", err)
549 }
550
551 group.pendingProposals = append(group.pendingProposals, pendingProposal{
552 ref: ref,
553 proposal: proposal,
554 sender: authContent.content.sender.leafIndex,
555 })
556 return nil
557 }
558
559 func (group *Group) processCommit(authContent *authenticatedContent, pskIDs []preSharedKeyID, psks [][]byte, now func() time.Time) error {
560 cs := group.groupContext.cipherSuite
561 senderLeafIndex := authContent.content.sender.leafIndex
562
563 if authContent.content.contentType != contentTypeCommit {
564 panic("mls: expected a commit")
565 }
566 commit := authContent.content.commit
567
568 proposals, senders, err := resolveProposals(commit.proposals, senderLeafIndex, group.pendingProposals)
569 if err != nil {
570 return err
571 }
572
573 if err := verifyProposalList(proposals, senders, senderLeafIndex); err != nil {
574 return fmt.Errorf("failed to verify proposals: %v", err)
575 }
576
577 for _, prop := range proposals {
578 if prop.proposalType == proposalTypeAdd {
579 if err := prop.add.keyPackage.verify(&group.groupContext); err != nil {
580 return fmt.Errorf("failed to verify add proposal: %v", err)
581 }
582 }
583 }
584
585 // TODO: additional proposal list checks
586
587 if proposalListNeedsPath(proposals) && commit.path == nil {
588 return fmt.Errorf("mls: commit is missing update path but required by proposal list")
589 }
590
591 newGroupCtx := group.groupContext
592 newGroupCtx.epoch++
593
594 newTree := group.tree.copy()
595 newTree.apply(proposals, senders)
596
597 newPrivTree := make([]hpkePrivateKey, len(newTree))
598 for i := range group.tree {
599 if i < len(newPrivTree) {
600 newPrivTree[i] = group.privTree[i]
601 }
602 }
603
604 _, kdf, _ := cs.hpke().Params()
605 commitSecret := make([]byte, kdf.ExtractSize())
606 if commit.path != nil {
607 if commit.path.leafNode.leafNodeSource != leafNodeSourceCommit {
608 return fmt.Errorf("mls: commit path leaf node source must be commit")
609 }
610
611 // TODO: check tree length
612 senderNode := newTree.getLeaf(senderLeafIndex)
613
614 // The same signature key can be re-used, but the encryption key
615 // must change
616 signatureKeys, encryptionKeys := newTree.keys()
617 delete(signatureKeys, string(senderNode.signatureKey))
618 err := commit.path.leafNode.verify(&leafNodeVerifyOptions{
619 cipherSuite: cs,
620 groupID: group.groupContext.groupID,
621 leafIndex: senderLeafIndex,
622 supportedCreds: newTree.supportedCreds(),
623 signatureKeys: signatureKeys,
624 encryptionKeys: encryptionKeys,
625 now: now,
626 })
627 if err != nil {
628 return fmt.Errorf("failed to verify leaf node: %v", err)
629 }
630
631 for _, updateNode := range commit.path.nodes {
632 if _, dup := encryptionKeys[string(updateNode.encryptionKey)]; dup {
633 return fmt.Errorf("mls: encryption key in update path already used in ratchet tree")
634 }
635 }
636
637 if err := newTree.mergeUpdatePath(cs, senderLeafIndex, commit.path); err != nil {
638 return fmt.Errorf("failed to merge update path in ratchet tree: %v", err)
639 }
640
641 newGroupCtx.treeHash, err = newTree.computeRootTreeHash(cs)
642 if err != nil {
643 return fmt.Errorf("failed to compute root tree hash: %v", err)
644 }
645
646 // TODO: update group context extensions
647
648 commitSecret, err = newTree.decryptPathSecrets(cs, &newGroupCtx, senderLeafIndex, group.myLeafIndex, commit.path, newPrivTree)
649 if err != nil {
650 return fmt.Errorf("failed to decrypt path secrets: %v", err)
651 }
652 } else {
653 // TODO: only recompute parts of the tree affected by proposals
654 newGroupCtx.treeHash, err = newTree.computeRootTreeHash(cs)
655 if err != nil {
656 return fmt.Errorf("failed to compute root tree hash: %v", err)
657 }
658 }
659
660 newGroupCtx.confirmedTranscriptHash, err = authContent.confirmedTranscriptHashInput().hash(cs, group.interimTranscriptHash)
661 if err != nil {
662 return fmt.Errorf("failed to hash confirmed transcript hash input: %v", err)
663 }
664
665 newInterimTranscriptHash, err := nextInterimTranscriptHash(cs, newGroupCtx.confirmedTranscriptHash, authContent.auth.confirmationTag)
666 if err != nil {
667 return fmt.Errorf("failed to compute next interim transcript hash: %v", err)
668 }
669
670 newJoinerSecret, err := newGroupCtx.extractJoinerSecret(group.initSecret, commitSecret)
671 if err != nil {
672 return fmt.Errorf("failed to extract joined secret: %v", err)
673 }
674
675 newPSKSecret, err := extractPSKSecret(cs, pskIDs, psks)
676 if err != nil {
677 return fmt.Errorf("failed to extract PSK secret: %v", err)
678 }
679
680 newEpochSecret, err := newGroupCtx.extractEpochSecret(newJoinerSecret, newPSKSecret)
681 if err != nil {
682 return fmt.Errorf("failed to extract epoch secret: %v", err)
683 }
684
685 newInitSecret, err := cs.deriveSecret(newEpochSecret, secretLabelInit)
686 if err != nil {
687 return fmt.Errorf("failed to erive init secret: %v", err)
688 }
689
690 group.tree = newTree
691 group.privTree = newPrivTree
692 group.groupContext = newGroupCtx
693 group.interimTranscriptHash = newInterimTranscriptHash
694 group.pskSecret = newPSKSecret
695 group.epochSecret = newEpochSecret
696 group.initSecret = newInitSecret
697 group.pendingProposals = nil // TODO: only clear proposals we've consumed
698 return nil
699 }
700
701 func resolveProposals(proposalOrRefs []proposalOrRef, senderLeafIndex leafIndex, pendingProposals []pendingProposal) ([]proposal, []leafIndex, error) {
702 var (
703 proposals []proposal
704 senders []leafIndex
705 )
706 for _, propOrRef := range proposalOrRefs {
707 switch propOrRef.typ {
708 case proposalOrRefTypeProposal:
709 proposals = append(proposals, *propOrRef.proposal)
710 senders = append(senders, senderLeafIndex)
711 case proposalOrRefTypeReference:
712 var found bool
713 for _, pp := range pendingProposals {
714 if pp.ref.Equal(propOrRef.reference) {
715 found = true
716 proposals = append(proposals, *pp.proposal)
717 senders = append(senders, pp.sender)
718 break
719 }
720 }
721 if !found {
722 return nil, nil, fmt.Errorf("mls: cannot find proposal reference: %v", propOrRef.reference)
723 }
724 }
725 }
726
727 return proposals, senders, nil
728 }
729
730 // CreateWelcome creates a new welcome message, inviting new members to the
731 // group.
732 //
733 // The welcome message should be sent to the new members. Alongside the welcome
734 // message, a raw MLS message is returned and must be consumed by all existing
735 // members of the group to add the new members.
736 func (group *Group) CreateWelcome(keyPkgs []KeyPackage) (*Welcome, []byte, error) {
737 // TODO: missing steps from section 12.4.1
738 cs := group.groupContext.cipherSuite
739
740 if len(keyPkgs) == 0 {
741 panic("mls: expected at least one key package")
742 }
743
744 proposals := make([]proposal, len(keyPkgs))
745 proposalOrRefs := make([]proposalOrRef, len(keyPkgs))
746 for i, keyPkg := range keyPkgs {
747 proposals[i] = proposal{
748 proposalType: proposalTypeAdd,
749 add: &add{keyPackage: keyPkg},
750 }
751 proposalOrRefs[i] = proposalOrRef{
752 typ: proposalOrRefTypeProposal,
753 proposal: &proposals[i],
754 }
755 }
756
757 // TODO: check proposal list validity per section 12.2
758 commit := commit{proposals: proposalOrRefs}
759
760 newGroupCtx := group.groupContext
761 newGroupCtx.epoch++
762
763 newTree := group.tree.copy()
764 newTree.apply(proposals, []leafIndex{group.myLeafIndex})
765
766 // TODO: only recompute parts of the tree affected by proposals
767 var err error
768 newGroupCtx.treeHash, err = newTree.computeRootTreeHash(cs)
769 if err != nil {
770 return nil, nil, fmt.Errorf("failed to compute root tree hash: %v", err)
771 }
772
773 _, kdf, _ := cs.hpke().Params()
774 commitSecret := make([]byte, kdf.ExtractSize())
775
776 pskSecret, err := extractPSKSecret(cs, nil, nil)
777 if err != nil {
778 return nil, nil, fmt.Errorf("failed to extract PSK secret: %v", err)
779 }
780
781 framedContent := framedContent{
782 groupID: group.groupContext.groupID,
783 epoch: group.groupContext.epoch,
784 sender: sender{
785 senderType: senderTypeMember,
786 leafIndex: group.myLeafIndex,
787 },
788 contentType: contentTypeCommit,
789 commit: &commit,
790 }
791
792 public := false // TODO: add option to enable this
793 var (
794 authContent *authenticatedContent
795 authData *framedContentAuthData
796 pubMsg *publicMessage
797 privContent *privateMessageContent
798 )
799 if public {
800 pubMsg, err = signPublicMessage(cs, group.signaturePriv, &framedContent, &group.groupContext)
801 if err != nil {
802 return nil, nil, fmt.Errorf("failed to sign public message: %v", err)
803 }
804 authContent = pubMsg.authenticatedContent()
805 authData = &pubMsg.auth
806 } else {
807 privContent, err = signPrivateMessageContent(cs, group.signaturePriv, &framedContent, &group.groupContext)
808 if err != nil {
809 return nil, nil, fmt.Errorf("failed to sign private message: %v", err)
810 }
811 authContent = privContent.authenticatedContent(&framedContent)
812 authData = &privContent.auth
813 }
814
815 newGroupCtx.confirmedTranscriptHash, err = authContent.confirmedTranscriptHashInput().hash(cs, group.interimTranscriptHash)
816 if err != nil {
817 return nil, nil, fmt.Errorf("failed to hash confirmed transcript hash input: %v", err)
818 }
819
820 joinerSecret, err := newGroupCtx.extractJoinerSecret(group.initSecret, commitSecret)
821 if err != nil {
822 return nil, nil, fmt.Errorf("failed to extract joiner secret: %v", err)
823 }
824
825 epochSecret, err := newGroupCtx.extractEpochSecret(joinerSecret, pskSecret)
826 if err != nil {
827 return nil, nil, fmt.Errorf("failed to extract epoch secret: %v", err)
828 }
829
830 confirmationTag, err := newGroupCtx.signConfirmationTag(epochSecret)
831 if err != nil {
832 return nil, nil, fmt.Errorf("failed to sign confirmation tag: %v", err)
833 }
834 authData.confirmationTag = confirmationTag
835
836 rawTree, err := marshal(newTree)
837 if err != nil {
838 return nil, nil, fmt.Errorf("failed to marshal ratchet tree: %v", err)
839 }
840
841 newGroupInfo := groupInfo{
842 groupContext: newGroupCtx,
843 confirmationTag: confirmationTag,
844 signer: group.myLeafIndex,
845 extensions: []extension{
846 {
847 extensionType: extensionTypeRatchetTree,
848 extensionData: rawTree,
849 },
850 },
851 }
852 if err := newGroupInfo.sign(group.signaturePriv); err != nil {
853 return nil, nil, fmt.Errorf("failed to sign group info: %v", err)
854 }
855
856 encryptedGroupInfo, err := newGroupInfo.encrypt(joinerSecret, pskSecret)
857 if err != nil {
858 return nil, nil, fmt.Errorf("failed to encrypt group info: %v", err)
859 }
860
861 groupSecrets := groupSecrets{joinerSecret: joinerSecret}
862 encGroupSecrets := make([]encryptedGroupSecrets, len(keyPkgs))
863 for i, keyPkg := range keyPkgs {
864 keyPkgRef, err := keyPkg.GenerateRef()
865 if err != nil {
866 return nil, nil, fmt.Errorf("failed to generate key package ref: %v", err)
867 }
868
869 rawEncryptedGroupSecrets, err := groupSecrets.encrypt(cs, keyPkg.initKey, encryptedGroupInfo)
870 if err != nil {
871 return nil, nil, fmt.Errorf("failed to encrypt group secrets: %v", err)
872 }
873
874 encGroupSecrets[i] = encryptedGroupSecrets{
875 newMember: keyPkgRef,
876 encryptedGroupSecrets: *rawEncryptedGroupSecrets,
877 }
878 }
879
880 var rawMsg []byte
881 if public {
882 rawMsg, err = group.signPublicMessageMembershipTag(pubMsg)
883 if err != nil {
884 return nil, nil, err
885 }
886 } else {
887 rawMsg, err = group.encryptPrivateMessage(&framedContent, privContent)
888 if err != nil {
889 return nil, nil, fmt.Errorf("failed to encrypt private message: %v", err)
890 }
891 }
892
893 return &Welcome{
894 cipherSuite: cs,
895 secrets: encGroupSecrets,
896 encryptedGroupInfo: encryptedGroupInfo,
897 }, rawMsg, nil
898 }
899
900 // CreateApplicationMessage creates a new encrypted application message for the
901 // group. The message contains an arbitrary application-specific payload.
902 func (group *Group) CreateApplicationMessage(data []byte) ([]byte, error) {
903 cs := group.groupContext.cipherSuite
904
905 framedContent := framedContent{
906 groupID: group.groupContext.groupID,
907 epoch: group.groupContext.epoch,
908 sender: sender{
909 senderType: senderTypeMember,
910 leafIndex: group.myLeafIndex,
911 },
912 contentType: contentTypeApplication,
913 applicationData: data,
914 }
915 privContent, err := signPrivateMessageContent(cs, group.signaturePriv, &framedContent, &group.groupContext)
916 if err != nil {
917 return nil, fmt.Errorf("failed to sign private message: %v", err)
918 }
919
920 return group.encryptPrivateMessage(&framedContent, privContent)
921 }
922
923 func (group *Group) encryptPrivateMessage(framedContent *framedContent, privContent *privateMessageContent) ([]byte, error) {
924 cs := group.groupContext.cipherSuite
925
926 senderData, err := newSenderData(group.myLeafIndex, 0) // TODO: set generation > 0
927 if err != nil {
928 return nil, fmt.Errorf("failed to create sender data: %v", err)
929 }
930
931 encryptionSecret, err := cs.deriveSecret(group.epochSecret, secretLabelEncryption)
932 if err != nil {
933 return nil, fmt.Errorf("failed to derive encryption secret: %v", err)
934 }
935
936 secretTree, err := deriveSecretTree(cs, group.tree.numLeaves(), encryptionSecret)
937 if err != nil {
938 return nil, fmt.Errorf("failed to erive secret tree: %v", err)
939 }
940
941 label := ratchetLabelFromContentType(framedContent.contentType)
942 secret, err := secretTree.deriveRatchetRoot(cs, group.myLeafIndex.nodeIndex(), label)
943 if err != nil {
944 return nil, fmt.Errorf("failed to derive secret ratchet tree root: %v", err)
945 }
946
947 senderDataSecret, err := cs.deriveSecret(group.epochSecret, secretLabelSenderData)
948 if err != nil {
949 return nil, fmt.Errorf("failed to derive sender data secret: %v", err)
950 }
951
952 privMsg, err := encryptPrivateMessage(cs, secret, senderDataSecret, framedContent, privContent, senderData)
953 if err != nil {
954 return nil, fmt.Errorf("failed to encrypt private message: %v", err)
955 }
956
957 rawMsg, err := marshal(&mlsMessage{
958 version: protocolVersionMLS10,
959 wireFormat: wireFormatMLSPrivateMessage,
960 privateMessage: privMsg,
961 })
962 if err != nil {
963 return nil, fmt.Errorf("failed to marshal private message: %v", err)
964 }
965
966 return rawMsg, nil
967 }
968
969 func (group *Group) signPublicMessageMembershipTag(pubMsg *publicMessage) ([]byte, error) {
970 cs := group.groupContext.cipherSuite
971
972 membershipKey, err := group.groupContext.cipherSuite.deriveSecret(group.epochSecret, secretLabelMembership)
973 if err != nil {
974 return nil, fmt.Errorf("failed to derive membership key: %v", err)
975 }
976 if err := pubMsg.signMembershipTag(cs, membershipKey, &group.groupContext); err != nil {
977 return nil, fmt.Errorf("failed to sign public message membership tag: %v", err)
978 }
979
980 rawMsg, err := marshal(&mlsMessage{
981 version: protocolVersionMLS10,
982 wireFormat: wireFormatMLSPublicMessage,
983 publicMessage: pubMsg,
984 })
985 if err != nil {
986 return nil, fmt.Errorf("failed to marshal public message: %v", err)
987 }
988
989 return rawMsg, nil
990 }
991
// commit is the content of an MLS Commit: the proposals it covers (as
// proposalOrRef entries, encoded as a vector) followed by an optional
// UpdatePath.
type commit struct {
	proposals []proposalOrRef
	path      *updatePath // optional; nil is encoded as "absent"
}
996
997 func (c *commit) unmarshal(s *cryptobyte.String) error {
998 *c = commit{}
999
1000 err := readVector(s, func(s *cryptobyte.String) error {
1001 var propOrRef proposalOrRef
1002 if err := propOrRef.unmarshal(s); err != nil {
1003 return err
1004 }
1005 c.proposals = append(c.proposals, propOrRef)
1006 return nil
1007 })
1008 if err != nil {
1009 return err
1010 }
1011
1012 var hasPath bool
1013 if !readOptional(s, &hasPath) {
1014 return io.ErrUnexpectedEOF
1015 } else if hasPath {
1016 c.path = new(updatePath)
1017 if err := c.path.unmarshal(s); err != nil {
1018 return err
1019 }
1020 }
1021
1022 return nil
1023 }
1024
1025 func (c *commit) marshal(b *cryptobyte.Builder) {
1026 writeVector(b, len(c.proposals), func(b *cryptobyte.Builder, i int) {
1027 c.proposals[i].marshal(b)
1028 })
1029 writeOptional(b, c.path != nil)
1030 if c.path != nil {
1031 c.path.marshal(b)
1032 }
1033 }
1034
// groupInfo is a GroupInfo structure: the group context, GroupInfo-level
// extensions, the confirmation tag, the leaf index of the member that signed
// it, and that member's signature. It is encrypted into a Welcome for joiners.
type groupInfo struct {
	groupContext    groupContext
	extensions      []extension
	confirmationTag []byte
	signer          leafIndex
	signature       []byte // signature over the groupInfoTBS encoding of the other fields
}
1042
1043 func (info *groupInfo) unmarshal(s *cryptobyte.String) error {
1044 *info = groupInfo{}
1045
1046 if err := info.groupContext.unmarshal(s); err != nil {
1047 return err
1048 }
1049
1050 exts, err := unmarshalExtensionVec(s)
1051 if err != nil {
1052 return err
1053 }
1054 info.extensions = exts
1055
1056 if !readOpaqueVec(s, &info.confirmationTag) || !s.ReadUint32((*uint32)(&info.signer)) || !readOpaqueVec(s, &info.signature) {
1057 return err
1058 }
1059
1060 return nil
1061 }
1062
1063 func (info *groupInfo) marshal(b *cryptobyte.Builder) {
1064 (*groupInfoTBS)(info).marshal(b)
1065 writeOpaqueVec(b, info.signature)
1066 }
1067
1068 func (info *groupInfo) verifySignature(signerPub signaturePublicKey) bool {
1069 cs := info.groupContext.cipherSuite
1070 tbs, err := marshal((*groupInfoTBS)(info))
1071 if err != nil {
1072 return false
1073 }
1074 return cs.verifyWithLabel(signerPub, []byte("GroupInfoTBS"), tbs, info.signature)
1075 }
1076
1077 func (info *groupInfo) sign(signerPriv signaturePrivateKey) error {
1078 cs := info.groupContext.cipherSuite
1079 tbs, err := marshal((*groupInfoTBS)(info))
1080 if err != nil {
1081 return err
1082 }
1083 sig, err := cs.signWithLabel(signerPriv, []byte("GroupInfoTBS"), tbs)
1084 if err != nil {
1085 return err
1086 }
1087 info.signature = sig
1088 return nil
1089 }
1090
1091 func (info *groupInfo) verifyConfirmationTag(joinerSecret, pskSecret []byte) bool {
1092 cs := info.groupContext.cipherSuite
1093 epochSecret, err := info.groupContext.extractEpochSecret(joinerSecret, pskSecret)
1094 if err != nil {
1095 return false
1096 }
1097 confirmationKey, err := cs.deriveSecret(epochSecret, secretLabelConfirm)
1098 if err != nil {
1099 return false
1100 }
1101 return cs.verifyMAC(confirmationKey, info.groupContext.confirmedTranscriptHash, info.confirmationTag)
1102 }
1103
1104 func (info *groupInfo) encrypt(joinerSecret, pskSecret []byte) ([]byte, error) {
1105 cs := info.groupContext.cipherSuite
1106 _, _, aead := cs.hpke().Params()
1107
1108 welcomeSecret, err := extractWelcomeSecret(cs, joinerSecret, pskSecret)
1109 if err != nil {
1110 return nil, err
1111 }
1112
1113 welcomeNonce, err := cs.expandWithLabel(welcomeSecret, []byte("nonce"), nil, uint16(aead.NonceSize()))
1114 if err != nil {
1115 return nil, err
1116 }
1117 welcomeKey, err := cs.expandWithLabel(welcomeSecret, []byte("key"), nil, uint16(aead.KeySize()))
1118 if err != nil {
1119 return nil, err
1120 }
1121
1122 cipher, err := aead.New(welcomeKey)
1123 if err != nil {
1124 return nil, err
1125 }
1126
1127 rawGroupInfo, err := marshal(info)
1128 if err != nil {
1129 return nil, err
1130 }
1131
1132 return cipher.Seal(nil, welcomeNonce, rawGroupInfo, nil), nil
1133 }
1134
// groupInfoTBS has the same fields as groupInfo but encodes only the
// to-be-signed portion, i.e. everything except the signature. groupInfo
// converts itself to this type to produce the bytes it signs and verifies.
type groupInfoTBS groupInfo

// marshal writes, in this order, the group context, the extensions vector, the
// confirmation tag and the signer leaf index (as a uint32). The order is part
// of the wire format and of the signed bytes, so it must not change.
func (info *groupInfoTBS) marshal(b *cryptobyte.Builder) {
	info.groupContext.marshal(b)
	marshalExtensionVec(b, info.extensions)
	writeOpaqueVec(b, info.confirmationTag)
	b.AddUint32(uint32(info.signer))
}
1143
// groupSecrets is the per-joiner secret payload of a Welcome: the joiner
// secret, an optional path secret, and a list of pre-shared key IDs. It is
// HPKE-encrypted to each new member's init key.
type groupSecrets struct {
	joinerSecret []byte
	pathSecret   []byte // optional; nil is encoded as "absent"
	psks         []preSharedKeyID
}
1149
1150 func (sec *groupSecrets) unmarshal(s *cryptobyte.String) error {
1151 *sec = groupSecrets{}
1152
1153 if !readOpaqueVec(s, &sec.joinerSecret) {
1154 return io.ErrUnexpectedEOF
1155 }
1156
1157 var hasPathSecret bool
1158 if !readOptional(s, &hasPathSecret) {
1159 return io.ErrUnexpectedEOF
1160 } else if hasPathSecret && !readOpaqueVec(s, &sec.pathSecret) {
1161 return io.ErrUnexpectedEOF
1162 }
1163
1164 return readVector(s, func(s *cryptobyte.String) error {
1165 var psk preSharedKeyID
1166 if err := psk.unmarshal(s); err != nil {
1167 return err
1168 }
1169 sec.psks = append(sec.psks, psk)
1170 return nil
1171 })
1172 }
1173
1174 func (sec *groupSecrets) marshal(b *cryptobyte.Builder) {
1175 writeOpaqueVec(b, sec.joinerSecret)
1176
1177 writeOptional(b, sec.pathSecret != nil)
1178 if sec.pathSecret != nil {
1179 writeOpaqueVec(b, sec.pathSecret)
1180 }
1181
1182 writeVector(b, len(sec.psks), func(b *cryptobyte.Builder, i int) {
1183 sec.psks[i].marshal(b)
1184 })
1185 }
1186
1187 // verifySingleReInitOrBranchPSK verifies that at most one key has type
1188 // resumption with usage reinit or branch.
1189 func (sec *groupSecrets) verifySingleReinitOrBranchPSK() bool {
1190 n := 0
1191 for _, pskID := range sec.psks {
1192 if pskID.pskType != pskTypeResumption {
1193 continue
1194 }
1195 switch pskID.usage {
1196 case resumptionPSKUsageReinit, resumptionPSKUsageBranch:
1197 n++
1198 }
1199 }
1200 return n <= 1
1201 }
1202
1203 func (sec *groupSecrets) encrypt(cs CipherSuite, initKey hpkePublicKey, encryptedGroupInfo []byte) (*hpkeCiphertext, error) {
1204 rawGroupSecrets, err := marshal(sec)
1205 if err != nil {
1206 return nil, err
1207 }
1208
1209 kemOutput, ciphertext, err := cs.encryptWithLabel(initKey, []byte("Welcome"), encryptedGroupInfo, rawGroupSecrets)
1210 if err != nil {
1211 return nil, err
1212 }
1213
1214 return &hpkeCiphertext{
1215 kemOutput: kemOutput,
1216 ciphertext: ciphertext,
1217 }, nil
1218 }
1219
// A Welcome message includes secret keying information necessary to join a
// group: per-member encrypted group secrets plus the encrypted GroupInfo.
type Welcome struct {
	cipherSuite        CipherSuite
	secrets            []encryptedGroupSecrets // one entry per new member, identified by KeyPackageRef
	encryptedGroupInfo []byte                  // GroupInfo sealed under the welcome key and nonce
}
1227
1228 // UnmarshalWelcome reads a welcome message.
1229 func UnmarshalWelcome(raw []byte) (*Welcome, error) {
1230 var msg mlsMessage
1231 if err := unmarshal(raw, &msg); err != nil {
1232 return nil, err
1233 } else if msg.wireFormat != wireFormatMLSWelcome {
1234 return nil, fmt.Errorf("mls: expected a key package message, got wire format %v", msg.wireFormat)
1235 }
1236 return msg.welcome, nil
1237 }
1238
1239 // Bytes encodes the welcome message.
1240 func (w *Welcome) Bytes() []byte {
1241 raw, err := marshal(&mlsMessage{
1242 version: protocolVersionMLS10,
1243 wireFormat: wireFormatMLSWelcome,
1244 welcome: w,
1245 })
1246 if err != nil {
1247 // should never happen
1248 panic(fmt.Errorf("mls: failed to marshal welcome message: %v", err))
1249 }
1250 return raw
1251 }
1252
1253 func (w *Welcome) unmarshal(s *cryptobyte.String) error {
1254 *w = Welcome{}
1255
1256 if !s.ReadUint16((*uint16)(&w.cipherSuite)) {
1257 return io.ErrUnexpectedEOF
1258 }
1259
1260 err := readVector(s, func(s *cryptobyte.String) error {
1261 var sec encryptedGroupSecrets
1262 if err := sec.unmarshal(s); err != nil {
1263 return err
1264 }
1265 w.secrets = append(w.secrets, sec)
1266 return nil
1267 })
1268 if err != nil {
1269 return err
1270 }
1271
1272 if !readOpaqueVec(s, &w.encryptedGroupInfo) {
1273 return io.ErrUnexpectedEOF
1274 }
1275
1276 return nil
1277 }
1278
1279 func (w *Welcome) marshal(b *cryptobyte.Builder) {
1280 b.AddUint16(uint16(w.cipherSuite))
1281 writeVector(b, len(w.secrets), func(b *cryptobyte.Builder, i int) {
1282 w.secrets[i].marshal(b)
1283 })
1284 writeOpaqueVec(b, w.encryptedGroupInfo)
1285 }
1286
1287 // NewMembers returns the list of key package references this welcome message
1288 // contains secret keying information for.
1289 func (w *Welcome) NewMembers() []KeyPackageRef {
1290 refs := make([]KeyPackageRef, len(w.secrets))
1291 for i, sec := range w.secrets {
1292 refs[i] = sec.newMember
1293 }
1294 return refs
1295 }
1296
1297 func (w *Welcome) findSecret(ref KeyPackageRef) *encryptedGroupSecrets {
1298 for i, sec := range w.secrets {
1299 if sec.newMember.Equal(ref) {
1300 return &w.secrets[i]
1301 }
1302 }
1303 return nil
1304 }
1305
1306 func (w *Welcome) decryptGroupSecrets(ref KeyPackageRef, initKeyPriv hpkePrivateKey) (*groupSecrets, error) {
1307 cs := w.cipherSuite
1308
1309 sec := w.findSecret(ref)
1310 if sec == nil {
1311 return nil, fmt.Errorf("mls: encrypted group secrets not found for provided key package ref")
1312 }
1313
1314 rawGroupSecrets, err := cs.decryptWithLabel(initKeyPriv, []byte("Welcome"), w.encryptedGroupInfo, sec.encryptedGroupSecrets.kemOutput, sec.encryptedGroupSecrets.ciphertext)
1315 if err != nil {
1316 return nil, err
1317 }
1318 var groupSecrets groupSecrets
1319 if err := unmarshal(rawGroupSecrets, &groupSecrets); err != nil {
1320 return nil, err
1321 }
1322
1323 return &groupSecrets, err
1324 }
1325
1326 func (w *Welcome) decryptGroupInfo(joinerSecret, pskSecret []byte) (*groupInfo, error) {
1327 cs := w.cipherSuite
1328 _, _, aead := cs.hpke().Params()
1329
1330 welcomeSecret, err := extractWelcomeSecret(cs, joinerSecret, pskSecret)
1331 if err != nil {
1332 return nil, err
1333 }
1334
1335 welcomeNonce, err := cs.expandWithLabel(welcomeSecret, []byte("nonce"), nil, uint16(aead.NonceSize()))
1336 if err != nil {
1337 return nil, err
1338 }
1339 welcomeKey, err := cs.expandWithLabel(welcomeSecret, []byte("key"), nil, uint16(aead.KeySize()))
1340 if err != nil {
1341 return nil, err
1342 }
1343
1344 welcomeCipher, err := aead.New(welcomeKey)
1345 if err != nil {
1346 return nil, err
1347 }
1348 rawGroupInfo, err := welcomeCipher.Open(nil, welcomeNonce, w.encryptedGroupInfo, nil)
1349 if err != nil {
1350 return nil, err
1351 }
1352
1353 var groupInfo groupInfo
1354 if err := unmarshal(rawGroupInfo, &groupInfo); err != nil {
1355 return nil, err
1356 }
1357
1358 return &groupInfo, nil
1359 }
1360
// encryptedGroupSecrets pairs a new member's KeyPackageRef with the
// groupSecrets HPKE-encrypted to that key package's init key.
type encryptedGroupSecrets struct {
	newMember             KeyPackageRef
	encryptedGroupSecrets hpkeCiphertext
}
1365
1366 func (sec *encryptedGroupSecrets) unmarshal(s *cryptobyte.String) error {
1367 *sec = encryptedGroupSecrets{}
1368 if !readOpaqueVec(s, (*[]byte)(&sec.newMember)) {
1369 return io.ErrUnexpectedEOF
1370 }
1371 if err := sec.encryptedGroupSecrets.unmarshal(s); err != nil {
1372 return err
1373 }
1374 return nil
1375 }
1376
1377 func (sec *encryptedGroupSecrets) marshal(b *cryptobyte.Builder) {
1378 writeOpaqueVec(b, []byte(sec.newMember))
1379 sec.encryptedGroupSecrets.marshal(b)
1380 }
1381