key_schedule.mx raw

   1  package mls
   2  
   3  // MLS key schedule types (RFC 9420 §8).
   4  // Data types and serialization — crypto operations go in key_schedule_crypto.mx.
   5  
   6  import "errors"
   7  
   8  var (
   9  	errInvalidPSKType    = errors.New("mls: invalid PSK type")
  10  	errInvalidPSKUsage   = errors.New("mls: invalid resumption PSK usage")
  11  	errInvalidProposalOrRefType = errors.New("mls: invalid proposal or ref type")
  12  )
  13  
  14  // --- CipherSuite ---
  15  
  16  type CipherSuite uint16
  17  
  18  const (
  19  	// MLS_128_DHKEMP256_AES128GCM_SHA256_P256
  20  	CipherSuite0x0001 CipherSuite = 0x0001
  21  	// MLS_128_DHKEMX25519_CHACHA20POLY1305_SHA256_Ed25519
  22  	CipherSuite0x0003 CipherSuite = 0x0003
  23  )
  24  
  25  // --- GroupContext ---
  26  
  27  type groupContext struct {
  28  	version                 protocolVersion
  29  	cipherSuite             CipherSuite
  30  	groupID                 GroupID
  31  	epoch                   uint64
  32  	treeHash                []byte
  33  	confirmedTranscriptHash []byte
  34  	extensions              []extension
  35  }
  36  
  37  func (ctx *groupContext) unmarshal(r *Reader) error {
  38  	*ctx = groupContext{}
  39  
  40  	v, ok := r.readUint16()
  41  	if !ok {
  42  		return errUnexpectedEOF
  43  	}
  44  	ctx.version = protocolVersion(v)
  45  	if ctx.version != protocolVersionMLS10 {
  46  		return errInvalidVersion
  47  	}
  48  
  49  	v, ok = r.readUint16()
  50  	if !ok {
  51  		return errUnexpectedEOF
  52  	}
  53  	ctx.cipherSuite = CipherSuite(v)
  54  
  55  	ctx.groupID, ok = r.readOpaqueVec()
  56  	if !ok {
  57  		return errUnexpectedEOF
  58  	}
  59  	ctx.epoch, ok = r.readUint64()
  60  	if !ok {
  61  		return errUnexpectedEOF
  62  	}
  63  	ctx.treeHash, ok = r.readOpaqueVec()
  64  	if !ok {
  65  		return errUnexpectedEOF
  66  	}
  67  	ctx.confirmedTranscriptHash, ok = r.readOpaqueVec()
  68  	if !ok {
  69  		return errUnexpectedEOF
  70  	}
  71  
  72  	exts, err := unmarshalExtensionVec(r)
  73  	if err != nil {
  74  		return err
  75  	}
  76  	ctx.extensions = exts
  77  	return nil
  78  }
  79  
  80  func (ctx *groupContext) marshal(w *Writer) {
  81  	w.addUint16(uint16(ctx.version))
  82  	w.addUint16(uint16(ctx.cipherSuite))
  83  	w.writeOpaqueVec([]byte(ctx.groupID))
  84  	w.addUint64(ctx.epoch)
  85  	w.writeOpaqueVec(ctx.treeHash)
  86  	w.writeOpaqueVec(ctx.confirmedTranscriptHash)
  87  	marshalExtensionVec(w, ctx.extensions)
  88  }
  89  
  90  // --- Secret labels ---
  91  
  92  var (
  93  	secretLabelInit           = []byte("init")
  94  	secretLabelSenderData     = []byte("sender data")
  95  	secretLabelEncryption     = []byte("encryption")
  96  	secretLabelExporter       = []byte("exporter")
  97  	secretLabelExternal       = []byte("external")
  98  	secretLabelConfirm        = []byte("confirm")
  99  	secretLabelMembership     = []byte("membership")
 100  	secretLabelResumption     = []byte("resumption")
 101  	secretLabelAuthentication = []byte("authentication")
 102  )
 103  
 104  // --- ConfirmedTranscriptHashInput ---
 105  
