group.go raw

   1  package mls
   2  
   3  import (
   4  	"crypto/rand"
   5  	"fmt"
   6  	"io"
   7  	"time"
   8  
   9  	"golang.org/x/crypto/cryptobyte"
  10  )
  11  
  12  type pendingProposal struct {
  13  	ref      proposalRef
  14  	proposal *proposal
  15  	sender   leafIndex
  16  }
  17  
  18  // A Group is a high-level API for an MLS group.
  19  type Group struct {
  20  	tree         ratchetTree
  21  	groupContext groupContext
  22  
  23  	interimTranscriptHash []byte
  24  	pskSecret             []byte
  25  	epochSecret           []byte
  26  	initSecret            []byte
  27  
  28  	myLeafIndex   leafIndex
  29  	privTree      []hpkePrivateKey
  30  	signaturePriv signaturePrivateKey
  31  
  32  	pendingProposals []pendingProposal
  33  }
  34  
  35  // Epoch returns the current MLS epoch number. The epoch increments on every
  36  // Commit (including application message encrypt/decrypt that triggers ratcheting).
  37  func (group *Group) Epoch() uint64 {
  38  	return group.groupContext.epoch
  39  }
  40  
  41  // ExporterSecret derives the exporter secret from the current epoch secret.
  42  // This is needed by NIP-EE to derive the NIP-44 conversation key for
  43  // encrypting kind 445 group message content.
  44  func (group *Group) ExporterSecret() ([]byte, error) {
  45  	return group.groupContext.cipherSuite.deriveSecret(
  46  		group.epochSecret, secretLabelExporter,
  47  	)
  48  }
  49  
  50  // GroupContextExtensions returns the extensions from the group context.
  51  // Use this to extract application-specific data like NostrGroupData (0xf2ee).
  52  func (group *Group) GroupContextExtensions() []extension {
  53  	return group.groupContext.extensions
  54  }
  55  
  56  // FindGroupContextExtension returns the data for the extension with the given
  57  // type, or nil if not found.
  58  func (group *Group) FindGroupContextExtension(t extensionType) []byte {
  59  	return findExtensionData(group.groupContext.extensions, t)
  60  }
  61  
  62  // Marshal serializes the full Group state (including private keys and epoch
  63  // secrets) so it can be persisted and restored later. This is NOT a wire
  64  // format — it's for local storage only. The output contains sensitive key
  65  // material and must be encrypted at rest.
  66  func (group *Group) Marshal() ([]byte, error) {
  67  	var b cryptobyte.Builder
  68  
  69  	// 1. groupContext (TLS-serialized)
  70  	group.groupContext.marshal(&b)
  71  
  72  	// 2. ratchetTree
  73  	group.tree.marshal(&b)
  74  
  75  	// 3. Secrets
  76  	writeOpaqueVec(&b, group.interimTranscriptHash)
  77  	writeOpaqueVec(&b, group.pskSecret)
  78  	writeOpaqueVec(&b, group.epochSecret)
  79  	writeOpaqueVec(&b, group.initSecret)
  80  
  81  	// 4. My identity within the group
  82  	b.AddUint32(uint32(group.myLeafIndex))
  83  	writeOpaqueVec(&b, []byte(group.signaturePriv))
  84  
  85  	// 5. Private tree (HPKE private keys, indexed by node position)
  86  	writeVector(&b, len(group.privTree), func(b *cryptobyte.Builder, i int) {
  87  		writeOpaqueVec(b, []byte(group.privTree[i]))
  88  	})
  89  
  90  	return b.Bytes()
  91  }
  92  
  93  // UnmarshalGroup restores a Group from bytes produced by Marshal.
  94  func UnmarshalGroup(raw []byte) (*Group, error) {
  95  	s := cryptobyte.String(raw)
  96  	g := &Group{}
  97  
  98  	// 1. groupContext
  99  	if err := g.groupContext.unmarshal(&s); err != nil {
 100  		return nil, fmt.Errorf("mls: unmarshal group context: %w", err)
 101  	}
 102  
 103  	// 2. ratchetTree
 104  	if err := g.tree.unmarshal(&s); err != nil {
 105  		return nil, fmt.Errorf("mls: unmarshal ratchet tree: %w", err)
 106  	}
 107  
 108  	// 3. Secrets
 109  	if !readOpaqueVec(&s, &g.interimTranscriptHash) ||
 110  		!readOpaqueVec(&s, &g.pskSecret) ||
 111  		!readOpaqueVec(&s, &g.epochSecret) ||
 112  		!readOpaqueVec(&s, &g.initSecret) {
 113  		return nil, fmt.Errorf("mls: unmarshal secrets: unexpected EOF")
 114  	}
 115  
 116  	// 4. My identity
 117  	if !s.ReadUint32((*uint32)(&g.myLeafIndex)) {
 118  		return nil, fmt.Errorf("mls: unmarshal leaf index: unexpected EOF")
 119  	}
 120  	var sigPriv []byte
 121  	if !readOpaqueVec(&s, &sigPriv) {
 122  		return nil, fmt.Errorf("mls: unmarshal signature priv: unexpected EOF")
 123  	}
 124  	g.signaturePriv = signaturePrivateKey(sigPriv)
 125  
 126  	// 5. Private tree
 127  	if err := readVector(&s, func(s *cryptobyte.String) error {
 128  		var k []byte
 129  		if !readOpaqueVec(s, &k) {
 130  			return io.ErrUnexpectedEOF
 131  		}
 132  		g.privTree = append(g.privTree, hpkePrivateKey(k))
 133  		return nil
 134  	}); err != nil {
 135  		return nil, fmt.Errorf("mls: unmarshal priv tree: %w", err)
 136  	}
 137  
 138  	return g, nil
 139  }
 140  
 141  // GroupID returns the MLS group ID.
 142  func (group *Group) GroupID() GroupID {
 143  	return group.groupContext.groupID
 144  }
 145  
 146  // DeriveExporter exports keying material from the group's exporter secret
 147  // using the MLS exporter derivation (RFC 9420 Section 8).
 148  func (group *Group) DeriveExporter(label, context []byte, length uint16) ([]byte, error) {
 149  	exporterSecret, err := group.ExporterSecret()
 150  	if err != nil {
 151  		return nil, err
 152  	}
 153  	return deriveExporter(group.groupContext.cipherSuite, exporterSecret, label, context, length)
 154  }
 155  
 156  // GroupOptions configures group creation.
 157  type GroupOptions struct {
 158  	// Extensions are included in the group context. For Marmot, this
 159  	// should include a NostrGroupData extension (0xf2ee).
 160  	Extensions []extension
 161  }
 162  
 163  // CreateGroup creates a new group with a single member.
 164  func CreateGroup(groupID GroupID, keyPairPkg *KeyPairPackage) (*Group, error) {
 165  	return CreateGroupWithOptions(groupID, keyPairPkg, nil)
 166  }
 167  
 168  // CreateGroupWithOptions creates a new group with custom group context extensions.
 169  func CreateGroupWithOptions(groupID GroupID, keyPairPkg *KeyPairPackage, opts *GroupOptions) (*Group, error) {
 170  	cs := keyPairPkg.Public.cipherSuite
 171  
 172  	tree := make(ratchetTree, 1)
 173  	tree.add(&keyPairPkg.Public.leafNode)
 174  
 175  	privTree := make([]hpkePrivateKey, len(tree))
 176  	privTree[0] = keyPairPkg.Private.EncryptionKey
 177  
 178  	treeHash, err := tree.computeRootTreeHash(cs)
 179  	if err != nil {
 180  		return nil, fmt.Errorf("failed to compute root tree hash: %v", err)
 181  	}
 182  
 183  	confirmedTranscriptHash := make([]byte, cs.hash().Size())
 184  
 185  	_, kdf, _ := cs.hpke().Params()
 186  	epochSecret := make([]byte, kdf.ExtractSize())
 187  	if _, err := rand.Read(epochSecret); err != nil {
 188  		return nil, fmt.Errorf("failed to generate epoch secret: %v", err)
 189  	}
 190  
 191  	var ctxExts []extension
 192  	if opts != nil {
 193  		ctxExts = opts.Extensions
 194  	}
 195  
 196  	groupCtx := groupContext{
 197  		version:                 keyPairPkg.Public.version,
 198  		cipherSuite:             keyPairPkg.Public.cipherSuite,
 199  		groupID:                 groupID,
 200  		epoch:                   0,
 201  		treeHash:                treeHash,
 202  		confirmedTranscriptHash: confirmedTranscriptHash,
 203  		extensions:              ctxExts,
 204  	}
 205  
 206  	confirmationTag, err := groupCtx.signConfirmationTag(epochSecret)
 207  	if err != nil {
 208  		return nil, fmt.Errorf("failed to sign confirmation tag: %v", err)
 209  	}
 210  
 211  	interimTranscriptHash, err := nextInterimTranscriptHash(cs, confirmedTranscriptHash, confirmationTag)
 212  	if err != nil {
 213  		return nil, fmt.Errorf("failed to compute initial interim transcript hash: %v", err)
 214  	}
 215  
 216  	pskSecret, err := extractPSKSecret(cs, nil, nil)
 217  	if err != nil {
 218  		return nil, fmt.Errorf("failed to extract PSK secret: %v", err)
 219  	}
 220  
 221  	initSecret, err := groupCtx.cipherSuite.deriveSecret(epochSecret, secretLabelInit)
 222  	if err != nil {
 223  		return nil, fmt.Errorf("failed to derive init secret: %v", err)
 224  	}
 225  
 226  	return &Group{
 227  		tree:                  tree,
 228  		privTree:              privTree,
 229  		myLeafIndex:           0,
 230  		signaturePriv:         keyPairPkg.Private.SignatureKey,
 231  		groupContext:          groupCtx,
 232  		interimTranscriptHash: interimTranscriptHash,
 233  		pskSecret:             pskSecret,
 234  		epochSecret:           epochSecret,
 235  		initSecret:            initSecret,
 236  	}, nil
 237  }
 238  
 239  // GroupFromWelcome creates a new group from a welcome message.
 240  func GroupFromWelcome(welcome *Welcome, keyPairPkg *KeyPairPackage) (*Group, error) {
 241  	keyPkgRef, err := keyPairPkg.Public.GenerateRef()
 242  	if err != nil {
 243  		return nil, fmt.Errorf("failed to generate key package ref: %v", err)
 244  	}
 245  
 246  	groupSecrets, err := welcome.decryptGroupSecrets(keyPkgRef, keyPairPkg.Private.InitKey)
 247  	if err != nil {
 248  		return nil, fmt.Errorf("failed to decrypt group secrets: %v", err)
 249  	}
 250  
 251  	if !groupSecrets.verifySingleReinitOrBranchPSK() {
 252  		return nil, fmt.Errorf("mls: more than one key has usage reinit or branch in group secrets")
 253  	}
 254  
 255  	if len(groupSecrets.psks) != 0 {
 256  		return nil, fmt.Errorf("mls: group secret PSKs are not yet supported")
 257  	}
 258  
 259  	return groupFromSecrets(welcome, keyPairPkg, groupSecrets, nil)
 260  }
 261  
 262  type groupFromSecretsOptions struct {
 263  	rawTree []byte
 264  	psks    [][]byte
 265  	now     func() time.Time
 266  }
 267  
 268  func groupFromSecrets(welcome *Welcome, keyPairPkg *KeyPairPackage, groupSecrets *groupSecrets, options *groupFromSecretsOptions) (*Group, error) {
 269  	if options == nil {
 270  		options = new(groupFromSecretsOptions)
 271  	}
 272  
 273  	pskSecret, err := extractPSKSecret(welcome.cipherSuite, groupSecrets.psks, options.psks)
 274  	if err != nil {
 275  		return nil, fmt.Errorf("failed to extract PSK secret: %v", err)
 276  	}
 277  
 278  	groupInfo, err := welcome.decryptGroupInfo(groupSecrets.joinerSecret, pskSecret)
 279  	if err != nil {
 280  		return nil, fmt.Errorf("failed to decrypt group info: %v", err)
 281  	}
 282  
 283  	rawTree := options.rawTree
 284  	if rawTree == nil {
 285  		rawTree = findExtensionData(groupInfo.extensions, extensionTypeRatchetTree)
 286  	}
 287  	if rawTree == nil {
 288  		return nil, fmt.Errorf("mls: missing ratchet tree")
 289  	}
 290  
 291  	var tree ratchetTree
 292  	if err := unmarshal(rawTree, &tree); err != nil {
 293  		return nil, fmt.Errorf("failed to unmarshal ratchet tree: %v", err)
 294  	}
 295  
 296  	signerNode := tree.getLeaf(groupInfo.signer)
 297  	if signerNode == nil {
 298  		return nil, fmt.Errorf("mls: signer node is blank")
 299  	} else if !groupInfo.verifySignature(signerNode.signatureKey) {
 300  		return nil, fmt.Errorf("mls: failed to verify signer node signature")
 301  	}
 302  	if !groupInfo.verifyConfirmationTag(groupSecrets.joinerSecret, pskSecret) {
 303  		return nil, fmt.Errorf("mls: failed to verify confirmation tag")
 304  	}
 305  	if groupInfo.groupContext.cipherSuite != welcome.cipherSuite {
 306  		return nil, fmt.Errorf("mls: group info cipher suite doesn't match key package")
 307  	}
 308  
 309  	if err := tree.verifyIntegrity(&groupInfo.groupContext, options.now); err != nil {
 310  		return nil, fmt.Errorf("failed to verify ratchet tree integrity: %v", err)
 311  	}
 312  
 313  	// TODO: perform other group info verification steps
 314  
 315  	groupCtx := groupInfo.groupContext
 316  
 317  	epochSecret, err := groupCtx.extractEpochSecret(groupSecrets.joinerSecret, pskSecret)
 318  	if err != nil {
 319  		return nil, fmt.Errorf("failed to extract epoch secret: %v", err)
 320  	}
 321  
 322  	initSecret, err := groupCtx.cipherSuite.deriveSecret(epochSecret, secretLabelInit)
 323  	if err != nil {
 324  		return nil, fmt.Errorf("failed to derive init secret: %v", err)
 325  	}
 326  
 327  	interimTranscriptHash, err := nextInterimTranscriptHash(groupCtx.cipherSuite, groupCtx.confirmedTranscriptHash, groupInfo.confirmationTag)
 328  	if err != nil {
 329  		return nil, fmt.Errorf("failed to compute next interim transcript hash: %v", err)
 330  	}
 331  
 332  	myLeafIndex, ok := tree.findLeaf(&keyPairPkg.Public.leafNode)
 333  	if !ok {
 334  		return nil, fmt.Errorf("mls: failed to find my leaf node in ratchet tree")
 335  	}
 336  
 337  	privTree := make([]hpkePrivateKey, len(tree))
 338  	privTree[int(myLeafIndex.nodeIndex())] = keyPairPkg.Private.EncryptionKey
 339  
 340  	if groupSecrets.pathSecret != nil {
 341  		nodeIndex := commonAncestor(myLeafIndex.nodeIndex(), groupInfo.signer.nodeIndex())
 342  		err := processPathSecret(groupCtx.cipherSuite, tree, privTree, groupSecrets.pathSecret, nodeIndex)
 343  		if err != nil {
 344  			return nil, fmt.Errorf("failed to process path secret: %v", err)
 345  		}
 346  	}
 347  
 348  	return &Group{
 349  		tree:                  tree,
 350  		groupContext:          groupCtx,
 351  		interimTranscriptHash: interimTranscriptHash,
 352  		pskSecret:             pskSecret,
 353  		epochSecret:           epochSecret,
 354  		initSecret:            initSecret,
 355  		myLeafIndex:           myLeafIndex,
 356  		privTree:              privTree,
 357  		signaturePriv:         keyPairPkg.Private.SignatureKey,
 358  	}, nil
 359  }
 360  
 361  func processPathSecret(cs CipherSuite, tree ratchetTree, privTree []hpkePrivateKey, pathSecret []byte, nodeIndex nodeIndex) error {
 362  	nodePriv, err := nodePrivFromPathSecret(cs, pathSecret, tree.get(nodeIndex).encryptionKey())
 363  	if err != nil {
 364  		return fmt.Errorf("failed to derive node %v private key from path secret: %v", nodeIndex, err)
 365  	}
 366  	privTree[int(nodeIndex)] = nodePriv
 367  
 368  	for {
 369  		var ok bool
 370  		nodeIndex, ok = tree.numLeaves().parent(nodeIndex)
 371  		if !ok {
 372  			break
 373  		}
 374  
 375  		pathSecret, err := cs.deriveSecret(pathSecret, []byte("path"))
 376  		if err != nil {
 377  			return fmt.Errorf("failed to derive path secret: %v", err)
 378  		}
 379  
 380  		nodePriv, err := nodePrivFromPathSecret(cs, pathSecret, tree.get(nodeIndex).encryptionKey())
 381  		if err != nil {
 382  			return fmt.Errorf("failed to derive node %v private key from path secret: %v", nodeIndex, err)
 383  		}
 384  		privTree[int(nodeIndex)] = nodePriv
 385  	}
 386  
 387  	return nil
 388  }
 389  
 390  // UnmarshalAndProcessMessage decodes a raw MLS message intended for the group
 391  // and processes it.
 392  //
 393  // If the MLS message contains encrypted application data, the decrypted data
 394  // is returned.
 395  func (group *Group) UnmarshalAndProcessMessage(raw []byte) (plaintext []byte, selfSent bool, err error) {
 396  	var msg mlsMessage
 397  	if err := unmarshal([]byte(raw), &msg); err != nil {
 398  		return nil, false, fmt.Errorf("failed to unmarshal MLS message: %v", err)
 399  	}
 400  
 401  	switch msg.wireFormat {
 402  	case wireFormatMLSPublicMessage:
 403  		return nil, false, group.processPublicMessage(msg.publicMessage)
 404  	case wireFormatMLSPrivateMessage:
 405  		return group.processPrivateMessage(msg.privateMessage)
 406  	default:
 407  		// TODO: support other wire formats
 408  		return nil, false, fmt.Errorf("mls: unsupported wire format: %v", msg.wireFormat)
 409  	}
 410  }
 411  
 412  func (group *Group) processPublicMessage(pubMsg *publicMessage) error {
 413  	authContent, err := group.verifyPublicMessage(pubMsg)
 414  	if err != nil {
 415  		return fmt.Errorf("failed to verify public message: %v", err)
 416  	}
 417  
 418  	switch authContent.content.contentType {
 419  	case contentTypeProposal:
 420  		return group.processProposal(authContent)
 421  	case contentTypeCommit:
 422  		return group.processCommit(authContent, nil, nil, nil)
 423  	case contentTypeApplication:
 424  		return fmt.Errorf("mls: application content type must be encrypted")
 425  	default:
 426  		// TODO: support other content types
 427  		return fmt.Errorf("mls: unsupported content type: %v", authContent.content.contentType)
 428  	}
 429  }
 430  
 431  func (group *Group) verifyPublicMessage(pubMsg *publicMessage) (*authenticatedContent, error) {
 432  	if !pubMsg.content.groupID.Equal(group.groupContext.groupID) {
 433  		return nil, fmt.Errorf("mls: message group ID mismatch")
 434  	}
 435  	if pubMsg.content.epoch != group.groupContext.epoch {
 436  		return nil, fmt.Errorf("mls: epoch mismatch: got %v, want %v", pubMsg.content.epoch, group.groupContext.epoch)
 437  	}
 438  
 439  	if pubMsg.content.sender.senderType != senderTypeMember {
 440  		// TODO: support other sender types
 441  		return nil, fmt.Errorf("mls: unsupported sender type: %v", pubMsg.content.sender.senderType)
 442  	}
 443  	senderLeafIndex := pubMsg.content.sender.leafIndex
 444  	// TODO: check tree length
 445  	senderNode := group.tree.getLeaf(senderLeafIndex)
 446  	if senderNode == nil {
 447  		return nil, fmt.Errorf("mls: blank leaf node for sender")
 448  	}
 449  
 450  	authContent := pubMsg.authenticatedContent()
 451  	if !authContent.verifySignature([]byte(senderNode.signatureKey), &group.groupContext) {
 452  		return nil, fmt.Errorf("mls: failed to verify public message signature")
 453  	}
 454  
 455  	membershipKey, err := group.groupContext.cipherSuite.deriveSecret(group.epochSecret, secretLabelMembership)
 456  	if err != nil {
 457  		return nil, fmt.Errorf("failed to derive membership key: %v", err)
 458  	} else if !pubMsg.verifyMembershipTag(membershipKey, &group.groupContext) {
 459  		return nil, fmt.Errorf("failed to verify membership tag")
 460  	}
 461  
 462  	return authContent, nil
 463  }
 464  
 465  func (group *Group) processPrivateMessage(privMsg *privateMessage) ([]byte, bool, error) {
 466  	cs := group.groupContext.cipherSuite
 467  
 468  	if !privMsg.groupID.Equal(group.groupContext.groupID) {
 469  		return nil, false, fmt.Errorf("mls: message group ID mismatch")
 470  	}
 471  	if privMsg.epoch != group.groupContext.epoch {
 472  		return nil, false, fmt.Errorf("mls: epoch mismatch: got %v, want %v", privMsg.epoch, group.groupContext.epoch)
 473  	}
 474  
 475  	senderDataSecret, err := cs.deriveSecret(group.epochSecret, secretLabelSenderData)
 476  	if err != nil {
 477  		return nil, false, fmt.Errorf("failed to derive sender data secret: %v", err)
 478  	}
 479  
 480  	senderData, err := privMsg.decryptSenderData(cs, senderDataSecret)
 481  	if err != nil {
 482  		return nil, false, fmt.Errorf("failed to decrypt sender data: %v", err)
 483  	}
 484  
 485  	encryptionSecret, err := cs.deriveSecret(group.epochSecret, secretLabelEncryption)
 486  	if err != nil {
 487  		return nil, false, fmt.Errorf("failed to derive encryption secret: %v", err)
 488  	}
 489  
 490  	secretTree, err := deriveSecretTree(cs, group.tree.numLeaves(), encryptionSecret)
 491  	if err != nil {
 492  		return nil, false, fmt.Errorf("failed to erive secret tree: %v", err)
 493  	}
 494  
 495  	label := ratchetLabelFromContentType(privMsg.contentType)
 496  	secret, err := secretTree.deriveRatchetRoot(cs, senderData.leafIndex.nodeIndex(), label)
 497  	if err != nil {
 498  		return nil, false, fmt.Errorf("failed to derive secret ratchet tree root: %v", err)
 499  	}
 500  
 501  	// TODO: limit number of iterations
 502  	// TODO: erase knowledge about used generations to ensure forward secrecy
 503  	for secret.generation != senderData.generation {
 504  		secret, err = secret.deriveNext(cs)
 505  		if err != nil {
 506  			return nil, false, fmt.Errorf("failed to derive next ratchet secret: %v", err)
 507  		}
 508  	}
 509  
 510  	privContent, err := privMsg.decryptContent(cs, secret, senderData.reuseGuard)
 511  	if err != nil {
 512  		return nil, false, fmt.Errorf("failed to decrypt private message content: %v", err)
 513  	}
 514  
 515  	signerNode := group.tree.getLeaf(senderData.leafIndex)
 516  	if signerNode == nil {
 517  		return nil, false, fmt.Errorf("mls: signer node is blank")
 518  	}
 519  
 520  	authContent := privMsg.authenticatedContent(senderData, privContent)
 521  	if !authContent.verifySignature(signerNode.signatureKey, &group.groupContext) {
 522  		return nil, false, fmt.Errorf("failed to verify private message content signature: %v", err)
 523  	}
 524  
 525  	selfSent := senderData.leafIndex == group.myLeafIndex
 526  
 527  	switch authContent.content.contentType {
 528  	case contentTypeProposal:
 529  		return nil, false, group.processProposal(authContent)
 530  	case contentTypeCommit:
 531  		return nil, false, group.processCommit(authContent, nil, nil, nil)
 532  	case contentTypeApplication:
 533  		return authContent.content.applicationData, selfSent, nil
 534  	default:
 535  		// TODO: support other content types
 536  		return nil, false, fmt.Errorf("mls: unsupported content type: %v", authContent.content.contentType)
 537  	}
 538  }
 539  
 540  func (group *Group) processProposal(authContent *authenticatedContent) error {
 541  	if authContent.content.contentType != contentTypeProposal {
 542  		panic("mls: expected a proposal")
 543  	}
 544  	proposal := authContent.content.proposal
 545  
 546  	ref, err := authContent.generateProposalRef(group.groupContext.cipherSuite)
 547  	if err != nil {
 548  		return fmt.Errorf("failed to generate proposal ref: %v", err)
 549  	}
 550  
 551  	group.pendingProposals = append(group.pendingProposals, pendingProposal{
 552  		ref:      ref,
 553  		proposal: proposal,
 554  		sender:   authContent.content.sender.leafIndex,
 555  	})
 556  	return nil
 557  }
 558  
 559  func (group *Group) processCommit(authContent *authenticatedContent, pskIDs []preSharedKeyID, psks [][]byte, now func() time.Time) error {
 560  	cs := group.groupContext.cipherSuite
 561  	senderLeafIndex := authContent.content.sender.leafIndex
 562  
 563  	if authContent.content.contentType != contentTypeCommit {
 564  		panic("mls: expected a commit")
 565  	}
 566  	commit := authContent.content.commit
 567  
 568  	proposals, senders, err := resolveProposals(commit.proposals, senderLeafIndex, group.pendingProposals)
 569  	if err != nil {
 570  		return err
 571  	}
 572  
 573  	if err := verifyProposalList(proposals, senders, senderLeafIndex); err != nil {
 574  		return fmt.Errorf("failed to verify proposals: %v", err)
 575  	}
 576  
 577  	for _, prop := range proposals {
 578  		if prop.proposalType == proposalTypeAdd {
 579  			if err := prop.add.keyPackage.verify(&group.groupContext); err != nil {
 580  				return fmt.Errorf("failed to verify add proposal: %v", err)
 581  			}
 582  		}
 583  	}
 584  
 585  	// TODO: additional proposal list checks
 586  
 587  	if proposalListNeedsPath(proposals) && commit.path == nil {
 588  		return fmt.Errorf("mls: commit is missing update path but required by proposal list")
 589  	}
 590  
 591  	newGroupCtx := group.groupContext
 592  	newGroupCtx.epoch++
 593  
 594  	newTree := group.tree.copy()
 595  	newTree.apply(proposals, senders)
 596  
 597  	newPrivTree := make([]hpkePrivateKey, len(newTree))
 598  	for i := range group.tree {
 599  		if i < len(newPrivTree) {
 600  			newPrivTree[i] = group.privTree[i]
 601  		}
 602  	}
 603  
 604  	_, kdf, _ := cs.hpke().Params()
 605  	commitSecret := make([]byte, kdf.ExtractSize())
 606  	if commit.path != nil {
 607  		if commit.path.leafNode.leafNodeSource != leafNodeSourceCommit {
 608  			return fmt.Errorf("mls: commit path leaf node source must be commit")
 609  		}
 610  
 611  		// TODO: check tree length
 612  		senderNode := newTree.getLeaf(senderLeafIndex)
 613  
 614  		// The same signature key can be re-used, but the encryption key
 615  		// must change
 616  		signatureKeys, encryptionKeys := newTree.keys()
 617  		delete(signatureKeys, string(senderNode.signatureKey))
 618  		err := commit.path.leafNode.verify(&leafNodeVerifyOptions{
 619  			cipherSuite:    cs,
 620  			groupID:        group.groupContext.groupID,
 621  			leafIndex:      senderLeafIndex,
 622  			supportedCreds: newTree.supportedCreds(),
 623  			signatureKeys:  signatureKeys,
 624  			encryptionKeys: encryptionKeys,
 625  			now:            now,
 626  		})
 627  		if err != nil {
 628  			return fmt.Errorf("failed to verify leaf node: %v", err)
 629  		}
 630  
 631  		for _, updateNode := range commit.path.nodes {
 632  			if _, dup := encryptionKeys[string(updateNode.encryptionKey)]; dup {
 633  				return fmt.Errorf("mls: encryption key in update path already used in ratchet tree")
 634  			}
 635  		}
 636  
 637  		if err := newTree.mergeUpdatePath(cs, senderLeafIndex, commit.path); err != nil {
 638  			return fmt.Errorf("failed to merge update path in ratchet tree: %v", err)
 639  		}
 640  
 641  		newGroupCtx.treeHash, err = newTree.computeRootTreeHash(cs)
 642  		if err != nil {
 643  			return fmt.Errorf("failed to compute root tree hash: %v", err)
 644  		}
 645  
 646  		// TODO: update group context extensions
 647  
 648  		commitSecret, err = newTree.decryptPathSecrets(cs, &newGroupCtx, senderLeafIndex, group.myLeafIndex, commit.path, newPrivTree)
 649  		if err != nil {
 650  			return fmt.Errorf("failed to decrypt path secrets: %v", err)
 651  		}
 652  	} else {
 653  		// TODO: only recompute parts of the tree affected by proposals
 654  		newGroupCtx.treeHash, err = newTree.computeRootTreeHash(cs)
 655  		if err != nil {
 656  			return fmt.Errorf("failed to compute root tree hash: %v", err)
 657  		}
 658  	}
 659  
 660  	newGroupCtx.confirmedTranscriptHash, err = authContent.confirmedTranscriptHashInput().hash(cs, group.interimTranscriptHash)
 661  	if err != nil {
 662  		return fmt.Errorf("failed to hash confirmed transcript hash input: %v", err)
 663  	}
 664  
 665  	newInterimTranscriptHash, err := nextInterimTranscriptHash(cs, newGroupCtx.confirmedTranscriptHash, authContent.auth.confirmationTag)
 666  	if err != nil {
 667  		return fmt.Errorf("failed to compute next interim transcript hash: %v", err)
 668  	}
 669  
 670  	newJoinerSecret, err := newGroupCtx.extractJoinerSecret(group.initSecret, commitSecret)
 671  	if err != nil {
 672  		return fmt.Errorf("failed to extract joined secret: %v", err)
 673  	}
 674  
 675  	newPSKSecret, err := extractPSKSecret(cs, pskIDs, psks)
 676  	if err != nil {
 677  		return fmt.Errorf("failed to extract PSK secret: %v", err)
 678  	}
 679  
 680  	newEpochSecret, err := newGroupCtx.extractEpochSecret(newJoinerSecret, newPSKSecret)
 681  	if err != nil {
 682  		return fmt.Errorf("failed to extract epoch secret: %v", err)
 683  	}
 684  
 685  	newInitSecret, err := cs.deriveSecret(newEpochSecret, secretLabelInit)
 686  	if err != nil {
 687  		return fmt.Errorf("failed to erive init secret: %v", err)
 688  	}
 689  
 690  	group.tree = newTree
 691  	group.privTree = newPrivTree
 692  	group.groupContext = newGroupCtx
 693  	group.interimTranscriptHash = newInterimTranscriptHash
 694  	group.pskSecret = newPSKSecret
 695  	group.epochSecret = newEpochSecret
 696  	group.initSecret = newInitSecret
 697  	group.pendingProposals = nil // TODO: only clear proposals we've consumed
 698  	return nil
 699  }
 700  
 701  func resolveProposals(proposalOrRefs []proposalOrRef, senderLeafIndex leafIndex, pendingProposals []pendingProposal) ([]proposal, []leafIndex, error) {
 702  	var (
 703  		proposals []proposal
 704  		senders   []leafIndex
 705  	)
 706  	for _, propOrRef := range proposalOrRefs {
 707  		switch propOrRef.typ {
 708  		case proposalOrRefTypeProposal:
 709  			proposals = append(proposals, *propOrRef.proposal)
 710  			senders = append(senders, senderLeafIndex)
 711  		case proposalOrRefTypeReference:
 712  			var found bool
 713  			for _, pp := range pendingProposals {
 714  				if pp.ref.Equal(propOrRef.reference) {
 715  					found = true
 716  					proposals = append(proposals, *pp.proposal)
 717  					senders = append(senders, pp.sender)
 718  					break
 719  				}
 720  			}
 721  			if !found {
 722  				return nil, nil, fmt.Errorf("mls: cannot find proposal reference: %v", propOrRef.reference)
 723  			}
 724  		}
 725  	}
 726  
 727  	return proposals, senders, nil
 728  }
 729  
 730  // CreateWelcome creates a new welcome message, inviting new members to the
 731  // group.
 732  //
 733  // The welcome message should be sent to the new members. Alongside the welcome
 734  // message, a raw MLS message is returned and must be consumed by all existing
 735  // members of the group to add the new members.
 736  func (group *Group) CreateWelcome(keyPkgs []KeyPackage) (*Welcome, []byte, error) {
 737  	// TODO: missing steps from section 12.4.1
 738  	cs := group.groupContext.cipherSuite
 739  
 740  	if len(keyPkgs) == 0 {
 741  		panic("mls: expected at least one key package")
 742  	}
 743  
 744  	proposals := make([]proposal, len(keyPkgs))
 745  	proposalOrRefs := make([]proposalOrRef, len(keyPkgs))
 746  	for i, keyPkg := range keyPkgs {
 747  		proposals[i] = proposal{
 748  			proposalType: proposalTypeAdd,
 749  			add:          &add{keyPackage: keyPkg},
 750  		}
 751  		proposalOrRefs[i] = proposalOrRef{
 752  			typ:      proposalOrRefTypeProposal,
 753  			proposal: &proposals[i],
 754  		}
 755  	}
 756  
 757  	// TODO: check proposal list validity per section 12.2
 758  	commit := commit{proposals: proposalOrRefs}
 759  
 760  	newGroupCtx := group.groupContext
 761  	newGroupCtx.epoch++
 762  
 763  	newTree := group.tree.copy()
 764  	newTree.apply(proposals, []leafIndex{group.myLeafIndex})
 765  
 766  	// TODO: only recompute parts of the tree affected by proposals
 767  	var err error
 768  	newGroupCtx.treeHash, err = newTree.computeRootTreeHash(cs)
 769  	if err != nil {
 770  		return nil, nil, fmt.Errorf("failed to compute root tree hash: %v", err)
 771  	}
 772  
 773  	_, kdf, _ := cs.hpke().Params()
 774  	commitSecret := make([]byte, kdf.ExtractSize())
 775  
 776  	pskSecret, err := extractPSKSecret(cs, nil, nil)
 777  	if err != nil {
 778  		return nil, nil, fmt.Errorf("failed to extract PSK secret: %v", err)
 779  	}
 780  
 781  	framedContent := framedContent{
 782  		groupID: group.groupContext.groupID,
 783  		epoch:   group.groupContext.epoch,
 784  		sender: sender{
 785  			senderType: senderTypeMember,
 786  			leafIndex:  group.myLeafIndex,
 787  		},
 788  		contentType: contentTypeCommit,
 789  		commit:      &commit,
 790  	}
 791  
 792  	public := false // TODO: add option to enable this
 793  	var (
 794  		authContent *authenticatedContent
 795  		authData    *framedContentAuthData
 796  		pubMsg      *publicMessage
 797  		privContent *privateMessageContent
 798  	)
 799  	if public {
 800  		pubMsg, err = signPublicMessage(cs, group.signaturePriv, &framedContent, &group.groupContext)
 801  		if err != nil {
 802  			return nil, nil, fmt.Errorf("failed to sign public message: %v", err)
 803  		}
 804  		authContent = pubMsg.authenticatedContent()
 805  		authData = &pubMsg.auth
 806  	} else {
 807  		privContent, err = signPrivateMessageContent(cs, group.signaturePriv, &framedContent, &group.groupContext)
 808  		if err != nil {
 809  			return nil, nil, fmt.Errorf("failed to sign private message: %v", err)
 810  		}
 811  		authContent = privContent.authenticatedContent(&framedContent)
 812  		authData = &privContent.auth
 813  	}
 814  
 815  	newGroupCtx.confirmedTranscriptHash, err = authContent.confirmedTranscriptHashInput().hash(cs, group.interimTranscriptHash)
 816  	if err != nil {
 817  		return nil, nil, fmt.Errorf("failed to hash confirmed transcript hash input: %v", err)
 818  	}
 819  
 820  	joinerSecret, err := newGroupCtx.extractJoinerSecret(group.initSecret, commitSecret)
 821  	if err != nil {
 822  		return nil, nil, fmt.Errorf("failed to extract joiner secret: %v", err)
 823  	}
 824  
 825  	epochSecret, err := newGroupCtx.extractEpochSecret(joinerSecret, pskSecret)
 826  	if err != nil {
 827  		return nil, nil, fmt.Errorf("failed to extract epoch secret: %v", err)
 828  	}
 829  
 830  	confirmationTag, err := newGroupCtx.signConfirmationTag(epochSecret)
 831  	if err != nil {
 832  		return nil, nil, fmt.Errorf("failed to sign confirmation tag: %v", err)
 833  	}
 834  	authData.confirmationTag = confirmationTag
 835  
 836  	rawTree, err := marshal(newTree)
 837  	if err != nil {
 838  		return nil, nil, fmt.Errorf("failed to marshal ratchet tree: %v", err)
 839  	}
 840  
 841  	newGroupInfo := groupInfo{
 842  		groupContext:    newGroupCtx,
 843  		confirmationTag: confirmationTag,
 844  		signer:          group.myLeafIndex,
 845  		extensions: []extension{
 846  			{
 847  				extensionType: extensionTypeRatchetTree,
 848  				extensionData: rawTree,
 849  			},
 850  		},
 851  	}
 852  	if err := newGroupInfo.sign(group.signaturePriv); err != nil {
 853  		return nil, nil, fmt.Errorf("failed to sign group info: %v", err)
 854  	}
 855  
 856  	encryptedGroupInfo, err := newGroupInfo.encrypt(joinerSecret, pskSecret)
 857  	if err != nil {
 858  		return nil, nil, fmt.Errorf("failed to encrypt group info: %v", err)
 859  	}
 860  
 861  	groupSecrets := groupSecrets{joinerSecret: joinerSecret}
 862  	encGroupSecrets := make([]encryptedGroupSecrets, len(keyPkgs))
 863  	for i, keyPkg := range keyPkgs {
 864  		keyPkgRef, err := keyPkg.GenerateRef()
 865  		if err != nil {
 866  			return nil, nil, fmt.Errorf("failed to generate key package ref: %v", err)
 867  		}
 868  
 869  		rawEncryptedGroupSecrets, err := groupSecrets.encrypt(cs, keyPkg.initKey, encryptedGroupInfo)
 870  		if err != nil {
 871  			return nil, nil, fmt.Errorf("failed to encrypt group secrets: %v", err)
 872  		}
 873  
 874  		encGroupSecrets[i] = encryptedGroupSecrets{
 875  			newMember:             keyPkgRef,
 876  			encryptedGroupSecrets: *rawEncryptedGroupSecrets,
 877  		}
 878  	}
 879  
 880  	var rawMsg []byte
 881  	if public {
 882  		rawMsg, err = group.signPublicMessageMembershipTag(pubMsg)
 883  		if err != nil {
 884  			return nil, nil, err
 885  		}
 886  	} else {
 887  		rawMsg, err = group.encryptPrivateMessage(&framedContent, privContent)
 888  		if err != nil {
 889  			return nil, nil, fmt.Errorf("failed to encrypt private message: %v", err)
 890  		}
 891  	}
 892  
 893  	return &Welcome{
 894  		cipherSuite:        cs,
 895  		secrets:            encGroupSecrets,
 896  		encryptedGroupInfo: encryptedGroupInfo,
 897  	}, rawMsg, nil
 898  }
 899  
 900  // CreateApplicationMessage creates a new encrypted application message for the
 901  // group. The message contains an arbitrary application-specific payload.
 902  func (group *Group) CreateApplicationMessage(data []byte) ([]byte, error) {
 903  	cs := group.groupContext.cipherSuite
 904  
 905  	framedContent := framedContent{
 906  		groupID: group.groupContext.groupID,
 907  		epoch:   group.groupContext.epoch,
 908  		sender: sender{
 909  			senderType: senderTypeMember,
 910  			leafIndex:  group.myLeafIndex,
 911  		},
 912  		contentType:     contentTypeApplication,
 913  		applicationData: data,
 914  	}
 915  	privContent, err := signPrivateMessageContent(cs, group.signaturePriv, &framedContent, &group.groupContext)
 916  	if err != nil {
 917  		return nil, fmt.Errorf("failed to sign private message: %v", err)
 918  	}
 919  
 920  	return group.encryptPrivateMessage(&framedContent, privContent)
 921  }
 922  
 923  func (group *Group) encryptPrivateMessage(framedContent *framedContent, privContent *privateMessageContent) ([]byte, error) {
 924  	cs := group.groupContext.cipherSuite
 925  
 926  	senderData, err := newSenderData(group.myLeafIndex, 0) // TODO: set generation > 0
 927  	if err != nil {
 928  		return nil, fmt.Errorf("failed to create sender data: %v", err)
 929  	}
 930  
 931  	encryptionSecret, err := cs.deriveSecret(group.epochSecret, secretLabelEncryption)
 932  	if err != nil {
 933  		return nil, fmt.Errorf("failed to derive encryption secret: %v", err)
 934  	}
 935  
 936  	secretTree, err := deriveSecretTree(cs, group.tree.numLeaves(), encryptionSecret)
 937  	if err != nil {
 938  		return nil, fmt.Errorf("failed to erive secret tree: %v", err)
 939  	}
 940  
 941  	label := ratchetLabelFromContentType(framedContent.contentType)
 942  	secret, err := secretTree.deriveRatchetRoot(cs, group.myLeafIndex.nodeIndex(), label)
 943  	if err != nil {
 944  		return nil, fmt.Errorf("failed to derive secret ratchet tree root: %v", err)
 945  	}
 946  
 947  	senderDataSecret, err := cs.deriveSecret(group.epochSecret, secretLabelSenderData)
 948  	if err != nil {
 949  		return nil, fmt.Errorf("failed to derive sender data secret: %v", err)
 950  	}
 951  
 952  	privMsg, err := encryptPrivateMessage(cs, secret, senderDataSecret, framedContent, privContent, senderData)
 953  	if err != nil {
 954  		return nil, fmt.Errorf("failed to encrypt private message: %v", err)
 955  	}
 956  
 957  	rawMsg, err := marshal(&mlsMessage{
 958  		version:        protocolVersionMLS10,
 959  		wireFormat:     wireFormatMLSPrivateMessage,
 960  		privateMessage: privMsg,
 961  	})
 962  	if err != nil {
 963  		return nil, fmt.Errorf("failed to marshal private message: %v", err)
 964  	}
 965  
 966  	return rawMsg, nil
 967  }
 968  
 969  func (group *Group) signPublicMessageMembershipTag(pubMsg *publicMessage) ([]byte, error) {
 970  	cs := group.groupContext.cipherSuite
 971  
 972  	membershipKey, err := group.groupContext.cipherSuite.deriveSecret(group.epochSecret, secretLabelMembership)
 973  	if err != nil {
 974  		return nil, fmt.Errorf("failed to derive membership key: %v", err)
 975  	}
 976  	if err := pubMsg.signMembershipTag(cs, membershipKey, &group.groupContext); err != nil {
 977  		return nil, fmt.Errorf("failed to sign public message membership tag: %v", err)
 978  	}
 979  
 980  	rawMsg, err := marshal(&mlsMessage{
 981  		version:       protocolVersionMLS10,
 982  		wireFormat:    wireFormatMLSPublicMessage,
 983  		publicMessage: pubMsg,
 984  	})
 985  	if err != nil {
 986  		return nil, fmt.Errorf("failed to marshal public message: %v", err)
 987  	}
 988  
 989  	return rawMsg, nil
 990  }
 991  
 992  type commit struct {
 993  	proposals []proposalOrRef
 994  	path      *updatePath // optional
 995  }
 996  
 997  func (c *commit) unmarshal(s *cryptobyte.String) error {
 998  	*c = commit{}
 999  
1000  	err := readVector(s, func(s *cryptobyte.String) error {
1001  		var propOrRef proposalOrRef
1002  		if err := propOrRef.unmarshal(s); err != nil {
1003  			return err
1004  		}
1005  		c.proposals = append(c.proposals, propOrRef)
1006  		return nil
1007  	})
1008  	if err != nil {
1009  		return err
1010  	}
1011  
1012  	var hasPath bool
1013  	if !readOptional(s, &hasPath) {
1014  		return io.ErrUnexpectedEOF
1015  	} else if hasPath {
1016  		c.path = new(updatePath)
1017  		if err := c.path.unmarshal(s); err != nil {
1018  			return err
1019  		}
1020  	}
1021  
1022  	return nil
1023  }
1024  
1025  func (c *commit) marshal(b *cryptobyte.Builder) {
1026  	writeVector(b, len(c.proposals), func(b *cryptobyte.Builder, i int) {
1027  		c.proposals[i].marshal(b)
1028  	})
1029  	writeOptional(b, c.path != nil)
1030  	if c.path != nil {
1031  		c.path.marshal(b)
1032  	}
1033  }
1034  
1035  type groupInfo struct {
1036  	groupContext    groupContext
1037  	extensions      []extension
1038  	confirmationTag []byte
1039  	signer          leafIndex
1040  	signature       []byte
1041  }
1042  
1043  func (info *groupInfo) unmarshal(s *cryptobyte.String) error {
1044  	*info = groupInfo{}
1045  
1046  	if err := info.groupContext.unmarshal(s); err != nil {
1047  		return err
1048  	}
1049  
1050  	exts, err := unmarshalExtensionVec(s)
1051  	if err != nil {
1052  		return err
1053  	}
1054  	info.extensions = exts
1055  
1056  	if !readOpaqueVec(s, &info.confirmationTag) || !s.ReadUint32((*uint32)(&info.signer)) || !readOpaqueVec(s, &info.signature) {
1057  		return err
1058  	}
1059  
1060  	return nil
1061  }
1062  
1063  func (info *groupInfo) marshal(b *cryptobyte.Builder) {
1064  	(*groupInfoTBS)(info).marshal(b)
1065  	writeOpaqueVec(b, info.signature)
1066  }
1067  
1068  func (info *groupInfo) verifySignature(signerPub signaturePublicKey) bool {
1069  	cs := info.groupContext.cipherSuite
1070  	tbs, err := marshal((*groupInfoTBS)(info))
1071  	if err != nil {
1072  		return false
1073  	}
1074  	return cs.verifyWithLabel(signerPub, []byte("GroupInfoTBS"), tbs, info.signature)
1075  }
1076  
1077  func (info *groupInfo) sign(signerPriv signaturePrivateKey) error {
1078  	cs := info.groupContext.cipherSuite
1079  	tbs, err := marshal((*groupInfoTBS)(info))
1080  	if err != nil {
1081  		return err
1082  	}
1083  	sig, err := cs.signWithLabel(signerPriv, []byte("GroupInfoTBS"), tbs)
1084  	if err != nil {
1085  		return err
1086  	}
1087  	info.signature = sig
1088  	return nil
1089  }
1090  
1091  func (info *groupInfo) verifyConfirmationTag(joinerSecret, pskSecret []byte) bool {
1092  	cs := info.groupContext.cipherSuite
1093  	epochSecret, err := info.groupContext.extractEpochSecret(joinerSecret, pskSecret)
1094  	if err != nil {
1095  		return false
1096  	}
1097  	confirmationKey, err := cs.deriveSecret(epochSecret, secretLabelConfirm)
1098  	if err != nil {
1099  		return false
1100  	}
1101  	return cs.verifyMAC(confirmationKey, info.groupContext.confirmedTranscriptHash, info.confirmationTag)
1102  }
1103  
1104  func (info *groupInfo) encrypt(joinerSecret, pskSecret []byte) ([]byte, error) {
1105  	cs := info.groupContext.cipherSuite
1106  	_, _, aead := cs.hpke().Params()
1107  
1108  	welcomeSecret, err := extractWelcomeSecret(cs, joinerSecret, pskSecret)
1109  	if err != nil {
1110  		return nil, err
1111  	}
1112  
1113  	welcomeNonce, err := cs.expandWithLabel(welcomeSecret, []byte("nonce"), nil, uint16(aead.NonceSize()))
1114  	if err != nil {
1115  		return nil, err
1116  	}
1117  	welcomeKey, err := cs.expandWithLabel(welcomeSecret, []byte("key"), nil, uint16(aead.KeySize()))
1118  	if err != nil {
1119  		return nil, err
1120  	}
1121  
1122  	cipher, err := aead.New(welcomeKey)
1123  	if err != nil {
1124  		return nil, err
1125  	}
1126  
1127  	rawGroupInfo, err := marshal(info)
1128  	if err != nil {
1129  		return nil, err
1130  	}
1131  
1132  	return cipher.Seal(nil, welcomeNonce, rawGroupInfo, nil), nil
1133  }
1134  
1135  type groupInfoTBS groupInfo
1136  
1137  func (info *groupInfoTBS) marshal(b *cryptobyte.Builder) {
1138  	info.groupContext.marshal(b)
1139  	marshalExtensionVec(b, info.extensions)
1140  	writeOpaqueVec(b, info.confirmationTag)
1141  	b.AddUint32(uint32(info.signer))
1142  }
1143  
1144  type groupSecrets struct {
1145  	joinerSecret []byte
1146  	pathSecret   []byte // optional
1147  	psks         []preSharedKeyID
1148  }
1149  
1150  func (sec *groupSecrets) unmarshal(s *cryptobyte.String) error {
1151  	*sec = groupSecrets{}
1152  
1153  	if !readOpaqueVec(s, &sec.joinerSecret) {
1154  		return io.ErrUnexpectedEOF
1155  	}
1156  
1157  	var hasPathSecret bool
1158  	if !readOptional(s, &hasPathSecret) {
1159  		return io.ErrUnexpectedEOF
1160  	} else if hasPathSecret && !readOpaqueVec(s, &sec.pathSecret) {
1161  		return io.ErrUnexpectedEOF
1162  	}
1163  
1164  	return readVector(s, func(s *cryptobyte.String) error {
1165  		var psk preSharedKeyID
1166  		if err := psk.unmarshal(s); err != nil {
1167  			return err
1168  		}
1169  		sec.psks = append(sec.psks, psk)
1170  		return nil
1171  	})
1172  }
1173  
1174  func (sec *groupSecrets) marshal(b *cryptobyte.Builder) {
1175  	writeOpaqueVec(b, sec.joinerSecret)
1176  
1177  	writeOptional(b, sec.pathSecret != nil)
1178  	if sec.pathSecret != nil {
1179  		writeOpaqueVec(b, sec.pathSecret)
1180  	}
1181  
1182  	writeVector(b, len(sec.psks), func(b *cryptobyte.Builder, i int) {
1183  		sec.psks[i].marshal(b)
1184  	})
1185  }
1186  
1187  // verifySingleReInitOrBranchPSK verifies that at most one key has type
1188  // resumption with usage reinit or branch.
1189  func (sec *groupSecrets) verifySingleReinitOrBranchPSK() bool {
1190  	n := 0
1191  	for _, pskID := range sec.psks {
1192  		if pskID.pskType != pskTypeResumption {
1193  			continue
1194  		}
1195  		switch pskID.usage {
1196  		case resumptionPSKUsageReinit, resumptionPSKUsageBranch:
1197  			n++
1198  		}
1199  	}
1200  	return n <= 1
1201  }
1202  
1203  func (sec *groupSecrets) encrypt(cs CipherSuite, initKey hpkePublicKey, encryptedGroupInfo []byte) (*hpkeCiphertext, error) {
1204  	rawGroupSecrets, err := marshal(sec)
1205  	if err != nil {
1206  		return nil, err
1207  	}
1208  
1209  	kemOutput, ciphertext, err := cs.encryptWithLabel(initKey, []byte("Welcome"), encryptedGroupInfo, rawGroupSecrets)
1210  	if err != nil {
1211  		return nil, err
1212  	}
1213  
1214  	return &hpkeCiphertext{
1215  		kemOutput:  kemOutput,
1216  		ciphertext: ciphertext,
1217  	}, nil
1218  }
1219  
1220  // A Welcome message includes secret keying information necessary to join a
1221  // group.
1222  type Welcome struct {
1223  	cipherSuite        CipherSuite
1224  	secrets            []encryptedGroupSecrets
1225  	encryptedGroupInfo []byte
1226  }
1227  
1228  // UnmarshalWelcome reads a welcome message.
1229  func UnmarshalWelcome(raw []byte) (*Welcome, error) {
1230  	var msg mlsMessage
1231  	if err := unmarshal(raw, &msg); err != nil {
1232  		return nil, err
1233  	} else if msg.wireFormat != wireFormatMLSWelcome {
1234  		return nil, fmt.Errorf("mls: expected a key package message, got wire format %v", msg.wireFormat)
1235  	}
1236  	return msg.welcome, nil
1237  }
1238  
1239  // Bytes encodes the welcome message.
1240  func (w *Welcome) Bytes() []byte {
1241  	raw, err := marshal(&mlsMessage{
1242  		version:    protocolVersionMLS10,
1243  		wireFormat: wireFormatMLSWelcome,
1244  		welcome:    w,
1245  	})
1246  	if err != nil {
1247  		// should never happen
1248  		panic(fmt.Errorf("mls: failed to marshal welcome message: %v", err))
1249  	}
1250  	return raw
1251  }
1252  
1253  func (w *Welcome) unmarshal(s *cryptobyte.String) error {
1254  	*w = Welcome{}
1255  
1256  	if !s.ReadUint16((*uint16)(&w.cipherSuite)) {
1257  		return io.ErrUnexpectedEOF
1258  	}
1259  
1260  	err := readVector(s, func(s *cryptobyte.String) error {
1261  		var sec encryptedGroupSecrets
1262  		if err := sec.unmarshal(s); err != nil {
1263  			return err
1264  		}
1265  		w.secrets = append(w.secrets, sec)
1266  		return nil
1267  	})
1268  	if err != nil {
1269  		return err
1270  	}
1271  
1272  	if !readOpaqueVec(s, &w.encryptedGroupInfo) {
1273  		return io.ErrUnexpectedEOF
1274  	}
1275  
1276  	return nil
1277  }
1278  
1279  func (w *Welcome) marshal(b *cryptobyte.Builder) {
1280  	b.AddUint16(uint16(w.cipherSuite))
1281  	writeVector(b, len(w.secrets), func(b *cryptobyte.Builder, i int) {
1282  		w.secrets[i].marshal(b)
1283  	})
1284  	writeOpaqueVec(b, w.encryptedGroupInfo)
1285  }
1286  
1287  // NewMembers returns the list of key package references this welcome message
1288  // contains secret keying information for.
1289  func (w *Welcome) NewMembers() []KeyPackageRef {
1290  	refs := make([]KeyPackageRef, len(w.secrets))
1291  	for i, sec := range w.secrets {
1292  		refs[i] = sec.newMember
1293  	}
1294  	return refs
1295  }
1296  
1297  func (w *Welcome) findSecret(ref KeyPackageRef) *encryptedGroupSecrets {
1298  	for i, sec := range w.secrets {
1299  		if sec.newMember.Equal(ref) {
1300  			return &w.secrets[i]
1301  		}
1302  	}
1303  	return nil
1304  }
1305  
1306  func (w *Welcome) decryptGroupSecrets(ref KeyPackageRef, initKeyPriv hpkePrivateKey) (*groupSecrets, error) {
1307  	cs := w.cipherSuite
1308  
1309  	sec := w.findSecret(ref)
1310  	if sec == nil {
1311  		return nil, fmt.Errorf("mls: encrypted group secrets not found for provided key package ref")
1312  	}
1313  
1314  	rawGroupSecrets, err := cs.decryptWithLabel(initKeyPriv, []byte("Welcome"), w.encryptedGroupInfo, sec.encryptedGroupSecrets.kemOutput, sec.encryptedGroupSecrets.ciphertext)
1315  	if err != nil {
1316  		return nil, err
1317  	}
1318  	var groupSecrets groupSecrets
1319  	if err := unmarshal(rawGroupSecrets, &groupSecrets); err != nil {
1320  		return nil, err
1321  	}
1322  
1323  	return &groupSecrets, err
1324  }
1325  
1326  func (w *Welcome) decryptGroupInfo(joinerSecret, pskSecret []byte) (*groupInfo, error) {
1327  	cs := w.cipherSuite
1328  	_, _, aead := cs.hpke().Params()
1329  
1330  	welcomeSecret, err := extractWelcomeSecret(cs, joinerSecret, pskSecret)
1331  	if err != nil {
1332  		return nil, err
1333  	}
1334  
1335  	welcomeNonce, err := cs.expandWithLabel(welcomeSecret, []byte("nonce"), nil, uint16(aead.NonceSize()))
1336  	if err != nil {
1337  		return nil, err
1338  	}
1339  	welcomeKey, err := cs.expandWithLabel(welcomeSecret, []byte("key"), nil, uint16(aead.KeySize()))
1340  	if err != nil {
1341  		return nil, err
1342  	}
1343  
1344  	welcomeCipher, err := aead.New(welcomeKey)
1345  	if err != nil {
1346  		return nil, err
1347  	}
1348  	rawGroupInfo, err := welcomeCipher.Open(nil, welcomeNonce, w.encryptedGroupInfo, nil)
1349  	if err != nil {
1350  		return nil, err
1351  	}
1352  
1353  	var groupInfo groupInfo
1354  	if err := unmarshal(rawGroupInfo, &groupInfo); err != nil {
1355  		return nil, err
1356  	}
1357  
1358  	return &groupInfo, nil
1359  }
1360  
1361  type encryptedGroupSecrets struct {
1362  	newMember             KeyPackageRef
1363  	encryptedGroupSecrets hpkeCiphertext
1364  }
1365  
1366  func (sec *encryptedGroupSecrets) unmarshal(s *cryptobyte.String) error {
1367  	*sec = encryptedGroupSecrets{}
1368  	if !readOpaqueVec(s, (*[]byte)(&sec.newMember)) {
1369  		return io.ErrUnexpectedEOF
1370  	}
1371  	if err := sec.encryptedGroupSecrets.unmarshal(s); err != nil {
1372  		return err
1373  	}
1374  	return nil
1375  }
1376  
1377  func (sec *encryptedGroupSecrets) marshal(b *cryptobyte.Builder) {
1378  	writeOpaqueVec(b, []byte(sec.newMember))
1379  	sec.encryptedGroupSecrets.marshal(b)
1380  }
1381