key_schedule.go raw
1 package mls
2
3 import (
4 "fmt"
5 "io"
6
7 "golang.org/x/crypto/cryptobyte"
8 )
9
10 type groupContext struct {
11 version protocolVersion
12 cipherSuite CipherSuite
13 groupID GroupID
14 epoch uint64
15 treeHash []byte
16 confirmedTranscriptHash []byte
17 extensions []extension
18 }
19
20 func (ctx *groupContext) unmarshal(s *cryptobyte.String) error {
21 *ctx = groupContext{}
22
23 ok := s.ReadUint16((*uint16)(&ctx.version)) &&
24 s.ReadUint16((*uint16)(&ctx.cipherSuite)) &&
25 readOpaqueVec(s, (*[]byte)(&ctx.groupID)) &&
26 s.ReadUint64(&ctx.epoch) &&
27 readOpaqueVec(s, &ctx.treeHash) &&
28 readOpaqueVec(s, &ctx.confirmedTranscriptHash)
29 if !ok {
30 return io.ErrUnexpectedEOF
31 }
32
33 if ctx.version != protocolVersionMLS10 {
34 return fmt.Errorf("mls: invalid protocol version %d", ctx.version)
35 }
36
37 exts, err := unmarshalExtensionVec(s)
38 if err != nil {
39 return err
40 }
41 ctx.extensions = exts
42
43 return nil
44 }
45
46 func (ctx *groupContext) marshal(b *cryptobyte.Builder) {
47 b.AddUint16(uint16(ctx.version))
48 b.AddUint16(uint16(ctx.cipherSuite))
49 writeOpaqueVec(b, []byte(ctx.groupID))
50 b.AddUint64(ctx.epoch)
51 writeOpaqueVec(b, ctx.treeHash)
52 writeOpaqueVec(b, ctx.confirmedTranscriptHash)
53 marshalExtensionVec(b, ctx.extensions)
54 }
55
56 func (ctx *groupContext) extractJoinerSecret(prevInitSecret, commitSecret []byte) ([]byte, error) {
57 cs := ctx.cipherSuite
58 _, kdf, _ := cs.hpke().Params()
59
60 extracted := kdf.Extract(commitSecret, prevInitSecret)
61
62 rawGroupContext, err := marshal(ctx)
63 if err != nil {
64 return nil, err
65 }
66 return cs.expandWithLabel(extracted, []byte("joiner"), rawGroupContext, uint16(kdf.ExtractSize()))
67 }
68
69 func (ctx *groupContext) extractEpochSecret(joinerSecret, pskSecret []byte) ([]byte, error) {
70 cs := ctx.cipherSuite
71 _, kdf, _ := cs.hpke().Params()
72
73 // TODO de-duplicate with extractWelcomeSecret
74 if pskSecret == nil {
75 pskSecret = make([]byte, kdf.ExtractSize())
76 }
77 extracted := kdf.Extract(pskSecret, joinerSecret)
78
79 rawGroupContext, err := marshal(ctx)
80 if err != nil {
81 return nil, err
82 }
83 return cs.expandWithLabel(extracted, []byte("epoch"), rawGroupContext, uint16(kdf.ExtractSize()))
84 }
85
86 func (ctx *groupContext) signConfirmationTag(epochSecret []byte) ([]byte, error) {
87 cs := ctx.cipherSuite
88
89 confirmationKey, err := cs.deriveSecret(epochSecret, secretLabelConfirm)
90 if err != nil {
91 return nil, err
92 }
93
94 confirmationTag := cs.signMAC(confirmationKey, ctx.confirmedTranscriptHash)
95 return confirmationTag, nil
96 }
97
98 func extractWelcomeSecret(cs CipherSuite, joinerSecret, pskSecret []byte) ([]byte, error) {
99 _, kdf, _ := cs.hpke().Params()
100
101 if pskSecret == nil {
102 pskSecret = make([]byte, kdf.ExtractSize())
103 }
104 extracted := kdf.Extract(pskSecret, joinerSecret)
105
106 return cs.deriveSecret(extracted, []byte("welcome"))
107 }
108
109 func deriveExporter(cs CipherSuite, exporterSecret, label, context []byte, length uint16) ([]byte, error) {
110 derived, err := cs.deriveSecret(exporterSecret, label)
111 if err != nil {
112 return nil, err
113 }
114
115 h := cs.hash().New()
116 h.Write(context)
117
118 return cs.expandWithLabel(derived, []byte("exported"), h.Sum(nil), length)
119 }
120
// Label byte strings for derived secrets. secretLabelConfirm is passed to
// deriveSecret by groupContext.signConfirmationTag; the other labels are
// presumably used the same way elsewhere in the package.
var (
	secretLabelInit           = []byte("init")
	secretLabelSenderData     = []byte("sender data")
	secretLabelEncryption     = []byte("encryption")
	secretLabelExporter       = []byte("exporter")
	secretLabelExternal       = []byte("external")
	secretLabelConfirm        = []byte("confirm")
	secretLabelMembership     = []byte("membership")
	secretLabelResumption     = []byte("resumption")
	secretLabelAuthentication = []byte("authentication")
)
132
133 type confirmedTranscriptHashInput struct {
134 wireFormat wireFormat
135 content framedContent
136 signature []byte
137 }
138
139 func (input *confirmedTranscriptHashInput) marshal(b *cryptobyte.Builder) {
140 if input.content.contentType != contentTypeCommit {
141 b.SetError(fmt.Errorf("mls: confirmedTranscriptHashInput can only contain contentTypeCommit"))
142 return
143 }
144 input.wireFormat.marshal(b)
145 input.content.marshal(b)
146 writeOpaqueVec(b, input.signature)
147 }
148
149 func (input *confirmedTranscriptHashInput) hash(cs CipherSuite, interimTranscriptHashBefore []byte) ([]byte, error) {
150 rawInput, err := marshal(input)
151 if err != nil {
152 return nil, err
153 }
154
155 h := cs.hash().New()
156 h.Write(interimTranscriptHashBefore)
157 h.Write(rawInput)
158 return h.Sum(nil), nil
159 }
160
161 func nextInterimTranscriptHash(cs CipherSuite, confirmedTranscriptHash, confirmationTag []byte) ([]byte, error) {
162 var b cryptobyte.Builder
163 writeOpaqueVec(&b, confirmationTag)
164 rawInput, err := b.Bytes()
165 if err != nil {
166 return nil, err
167 }
168
169 h := cs.hash().New()
170 h.Write(confirmedTranscriptHash)
171 h.Write(rawInput)
172 return h.Sum(nil), nil
173 }
174
// pskType is the one-byte tag that selects which fields of a preSharedKeyID
// are present.
type pskType uint8

