proposal.go raw

   1  package mls
   2  
   3  import (
   4  	"bytes"
   5  	"fmt"
   6  	"io"
   7  
   8  	"golang.org/x/crypto/cryptobyte"
   9  )
  10  
  11  // http://www.iana.org/assignments/mls/mls.xhtml#mls-proposal-types
  12  type proposalType uint16
  13  
  14  const (
  15  	proposalTypeAdd                    proposalType = 0x0001
  16  	proposalTypeUpdate                 proposalType = 0x0002
  17  	proposalTypeRemove                 proposalType = 0x0003
  18  	proposalTypePSK                    proposalType = 0x0004
  19  	proposalTypeReinit                 proposalType = 0x0005
  20  	proposalTypeExternalInit           proposalType = 0x0006
  21  	proposalTypeGroupContextExtensions proposalType = 0x0007
  22  )
  23  
  24  func (t *proposalType) unmarshal(s *cryptobyte.String) error {
  25  	if !s.ReadUint16((*uint16)(t)) {
  26  		return io.ErrUnexpectedEOF
  27  	}
  28  	switch *t {
  29  	case proposalTypeAdd, proposalTypeUpdate, proposalTypeRemove, proposalTypePSK, proposalTypeReinit, proposalTypeExternalInit, proposalTypeGroupContextExtensions:
  30  		return nil
  31  	default:
  32  		return fmt.Errorf("mls: invalid proposal type %d", *t)
  33  	}
  34  }
  35  
  36  func (t proposalType) marshal(b *cryptobyte.Builder) {
  37  	b.AddUint16(uint16(t))
  38  }
  39  
  40  type proposal struct {
  41  	proposalType           proposalType
  42  	add                    *add                    // for proposalTypeAdd
  43  	update                 *update                 // for proposalTypeUpdate
  44  	remove                 *remove                 // for proposalTypeRemove
  45  	preSharedKey           *preSharedKey           // for proposalTypePSK
  46  	reInit                 *reInit                 // for proposalTypeReinit
  47  	externalInit           *externalInit           // for proposalTypeExternalInit
  48  	groupContextExtensions *groupContextExtensions // for proposalTypeGroupContextExtensions
  49  }
  50  
  51  func (prop *proposal) unmarshal(s *cryptobyte.String) error {
  52  	*prop = proposal{}
  53  	if err := prop.proposalType.unmarshal(s); err != nil {
  54  		return err
  55  	}
  56  	switch prop.proposalType {
  57  	case proposalTypeAdd:
  58  		prop.add = new(add)
  59  		return prop.add.unmarshal(s)
  60  	case proposalTypeUpdate:
  61  		prop.update = new(update)
  62  		return prop.update.unmarshal(s)
  63  	case proposalTypeRemove:
  64  		prop.remove = new(remove)
  65  		return prop.remove.unmarshal(s)
  66  	case proposalTypePSK:
  67  		prop.preSharedKey = new(preSharedKey)
  68  		return prop.preSharedKey.unmarshal(s)
  69  	case proposalTypeReinit:
  70  		prop.reInit = new(reInit)
  71  		return prop.reInit.unmarshal(s)
  72  	case proposalTypeExternalInit:
  73  		prop.externalInit = new(externalInit)
  74  		return prop.externalInit.unmarshal(s)
  75  	case proposalTypeGroupContextExtensions:
  76  		prop.groupContextExtensions = new(groupContextExtensions)
  77  		return prop.groupContextExtensions.unmarshal(s)
  78  	default:
  79  		panic("unreachable")
  80  	}
  81  }
  82  
  83  func (prop *proposal) marshal(b *cryptobyte.Builder) {
  84  	prop.proposalType.marshal(b)
  85  	switch prop.proposalType {
  86  	case proposalTypeAdd:
  87  		prop.add.marshal(b)
  88  	case proposalTypeUpdate:
  89  		prop.update.marshal(b)
  90  	case proposalTypeRemove:
  91  		prop.remove.marshal(b)
  92  	case proposalTypePSK:
  93  		prop.preSharedKey.marshal(b)
  94  	case proposalTypeReinit:
  95  		prop.reInit.marshal(b)
  96  	case proposalTypeExternalInit:
  97  		prop.externalInit.marshal(b)
  98  	case proposalTypeGroupContextExtensions:
  99  		prop.groupContextExtensions.marshal(b)
 100  	default:
 101  		panic("unreachable")
 102  	}
 103  }
 104  
 105  type add struct {
 106  	keyPackage KeyPackage
 107  }
 108  
 109  func (a *add) unmarshal(s *cryptobyte.String) error {
 110  	*a = add{}
 111  	return a.keyPackage.unmarshal(s)
 112  }
 113  
 114  func (a *add) marshal(b *cryptobyte.Builder) {
 115  	a.keyPackage.marshal(b)
 116  }
 117  
 118  type update struct {
 119  	leafNode leafNode
 120  }
 121  
 122  func (upd *update) unmarshal(s *cryptobyte.String) error {
 123  	*upd = update{}
 124  	return upd.leafNode.unmarshal(s)
 125  }
 126  
 127  func (upd *update) marshal(b *cryptobyte.Builder) {
 128  	upd.leafNode.marshal(b)
 129  }
 130  
 131  type remove struct {
 132  	removed leafIndex
 133  }
 134  
 135  func (rm *remove) unmarshal(s *cryptobyte.String) error {
 136  	*rm = remove{}
 137  	if !s.ReadUint32((*uint32)(&rm.removed)) {
 138  		return io.ErrUnexpectedEOF
 139  	}
 140  	return nil
 141  }
 142  
 143  func (rm *remove) marshal(b *cryptobyte.Builder) {
 144  	b.AddUint32(uint32(rm.removed))
 145  }
 146  
 147  type preSharedKey struct {
 148  	psk preSharedKeyID
 149  }
 150  
 151  func (psk *preSharedKey) unmarshal(s *cryptobyte.String) error {
 152  	*psk = preSharedKey{}
 153  	return psk.psk.unmarshal(s)
 154  }
 155  
 156  func (psk *preSharedKey) marshal(b *cryptobyte.Builder) {
 157  	psk.psk.marshal(b)
 158  }
 159  
 160  type reInit struct {
 161  	groupID     GroupID
 162  	version     protocolVersion
 163  	cipherSuite CipherSuite
 164  	extensions  []extension
 165  }
 166  
 167  func (ri *reInit) unmarshal(s *cryptobyte.String) error {
 168  	*ri = reInit{}
 169  
 170  	if !readOpaqueVec(s, (*[]byte)(&ri.groupID)) || !s.ReadUint16((*uint16)(&ri.version)) || !s.ReadUint16((*uint16)(&ri.cipherSuite)) {
 171  		return io.ErrUnexpectedEOF
 172  	}
 173  
 174  	exts, err := unmarshalExtensionVec(s)
 175  	if err != nil {
 176  		return err
 177  	}
 178  	ri.extensions = exts
 179  
 180  	return nil
 181  }
 182  
 183  func (ri *reInit) marshal(b *cryptobyte.Builder) {
 184  	writeOpaqueVec(b, []byte(ri.groupID))
 185  	b.AddUint16(uint16(ri.version))
 186  	b.AddUint16(uint16(ri.cipherSuite))
 187  	marshalExtensionVec(b, ri.extensions)
 188  }
 189  
 190  type externalInit struct {
 191  	kemOutput []byte
 192  }
 193  
 194  func (ei *externalInit) unmarshal(s *cryptobyte.String) error {
 195  	*ei = externalInit{}
 196  	if !readOpaqueVec(s, &ei.kemOutput) {
 197  		return io.ErrUnexpectedEOF
 198  	}
 199  	return nil
 200  }
 201  
 202  func (ei *externalInit) marshal(b *cryptobyte.Builder) {
 203  	writeOpaqueVec(b, ei.kemOutput)
 204  }
 205  
 206  type groupContextExtensions struct {
 207  	extensions []extension
 208  }
 209  
 210  func (exts *groupContextExtensions) unmarshal(s *cryptobyte.String) error {
 211  	*exts = groupContextExtensions{}
 212  
 213  	l, err := unmarshalExtensionVec(s)
 214  	if err != nil {
 215  		return err
 216  	}
 217  	exts.extensions = l
 218  
 219  	return nil
 220  }
 221  
 222  func (exts *groupContextExtensions) marshal(b *cryptobyte.Builder) {
 223  	marshalExtensionVec(b, exts.extensions)
 224  }
 225  
 226  type proposalOrRefType uint8
 227  
 228  const (
 229  	proposalOrRefTypeProposal  proposalOrRefType = 1
 230  	proposalOrRefTypeReference proposalOrRefType = 2
 231  )
 232  
 233  func (t *proposalOrRefType) unmarshal(s *cryptobyte.String) error {
 234  	if !s.ReadUint8((*uint8)(t)) {
 235  		return io.ErrUnexpectedEOF
 236  	}
 237  	switch *t {
 238  	case proposalOrRefTypeProposal, proposalOrRefTypeReference:
 239  		return nil
 240  	default:
 241  		return fmt.Errorf("mls: invalid proposal or ref type %d", *t)
 242  	}
 243  }
 244  
 245  func (t proposalOrRefType) marshal(b *cryptobyte.Builder) {
 246  	b.AddUint8(uint8(t))
 247  }
 248  
 249  type proposalRef []byte
 250  
 251  func (ref proposalRef) Equal(other proposalRef) bool {
 252  	return bytes.Equal([]byte(ref), []byte(other))
 253  }
 254  
 255  type proposalOrRef struct {
 256  	typ       proposalOrRefType
 257  	proposal  *proposal   // for proposalOrRefTypeProposal
 258  	reference proposalRef // for proposalOrRefTypeReference
 259  }
 260  
 261  func (propOrRef *proposalOrRef) unmarshal(s *cryptobyte.String) error {
 262  	*propOrRef = proposalOrRef{}
 263  
 264  	if err := propOrRef.typ.unmarshal(s); err != nil {
 265  		return err
 266  	}
 267  
 268  	switch propOrRef.typ {
 269  	case proposalOrRefTypeProposal:
 270  		propOrRef.proposal = new(proposal)
 271  		return propOrRef.proposal.unmarshal(s)
 272  	case proposalOrRefTypeReference:
 273  		if !readOpaqueVec(s, (*[]byte)(&propOrRef.reference)) {
 274  			return io.ErrUnexpectedEOF
 275  		}
 276  		return nil
 277  	default:
 278  		panic("unreachable")
 279  	}
 280  }
 281  
 282  func (propOrRef *proposalOrRef) marshal(b *cryptobyte.Builder) {
 283  	propOrRef.typ.marshal(b)
 284  	switch propOrRef.typ {
 285  	case proposalOrRefTypeProposal:
 286  		propOrRef.proposal.marshal(b)
 287  	case proposalOrRefTypeReference:
 288  		writeOpaqueVec(b, []byte(propOrRef.reference))
 289  	default:
 290  		panic("unreachable")
 291  	}
 292  }
 293  
 294  // verifyProposalList ensures that a list of proposals passes the checks for a
 295  // regular commit described in section 12.2.
 296  //
 297  // It does not perform all checks:
 298  //
 299  //   - It does not check the validity of individual proposals (section 12.1).
 300  //   - It does not check whether members in add proposals are already part of
 301  //     the group.
 302  //   - It does not check whether non-default proposal types are supported by
 303  //     all members of the group who will process the commit.
 304  //   - It does not check whether the ratchet tree is valid after processing the
 305  //     commit.
 306  func verifyProposalList(proposals []proposal, senders []leafIndex, committer leafIndex) error {
 307  	if len(proposals) != len(senders) {
 308  		panic("unreachable")
 309  	}
 310  
 311  	add := make(map[string]struct{})
 312  	updateOrRemove := make(map[leafIndex]struct{})
 313  	psk := make(map[string]struct{})
 314  	groupContextExtensions := false
 315  	for i, prop := range proposals {
 316  		sender := senders[i]
 317  
 318  		switch prop.proposalType {
 319  		case proposalTypeAdd:
 320  			k := string(prop.add.keyPackage.leafNode.signatureKey)
 321  			if _, dup := add[k]; dup {
 322  				return fmt.Errorf("mls: multiple add proposals have the same signature key")
 323  			}
 324  			add[k] = struct{}{}
 325  		case proposalTypeUpdate:
 326  			if sender == committer {
 327  				return fmt.Errorf("mls: update proposal generated by the committer")
 328  			}
 329  			if _, dup := updateOrRemove[sender]; dup {
 330  				return fmt.Errorf("mls: multiple update and/or remove proposals apply to the same leaf")
 331  			}
 332  			updateOrRemove[sender] = struct{}{}
 333  		case proposalTypeRemove:
 334  			if prop.remove.removed == committer {
 335  				return fmt.Errorf("mls: remove proposal removes the committer")
 336  			}
 337  			if _, dup := updateOrRemove[prop.remove.removed]; dup {
 338  				return fmt.Errorf("mls: multiple update and/or remove proposals apply to the same leaf")
 339  			}
 340  			updateOrRemove[prop.remove.removed] = struct{}{}
 341  		case proposalTypePSK:
 342  			b, err := marshal(&prop.preSharedKey.psk)
 343  			if err != nil {
 344  				return err
 345  			}
 346  			k := string(b)
 347  			if _, dup := psk[k]; dup {
 348  				return fmt.Errorf("mls: multiple PSK proposals reference the same PSK ID")
 349  			}
 350  			psk[k] = struct{}{}
 351  		case proposalTypeGroupContextExtensions:
 352  			if groupContextExtensions {
 353  				return fmt.Errorf("mls: multiple group context extensions proposals")
 354  			}
 355  			groupContextExtensions = true
 356  		case proposalTypeReinit:
 357  			if len(proposals) > 1 {
 358  				return fmt.Errorf("mls: reinit proposal together with any other proposal")
 359  			}
 360  		case proposalTypeExternalInit:
 361  			return fmt.Errorf("mls: external init proposal is not allowed")
 362  		}
 363  	}
 364  	return nil
 365  }
 366  
 367  func proposalListNeedsPath(proposals []proposal) bool {
 368  	if len(proposals) == 0 {
 369  		return true
 370  	}
 371  
 372  	for _, prop := range proposals {
 373  		switch prop.proposalType {
 374  		case proposalTypeUpdate, proposalTypeRemove, proposalTypeExternalInit, proposalTypeGroupContextExtensions:
 375  			return true
 376  		}
 377  	}
 378  
 379  	return false
 380  }
 381