mls.go raw

   1  // Package mls implements the Messaging Layer Security protocol.
   2  //
   3  // MLS is specified in RFC 9420.
   4  package mls
   5  
   6  import (
   7  	"fmt"
   8  	"io"
   9  
  10  	"golang.org/x/crypto/cryptobyte"
  11  )
  12  
  13  func readVarint(s *cryptobyte.String, out *uint32) bool {
  14  	var b uint8
  15  	if !s.ReadUint8(&b) {
  16  		return false
  17  	}
  18  
  19  	prefix := b >> 6
  20  	if prefix == 3 {
  21  		return false // invalid variable length integer prefix
  22  	}
  23  
  24  	n := 1 << prefix
  25  	v := uint32(b & 0x3F)
  26  	for i := 0; i < n-1; i++ {
  27  		if !s.ReadUint8(&b) {
  28  			return false
  29  		}
  30  		v = (v << 8) + uint32(b)
  31  	}
  32  
  33  	if prefix >= 1 && v < uint32(1)<<(8*(n/2)-2) {
  34  		return false // minimum encoding was not used
  35  	}
  36  
  37  	*out = v
  38  	return true
  39  }
  40  
  41  func writeVarint(b *cryptobyte.Builder, n uint32) {
  42  	switch {
  43  	case n < 1<<6:
  44  		b.AddUint8(uint8(n))
  45  	case n < 1<<14:
  46  		b.AddUint16(0b01<<14 | uint16(n))
  47  	case n < 1<<30:
  48  		b.AddUint32(0b10<<30 | n)
  49  	default:
  50  		b.SetError(fmt.Errorf("mls: varint exceeds 30 bits"))
  51  	}
  52  }
  53  
  54  func readOpaqueVec(s *cryptobyte.String, out *[]byte) bool {
  55  	var n uint32
  56  	if !readVarint(s, &n) {
  57  		return false
  58  	}
  59  
  60  	b := make([]byte, n)
  61  	if !s.CopyBytes(b) {
  62  		return false
  63  	}
  64  
  65  	*out = b
  66  	return true
  67  }
  68  
  69  func writeOpaqueVec(b *cryptobyte.Builder, value []byte) {
  70  	if uint64(len(value)) >= 1<<32 {
  71  		b.SetError(fmt.Errorf("mls: opaque size exceeds maximum value of uint32"))
  72  		return
  73  	}
  74  	writeVarint(b, uint32(len(value)))
  75  	b.AddBytes(value)
  76  }
  77  
  78  func readVector(s *cryptobyte.String, f func(s *cryptobyte.String) error) error {
  79  	var n uint32
  80  	if !readVarint(s, &n) {
  81  		return io.ErrUnexpectedEOF
  82  	}
  83  	var vec []byte
  84  	if !s.ReadBytes(&vec, int(n)) {
  85  		return io.ErrUnexpectedEOF
  86  	}
  87  	ss := cryptobyte.String(vec)
  88  	for !ss.Empty() {
  89  		if err := f(&ss); err != nil {
  90  			return err
  91  		}
  92  	}
  93  	return nil
  94  }
  95  
  96  func writeVector(b *cryptobyte.Builder, n int, f func(b *cryptobyte.Builder, i int)) {
  97  	// We don't know the total size in advance, and the vector is prefixed with
  98  	// a varint, so we can't avoid the temporary buffer here
  99  	var child cryptobyte.Builder
 100  	for i := 0; i < n; i++ {
 101  		f(&child, i)
 102  	}
 103  
 104  	raw, err := child.Bytes()
 105  	if err != nil {
 106  		b.SetError(err)
 107  		return
 108  	}
 109  
 110  	writeOpaqueVec(b, raw)
 111  }
 112  
 113  func readOptional(s *cryptobyte.String, present *bool) bool {
 114  	var u8 uint8
 115  	if !s.ReadUint8(&u8) {
 116  		return false
 117  	}
 118  	switch u8 {
 119  	case 0:
 120  		*present = false
 121  	case 1:
 122  		*present = true
 123  	default:
 124  		return false
 125  	}
 126  	return true
 127  }
 128  
 129  func writeOptional(b *cryptobyte.Builder, present bool) {
 130  	u8 := uint8(0)
 131  	if present {
 132  		u8 = 1
 133  	}
 134  	b.AddUint8(u8)
 135  }
 136  
 137  type unmarshaler interface {
 138  	unmarshal(*cryptobyte.String) error
 139  }
 140  
 141  type marshaler interface {
 142  	marshal(*cryptobyte.Builder)
 143  }
 144  
 145  func unmarshal(raw []byte, v unmarshaler) error {
 146  	s := cryptobyte.String(raw)
 147  	if err := v.unmarshal(&s); err != nil {
 148  		return err
 149  	}
 150  	if !s.Empty() {
 151  		return fmt.Errorf("mls: input for %T contains %v excess bytes", v, len(s))
 152  	}
 153  	return nil
 154  }
 155  
 156  func marshal(v marshaler) ([]byte, error) {
 157  	var b cryptobyte.Builder
 158  	v.marshal(&b)
 159  	return b.Bytes()
 160  }
 161