tree_crypto.mx raw

   1  package mls
   2  
   3  // MLS tree crypto operations (RFC 9420 §7).
   4  
   5  import "errors"
   6  
   7  var (
   8  	errNodeKeyMismatch = errors.New("mls: node public key mismatch")
   9  )
  10  
  11  func (node *parentNode) computeParentHash(cs CipherSuite, originalSiblingTreeHash []byte) ([]byte, error) {
  12  	raw, err := marshalParentHashInput(node.encryptionKey, node.parentHash, originalSiblingTreeHash)
  13  	if err != nil {
  14  		return nil, err
  15  	}
  16  	return cs.hash(raw), nil
  17  }
  18  
  19  func marshalParentHashInput(encryptionKey hpkePublicKey, parentHash, originalSiblingTreeHash []byte) ([]byte, error) {
  20  	var w Writer
  21  	w.writeOpaqueVec([]byte(encryptionKey))
  22  	w.writeOpaqueVec(parentHash)
  23  	w.writeOpaqueVec(originalSiblingTreeHash)
  24  	return w.bytes()
  25  }
  26  
  27  func (node *leafNode) sign(cs CipherSuite, groupID GroupID, li leafIndex, signerPriv signaturePrivateKey) error {
  28  	tbs, err := marshalRaw(&leafNodeTBS{
  29  		node:      node,
  30  		groupID:   groupID,
  31  		leafIndex: li,
  32  	})
  33  	if err != nil {
  34  		return err
  35  	}
  36  	sig, err := cs.signWithLabel(signerPriv, []byte("LeafNodeTBS"), tbs)
  37  	if err != nil {
  38  		return err
  39  	}
  40  	node.signature = sig
  41  	return nil
  42  }
  43  
  44  func (node *leafNode) verifySignature(cs CipherSuite, groupID GroupID, li leafIndex) bool {
  45  	tbs, err := marshalRaw(&leafNodeTBS{
  46  		node:      node,
  47  		groupID:   groupID,
  48  		leafIndex: li,
  49  	})
  50  	if err != nil {
  51  		return false
  52  	}
  53  	return cs.verifyWithLabel(node.signatureKey, []byte("LeafNodeTBS"), tbs, node.signature)
  54  }
  55  
  56  func decryptPathSecret(cs CipherSuite, nodePriv hpkePrivateKey, ctx *groupContext, ct hpkeCiphertext) ([]byte, error) {
  57  	rawCtx, err := marshalRaw(ctx)
  58  	if err != nil {
  59  		return nil, err
  60  	}
  61  	return cs.decryptWithLabel(nodePriv, []byte("UpdatePathNode"), rawCtx, ct.kemOutput, ct.ciphertext)
  62  }
  63  
  64  func nodePrivFromPathSecret(cs CipherSuite, pathSecret []byte, nodePub hpkePublicKey) (hpkePrivateKey, error) {
  65  	nodeSecret, err := cs.deriveSecret(pathSecret, []byte("node"))
  66  	if err != nil {
  67  		return nil, err
  68  	}
  69  	pub, priv, err := cs.deriveEncryptionKeyPair(nodeSecret)
  70  	if err != nil {
  71  		return nil, err
  72  	}
  73  	if !bytesEqual(pub, nodePub) {
  74  		return nil, errNodeKeyMismatch
  75  	}
  76  	return priv, nil
  77  }
  78  
  79  // computeRootTreeHash computes the tree hash for integrity verification.
  80  func (tree ratchetTree) computeRootTreeHash(cs CipherSuite) ([]byte, error) {
  81  	n := tree.numLeaves()
  82  	return tree.computeTreeHash(cs, n.root(), n, nil)
  83  }
  84  
  85  func (tree ratchetTree) computeTreeHash(cs CipherSuite, x nodeIndex, n numLeaves, exclude map[leafIndex]bool) ([]byte, error) {
  86  	if x.isLeaf() {
  87  		return tree.computeLeafTreeHash(cs, x, exclude)
  88  	}
  89  	return tree.computeParentTreeHash(cs, x, n, exclude)
  90  }
  91  
  92  func (tree ratchetTree) computeLeafTreeHash(cs CipherSuite, x nodeIndex, exclude map[leafIndex]bool) ([]byte, error) {
  93  	var w Writer
  94  	li, _ := x.leafIndex()
  95  	w.addUint32(uint32(li))
  96  	nd := tree.get(x)
  97  	if exclude[li] {
  98  		nd = nil
  99  	}
 100  	w.writeOptional(nd != nil)
 101  	if nd != nil {
 102  		nd.leafNode.marshal(&w)
 103  	}
 104  	input, err := w.bytes()
 105  	if err != nil {
 106  		return nil, err
 107  	}
 108  	return cs.hash(input), nil
 109  }
 110  
 111  func (tree ratchetTree) computeParentTreeHash(cs CipherSuite, x nodeIndex, n numLeaves, exclude map[leafIndex]bool) ([]byte, error) {
 112  	l, r, ok := x.children()
 113  	if !ok {
 114  		return nil, errUnexpectedEOF
 115  	}
 116  	leftHash, err := tree.computeTreeHash(cs, l, n, exclude)
 117  	if err != nil {
 118  		return nil, err
 119  	}
 120  	rightHash, err := tree.computeTreeHash(cs, r, n, exclude)
 121  	if err != nil {
 122  		return nil, err
 123  	}
 124  
 125  	var w Writer
 126  	nd := tree.get(x)
 127  	if nd != nil && len(exclude) > 0 {
 128  		// Filter unmerged leaves for parent hash verification
 129  		pn := nd.parentNode
 130  		if pn != nil && len(pn.unmergedLeaves) > 0 {
 131  			filtered := []leafIndex{:0:len(pn.unmergedLeaves)}
 132  			for _, li := range pn.unmergedLeaves {
 133  				if !exclude[li] {
 134  					filtered = append(filtered, li)
 135  				}
 136  			}
 137  			filteredNode := *pn
 138  			filteredNode.unmergedLeaves = filtered
 139  			w.writeOptional(true)
 140  			filteredNode.marshal(&w)
 141  			w.writeOpaqueVec(leftHash)
 142  			w.writeOpaqueVec(rightHash)
 143  			input, err := w.bytes()
 144  			if err != nil {
 145  				return nil, err
 146  			}
 147  			return cs.hash(input), nil
 148  		}
 149  	}
 150  	w.writeOptional(nd != nil)
 151  	if nd != nil {
 152  		nd.parentNode.marshal(&w)
 153  	}
 154  	w.writeOpaqueVec(leftHash)
 155  	w.writeOpaqueVec(rightHash)
 156  	input, err := w.bytes()
 157  	if err != nil {
 158  		return nil, err
 159  	}
 160  	return cs.hash(input), nil
 161  }
 162  
 163  // verifyParentHashes verifies the parent hash chain (RFC 9420 §7.9.2).
 164  func (tree ratchetTree) verifyParentHashes(cs CipherSuite) bool {
 165  	n := tree.numLeaves()
 166  	for i, nd := range tree {
 167  		if nd == nil {
 168  			continue
 169  		}
 170  		x := nodeIndex(i)
 171  		l, r, ok := x.children()
 172  		if !ok {
 173  			continue
 174  		}
 175  
 176  		pn := nd.parentNode
 177  		exclude := map[leafIndex]bool{}
 178  		for _, li := range pn.unmergedLeaves {
 179  			exclude[li] = true
 180  		}
 181  
 182  		leftTreeHash, err := tree.computeTreeHash(cs, l, n, exclude)
 183  		if err != nil {
 184  			return false
 185  		}
 186  		rightTreeHash, err := tree.computeTreeHash(cs, r, n, exclude)
 187  		if err != nil {
 188  			return false
 189  		}
 190  
 191  		leftParentHash, err := pn.computeParentHash(cs, rightTreeHash)
 192  		if err != nil {
 193  			return false
 194  		}
 195  		rightParentHash, err := pn.computeParentHash(cs, leftTreeHash)
 196  		if err != nil {
 197  			return false
 198  		}
 199  
 200  		isLeft := tree.findParentHash(tree.resolve(l), leftParentHash)
 201  		isRight := tree.findParentHash(tree.resolve(r), rightParentHash)
 202  		if isLeft == isRight {
 203  			return false
 204  		}
 205  	}
 206  	return true
 207  }
 208  
 209  // verifyIntegrity verifies the ratchet tree (RFC 9420 §12.4.3.1).
 210  func (tree ratchetTree) verifyIntegrity(ctx *groupContext, nowUnix int64) error {
 211  	cs := ctx.cipherSuite
 212  	n := tree.numLeaves()
 213  
 214  	h, err := tree.computeRootTreeHash(cs)
 215  	if err != nil {
 216  		return err
 217  	}
 218  	if !bytesEqual(h, ctx.treeHash) {
 219  		return errors.New("mls: tree hash verification failed")
 220  	}
 221  	if !tree.verifyParentHashes(cs) {
 222  		return errors.New("mls: parent hash verification failed")
 223  	}
 224  
 225  	supportedCreds := tree.supportedCreds()
 226  	sigKeys := map[string]bool{}
 227  	encKeys := map[string]bool{}
 228  	for li := leafIndex(0); li < leafIndex(n); li++ {
 229  		nd := tree.getLeaf(li)
 230  		if nd == nil {
 231  			continue
 232  		}
 233  		err := nd.verify(&leafNodeVerifyOptions{
 234  			cipherSuite:    cs,
 235  			groupID:        ctx.groupID,
 236  			leafIndex:      li,
 237  			supportedCreds: supportedCreds,
 238  			signatureKeys:  sigKeys,
 239  			encryptionKeys: encKeys,
 240  			nowUnix:        nowUnix,
 241  		})
 242  		if err != nil {
 243  			return err
 244  		}
 245  		sigKeys[string(nd.signatureKey)] = true
 246  		encKeys[string(nd.encryptionKey)] = true
 247  	}
 248  
 249  	// Check unmerged leaf ancestry
 250  	for i, nd := range tree {
 251  		if nd == nil || nd.nodeType != nodeTypeParent {
 252  			continue
 253  		}
 254  		p := nodeIndex(i)
 255  		for _, ul := range nd.parentNode.unmergedLeaves {
 256  			x := ul.nodeIndex()
 257  			for {
 258  				var ok bool
 259  				x, ok = n.parent(x)
 260  				if !ok {
 261  					return errors.New("mls: unmerged leaf not descendant of parent")
 262  				}
 263  				if x == p {
 264  					break
 265  				}
 266  				intermediate := tree.get(x)
 267  				if intermediate != nil && !hasUnmergedLeaf(intermediate.parentNode, ul) {
 268  					return errors.New("mls: intermediate node missing unmerged leaf")
 269  				}
 270  			}
 271  		}
 272  		if encKeys[string(nd.parentNode.encryptionKey)] {
 273  			return errors.New("mls: duplicate encryption key in tree")
 274  		}
 275  		encKeys[string(nd.parentNode.encryptionKey)] = true
 276  	}
 277  	return nil
 278  }
 279  
 280  // mergeUpdatePath applies an update path to the tree (RFC 9420 §7.5).
 281  func (tree ratchetTree) mergeUpdatePath(cs CipherSuite, senderLI leafIndex, path *updatePath) error {
 282  	senderNI := senderLI.nodeIndex()
 283  	n := tree.numLeaves()
 284  
 285  	directPath := n.directPath(senderNI)
 286  	for _, ni := range directPath {
 287  		tree.set(ni, nil)
 288  	}
 289  
 290  	filteredDP := tree.filteredDirectPath(senderNI)
 291  	if len(filteredDP) != len(path.nodes) {
 292  		return errors.New("mls: update path length mismatch")
 293  	}
 294  	for i, ni := range filteredDP {
 295  		tree.set(ni, &node{
 296  			nodeType: nodeTypeParent,
 297  			parentNode: &parentNode{
 298  				encryptionKey: path.nodes[i].encryptionKey,
 299  			},
 300  		})
 301  	}
 302  
 303  	// Compute parent hashes root-to-leaf
 304  	var prevParentHash []byte
 305  	for i := len(filteredDP) - 1; i >= 0; i-- {
 306  		ni := filteredDP[i]
 307  		pn := tree.get(ni).parentNode
 308  
 309  		l, r, ok := ni.children()
 310  		if !ok {
 311  			panic("unreachable")
 312  		}
 313  		s := l
 314  		found := false
 315  		for _, dp := range directPath {
 316  			if dp == s {
 317  				found = true
 318  				break
 319  			}
 320  		}
 321  		if s == senderNI || found {
 322  			s = r
 323  		}
 324  
 325  		treeHash, err := tree.computeTreeHash(cs, s, n, nil)
 326  		if err != nil {
 327  			return err
 328  		}
 329  
 330  		pn.parentHash = prevParentHash
 331  		h, err := pn.computeParentHash(cs, treeHash)
 332  		if err != nil {
 333  			return err
 334  		}
 335  		prevParentHash = h
 336  	}
 337  
 338  	if !bytesEqual(path.leafNode.parentHash, prevParentHash) {
 339  		return errors.New("mls: parent hash mismatch for update path leaf node")
 340  	}
 341  
 342  	tree.set(senderNI, &node{
 343  		nodeType: nodeTypeLeaf,
 344  		leafNode: &path.leafNode,
 345  	})
 346  	return nil
 347  }
 348  
 349  // decryptPathSecrets decrypts path secrets from an update path (RFC 9420 §7.6).
 350  func (tree ratchetTree) decryptPathSecrets(cs CipherSuite, ctx *groupContext, senderLI, recipientLI leafIndex, path *updatePath, privTree []hpkePrivateKey) ([]byte, error) {
 351  	senderNI := senderLI.nodeIndex()
 352  	recipientNI := recipientLI.nodeIndex()
 353  
 354  	senderFDP := tree.filteredDirectPath(senderNI)
 355  	if len(path.nodes) != len(senderFDP) {
 356  		return nil, errors.New("mls: invalid update path length")
 357  	}
 358  
 359  	// Find the common ancestor in the filtered direct path
 360  	recipientAncestor := commonAncestor(senderNI, recipientNI)
 361  	recipientAncestorIdx := -1
 362  	for i, ni := range senderFDP {
 363  		if ni == recipientAncestor {
 364  			recipientAncestorIdx = i
 365  			break
 366  		}
 367  	}
 368  	if recipientAncestorIdx < 0 {
 369  		return nil, errors.New("mls: cannot find recipient ancestor")
 370  	}
 371  	upNode := path.nodes[recipientAncestorIdx]
 372  
 373  	// Find the copath node
 374  	ancestor := commonAncestor(senderNI, recipientNI)
 375  	var copathNode nodeIndex
 376  	var ok bool
 377  	if recipientNI < senderNI {
 378  		copathNode, ok = ancestor.left()
 379  	} else {
 380  		copathNode, ok = ancestor.right()
 381  	}
 382  	if !ok {
 383  		panic("unreachable")
 384  	}
 385  
 386  	copathRes := tree.resolve(copathNode)
 387  	if len(upNode.encryptedPathSecret) != len(copathRes) {
 388  		return nil, errors.New("mls: invalid encrypted path secret length")
 389  	}
 390  
 391  	// Find a node in the resolution for which we have a private key
 392  	var nodePriv hpkePrivateKey
 393  	resIdx := -1
 394  	for i, ni := range copathRes {
 395  		if p := privTree[int(ni)]; p != nil {
 396  			nodePriv = p
 397  			resIdx = i
 398  			break
 399  		}
 400  	}
 401  	if nodePriv == nil {
 402  		return nil, errors.New("mls: no private key found")
 403  	}
 404  
 405  	pathSecret, err := decryptPathSecret(cs, nodePriv, ctx, upNode.encryptedPathSecret[resIdx])
 406  	if err != nil {
 407  		return nil, err
 408  	}
 409  	nodePub := tree.get(recipientAncestor).encryptionKey()
 410  	nodePriv, err = nodePrivFromPathSecret(cs, pathSecret, nodePub)
 411  	if err != nil {
 412  		return nil, err
 413  	}
 414  	privTree[int(recipientAncestor)] = nodePriv
 415  
 416  	// Derive path secrets for remaining ancestors
 417  	for _, ni := range senderFDP[recipientAncestorIdx+1:] {
 418  		pathSecret, err = cs.deriveSecret(pathSecret, []byte("path"))
 419  		if err != nil {
 420  			return nil, err
 421  		}
 422  		nodePriv, err = nodePrivFromPathSecret(cs, pathSecret, tree.get(ni).encryptionKey())
 423  		if err != nil {
 424  			return nil, err
 425  		}
 426  		privTree[int(ni)] = nodePriv
 427  	}
 428  
 429  	commitSecret, err := cs.deriveSecret(pathSecret, []byte("path"))
 430  	if err != nil {
 431  		return nil, err
 432  	}
 433  	return commitSecret, nil
 434  }
 435