tree.go raw

   1  package mls
   2  
   3  import (
   4  	"bytes"
   5  	"fmt"
   6  	"io"
   7  	"time"
   8  
   9  	"golang.org/x/crypto/cryptobyte"
  10  )
  11  
  12  type parentNode struct {
  13  	encryptionKey  hpkePublicKey
  14  	parentHash     []byte
  15  	unmergedLeaves []leafIndex
  16  }
  17  
  18  func (node *parentNode) unmarshal(s *cryptobyte.String) error {
  19  	*node = parentNode{}
  20  	if !readOpaqueVec(s, (*[]byte)(&node.encryptionKey)) || !readOpaqueVec(s, &node.parentHash) {
  21  		return io.ErrUnexpectedEOF
  22  	}
  23  	return readVector(s, func(s *cryptobyte.String) error {
  24  		var i leafIndex
  25  		if !s.ReadUint32((*uint32)(&i)) {
  26  			return io.ErrUnexpectedEOF
  27  		}
  28  		node.unmergedLeaves = append(node.unmergedLeaves, i)
  29  		return nil
  30  	})
  31  }
  32  
  33  func (node *parentNode) marshal(b *cryptobyte.Builder) {
  34  	writeOpaqueVec(b, []byte(node.encryptionKey))
  35  	writeOpaqueVec(b, node.parentHash)
  36  	writeVector(b, len(node.unmergedLeaves), func(b *cryptobyte.Builder, i int) {
  37  		b.AddUint32(uint32(node.unmergedLeaves[i]))
  38  	})
  39  }
  40  
  41  func (node *parentNode) computeParentHash(cs CipherSuite, originalSiblingTreeHash []byte) ([]byte, error) {
  42  	rawInput, err := marshalParentHashInput(node.encryptionKey, node.parentHash, originalSiblingTreeHash)
  43  	if err != nil {
  44  		return nil, err
  45  	}
  46  	h := cs.hash().New()
  47  	h.Write(rawInput)
  48  	return h.Sum(nil), nil
  49  }
  50  
  51  func marshalParentHashInput(encryptionKey hpkePublicKey, parentHash, originalSiblingTreeHash []byte) ([]byte, error) {
  52  	var b cryptobyte.Builder
  53  	writeOpaqueVec(&b, []byte(encryptionKey))
  54  	writeOpaqueVec(&b, parentHash)
  55  	writeOpaqueVec(&b, originalSiblingTreeHash)
  56  	return b.Bytes()
  57  }
  58  
  59  type leafNodeSource uint8
  60  
  61  const (
  62  	leafNodeSourceKeyPackage leafNodeSource = 1
  63  	leafNodeSourceUpdate     leafNodeSource = 2
  64  	leafNodeSourceCommit     leafNodeSource = 3
  65  )
  66  
  67  func (src *leafNodeSource) unmarshal(s *cryptobyte.String) error {
  68  	if !s.ReadUint8((*uint8)(src)) {
  69  		return io.ErrUnexpectedEOF
  70  	}
  71  	switch *src {
  72  	case leafNodeSourceKeyPackage, leafNodeSourceUpdate, leafNodeSourceCommit:
  73  		return nil
  74  	default:
  75  		return fmt.Errorf("mls: invalid leaf node source %d", *src)
  76  	}
  77  }
  78  
  79  func (src leafNodeSource) marshal(b *cryptobyte.Builder) {
  80  	b.AddUint8(uint8(src))
  81  }
  82  
  83  type capabilities struct {
  84  	versions     []protocolVersion
  85  	cipherSuites []CipherSuite
  86  	extensions   []extensionType
  87  	proposals    []proposalType
  88  	credentials  []credentialType
  89  }
  90  
  91  func (caps *capabilities) unmarshal(s *cryptobyte.String) error {
  92  	*caps = capabilities{}
  93  
  94  	// Note: all unknown values here must be ignored
  95  
  96  	err := readVector(s, func(s *cryptobyte.String) error {
  97  		var ver protocolVersion
  98  		if !s.ReadUint16((*uint16)(&ver)) {
  99  			return io.ErrUnexpectedEOF
 100  		}
 101  		caps.versions = append(caps.versions, ver)
 102  		return nil
 103  	})
 104  	if err != nil {
 105  		return err
 106  	}
 107  
 108  	err = readVector(s, func(s *cryptobyte.String) error {
 109  		var cs CipherSuite
 110  		if !s.ReadUint16((*uint16)(&cs)) {
 111  			return io.ErrUnexpectedEOF
 112  		}
 113  		caps.cipherSuites = append(caps.cipherSuites, cs)
 114  		return nil
 115  	})
 116  	if err != nil {
 117  		return err
 118  	}
 119  
 120  	err = readVector(s, func(s *cryptobyte.String) error {
 121  		var et extensionType
 122  		if !s.ReadUint16((*uint16)(&et)) {
 123  			return io.ErrUnexpectedEOF
 124  		}
 125  		caps.extensions = append(caps.extensions, et)
 126  		return nil
 127  	})
 128  	if err != nil {
 129  		return err
 130  	}
 131  
 132  	err = readVector(s, func(s *cryptobyte.String) error {
 133  		var pt proposalType
 134  		if !s.ReadUint16((*uint16)(&pt)) {
 135  			return io.ErrUnexpectedEOF
 136  		}
 137  		caps.proposals = append(caps.proposals, pt)
 138  		return nil
 139  	})
 140  	if err != nil {
 141  		return err
 142  	}
 143  
 144  	err = readVector(s, func(s *cryptobyte.String) error {
 145  		var ct credentialType
 146  		if !s.ReadUint16((*uint16)(&ct)) {
 147  			return io.ErrUnexpectedEOF
 148  		}
 149  		caps.credentials = append(caps.credentials, ct)
 150  		return nil
 151  	})
 152  	if err != nil {
 153  		return err
 154  	}
 155  
 156  	return nil
 157  }
 158  
 159  func (caps *capabilities) marshal(b *cryptobyte.Builder) {
 160  	writeVector(b, len(caps.versions), func(b *cryptobyte.Builder, i int) {
 161  		b.AddUint16(uint16(caps.versions[i]))
 162  	})
 163  
 164  	writeVector(b, len(caps.cipherSuites), func(b *cryptobyte.Builder, i int) {
 165  		b.AddUint16(uint16(caps.cipherSuites[i]))
 166  	})
 167  
 168  	writeVector(b, len(caps.extensions), func(b *cryptobyte.Builder, i int) {
 169  		b.AddUint16(uint16(caps.extensions[i]))
 170  	})
 171  
 172  	writeVector(b, len(caps.proposals), func(b *cryptobyte.Builder, i int) {
 173  		b.AddUint16(uint16(caps.proposals[i]))
 174  	})
 175  
 176  	writeVector(b, len(caps.credentials), func(b *cryptobyte.Builder, i int) {
 177  		b.AddUint16(uint16(caps.credentials[i]))
 178  	})
 179  }
 180  
 181  const maxLeafNodeLifetime = 3 * 30 * 24 * time.Hour
 182  
 183  type lifetime struct {
 184  	notBefore, notAfter uint64
 185  }
 186  
 187  func newLifetime(notBefore, notAfter time.Time) *lifetime {
 188  	return &lifetime{
 189  		notBefore: uint64(notBefore.Unix()),
 190  		notAfter:  uint64(notAfter.Unix()),
 191  	}
 192  }
 193  
 194  func (lt *lifetime) unmarshal(s *cryptobyte.String) error {
 195  	*lt = lifetime{}
 196  	if !s.ReadUint64(&lt.notBefore) || !s.ReadUint64(&lt.notAfter) {
 197  		return io.ErrUnexpectedEOF
 198  	}
 199  	return nil
 200  }
 201  
 202  func (lt *lifetime) marshal(b *cryptobyte.Builder) {
 203  	b.AddUint64(lt.notBefore)
 204  	b.AddUint64(lt.notAfter)
 205  }
 206  
 207  func (lt *lifetime) notBeforeTime() time.Time {
 208  	return time.Unix(int64(lt.notBefore), 0)
 209  }
 210  
 211  func (lt *lifetime) notAfterTime() time.Time {
 212  	return time.Unix(int64(lt.notAfter), 0)
 213  }
 214  
 215  // verify ensures that the lifetime is valid: it has an acceptable range and
 216  // the current time is within that range.
 217  func (lt *lifetime) verify(t time.Time) bool {
 218  	notBefore, notAfter := lt.notBeforeTime(), lt.notAfterTime()
 219  
 220  	if d := notAfter.Sub(notBefore); d <= 0 || d > maxLeafNodeLifetime {
 221  		return false
 222  	}
 223  
 224  	return t.After(notBefore) && notAfter.After(t)
 225  }
 226  
 227  type extensionType uint16
 228  
 229  // http://www.iana.org/assignments/mls/mls.xhtml#mls-extension-types
 230  const (
 231  	extensionTypeApplicationID        extensionType = 0x0001
 232  	extensionTypeRatchetTree          extensionType = 0x0002
 233  	extensionTypeRequiredCapabilities extensionType = 0x0003
 234  	extensionTypeExternalPub          extensionType = 0x0004
 235  	extensionTypeExternalSenders      extensionType = 0x0005
 236  
 237  	// ExtensionTypeLastResort marks a KeyPackage as reusable for multiple
 238  	// Welcome messages. Required by Marmot (MIP-00).
 239  	ExtensionTypeLastResort extensionType = 0x000a
 240  
 241  	// ExtensionTypeNostrGroupData carries Nostr group metadata (group ID,
 242  	// name, admins, relays). Required by Marmot (MIP-01).
 243  	ExtensionTypeNostrGroupData extensionType = 0xf2ee
 244  )
 245  
 246  // Extension holds a TLS-serialized MLS extension (type + opaque data).
 247  type Extension = extension
 248  
 249  type extension struct {
 250  	extensionType extensionType
 251  	extensionData []byte
 252  }
 253  
 254  // NewExtension creates an extension with the given type and data.
 255  func NewExtension(t extensionType, data []byte) extension {
 256  	return extension{extensionType: t, extensionData: data}
 257  }
 258  
 259  // ExtensionType is exported for use by the Marmot SDK.
 260  type ExtensionType = extensionType
 261  
 262  func unmarshalExtensionVec(s *cryptobyte.String) ([]extension, error) {
 263  	var exts []extension
 264  	err := readVector(s, func(s *cryptobyte.String) error {
 265  		var ext extension
 266  		if !s.ReadUint16((*uint16)(&ext.extensionType)) || !readOpaqueVec(s, &ext.extensionData) {
 267  			return io.ErrUnexpectedEOF
 268  		}
 269  		exts = append(exts, ext)
 270  		return nil
 271  	})
 272  	return exts, err
 273  }
 274  
 275  func marshalExtensionVec(b *cryptobyte.Builder, exts []extension) {
 276  	writeVector(b, len(exts), func(b *cryptobyte.Builder, i int) {
 277  		ext := exts[i]
 278  		b.AddUint16(uint16(ext.extensionType))
 279  		writeOpaqueVec(b, ext.extensionData)
 280  	})
 281  }
 282  
 283  func findExtensionData(exts []extension, t extensionType) []byte {
 284  	for _, ext := range exts {
 285  		if ext.extensionType == t {
 286  			return ext.extensionData
 287  		}
 288  	}
 289  	return nil
 290  }
 291  
 292  type leafNode struct {
 293  	encryptionKey hpkePublicKey
 294  	signatureKey  signaturePublicKey
 295  	credential    Credential
 296  	capabilities  capabilities
 297  
 298  	leafNodeSource leafNodeSource
 299  	lifetime       *lifetime // for leafNodeSourceKeyPackage
 300  	parentHash     []byte    // for leafNodeSourceCommit
 301  
 302  	extensions []extension
 303  	signature  []byte
 304  }
 305  
 306  func (node *leafNode) unmarshal(s *cryptobyte.String) error {
 307  	*node = leafNode{}
 308  
 309  	if !readOpaqueVec(s, (*[]byte)(&node.encryptionKey)) || !readOpaqueVec(s, (*[]byte)(&node.signatureKey)) {
 310  		return io.ErrUnexpectedEOF
 311  	}
 312  
 313  	if err := node.credential.unmarshal(s); err != nil {
 314  		return err
 315  	}
 316  	if err := node.capabilities.unmarshal(s); err != nil {
 317  		return err
 318  	}
 319  	if err := node.leafNodeSource.unmarshal(s); err != nil {
 320  		return err
 321  	}
 322  
 323  	var err error
 324  	switch node.leafNodeSource {
 325  	case leafNodeSourceKeyPackage:
 326  		node.lifetime = new(lifetime)
 327  		err = node.lifetime.unmarshal(s)
 328  	case leafNodeSourceCommit:
 329  		if !readOpaqueVec(s, &node.parentHash) {
 330  			err = io.ErrUnexpectedEOF
 331  		}
 332  	}
 333  	if err != nil {
 334  		return err
 335  	}
 336  
 337  	exts, err := unmarshalExtensionVec(s)
 338  	if err != nil {
 339  		return err
 340  	}
 341  	node.extensions = exts
 342  
 343  	if !readOpaqueVec(s, &node.signature) {
 344  		return io.ErrUnexpectedEOF
 345  	}
 346  
 347  	return nil
 348  }
 349  
 350  func (node *leafNode) marshalBase(b *cryptobyte.Builder) {
 351  	writeOpaqueVec(b, []byte(node.encryptionKey))
 352  	writeOpaqueVec(b, []byte(node.signatureKey))
 353  	node.credential.marshal(b)
 354  	node.capabilities.marshal(b)
 355  	node.leafNodeSource.marshal(b)
 356  	switch node.leafNodeSource {
 357  	case leafNodeSourceKeyPackage:
 358  		node.lifetime.marshal(b)
 359  	case leafNodeSourceCommit:
 360  		writeOpaqueVec(b, node.parentHash)
 361  	}
 362  	marshalExtensionVec(b, node.extensions)
 363  }
 364  
 365  func (node *leafNode) marshal(b *cryptobyte.Builder) {
 366  	node.marshalBase(b)
 367  	writeOpaqueVec(b, []byte(node.signature))
 368  }
 369  
 370  type leafNodeTBS struct {
 371  	*leafNode
 372  
 373  	// for leafNodeSourceUpdate and leafNodeSourceCommit
 374  	groupID   GroupID
 375  	leafIndex leafIndex
 376  }
 377  
 378  func (node *leafNodeTBS) marshal(b *cryptobyte.Builder) {
 379  	node.leafNode.marshalBase(b)
 380  	switch node.leafNode.leafNodeSource {
 381  	case leafNodeSourceUpdate, leafNodeSourceCommit:
 382  		writeOpaqueVec(b, []byte(node.groupID))
 383  		b.AddUint32(uint32(node.leafIndex))
 384  	}
 385  }
 386  
 387  func (node *leafNode) sign(cs CipherSuite, groupID GroupID, li leafIndex, signerPriv signaturePrivateKey) error {
 388  	leafNodeTBS, err := marshal(&leafNodeTBS{
 389  		leafNode:  node,
 390  		groupID:   groupID,
 391  		leafIndex: li,
 392  	})
 393  	if err != nil {
 394  		return err
 395  	}
 396  	sig, err := cs.signWithLabel(signerPriv, []byte("LeafNodeTBS"), leafNodeTBS)
 397  	if err != nil {
 398  		return err
 399  	}
 400  	node.signature = sig
 401  	return nil
 402  }
 403  
 404  // verifySignature verifies the signature of the leaf node.
 405  //
 406  // groupID and li can be left unspecified if the leaf node source is neither
 407  // update nor commit.
 408  func (node *leafNode) verifySignature(cs CipherSuite, groupID GroupID, li leafIndex) bool {
 409  	leafNodeTBS, err := marshal(&leafNodeTBS{
 410  		leafNode:  node,
 411  		groupID:   groupID,
 412  		leafIndex: li,
 413  	})
 414  	if err != nil {
 415  		return false
 416  	}
 417  	return cs.verifyWithLabel(node.signatureKey, []byte("LeafNodeTBS"), leafNodeTBS, node.signature)
 418  }
 419  
 420  // verify performs leaf node validation described in section 7.3.
 421  //
 422  // It does not perform all checks: it does not check that the credential is
 423  // valid.
 424  func (node *leafNode) verify(options *leafNodeVerifyOptions) error {
 425  	li := options.leafIndex
 426  
 427  	if !node.verifySignature(options.cipherSuite, options.groupID, li) {
 428  		return fmt.Errorf("mls: leaf node signature verification failed")
 429  	}
 430  
 431  	// TODO: check required_capabilities group extension
 432  
 433  	if _, ok := options.supportedCreds[node.credential.credentialType]; !ok {
 434  		return fmt.Errorf("mls: credential type %v used by leaf node not supported by all members", node.credential.credentialType)
 435  	}
 436  
 437  	if node.lifetime != nil {
 438  		now := options.now
 439  		if now == nil {
 440  			now = time.Now
 441  		}
 442  		if t := now(); !t.IsZero() && !node.lifetime.verify(t) {
 443  			return fmt.Errorf("mls: lifetime verification failed (not before %v, not after %v)", node.lifetime.notBeforeTime(), node.lifetime.notAfterTime())
 444  		}
 445  	}
 446  
 447  	supportedExts := make(map[extensionType]struct{})
 448  	for _, et := range node.capabilities.extensions {
 449  		supportedExts[et] = struct{}{}
 450  	}
 451  	for _, ext := range node.extensions {
 452  		if _, ok := supportedExts[ext.extensionType]; !ok {
 453  			return fmt.Errorf("mls: extension type %d used by leaf node not supported by that leaf node", ext.extensionType)
 454  		}
 455  	}
 456  
 457  	// TODO: verify the leaf_node_source field
 458  
 459  	if _, dup := options.signatureKeys[string(node.signatureKey)]; dup {
 460  		return fmt.Errorf("mls: duplicate signature key in ratchet tree")
 461  	}
 462  	if _, dup := options.encryptionKeys[string(node.encryptionKey)]; dup {
 463  		return fmt.Errorf("mls: duplicate encryption key in ratchet tree")
 464  	}
 465  
 466  	return nil
 467  }
 468  
 469  type leafNodeVerifyOptions struct {
 470  	cipherSuite    CipherSuite
 471  	groupID        GroupID
 472  	leafIndex      leafIndex
 473  	supportedCreds map[credentialType]struct{}
 474  	signatureKeys  map[string]struct{}
 475  	encryptionKeys map[string]struct{}
 476  	now            func() time.Time
 477  }
 478  
 479  type updatePathNode struct {
 480  	encryptionKey       hpkePublicKey
 481  	encryptedPathSecret []hpkeCiphertext
 482  }
 483  
 484  func (node *updatePathNode) unmarshal(s *cryptobyte.String) error {
 485  	*node = updatePathNode{}
 486  
 487  	if !readOpaqueVec(s, (*[]byte)(&node.encryptionKey)) {
 488  		return io.ErrUnexpectedEOF
 489  	}
 490  
 491  	return readVector(s, func(s *cryptobyte.String) error {
 492  		var ciphertext hpkeCiphertext
 493  		if err := ciphertext.unmarshal(s); err != nil {
 494  			return err
 495  		}
 496  		node.encryptedPathSecret = append(node.encryptedPathSecret, ciphertext)
 497  		return nil
 498  	})
 499  }
 500  
 501  func (node *updatePathNode) marshal(b *cryptobyte.Builder) {
 502  	writeOpaqueVec(b, []byte(node.encryptionKey))
 503  	writeVector(b, len(node.encryptedPathSecret), func(b *cryptobyte.Builder, i int) {
 504  		node.encryptedPathSecret[i].marshal(b)
 505  	})
 506  }
 507  
 508  func decryptPathSecret(cs CipherSuite, nodePriv hpkePrivateKey, ctx *groupContext, ciphertext hpkeCiphertext) ([]byte, error) {
 509  	rawCtx, err := marshal(ctx)
 510  	if err != nil {
 511  		return nil, err
 512  	}
 513  	return cs.decryptWithLabel(nodePriv, []byte("UpdatePathNode"), rawCtx, ciphertext.kemOutput, ciphertext.ciphertext)
 514  }
 515  
 516  func nodePrivFromPathSecret(cs CipherSuite, pathSecret []byte, nodePub hpkePublicKey) (hpkePrivateKey, error) {
 517  	nodeSecret, err := cs.deriveSecret(pathSecret, []byte("node"))
 518  	if err != nil {
 519  		return nil, err
 520  	}
 521  
 522  	pub, priv, err := cs.deriveEncryptionKeyPair(nodeSecret)
 523  	if err != nil {
 524  		return nil, err
 525  	}
 526  
 527  	if !bytes.Equal(pub, nodePub) {
 528  		return nil, fmt.Errorf("mls: node public key mismatch")
 529  	}
 530  
 531  	return priv, nil
 532  }
 533  
 534  type updatePath struct {
 535  	leafNode leafNode
 536  	nodes    []updatePathNode
 537  }
 538  
 539  func (up *updatePath) unmarshal(s *cryptobyte.String) error {
 540  	*up = updatePath{}
 541  
 542  	if err := up.leafNode.unmarshal(s); err != nil {
 543  		return err
 544  	}
 545  
 546  	return readVector(s, func(s *cryptobyte.String) error {
 547  		var node updatePathNode
 548  		if err := node.unmarshal(s); err != nil {
 549  			return err
 550  		}
 551  		up.nodes = append(up.nodes, node)
 552  		return nil
 553  	})
 554  }
 555  
 556  func (up *updatePath) marshal(b *cryptobyte.Builder) {
 557  	up.leafNode.marshal(b)
 558  	writeVector(b, len(up.nodes), func(b *cryptobyte.Builder, i int) {
 559  		up.nodes[i].marshal(b)
 560  	})
 561  }
 562  
 563  type nodeType uint8
 564  
 565  const (
 566  	nodeTypeLeaf   nodeType = 1
 567  	nodeTypeParent nodeType = 2
 568  )
 569  
 570  func (t *nodeType) unmarshal(s *cryptobyte.String) error {
 571  	if !s.ReadUint8((*uint8)(t)) {
 572  		return io.ErrUnexpectedEOF
 573  	}
 574  	switch *t {
 575  	case nodeTypeLeaf, nodeTypeParent:
 576  		return nil
 577  	default:
 578  		return fmt.Errorf("mls: invalid node type %d", *t)
 579  	}
 580  }
 581  
 582  func (t nodeType) marshal(b *cryptobyte.Builder) {
 583  	b.AddUint8(uint8(t))
 584  }
 585  
 586  type node struct {
 587  	nodeType   nodeType
 588  	leafNode   *leafNode   // for nodeTypeLeaf
 589  	parentNode *parentNode // for nodeTypeParent
 590  }
 591  
 592  func (n *node) unmarshal(s *cryptobyte.String) error {
 593  	*n = node{}
 594  
 595  	if err := n.nodeType.unmarshal(s); err != nil {
 596  		return err
 597  	}
 598  
 599  	switch n.nodeType {
 600  	case nodeTypeLeaf:
 601  		n.leafNode = new(leafNode)
 602  		return n.leafNode.unmarshal(s)
 603  	case nodeTypeParent:
 604  		n.parentNode = new(parentNode)
 605  		return n.parentNode.unmarshal(s)
 606  	default:
 607  		panic("unreachable")
 608  	}
 609  }
 610  
 611  func (n *node) marshal(b *cryptobyte.Builder) {
 612  	n.nodeType.marshal(b)
 613  	switch n.nodeType {
 614  	case nodeTypeLeaf:
 615  		n.leafNode.marshal(b)
 616  	case nodeTypeParent:
 617  		n.parentNode.marshal(b)
 618  	default:
 619  		panic("unreachable")
 620  	}
 621  }
 622  
 623  func (n *node) encryptionKey() hpkePublicKey {
 624  	switch n.nodeType {
 625  	case nodeTypeLeaf:
 626  		return n.leafNode.encryptionKey
 627  	case nodeTypeParent:
 628  		return n.parentNode.encryptionKey
 629  	default:
 630  		panic("unreachable")
 631  	}
 632  }
 633  
 634  // ratchetTree is a ratchet tree represented as complete balanced binary tree,
 635  // stored with the array-based scheme described in appendix C.
 636  //
 637  // The length of the tree plus 1 is guaranteed to be a power of 2.
 638  type ratchetTree []*node
 639  
 640  func (tree *ratchetTree) unmarshal(s *cryptobyte.String) error {
 641  	*tree = ratchetTree{}
 642  	err := readVector(s, func(s *cryptobyte.String) error {
 643  		var n *node
 644  		var hasNode bool
 645  		if !readOptional(s, &hasNode) {
 646  			return io.ErrUnexpectedEOF
 647  		} else if hasNode {
 648  			n = new(node)
 649  			if err := n.unmarshal(s); err != nil {
 650  				return err
 651  			}
 652  		}
 653  		*tree = append(*tree, n)
 654  		return nil
 655  	})
 656  	if err != nil {
 657  		return err
 658  	}
 659  
 660  	// The raw tree doesn't include blank nodes at the end, fill it until next
 661  	// power of 2
 662  	for !isPowerOf2(uint32(len(*tree) + 1)) {
 663  		*tree = append(*tree, nil)
 664  	}
 665  
 666  	return nil
 667  }
 668  
 669  func (tree ratchetTree) marshal(b *cryptobyte.Builder) {
 670  	end := len(tree)
 671  	for end > 0 && tree[end-1] == nil {
 672  		end--
 673  	}
 674  
 675  	writeVector(b, len(tree[:end]), func(b *cryptobyte.Builder, i int) {
 676  		n := tree[i]
 677  		writeOptional(b, n != nil)
 678  		if n != nil {
 679  			n.marshal(b)
 680  		}
 681  	})
 682  }
 683  
 684  func (tree ratchetTree) copy() ratchetTree {
 685  	newTree := make(ratchetTree, len(tree))
 686  	copy(newTree, tree)
 687  	return newTree
 688  }
 689  
 690  // get returns the node at the provided index.
 691  //
 692  // nil is returned for blank nodes. get panics if the index is out of range.
 693  func (tree ratchetTree) get(i nodeIndex) *node {
 694  	return tree[int(i)]
 695  }
 696  
 697  func (tree ratchetTree) set(i nodeIndex, node *node) {
 698  	tree[int(i)] = node
 699  }
 700  
 701  func (tree ratchetTree) getLeaf(li leafIndex) *leafNode {
 702  	node := tree.get(li.nodeIndex())
 703  	if node == nil {
 704  		return nil
 705  	}
 706  	if node.nodeType != nodeTypeLeaf {
 707  		panic("unreachable")
 708  	}
 709  	return node.leafNode
 710  }
 711  
 712  // resolve computes the resolution of a node.
 713  func (tree ratchetTree) resolve(x nodeIndex) []nodeIndex {
 714  	n := tree.get(x)
 715  	if n == nil {
 716  		l, r, ok := x.children()
 717  		if !ok {
 718  			return nil // leaf
 719  		}
 720  		return append(tree.resolve(l), tree.resolve(r)...)
 721  	} else {
 722  		res := []nodeIndex{x}
 723  		if n.nodeType == nodeTypeParent {
 724  			for _, leafIndex := range n.parentNode.unmergedLeaves {
 725  				res = append(res, leafIndex.nodeIndex())
 726  			}
 727  		}
 728  		return res
 729  	}
 730  }
 731  
 732  func (tree ratchetTree) supportedCreds() map[credentialType]struct{} {
 733  	numMembers := 0
 734  	supportedCredsCount := make(map[credentialType]int)
 735  	for li := leafIndex(0); li < leafIndex(tree.numLeaves()); li++ {
 736  		node := tree.getLeaf(li)
 737  		if node == nil {
 738  			continue
 739  		}
 740  
 741  		numMembers++
 742  		for _, ct := range node.capabilities.credentials {
 743  			supportedCredsCount[ct]++
 744  		}
 745  	}
 746  
 747  	supportedCreds := make(map[credentialType]struct{})
 748  	for ct, n := range supportedCredsCount {
 749  		if n == numMembers {
 750  			supportedCreds[ct] = struct{}{}
 751  		}
 752  	}
 753  
 754  	return supportedCreds
 755  }
 756  
 757  func (tree ratchetTree) keys() (signatureKeys, encryptionKeys map[string]struct{}) {
 758  	signatureKeys = make(map[string]struct{})
 759  	encryptionKeys = make(map[string]struct{})
 760  	for li := leafIndex(0); li < leafIndex(tree.numLeaves()); li++ {
 761  		node := tree.getLeaf(li)
 762  		if node == nil {
 763  			continue
 764  		}
 765  		signatureKeys[string(node.signatureKey)] = struct{}{}
 766  		encryptionKeys[string(node.encryptionKey)] = struct{}{}
 767  	}
 768  	return signatureKeys, encryptionKeys
 769  }
 770  
 771  // verifyIntegrity verifies the integrity of the ratchet tree, as described in
 772  // section 12.4.3.1.
 773  //
 774  // This function does not perform full leaf node validation. In particular:
 775  //
 776  //   - It doesn't check that credentials are valid.
 777  //   - It doesn't check the lifetime field.
 778  func (tree ratchetTree) verifyIntegrity(ctx *groupContext, now func() time.Time) error {
 779  	cs := ctx.cipherSuite
 780  	numLeaves := tree.numLeaves()
 781  
 782  	if h, err := tree.computeRootTreeHash(cs); err != nil {
 783  		return err
 784  	} else if !bytes.Equal(h, ctx.treeHash) {
 785  		return fmt.Errorf("mls: tree hash verification failed")
 786  	}
 787  
 788  	if !tree.verifyParentHashes(cs) {
 789  		return fmt.Errorf("mls: parent hashes verification failed")
 790  	}
 791  
 792  	supportedCreds := tree.supportedCreds()
 793  	signatureKeys := make(map[string]struct{})
 794  	encryptionKeys := make(map[string]struct{})
 795  	for li := leafIndex(0); li < leafIndex(numLeaves); li++ {
 796  		node := tree.getLeaf(li)
 797  		if node == nil {
 798  			continue
 799  		}
 800  
 801  		err := node.verify(&leafNodeVerifyOptions{
 802  			cipherSuite:    cs,
 803  			groupID:        ctx.groupID,
 804  			leafIndex:      li,
 805  			supportedCreds: supportedCreds,
 806  			signatureKeys:  signatureKeys,
 807  			encryptionKeys: encryptionKeys,
 808  			now:            now,
 809  		})
 810  		if err != nil {
 811  			return fmt.Errorf("leaf node at index %v: %v", li, err)
 812  		}
 813  
 814  		signatureKeys[string(node.signatureKey)] = struct{}{}
 815  		encryptionKeys[string(node.encryptionKey)] = struct{}{}
 816  	}
 817  
 818  	for i, node := range tree {
 819  		if node == nil || node.nodeType != nodeTypeParent {
 820  			continue
 821  		}
 822  		p := nodeIndex(i)
 823  		for _, unmergedLeaf := range node.parentNode.unmergedLeaves {
 824  			x := unmergedLeaf.nodeIndex()
 825  			for {
 826  				var ok bool
 827  				if x, ok = numLeaves.parent(x); !ok {
 828  					return fmt.Errorf("mls: unmerged leaf %v is not a descendant of the parent node at index %v", unmergedLeaf, p)
 829  				} else if x == p {
 830  					break
 831  				}
 832  
 833  				intermediateNode := tree.get(x)
 834  				if intermediateNode != nil && !hasUnmergedLeaf(intermediateNode.parentNode, unmergedLeaf) {
 835  					return fmt.Errorf("mls: non-blank intermediate node at index %v is missing unmerged leaf %v", x, unmergedLeaf)
 836  				}
 837  			}
 838  		}
 839  
 840  		if _, dup := encryptionKeys[string(node.parentNode.encryptionKey)]; dup {
 841  			return fmt.Errorf("mls: duplicate encryption key in ratchet tree")
 842  		}
 843  		encryptionKeys[string(node.parentNode.encryptionKey)] = struct{}{}
 844  	}
 845  
 846  	return nil
 847  }
 848  
 849  func hasUnmergedLeaf(node *parentNode, unmergedLeaf leafIndex) bool {
 850  	for _, li := range node.unmergedLeaves {
 851  		if li == unmergedLeaf {
 852  			return true
 853  		}
 854  	}
 855  	return false
 856  }
 857  
 858  func (tree ratchetTree) computeRootTreeHash(cs CipherSuite) ([]byte, error) {
 859  	return tree.computeTreeHash(cs, tree.numLeaves().root(), nil)
 860  }
 861  
 862  func (tree ratchetTree) computeTreeHash(cs CipherSuite, x nodeIndex, exclude map[leafIndex]struct{}) ([]byte, error) {
 863  	n := tree.get(x)
 864  
 865  	var b cryptobyte.Builder
 866  	if li, ok := x.leafIndex(); ok {
 867  		_, excluded := exclude[li]
 868  
 869  		var l *leafNode
 870  		if n != nil && !excluded {
 871  			l = n.leafNode
 872  			if l == nil {
 873  				panic("unreachable")
 874  			}
 875  		}
 876  
 877  		marshalLeafNodeHashInput(&b, li, l)
 878  	} else {
 879  		left, right, ok := x.children()
 880  		if !ok {
 881  			panic("unreachable")
 882  		}
 883  
 884  		leftHash, err := tree.computeTreeHash(cs, left, exclude)
 885  		if err != nil {
 886  			return nil, err
 887  		}
 888  		rightHash, err := tree.computeTreeHash(cs, right, exclude)
 889  		if err != nil {
 890  			return nil, err
 891  		}
 892  
 893  		var p *parentNode
 894  		if n != nil {
 895  			p = n.parentNode
 896  			if p == nil {
 897  				panic("unreachable")
 898  			}
 899  
 900  			if len(p.unmergedLeaves) > 0 && len(exclude) > 0 {
 901  				unmergedLeaves := make([]leafIndex, 0, len(p.unmergedLeaves))
 902  				for _, li := range p.unmergedLeaves {
 903  					if _, excluded := exclude[li]; !excluded {
 904  						unmergedLeaves = append(unmergedLeaves, li)
 905  					}
 906  				}
 907  
 908  				filteredParent := *p
 909  				filteredParent.unmergedLeaves = unmergedLeaves
 910  				p = &filteredParent
 911  			}
 912  		}
 913  
 914  		marshalParentNodeHashInput(&b, p, leftHash, rightHash)
 915  	}
 916  	in, err := b.Bytes()
 917  	if err != nil {
 918  		return nil, err
 919  	}
 920  
 921  	h := cs.hash().New()
 922  	h.Write(in)
 923  	return h.Sum(nil), nil
 924  }
 925  
 926  func marshalLeafNodeHashInput(b *cryptobyte.Builder, i leafIndex, node *leafNode) {
 927  	b.AddUint8(uint8(nodeTypeLeaf))
 928  	b.AddUint32(uint32(i))
 929  	writeOptional(b, node != nil)
 930  	if node != nil {
 931  		node.marshal(b)
 932  	}
 933  }
 934  
 935  func marshalParentNodeHashInput(b *cryptobyte.Builder, node *parentNode, leftHash, rightHash []byte) {
 936  	b.AddUint8(uint8(nodeTypeParent))
 937  	writeOptional(b, node != nil)
 938  	if node != nil {
 939  		node.marshal(b)
 940  	}
 941  	writeOpaqueVec(b, leftHash)
 942  	writeOpaqueVec(b, rightHash)
 943  }
 944  
 945  func (tree ratchetTree) verifyParentHashes(cs CipherSuite) bool {
 946  	for i, node := range tree {
 947  		if node == nil {
 948  			continue
 949  		}
 950  
 951  		x := nodeIndex(i)
 952  		l, r, ok := x.children()
 953  		if !ok {
 954  			continue
 955  		}
 956  
 957  		parentNode := node.parentNode
 958  		exclude := make(map[leafIndex]struct{}, len(parentNode.unmergedLeaves))
 959  		for _, li := range parentNode.unmergedLeaves {
 960  			exclude[li] = struct{}{}
 961  		}
 962  
 963  		leftTreeHash, err := tree.computeTreeHash(cs, l, exclude)
 964  		if err != nil {
 965  			return false
 966  		}
 967  		rightTreeHash, err := tree.computeTreeHash(cs, r, exclude)
 968  		if err != nil {
 969  			return false
 970  		}
 971  
 972  		leftParentHash, err := parentNode.computeParentHash(cs, rightTreeHash)
 973  		if err != nil {
 974  			return false
 975  		}
 976  		rightParentHash, err := parentNode.computeParentHash(cs, leftTreeHash)
 977  		if err != nil {
 978  			return false
 979  		}
 980  
 981  		isLeftDescendant := tree.findParentHash(tree.resolve(l), leftParentHash)
 982  		isRightDescendant := tree.findParentHash(tree.resolve(r), rightParentHash)
 983  		if isLeftDescendant == isRightDescendant {
 984  			return false
 985  		}
 986  	}
 987  	return true
 988  }
 989  
 990  func (tree ratchetTree) findParentHash(nodeIndices []nodeIndex, parentHash []byte) bool {
 991  	for _, x := range nodeIndices {
 992  		node := tree.get(x)
 993  		if node == nil {
 994  			continue
 995  		}
 996  		var h []byte
 997  		switch node.nodeType {
 998  		case nodeTypeLeaf:
 999  			h = node.leafNode.parentHash
1000  		case nodeTypeParent:
1001  			h = node.parentNode.parentHash
1002  		}
1003  		if bytes.Equal(h, parentHash) {
1004  			return true
1005  		}
1006  	}
1007  	return false
1008  }
1009  
1010  func (tree ratchetTree) numLeaves() numLeaves {
1011  	return numLeavesFromWidth(uint32(len(tree)))
1012  }
1013  
1014  func (tree ratchetTree) findLeaf(node *leafNode) (leafIndex, bool) {
1015  	for li := leafIndex(0); li < leafIndex(tree.numLeaves()); li++ {
1016  		n := tree.getLeaf(li)
1017  		if n == nil {
1018  			continue
1019  		}
1020  
1021  		// Encryption keys are unique
1022  		if !bytes.Equal(n.encryptionKey, node.encryptionKey) {
1023  			continue
1024  		}
1025  
1026  		// Make sure both nodes are identical
1027  		raw1, err1 := marshal(node)
1028  		raw2, err2 := marshal(n)
1029  		return li, err1 == nil && err2 == nil && bytes.Equal(raw1, raw2)
1030  	}
1031  	return 0, false
1032  }
1033  
1034  func (tree *ratchetTree) add(leafNode *leafNode) {
1035  	li := leafIndex(0)
1036  	var ni nodeIndex
1037  	found := false
1038  	for {
1039  		ni = li.nodeIndex()
1040  		if int(ni) >= len(*tree) {
1041  			break
1042  		}
1043  		if tree.get(ni) == nil {
1044  			found = true
1045  			break
1046  		}
1047  		li++
1048  	}
1049  	if !found {
1050  		newLen := ((len(*tree) + 1) * 2) - 1
1051  		for len(*tree) < newLen {
1052  			*tree = append(*tree, nil)
1053  		}
1054  	}
1055  
1056  	numLeaves := tree.numLeaves()
1057  	p := ni
1058  	for {
1059  		var ok bool
1060  		p, ok = numLeaves.parent(p)
1061  		if !ok {
1062  			break
1063  		}
1064  		node := tree.get(p)
1065  		if node != nil {
1066  			node.parentNode.unmergedLeaves = append(node.parentNode.unmergedLeaves, li)
1067  		}
1068  	}
1069  
1070  	tree.set(ni, &node{
1071  		nodeType: nodeTypeLeaf,
1072  		leafNode: leafNode,
1073  	})
1074  }
1075  
1076  func (tree ratchetTree) update(li leafIndex, leafNode *leafNode) {
1077  	ni := li.nodeIndex()
1078  
1079  	tree.set(ni, &node{
1080  		nodeType: nodeTypeLeaf,
1081  		leafNode: leafNode,
1082  	})
1083  
1084  	numLeaves := tree.numLeaves()
1085  	for {
1086  		var ok bool
1087  		ni, ok = numLeaves.parent(ni)
1088  		if !ok {
1089  			break
1090  		}
1091  
1092  		tree.set(ni, nil)
1093  	}
1094  }
1095  
1096  func (tree *ratchetTree) remove(li leafIndex) {
1097  	ni := li.nodeIndex()
1098  
1099  	numLeaves := tree.numLeaves()
1100  	for {
1101  		tree.set(ni, nil)
1102  
1103  		var ok bool
1104  		ni, ok = numLeaves.parent(ni)
1105  		if !ok {
1106  			break
1107  		}
1108  	}
1109  
1110  	li = leafIndex(numLeaves - 1)
1111  	lastPowerOf2 := len(*tree) + 1
1112  	for {
1113  		ni = li.nodeIndex()
1114  		if tree.get(ni) != nil {
1115  			break
1116  		}
1117  
1118  		if isPowerOf2(uint32(ni)) {
1119  			lastPowerOf2 = int(ni)
1120  		}
1121  
1122  		if li == 0 {
1123  			*tree = nil
1124  			return
1125  		}
1126  		li--
1127  	}
1128  
1129  	if lastPowerOf2 < len(*tree)+1 {
1130  		*tree = (*tree)[:lastPowerOf2-1]
1131  	}
1132  }
1133  
1134  func (tree ratchetTree) filteredDirectPath(x nodeIndex) []nodeIndex {
1135  	numLeaves := tree.numLeaves()
1136  
1137  	var path []nodeIndex
1138  	for {
1139  		p, ok := numLeaves.parent(x)
1140  		if !ok {
1141  			break
1142  		}
1143  
1144  		s, ok := numLeaves.sibling(x)
1145  		if !ok {
1146  			panic("unreachable")
1147  		}
1148  
1149  		if len(tree.resolve(s)) > 0 {
1150  			path = append(path, p)
1151  		}
1152  
1153  		x = p
1154  	}
1155  
1156  	return path
1157  }
1158  
1159  func (tree ratchetTree) mergeUpdatePath(cs CipherSuite, senderLeafIndex leafIndex, path *updatePath) error {
1160  	senderNodeIndex := senderLeafIndex.nodeIndex()
1161  	numLeaves := tree.numLeaves()
1162  
1163  	directPath := numLeaves.directPath(senderNodeIndex)
1164  	for _, ni := range directPath {
1165  		tree.set(ni, nil)
1166  	}
1167  
1168  	filteredDirectPath := tree.filteredDirectPath(senderNodeIndex)
1169  	if len(filteredDirectPath) != len(path.nodes) {
1170  		return fmt.Errorf("mls: UpdatePath has %v nodes, but filtered direct path has %v nodes", len(path.nodes), len(filteredDirectPath))
1171  	}
1172  	for i, ni := range filteredDirectPath {
1173  		pathNode := path.nodes[i]
1174  		tree.set(ni, &node{
1175  			nodeType: nodeTypeParent,
1176  			parentNode: &parentNode{
1177  				encryptionKey: pathNode.encryptionKey,
1178  			},
1179  		})
1180  	}
1181  
1182  	// Compute parent hashes, from root to leaf
1183  	var prevParentHash []byte
1184  	for i := len(filteredDirectPath) - 1; i >= 0; i-- {
1185  		ni := filteredDirectPath[i]
1186  		node := tree.get(ni).parentNode
1187  
1188  		l, r, ok := ni.children()
1189  		if !ok {
1190  			panic("unreachable")
1191  		}
1192  
1193  		s := l
1194  		found := false
1195  		for _, ni := range directPath {
1196  			if ni == s {
1197  				found = true
1198  				break
1199  			}
1200  		}
1201  		if s == senderNodeIndex || found {
1202  			s = r
1203  		}
1204  
1205  		treeHash, err := tree.computeTreeHash(cs, s, nil)
1206  		if err != nil {
1207  			return err
1208  		}
1209  
1210  		node.parentHash = prevParentHash
1211  		h, err := node.computeParentHash(cs, treeHash)
1212  		if err != nil {
1213  			return err
1214  		}
1215  		prevParentHash = h
1216  	}
1217  
1218  	if !bytes.Equal(path.leafNode.parentHash, prevParentHash) {
1219  		return fmt.Errorf("mls: parent hash mismatch for update path's leaf node")
1220  	}
1221  
1222  	tree.set(senderNodeIndex, &node{
1223  		nodeType: nodeTypeLeaf,
1224  		leafNode: &path.leafNode,
1225  	})
1226  
1227  	return nil
1228  }
1229  
1230  func (tree ratchetTree) decryptPathSecrets(cs CipherSuite, groupCtx *groupContext, senderLeafIndex, recipientLeafIndex leafIndex, path *updatePath, privTree []hpkePrivateKey) ([]byte, error) {
1231  	senderNodeIndex := senderLeafIndex.nodeIndex()
1232  	recipientNodeIndex := recipientLeafIndex.nodeIndex()
1233  
1234  	senderFilteredDirectPath := tree.filteredDirectPath(senderNodeIndex)
1235  	if len(path.nodes) != len(senderFilteredDirectPath) {
1236  		return nil, fmt.Errorf("mls: invalid UpdatePath length")
1237  	}
1238  
1239  	// Identify a node in the filtered direct path for which the recipient is
1240  	// in the subtree of the non-updated child
1241  	recipientAncestorIndex := -1
1242  	recipientAncestor := commonAncestor(senderNodeIndex, recipientNodeIndex)
1243  	for i, ni := range senderFilteredDirectPath {
1244  		if ni == recipientAncestor {
1245  			recipientAncestorIndex = i
1246  			break
1247  		}
1248  	}
1249  	if recipientAncestorIndex < 0 {
1250  		return nil, fmt.Errorf("mls: cannot find recipient ancestor")
1251  	}
1252  	updatePathNode := path.nodes[recipientAncestorIndex]
1253  
1254  	// Find the copath node
1255  	ancestor := commonAncestor(senderNodeIndex, recipientNodeIndex)
1256  	var (
1257  		copathNode nodeIndex
1258  		ok         bool
1259  	)
1260  	if recipientNodeIndex < senderNodeIndex {
1261  		copathNode, ok = ancestor.left()
1262  	} else {
1263  		copathNode, ok = ancestor.right()
1264  	}
1265  	if !ok {
1266  		panic("unreachable")
1267  	}
1268  
1269  	copathResolution := tree.resolve(copathNode)
1270  	if len(updatePathNode.encryptedPathSecret) != len(copathResolution) {
1271  		return nil, fmt.Errorf("mls: invalid UpdatePathNode.encrypted_path_secret length")
1272  	}
1273  
1274  	// Identify a node in the resolution of the copath node for which we have
1275  	// a private key
1276  	var nodePriv hpkePrivateKey
1277  	resolutionIndex := -1
1278  	for i, ni := range copathResolution {
1279  		if p := privTree[int(ni)]; p != nil {
1280  			nodePriv = p
1281  			resolutionIndex = i
1282  			break
1283  		}
1284  	}
1285  	if nodePriv == nil {
1286  		return nil, fmt.Errorf("mls: no private key found")
1287  	}
1288  	ciphertext := updatePathNode.encryptedPathSecret[resolutionIndex]
1289  
1290  	// Decrypt the path secret using the private key from the resolution node
1291  	pathSecret, err := decryptPathSecret(cs, nodePriv, groupCtx, ciphertext)
1292  	if err != nil {
1293  		return nil, fmt.Errorf("failed to decrypt path secret: %v", err)
1294  	}
1295  	nodePub := tree.get(recipientAncestor).encryptionKey()
1296  	nodePriv, err = nodePrivFromPathSecret(cs, pathSecret, nodePub)
1297  	if err != nil {
1298  		return nil, fmt.Errorf("failed to derive node %v private key from path secret: %v", recipientAncestor, err)
1299  	}
1300  	privTree[int(recipientAncestor)] = nodePriv
1301  
1302  	// Derive path secrets for ancestors of that node in the sender's filtered
1303  	// direct path
1304  	for _, ni := range senderFilteredDirectPath[recipientAncestorIndex+1:] {
1305  		pathSecret, err = cs.deriveSecret(pathSecret, []byte("path"))
1306  		if err != nil {
1307  			return nil, fmt.Errorf("failed to derive path secret: %v", err)
1308  		}
1309  		nodePriv, err := nodePrivFromPathSecret(cs, pathSecret, tree.get(ni).encryptionKey())
1310  		if err != nil {
1311  			return nil, fmt.Errorf("failed to derive node %v private key from path secret: %v", ni, err)
1312  		}
1313  		privTree[int(ni)] = nodePriv
1314  	}
1315  
1316  	commitSecret, err := cs.deriveSecret(pathSecret, []byte("path"))
1317  	if err != nil {
1318  		return nil, fmt.Errorf("failed to derive commit secret: %v", err)
1319  	}
1320  
1321  	return commitSecret, nil
1322  }
1323  
1324  func (tree *ratchetTree) apply(proposals []proposal, senders []leafIndex) {
1325  	// Apply all update proposals
1326  	for i, prop := range proposals {
1327  		if prop.proposalType == proposalTypeUpdate {
1328  			tree.update(senders[i], &prop.update.leafNode)
1329  		}
1330  	}
1331  
1332  	// Apply all remove proposals
1333  	for _, prop := range proposals {
1334  		if prop.proposalType == proposalTypeRemove {
1335  			tree.remove(prop.remove.removed)
1336  		}
1337  	}
1338  
1339  	// Apply all add proposals
1340  	for _, prop := range proposals {
1341  		if prop.proposalType == proposalTypeAdd {
1342  			tree.add(&prop.add.keyPackage.leafNode)
1343  		}
1344  	}
1345  }
1346