key_schedule.go raw

   1  package mls
   2  
   3  import (
   4  	"fmt"
   5  	"io"
   6  
   7  	"golang.org/x/crypto/cryptobyte"
   8  )
   9  
// groupContext holds the group state that is serialized by marshal and used
// as the context argument of ExpandWithLabel when deriving the joiner and
// epoch secrets.
type groupContext struct {
	// version is the protocol version; unmarshal rejects anything other than
	// protocolVersionMLS10.
	version protocolVersion
	// cipherSuite selects the HPKE KDF, hash and MAC used by the key schedule.
	cipherSuite CipherSuite
	groupID     GroupID
	epoch       uint64
	treeHash    []byte
	// confirmedTranscriptHash is the value MACed by signConfirmationTag.
	confirmedTranscriptHash []byte
	extensions              []extension
}
  19  
  20  func (ctx *groupContext) unmarshal(s *cryptobyte.String) error {
  21  	*ctx = groupContext{}
  22  
  23  	ok := s.ReadUint16((*uint16)(&ctx.version)) &&
  24  		s.ReadUint16((*uint16)(&ctx.cipherSuite)) &&
  25  		readOpaqueVec(s, (*[]byte)(&ctx.groupID)) &&
  26  		s.ReadUint64(&ctx.epoch) &&
  27  		readOpaqueVec(s, &ctx.treeHash) &&
  28  		readOpaqueVec(s, &ctx.confirmedTranscriptHash)
  29  	if !ok {
  30  		return io.ErrUnexpectedEOF
  31  	}
  32  
  33  	if ctx.version != protocolVersionMLS10 {
  34  		return fmt.Errorf("mls: invalid protocol version %d", ctx.version)
  35  	}
  36  
  37  	exts, err := unmarshalExtensionVec(s)
  38  	if err != nil {
  39  		return err
  40  	}
  41  	ctx.extensions = exts
  42  
  43  	return nil
  44  }
  45  
  46  func (ctx *groupContext) marshal(b *cryptobyte.Builder) {
  47  	b.AddUint16(uint16(ctx.version))
  48  	b.AddUint16(uint16(ctx.cipherSuite))
  49  	writeOpaqueVec(b, []byte(ctx.groupID))
  50  	b.AddUint64(ctx.epoch)
  51  	writeOpaqueVec(b, ctx.treeHash)
  52  	writeOpaqueVec(b, ctx.confirmedTranscriptHash)
  53  	marshalExtensionVec(b, ctx.extensions)
  54  }
  55  
  56  func (ctx *groupContext) extractJoinerSecret(prevInitSecret, commitSecret []byte) ([]byte, error) {
  57  	cs := ctx.cipherSuite
  58  	_, kdf, _ := cs.hpke().Params()
  59  
  60  	extracted := kdf.Extract(commitSecret, prevInitSecret)
  61  
  62  	rawGroupContext, err := marshal(ctx)
  63  	if err != nil {
  64  		return nil, err
  65  	}
  66  	return cs.expandWithLabel(extracted, []byte("joiner"), rawGroupContext, uint16(kdf.ExtractSize()))
  67  }
  68  
  69  func (ctx *groupContext) extractEpochSecret(joinerSecret, pskSecret []byte) ([]byte, error) {
  70  	cs := ctx.cipherSuite
  71  	_, kdf, _ := cs.hpke().Params()
  72  
  73  	// TODO de-duplicate with extractWelcomeSecret
  74  	if pskSecret == nil {
  75  		pskSecret = make([]byte, kdf.ExtractSize())
  76  	}
  77  	extracted := kdf.Extract(pskSecret, joinerSecret)
  78  
  79  	rawGroupContext, err := marshal(ctx)
  80  	if err != nil {
  81  		return nil, err
  82  	}
  83  	return cs.expandWithLabel(extracted, []byte("epoch"), rawGroupContext, uint16(kdf.ExtractSize()))
  84  }
  85  
  86  func (ctx *groupContext) signConfirmationTag(epochSecret []byte) ([]byte, error) {
  87  	cs := ctx.cipherSuite
  88  
  89  	confirmationKey, err := cs.deriveSecret(epochSecret, secretLabelConfirm)
  90  	if err != nil {
  91  		return nil, err
  92  	}
  93  
  94  	confirmationTag := cs.signMAC(confirmationKey, ctx.confirmedTranscriptHash)
  95  	return confirmationTag, nil
  96  }
  97  
  98  func extractWelcomeSecret(cs CipherSuite, joinerSecret, pskSecret []byte) ([]byte, error) {
  99  	_, kdf, _ := cs.hpke().Params()
 100  
 101  	if pskSecret == nil {
 102  		pskSecret = make([]byte, kdf.ExtractSize())
 103  	}
 104  	extracted := kdf.Extract(pskSecret, joinerSecret)
 105  
 106  	return cs.deriveSecret(extracted, []byte("welcome"))
 107  }
 108  
 109  func deriveExporter(cs CipherSuite, exporterSecret, label, context []byte, length uint16) ([]byte, error) {
 110  	derived, err := cs.deriveSecret(exporterSecret, label)
 111  	if err != nil {
 112  		return nil, err
 113  	}
 114  
 115  	h := cs.hash().New()
 116  	h.Write(context)
 117  
 118  	return cs.expandWithLabel(derived, []byte("exported"), h.Sum(nil), length)
 119  }
 120  