// confirmedTranscriptHashInput holds the ConfirmedTranscriptHashInput
// structure (RFC 9420 §8.2). Its marshal method writes wireFormat, then
// content, then signature.
//
// RFC 9420 requires content to carry a Commit. Nothing in this type enforces
// that, so callers must ensure it.
type confirmedTranscriptHashInput struct {
	wireFormat wireFormat
	content    framedContent
	signature  []byte // written with writeOpaqueVec
}
 111  
 112  func (input *confirmedTranscriptHashInput) marshal(w *Writer) {
 113  	input.wireFormat.marshal(w)
 114  	input.content.marshal(w)
 115  	w.writeOpaqueVec(input.signature)
 116  }
 117  
 118  // --- PSK types ---
 119  
 120  type pskType uint8
 121  
 122  const (
 123  	pskTypeExternal   pskType = 1
 124  	pskTypeResumption pskType = 2
 125  )
 126  
 127  func (t *pskType) unmarshal(r *Reader) error {
 128  	b, ok := r.readByte()
 129  	if !ok {
 130  		return errUnexpectedEOF
 131  	}
 132  	*t = pskType(b)
 133  	switch *t {
 134  	case pskTypeExternal, pskTypeResumption:
 135  		return nil
 136  	default:
 137  		return errInvalidPSKType
 138  	}
 139  }
 140  
 141  func (t pskType) marshal(w *Writer) {
 142  	w.addByte(byte(t))
 143  }
 144  
 145  type resumptionPSKUsage uint8
 146  
 147  const (
 148  	resumptionPSKUsageApplication resumptionPSKUsage = 1
 149  	resumptionPSKUsageReinit      resumptionPSKUsage = 2
 150  	resumptionPSKUsageBranch      resumptionPSKUsage = 3
 151  )
 152  
 153  func (usage *resumptionPSKUsage) unmarshal(r *Reader) error {
 154  	b, ok := r.readByte()
 155  	if !ok {
 156  		return errUnexpectedEOF
 157  	}
 158  	*usage = resumptionPSKUsage(b)
 159  	switch *usage {
 160  	case resumptionPSKUsageApplication, resumptionPSKUsageReinit, resumptionPSKUsageBranch:
 161  		return nil
 162  	default:
 163  		return errInvalidPSKUsage
 164  	}
 165  }
 166  
 167  func (usage resumptionPSKUsage) marshal(w *Writer) {
 168  	w.addByte(byte(usage))
 169  }
 170  
 171  // --- PreSharedKeyID ---
 172  
// preSharedKeyID holds the PreSharedKeyID structure of RFC 9420 §8.4: a PSK
// type, fields specific to that type, and a nonce. marshal writes only the
// fields that match pskType; the fields for the other type are ignored.
type preSharedKeyID struct {
	pskType pskType

	pskID []byte // for pskTypeExternal

	usage      resumptionPSKUsage // for pskTypeResumption
	pskGroupID GroupID            // for pskTypeResumption
	pskEpoch   uint64             // for pskTypeResumption

	// pskNonce is written for every PSK type, after the type-specific fields.
	pskNonce []byte
}
 184  
 185  func (id *preSharedKeyID) unmarshal(r *Reader) error {
 186  	*id = preSharedKeyID{}
 187  	if err := id.pskType.unmarshal(r); err != nil {
 188  		return err
 189  	}
 190  
 191  	switch id.pskType {
 192  	case pskTypeExternal:
 193  		var ok bool
 194  		id.pskID, ok = r.readOpaqueVec()
 195  		if !ok {
 196  			return errUnexpectedEOF
 197  		}
 198  	case pskTypeResumption:
 199  		if err := id.usage.unmarshal(r); err != nil {
 200  			return err
 201  		}
 202  		var ok bool
 203  		id.pskGroupID, ok = r.readOpaqueVec()
 204  		if !ok {
 205  			return errUnexpectedEOF
 206  		}
 207  		id.pskEpoch, ok = r.readUint64()
 208  		if !ok {
 209  			return errUnexpectedEOF
 210  		}
 211  	default:
 212  		panic("unreachable")
 213  	}
 214  
 215  	var ok bool
 216  	id.pskNonce, ok = r.readOpaqueVec()
 217  	if !ok {
 218  		return errUnexpectedEOF
 219  	}
 220  	return nil
 221  }
 222  
 223  func (id *preSharedKeyID) marshal(w *Writer) {
 224  	id.pskType.marshal(w)
 225  	switch id.pskType {
 226  	case pskTypeExternal:
 227  		w.writeOpaqueVec(id.pskID)
 228  	case pskTypeResumption:
 229  		id.usage.marshal(w)
 230  		w.writeOpaqueVec([]byte(id.pskGroupID))
 231  		w.addUint64(id.pskEpoch)
 232  	default:
 233  		panic("unreachable")
 234  	}
 235  	w.writeOpaqueVec(id.pskNonce)
 236  }
 237