// Values accepted by pskType.unmarshal.
const (
	pskTypeExternal   pskType = 1
	pskTypeResumption pskType = 2
)
181
182 func (t *pskType) unmarshal(s *cryptobyte.String) error {
183 if !s.ReadUint8((*uint8)(t)) {
184 return io.ErrUnexpectedEOF
185 }
186 switch *t {
187 case pskTypeExternal, pskTypeResumption:
188 return nil
189 default:
190 return fmt.Errorf("mls: invalid PSK type %d", *t)
191 }
192 }
193
194 func (t pskType) marshal(b *cryptobyte.Builder) {
195 b.AddUint8(uint8(t))
196 }
197
// resumptionPSKUsage is the one-byte usage tag carried by a preSharedKeyID
// of type pskTypeResumption.
type resumptionPSKUsage uint8

// Values accepted by resumptionPSKUsage.unmarshal.
const (
	resumptionPSKUsageApplication resumptionPSKUsage = 1
	resumptionPSKUsageReinit      resumptionPSKUsage = 2
	resumptionPSKUsageBranch      resumptionPSKUsage = 3
)
205
206 func (usage *resumptionPSKUsage) unmarshal(s *cryptobyte.String) error {
207 if !s.ReadUint8((*uint8)(usage)) {
208 return io.ErrUnexpectedEOF
209 }
210 switch *usage {
211 case resumptionPSKUsageApplication, resumptionPSKUsageReinit, resumptionPSKUsageBranch:
212 return nil
213 default:
214 return fmt.Errorf("mls: invalid resumption PSK usage %d", *usage)
215 }
216 }
217
218 func (usage resumptionPSKUsage) marshal(b *cryptobyte.Builder) {
219 b.AddUint8(uint8(usage))
220 }
221
222 type preSharedKeyID struct {
223 pskType pskType
224
225 // for pskTypeExternal
226 pskID []byte
227
228 // for pskTypeResumption
229 usage resumptionPSKUsage
230 pskGroupID GroupID
231 pskEpoch uint64
232
233 pskNonce []byte
234 }
235
236 func (id *preSharedKeyID) unmarshal(s *cryptobyte.String) error {
237 *id = preSharedKeyID{}
238
239 if err := id.pskType.unmarshal(s); err != nil {
240 return err
241 }
242
243 switch id.pskType {
244 case pskTypeExternal:
245 if !readOpaqueVec(s, &id.pskID) {
246 return io.ErrUnexpectedEOF
247 }
248 case pskTypeResumption:
249 if err := id.usage.unmarshal(s); err != nil {
250 return err
251 }
252 if !readOpaqueVec(s, (*[]byte)(&id.pskGroupID)) || !s.ReadUint64(&id.pskEpoch) {
253 return io.ErrUnexpectedEOF
254 }
255 default:
256 panic("unreachable")
257 }
258
259 if !readOpaqueVec(s, &id.pskNonce) {
260 return io.ErrUnexpectedEOF
261 }
262
263 return nil
264 }
265
266 func (id *preSharedKeyID) marshal(b *cryptobyte.Builder) {
267 id.pskType.marshal(b)
268 switch id.pskType {
269 case pskTypeExternal:
270 writeOpaqueVec(b, id.pskID)
271 case pskTypeResumption:
272 id.usage.marshal(b)
273 writeOpaqueVec(b, []byte(id.pskGroupID))
274 b.AddUint64(id.pskEpoch)
275 default:
276 panic("unreachable")
277 }
278 writeOpaqueVec(b, id.pskNonce)
279 }
280
281 func extractPSKSecret(cs CipherSuite, pskIDs []preSharedKeyID, psks [][]byte) ([]byte, error) {
282 if len(pskIDs) != len(psks) {
283 return nil, fmt.Errorf("mls: got %v PSK IDs and %v PSKs, want same number", len(pskIDs), len(psks))
284 }
285
286 _, kdf, _ := cs.hpke().Params()
287 zero := make([]byte, kdf.ExtractSize())
288
289 pskSecret := zero
290 for i := range pskIDs {
291 pskExtracted := kdf.Extract(psks[i], zero)
292
293 pskLabel := pskLabel{
294 id: pskIDs[i],
295 index: uint16(i),
296 count: uint16(len(pskIDs)),
297 }
298 rawPSKLabel, err := marshal(&pskLabel)
299 if err != nil {
300 return nil, err
301 }
302
303 pskInput, err := cs.expandWithLabel(pskExtracted, []byte("derived psk"), rawPSKLabel, uint16(kdf.ExtractSize()))
304 if err != nil {
305 return nil, err
306 }
307
308 pskSecret = kdf.Extract(pskSecret, pskInput)
309 }
310
311 return pskSecret, nil
312 }
313
314 type pskLabel struct {
315 id preSharedKeyID
316 index uint16
317 count uint16
318 }
319
320 func (label *pskLabel) marshal(b *cryptobyte.Builder) {
321 label.id.marshal(b)
322 b.AddUint16(label.index)
323 b.AddUint16(label.count)
324 }
325