// Labels for secrets derived from an epoch's secrets. signConfirmationTag
// passes secretLabelConfirm to cs.deriveSecret together with the epoch secret.
// NOTE(review): the other labels are presumably used the same way elsewhere in
// the package; their call sites are not visible here.
var (
	secretLabelInit           = []byte("init")
	secretLabelSenderData     = []byte("sender data")
	secretLabelEncryption     = []byte("encryption")
	secretLabelExporter       = []byte("exporter")
	secretLabelExternal       = []byte("external")
	secretLabelConfirm        = []byte("confirm")
	secretLabelMembership     = []byte("membership")
	secretLabelResumption     = []byte("resumption")
	secretLabelAuthentication = []byte("authentication")
)
 132  
// confirmedTranscriptHashInput is the data hashed, after the previous interim
// transcript hash, to produce a confirmed transcript hash (see its hash
// method). marshal only accepts content whose contentType is
// contentTypeCommit.
type confirmedTranscriptHashInput struct {
	wireFormat wireFormat
	content    framedContent
	signature  []byte // encoded as an opaque vector after the content
}
 138  
 139  func (input *confirmedTranscriptHashInput) marshal(b *cryptobyte.Builder) {
 140  	if input.content.contentType != contentTypeCommit {
 141  		b.SetError(fmt.Errorf("mls: confirmedTranscriptHashInput can only contain contentTypeCommit"))
 142  		return
 143  	}
 144  	input.wireFormat.marshal(b)
 145  	input.content.marshal(b)
 146  	writeOpaqueVec(b, input.signature)
 147  }
 148  
 149  func (input *confirmedTranscriptHashInput) hash(cs CipherSuite, interimTranscriptHashBefore []byte) ([]byte, error) {
 150  	rawInput, err := marshal(input)
 151  	if err != nil {
 152  		return nil, err
 153  	}
 154  
 155  	h := cs.hash().New()
 156  	h.Write(interimTranscriptHashBefore)
 157  	h.Write(rawInput)
 158  	return h.Sum(nil), nil
 159  }
 160  
 161  func nextInterimTranscriptHash(cs CipherSuite, confirmedTranscriptHash, confirmationTag []byte) ([]byte, error) {
 162  	var b cryptobyte.Builder
 163  	writeOpaqueVec(&b, confirmationTag)
 164  	rawInput, err := b.Bytes()
 165  	if err != nil {
 166  		return nil, err
 167  	}
 168  
 169  	h := cs.hash().New()
 170  	h.Write(confirmedTranscriptHash)
 171  	h.Write(rawInput)
 172  	return h.Sum(nil), nil
 173  }
 174  
// pskType selects which identifying fields a preSharedKeyID carries. It is
// encoded on the wire as a single byte, and unmarshal accepts only the two
// values below.
type pskType uint8

const (
	pskTypeExternal   pskType = 1 // identified by an opaque pskID
	pskTypeResumption pskType = 2 // identified by usage, group ID and epoch
)
 181  
 182  func (t *pskType) unmarshal(s *cryptobyte.String) error {
 183  	if !s.ReadUint8((*uint8)(t)) {
 184  		return io.ErrUnexpectedEOF
 185  	}
 186  	switch *t {
 187  	case pskTypeExternal, pskTypeResumption:
 188  		return nil
 189  	default:
 190  		return fmt.Errorf("mls: invalid PSK type %d", *t)
 191  	}
 192  }
 193  
 194  func (t pskType) marshal(b *cryptobyte.Builder) {
 195  	b.AddUint8(uint8(t))
 196  }
 197  
// resumptionPSKUsage is carried by resumption PSK IDs. It is encoded as a
// single byte, and unmarshal accepts only the three values below.
type resumptionPSKUsage uint8

