proposal.mx raw

   1  package mls
   2  
   3  // MLS proposals (RFC 9420 §12).
   4  
   5  import "errors"
   6  
   7  var (
   8  	errInvalidProposalType    = errors.New("mls: invalid proposal type")
   9  	errDupAddKey              = errors.New("mls: duplicate add proposal signature key")
  10  	errUpdateByCommitter      = errors.New("mls: update proposal from committer")
  11  	errDupUpdateOrRemove      = errors.New("mls: duplicate update/remove on same leaf")
  12  	errRemoveCommitter        = errors.New("mls: remove proposal targets committer")
  13  	errDupPSK                 = errors.New("mls: duplicate PSK proposal")
  14  	errDupGroupContextExts    = errors.New("mls: duplicate group context extensions proposal")
  15  	errReinitWithOther        = errors.New("mls: reinit with other proposals")
  16  	errExternalInitNotAllowed = errors.New("mls: external init not allowed")
  17  )
  18  
  19  // --- ProposalType ---
  20  
  21  type proposalType uint16
  22  
  23  const (
  24  	proposalTypeAdd                    proposalType = 0x0001
  25  	proposalTypeUpdate                 proposalType = 0x0002
  26  	proposalTypeRemove                 proposalType = 0x0003
  27  	proposalTypePSK                    proposalType = 0x0004
  28  	proposalTypeReinit                 proposalType = 0x0005
  29  	proposalTypeExternalInit           proposalType = 0x0006
  30  	proposalTypeGroupContextExtensions proposalType = 0x0007
  31  )
  32  
  33  func (t *proposalType) unmarshal(r *Reader) error {
  34  	v, ok := r.readUint16()
  35  	if !ok {
  36  		return errUnexpectedEOF
  37  	}
  38  	*t = proposalType(v)
  39  	switch *t {
  40  	case proposalTypeAdd, proposalTypeUpdate, proposalTypeRemove,
  41  		proposalTypePSK, proposalTypeReinit, proposalTypeExternalInit,
  42  		proposalTypeGroupContextExtensions:
  43  		return nil
  44  	default:
  45  		return errInvalidProposalType
  46  	}
  47  }
  48  
  49  func (t proposalType) marshal(w *Writer) {
  50  	w.addUint16(uint16(t))
  51  }
  52  
  53  // --- Proposal ---
  54  
  55  type proposal struct {
  56  	proposalType           proposalType
  57  	add                    *add
  58  	update                 *update
  59  	remove                 *remove
  60  	preSharedKey           *preSharedKey
  61  	reInit                 *reInit
  62  	externalInit           *externalInit
  63  	groupContextExtensions *groupContextExtensions
  64  }
  65  
  66  func (prop *proposal) unmarshal(r *Reader) error {
  67  	*prop = proposal{}
  68  	if err := prop.proposalType.unmarshal(r); err != nil {
  69  		return err
  70  	}
  71  	switch prop.proposalType {
  72  	case proposalTypeAdd:
  73  		prop.add = &add{}
  74  		return prop.add.unmarshal(r)
  75  	case proposalTypeUpdate:
  76  		prop.update = &update{}
  77  		return prop.update.unmarshal(r)
  78  	case proposalTypeRemove:
  79  		prop.remove = &remove{}
  80  		return prop.remove.unmarshal(r)
  81  	case proposalTypePSK:
  82  		prop.preSharedKey = &preSharedKey{}
  83  		return prop.preSharedKey.unmarshal(r)
  84  	case proposalTypeReinit:
  85  		prop.reInit = &reInit{}
  86  		return prop.reInit.unmarshal(r)
  87  	case proposalTypeExternalInit:
  88  		prop.externalInit = &externalInit{}
  89  		return prop.externalInit.unmarshal(r)
  90  	case proposalTypeGroupContextExtensions:
  91  		prop.groupContextExtensions = &groupContextExtensions{}
  92  		return prop.groupContextExtensions.unmarshal(r)
  93  	default:
  94  		panic("unreachable")
  95  	}
  96  }
  97  
  98  func (prop *proposal) marshal(w *Writer) {
  99  	prop.proposalType.marshal(w)
 100  	switch prop.proposalType {
 101  	case proposalTypeAdd:
 102  		prop.add.marshal(w)
 103  	case proposalTypeUpdate:
 104  		prop.update.marshal(w)
 105  	case proposalTypeRemove:
 106  		prop.remove.marshal(w)
 107  	case proposalTypePSK:
 108  		prop.preSharedKey.marshal(w)
 109  	case proposalTypeReinit:
 110  		prop.reInit.marshal(w)
 111  	case proposalTypeExternalInit:
 112  		prop.externalInit.marshal(w)
 113  	case proposalTypeGroupContextExtensions:
 114  		prop.groupContextExtensions.marshal(w)
 115  	default:
 116  		panic("unreachable")
 117  	}
 118  }
 119  
 120  // --- Proposal sub-types ---
 121  
 122  type add struct {
 123  	keyPackage KeyPackage
 124  }
 125  
 126  func (a *add) unmarshal(r *Reader) error {
 127  	*a = add{}
 128  	return a.keyPackage.unmarshal(r)
 129  }
 130  
 131  func (a *add) marshal(w *Writer) {
 132  	a.keyPackage.marshal(w)
 133  }
 134  
 135  type update struct {
 136  	leafNode leafNode
 137  }
 138  
 139  func (upd *update) unmarshal(r *Reader) error {
 140  	*upd = update{}
 141  	return upd.leafNode.unmarshal(r)
 142  }
 143  
 144  func (upd *update) marshal(w *Writer) {
 145  	upd.leafNode.marshal(w)
 146  }
 147  
 148  type remove struct {
 149  	removed leafIndex
 150  }
 151  
 152  func (rm *remove) unmarshal(r *Reader) error {
 153  	*rm = remove{}
 154  	v, ok := r.readUint32()
 155  	if !ok {
 156  		return errUnexpectedEOF
 157  	}
 158  	rm.removed = leafIndex(v)
 159  	return nil
 160  }
 161  
 162  func (rm *remove) marshal(w *Writer) {
 163  	w.addUint32(uint32(rm.removed))
 164  }
 165  
 166  type preSharedKey struct {
 167  	psk preSharedKeyID
 168  }
 169  
 170  func (psk *preSharedKey) unmarshal(r *Reader) error {
 171  	*psk = preSharedKey{}
 172  	return psk.psk.unmarshal(r)
 173  }
 174  
 175  func (psk *preSharedKey) marshal(w *Writer) {
 176  	psk.psk.marshal(w)
 177  }
 178  
 179  type reInit struct {
 180  	groupID     GroupID
 181  	version     protocolVersion
 182  	cipherSuite CipherSuite
 183  	extensions  []extension
 184  }
 185  
 186  func (ri *reInit) unmarshal(r *Reader) error {
 187  	*ri = reInit{}
 188  	var ok bool
 189  	ri.groupID, ok = r.readOpaqueVec()
 190  	if !ok {
 191  		return errUnexpectedEOF
 192  	}
 193  	v, ok := r.readUint16()
 194  	if !ok {
 195  		return errUnexpectedEOF
 196  	}
 197  	ri.version = protocolVersion(v)
 198  	v, ok = r.readUint16()
 199  	if !ok {
 200  		return errUnexpectedEOF
 201  	}
 202  	ri.cipherSuite = CipherSuite(v)
 203  	exts, err := unmarshalExtensionVec(r)
 204  	if err != nil {
 205  		return err
 206  	}
 207  	ri.extensions = exts
 208  	return nil
 209  }
 210  
 211  func (ri *reInit) marshal(w *Writer) {
 212  	w.writeOpaqueVec([]byte(ri.groupID))
 213  	w.addUint16(uint16(ri.version))
 214  	w.addUint16(uint16(ri.cipherSuite))
 215  	marshalExtensionVec(w, ri.extensions)
 216  }
 217  
 218  type externalInit struct {
 219  	kemOutput []byte
 220  }
 221  
 222  func (ei *externalInit) unmarshal(r *Reader) error {
 223  	*ei = externalInit{}
 224  	var ok bool
 225  	ei.kemOutput, ok = r.readOpaqueVec()
 226  	if !ok {
 227  		return errUnexpectedEOF
 228  	}
 229  	return nil
 230  }
 231  
 232  func (ei *externalInit) marshal(w *Writer) {
 233  	w.writeOpaqueVec(ei.kemOutput)
 234  }
 235  
 236  type groupContextExtensions struct {
 237  	extensions []extension
 238  }
 239  
 240  func (exts *groupContextExtensions) unmarshal(r *Reader) error {
 241  	*exts = groupContextExtensions{}
 242  	l, err := unmarshalExtensionVec(r)
 243  	if err != nil {
 244  		return err
 245  	}
 246  	exts.extensions = l
 247  	return nil
 248  }
 249  
 250  func (exts *groupContextExtensions) marshal(w *Writer) {
 251  	marshalExtensionVec(w, exts.extensions)
 252  }
 253  
 254  // --- ProposalOrRef ---
 255  
 256  type proposalOrRefType uint8
 257  
 258  const (
 259  	proposalOrRefTypeProposal  proposalOrRefType = 1
 260  	proposalOrRefTypeReference proposalOrRefType = 2
 261  )
 262  
 263  func (t *proposalOrRefType) unmarshal(r *Reader) error {
 264  	b, ok := r.readByte()
 265  	if !ok {
 266  		return errUnexpectedEOF
 267  	}
 268  	*t = proposalOrRefType(b)
 269  	switch *t {
 270  	case proposalOrRefTypeProposal, proposalOrRefTypeReference:
 271  		return nil
 272  	default:
 273  		return errInvalidProposalOrRefType
 274  	}
 275  }
 276  
 277  func (t proposalOrRefType) marshal(w *Writer) {
 278  	w.addByte(byte(t))
 279  }
 280  
 281  type proposalRef []byte
 282  
 283  func (ref proposalRef) equal(other proposalRef) bool {
 284  	return bytesEqual([]byte(ref), []byte(other))
 285  }
 286  
 287  type proposalOrRef struct {
 288  	typ       proposalOrRefType
 289  	proposal  *proposal
 290  	reference proposalRef
 291  }
 292  
 293  func (por *proposalOrRef) unmarshal(r *Reader) error {
 294  	*por = proposalOrRef{}
 295  	if err := por.typ.unmarshal(r); err != nil {
 296  		return err
 297  	}
 298  	switch por.typ {
 299  	case proposalOrRefTypeProposal:
 300  		por.proposal = &proposal{}
 301  		return por.proposal.unmarshal(r)
 302  	case proposalOrRefTypeReference:
 303  		var ok bool
 304  		por.reference, ok = r.readOpaqueVec()
 305  		if !ok {
 306  			return errUnexpectedEOF
 307  		}
 308  		return nil
 309  	default:
 310  		panic("unreachable")
 311  	}
 312  }
 313  
 314  func (por *proposalOrRef) marshal(w *Writer) {
 315  	por.typ.marshal(w)
 316  	switch por.typ {
 317  	case proposalOrRefTypeProposal:
 318  		por.proposal.marshal(w)
 319  	case proposalOrRefTypeReference:
 320  		w.writeOpaqueVec([]byte(por.reference))
 321  	default:
 322  		panic("unreachable")
 323  	}
 324  }
 325  
 326  // --- Proposal list validation (RFC 9420 §12.2) ---
 327  
 328  func verifyProposalList(proposals []proposal, senders []leafIndex, committer leafIndex) error {
 329  	if len(proposals) != len(senders) {
 330  		panic("unreachable")
 331  	}
 332  
 333  	addKeys := map[string]struct{}{}
 334  	updateOrRemove := map[leafIndex]struct{}{}
 335  	pskKeys := map[string]struct{}{}
 336  	hasGroupCtxExts := false
 337  
 338  	for i, prop := range proposals {
 339  		snd := senders[i]
 340  		switch prop.proposalType {
 341  		case proposalTypeAdd:
 342  			k := string(prop.add.keyPackage.leafNode.signatureKey)
 343  			if _, dup := addKeys[k]; dup {
 344  				return errDupAddKey
 345  			}
 346  			addKeys[k] = struct{}{}
 347  		case proposalTypeUpdate:
 348  			if snd == committer {
 349  				return errUpdateByCommitter
 350  			}
 351  			if _, dup := updateOrRemove[snd]; dup {
 352  				return errDupUpdateOrRemove
 353  			}
 354  			updateOrRemove[snd] = struct{}{}
 355  		case proposalTypeRemove:
 356  			if prop.remove.removed == committer {
 357  				return errRemoveCommitter
 358  			}
 359  			if _, dup := updateOrRemove[prop.remove.removed]; dup {
 360  				return errDupUpdateOrRemove
 361  			}
 362  			updateOrRemove[prop.remove.removed] = struct{}{}
 363  		case proposalTypePSK:
 364  			raw, err := marshalRaw(&prop.preSharedKey.psk)
 365  			if err != nil {
 366  				return err
 367  			}
 368  			k := string(raw)
 369  			if _, dup := pskKeys[k]; dup {
 370  				return errDupPSK
 371  			}
 372  			pskKeys[k] = struct{}{}
 373  		case proposalTypeGroupContextExtensions:
 374  			if hasGroupCtxExts {
 375  				return errDupGroupContextExts
 376  			}
 377  			hasGroupCtxExts = true
 378  		case proposalTypeReinit:
 379  			if len(proposals) > 1 {
 380  				return errReinitWithOther
 381  			}
 382  		case proposalTypeExternalInit:
 383  			return errExternalInitNotAllowed
 384  		}
 385  	}
 386  	return nil
 387  }
 388  
 389  func proposalListNeedsPath(proposals []proposal) bool {
 390  	if len(proposals) == 0 {
 391  		return true
 392  	}
 393  	for _, prop := range proposals {
 394  		switch prop.proposalType {
 395  		case proposalTypeUpdate, proposalTypeRemove,
 396  			proposalTypeExternalInit, proposalTypeGroupContextExtensions:
 397  			return true
 398  		}
 399  	}
 400  	return false
 401  }
 402