key_schedule_crypto.mx raw
1 package mls
2
// MLS key schedule cryptographic operations (RFC 9420 §8).
// The functions and methods in this file derive key-schedule secrets and
// transcript hashes using a CipherSuite.
5
6 import "errors"
7
8 func (ctx *groupContext) extractJoinerSecret(prevInitSecret, commitSecret []byte) ([]byte, error) {
9 cs := ctx.cipherSuite
10 extracted := cs.hkdfExtract(prevInitSecret, commitSecret)
11
12 rawCtx, err := marshalRaw(ctx)
13 if err != nil {
14 return nil, err
15 }
16 return cs.expandWithLabel(extracted, []byte("joiner"), rawCtx, uint16(cs.ExtractSize()))
17 }
18
19 func (ctx *groupContext) extractEpochSecret(joinerSecret, pskSecret []byte) ([]byte, error) {
20 cs := ctx.cipherSuite
21 if pskSecret == nil {
22 pskSecret = []byte{:cs.ExtractSize()}
23 }
24 extracted := cs.hkdfExtract(joinerSecret, pskSecret)
25
26 rawCtx, err := marshalRaw(ctx)
27 if err != nil {
28 return nil, err
29 }
30 return cs.expandWithLabel(extracted, []byte("epoch"), rawCtx, uint16(cs.ExtractSize()))
31 }
32
33 func (ctx *groupContext) signConfirmationTag(epochSecret []byte) ([]byte, error) {
34 cs := ctx.cipherSuite
35 confirmationKey, err := cs.deriveSecret(epochSecret, secretLabelConfirm)
36 if err != nil {
37 return nil, err
38 }
39 return cs.signMAC(confirmationKey, ctx.confirmedTranscriptHash), nil
40 }
41
42 func extractWelcomeSecret(cs CipherSuite, joinerSecret, pskSecret []byte) ([]byte, error) {
43 if pskSecret == nil {
44 pskSecret = []byte{:cs.ExtractSize()}
45 }
46 extracted := cs.hkdfExtract(joinerSecret, pskSecret)
47 return cs.deriveSecret(extracted, []byte("welcome"))
48 }
49
50 func deriveExporter(cs CipherSuite, exporterSecret, label, context []byte, length uint16) ([]byte, error) {
51 derived, err := cs.deriveSecret(exporterSecret, label)
52 if err != nil {
53 return nil, err
54 }
55 contextHash := cs.hash(context)
56 return cs.expandWithLabel(derived, []byte("exported"), contextHash, length)
57 }
58
59 func (input *confirmedTranscriptHashInput) hashValue(cs CipherSuite, interimTranscriptHashBefore []byte) ([]byte, error) {
60 rawInput, err := marshalRaw(input)
61 if err != nil {
62 return nil, err
63 }
64 data := append(interimTranscriptHashBefore, rawInput...)
65 return cs.hash(data), nil
66 }
67
68 func nextInterimTranscriptHash(cs CipherSuite, confirmedTranscriptHash, confirmationTag []byte) ([]byte, error) {
69 var w Writer
70 w.writeOpaqueVec(confirmationTag)
71 rawInput, err := w.bytes()
72 if err != nil {
73 return nil, err
74 }
75 data := append(confirmedTranscriptHash, rawInput...)
76 return cs.hash(data), nil
77 }
78
79 // extractPSKSecret derives the PSK secret from a list of PSKs (RFC 9420 §8.4).
80 func extractPSKSecret(cs CipherSuite, pskIDs []preSharedKeyID, psks [][]byte) ([]byte, error) {
81 if len(pskIDs) != len(psks) {
82 return nil, errors.New("mls: PSK ID count != PSK count")
83 }
84
85 zero := []byte{:cs.ExtractSize()}
86 pskSecret := zero
87 for i := range pskIDs {
88 pskExtracted := cs.hkdfExtract(psks[i], zero)
89
90 label := pskLabel{
91 id: pskIDs[i],
92 index: uint16(i),
93 count: uint16(len(pskIDs)),
94 }
95 rawLabel, err := marshalRaw(&label)
96 if err != nil {
97 return nil, err
98 }
99
100 pskInput, err := cs.expandWithLabel(pskExtracted, []byte("derived psk"), rawLabel, uint16(cs.ExtractSize()))
101 if err != nil {
102 return nil, err
103 }
104 pskSecret = cs.hkdfExtract(pskSecret, pskInput)
105 }
106 return pskSecret, nil
107 }
108
109 type pskLabel struct {
110 id preSharedKeyID
111 index uint16
112 count uint16
113 }
114
115 func (label *pskLabel) marshal(w *Writer) {
116 label.id.marshal(w)
117 w.addUint16(label.index)
118 w.addUint16(label.count)
119 }
120