const (
	resumptionPSKUsageApplication resumptionPSKUsage = 1
	resumptionPSKUsageReinit      resumptionPSKUsage = 2
	resumptionPSKUsageBranch      resumptionPSKUsage = 3
)
 205  
 206  func (usage *resumptionPSKUsage) unmarshal(s *cryptobyte.String) error {
 207  	if !s.ReadUint8((*uint8)(usage)) {
 208  		return io.ErrUnexpectedEOF
 209  	}
 210  	switch *usage {
 211  	case resumptionPSKUsageApplication, resumptionPSKUsageReinit, resumptionPSKUsageBranch:
 212  		return nil
 213  	default:
 214  		return fmt.Errorf("mls: invalid resumption PSK usage %d", *usage)
 215  	}
 216  }
 217  
 218  func (usage resumptionPSKUsage) marshal(b *cryptobyte.Builder) {
 219  	b.AddUint8(uint8(usage))
 220  }
 221  
// preSharedKeyID identifies a pre-shared key. Which fields are encoded depends
// on pskType: pskID for external PSKs; usage, pskGroupID and pskEpoch for
// resumption PSKs. pskNonce is encoded for every type.
type preSharedKeyID struct {
	pskType pskType

	// for pskTypeExternal
	pskID []byte

	// for pskTypeResumption
	usage      resumptionPSKUsage
	pskGroupID GroupID
	pskEpoch   uint64

	// encoded last, regardless of pskType
	pskNonce []byte
}
 235  
 236  func (id *preSharedKeyID) unmarshal(s *cryptobyte.String) error {
 237  	*id = preSharedKeyID{}
 238  
 239  	if err := id.pskType.unmarshal(s); err != nil {
 240  		return err
 241  	}
 242  
 243  	switch id.pskType {
 244  	case pskTypeExternal:
 245  		if !readOpaqueVec(s, &id.pskID) {
 246  			return io.ErrUnexpectedEOF
 247  		}
 248  	case pskTypeResumption:
 249  		if err := id.usage.unmarshal(s); err != nil {
 250  			return err
 251  		}
 252  		if !readOpaqueVec(s, (*[]byte)(&id.pskGroupID)) || !s.ReadUint64(&id.pskEpoch) {
 253  			return io.ErrUnexpectedEOF
 254  		}
 255  	default:
 256  		panic("unreachable")
 257  	}
 258  
 259  	if !readOpaqueVec(s, &id.pskNonce) {
 260  		return io.ErrUnexpectedEOF
 261  	}
 262  
 263  	return nil
 264  }
 265  
 266  func (id *preSharedKeyID) marshal(b *cryptobyte.Builder) {
 267  	id.pskType.marshal(b)
 268  	switch id.pskType {
 269  	case pskTypeExternal:
 270  		writeOpaqueVec(b, id.pskID)
 271  	case pskTypeResumption:
 272  		id.usage.marshal(b)
 273  		writeOpaqueVec(b, []byte(id.pskGroupID))
 274  		b.AddUint64(id.pskEpoch)
 275  	default:
 276  		panic("unreachable")
 277  	}
 278  	writeOpaqueVec(b, id.pskNonce)
 279  }
 280  
 281  func extractPSKSecret(cs CipherSuite, pskIDs []preSharedKeyID, psks [][]byte) ([]byte, error) {
 282  	if len(pskIDs) != len(psks) {
 283  		return nil, fmt.Errorf("mls: got %v PSK IDs and %v PSKs, want same number", len(pskIDs), len(psks))
 284  	}
 285  
 286  	_, kdf, _ := cs.hpke().Params()
 287  	zero := make([]byte, kdf.ExtractSize())
 288  
 289  	pskSecret := zero
 290  	for i := range pskIDs {
 291  		pskExtracted := kdf.Extract(psks[i], zero)
 292  
 293  		pskLabel := pskLabel{
 294  			id:    pskIDs[i],
 295  			index: uint16(i),
 296  			count: uint16(len(pskIDs)),
 297  		}
 298  		rawPSKLabel, err := marshal(&pskLabel)
 299  		if err != nil {
 300  			return nil, err
 301  		}
 302  
 303  		pskInput, err := cs.expandWithLabel(pskExtracted, []byte("derived psk"), rawPSKLabel, uint16(kdf.ExtractSize()))
 304  		if err != nil {
 305  			return nil, err
 306  		}
 307  
 308  		pskSecret = kdf.Extract(pskSecret, pskInput)
 309  	}
 310  
 311  	return pskSecret, nil
 312  }
 313  
// pskLabel is the context serialized into the "derived psk" expansion in
// extractPSKSecret. It binds a PSK ID to its position in the PSK list.
type pskLabel struct {
	id    preSharedKeyID
	index uint16 // position of id within the PSK list
	count uint16 // total number of PSKs in the list
}
 319  
 320  func (label *pskLabel) marshal(b *cryptobyte.Builder) {
 321  	label.id.marshal(b)
 322  	b.AddUint16(label.index)
 323  	b.AddUint16(label.count)
 324  }
 325