proto.mx raw

   1  // Copyright 2014 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  // This file is a simple protocol buffer encoder and decoder.
   6  //
   7  // A protocol message must implement the message interface:
   8  //   decoder() []decoder
   9  //   encode(*buffer)
  10  //
  11  // The decoder method returns a slice indexed by field number that gives the
  12  // function to decode that field.
  13  // The encode method encodes its receiver into the given buffer.
  14  //
  15  // The two methods are simple enough to be implemented by hand rather than
  16  // by using a protocol compiler.
  17  //
  18  // See profile.go for examples of messages implementing this interface.
  19  //
  20  // There is no support for groups, message sets, or "has" bits.
  21  
  22  package profile
  23  
  24  import (
  25  	"errors"
  26  	"fmt"
  27  )
  28  
// buffer is the working state shared by the encode and decode helpers in
// this file.
type buffer struct {
	// field is the field number of the field most recently read by decodeField.
	field int
	// typ is the wire type of that field (the low three bits of its key).
	typ int
	// u64 holds the field's value for wire types 0, 1 and 5.
	u64 uint64
	// data is the output being built while encoding; while decoding it is the
	// payload of the current wire-type-2 field.
	data []byte
	// tmp is scratch space used to move a length prefix in front of a payload
	// that was encoded before its length was known.
	tmp [16]byte
}
  36  
// decoder decodes the field currently held in the buffer into the given
// message. decodeMessage selects one by field number and calls it.
type decoder func(*buffer, message) error
  38  
// message is the interface a type must implement to be marshaled and
// unmarshaled by this file.
type message interface {
	// decoder returns a slice indexed by field number whose entries decode
	// that field; decodeMessage skips fields with a nil or missing entry.
	decoder() []decoder
	// encode encodes the receiver into the given buffer.
	encode(*buffer)
}
  43  
  44  func marshal(m message) []byte {
  45  	var b buffer
  46  	m.encode(&b)
  47  	return b.data
  48  }
  49  
  50  func encodeVarint(b *buffer, x uint64) {
  51  	for x >= 128 {
  52  		b.data = append(b.data, byte(x)|0x80)
  53  		x >>= 7
  54  	}
  55  	b.data = append(b.data, byte(x))
  56  }
  57  
  58  func encodeLength(b *buffer, tag int, len int) {
  59  	encodeVarint(b, uint64(tag)<<3|2)
  60  	encodeVarint(b, uint64(len))
  61  }
  62  
  63  func encodeUint64(b *buffer, tag int, x uint64) {
  64  	// append varint to b.data
  65  	encodeVarint(b, uint64(tag)<<3|0)
  66  	encodeVarint(b, x)
  67  }
  68  
  69  func encodeUint64s(b *buffer, tag int, x []uint64) {
  70  	if len(x) > 2 {
  71  		// Use packed encoding
  72  		n1 := len(b.data)
  73  		for _, u := range x {
  74  			encodeVarint(b, u)
  75  		}
  76  		n2 := len(b.data)
  77  		encodeLength(b, tag, n2-n1)
  78  		n3 := len(b.data)
  79  		copy(b.tmp[:], b.data[n2:n3])
  80  		copy(b.data[n1+(n3-n2):], b.data[n1:n2])
  81  		copy(b.data[n1:], b.tmp[:n3-n2])
  82  		return
  83  	}
  84  	for _, u := range x {
  85  		encodeUint64(b, tag, u)
  86  	}
  87  }
  88  
  89  func encodeUint64Opt(b *buffer, tag int, x uint64) {
  90  	if x == 0 {
  91  		return
  92  	}
  93  	encodeUint64(b, tag, x)
  94  }
  95  
  96  func encodeInt64(b *buffer, tag int, x int64) {
  97  	u := uint64(x)
  98  	encodeUint64(b, tag, u)
  99  }
 100  
 101  func encodeInt64Opt(b *buffer, tag int, x int64) {
 102  	if x == 0 {
 103  		return
 104  	}
 105  	encodeInt64(b, tag, x)
 106  }
 107  
 108  func encodeInt64s(b *buffer, tag int, x []int64) {
 109  	if len(x) > 2 {
 110  		// Use packed encoding
 111  		n1 := len(b.data)
 112  		for _, u := range x {
 113  			encodeVarint(b, uint64(u))
 114  		}
 115  		n2 := len(b.data)
 116  		encodeLength(b, tag, n2-n1)
 117  		n3 := len(b.data)
 118  		copy(b.tmp[:], b.data[n2:n3])
 119  		copy(b.data[n1+(n3-n2):], b.data[n1:n2])
 120  		copy(b.data[n1:], b.tmp[:n3-n2])
 121  		return
 122  	}
 123  	for _, u := range x {
 124  		encodeInt64(b, tag, u)
 125  	}
 126  }
 127  
 128  func encodeString(b *buffer, tag int, x []byte) {
 129  	encodeLength(b, tag, len(x))
 130  	b.data = append(b.data, x...)
 131  }
 132  
 133  func encodeStrings(b *buffer, tag int, x [][]byte) {
 134  	for _, s := range x {
 135  		encodeString(b, tag, s)
 136  	}
 137  }
 138  
 139  func encodeBool(b *buffer, tag int, x bool) {
 140  	if x {
 141  		encodeUint64(b, tag, 1)
 142  	} else {
 143  		encodeUint64(b, tag, 0)
 144  	}
 145  }
 146  
 147  func encodeBoolOpt(b *buffer, tag int, x bool) {
 148  	if !x {
 149  		return
 150  	}
 151  	encodeBool(b, tag, x)
 152  }
 153  
 154  func encodeMessage(b *buffer, tag int, m message) {
 155  	n1 := len(b.data)
 156  	m.encode(b)
 157  	n2 := len(b.data)
 158  	encodeLength(b, tag, n2-n1)
 159  	n3 := len(b.data)
 160  	copy(b.tmp[:], b.data[n2:n3])
 161  	copy(b.data[n1+(n3-n2):], b.data[n1:n2])
 162  	copy(b.data[n1:], b.tmp[:n3-n2])
 163  }
 164  
 165  func unmarshal(data []byte, m message) (err error) {
 166  	b := buffer{data: data, typ: 2}
 167  	return decodeMessage(&b, m)
 168  }
 169  
 170  func le64(p []byte) uint64 {
 171  	return uint64(p[0]) | uint64(p[1])<<8 | uint64(p[2])<<16 | uint64(p[3])<<24 | uint64(p[4])<<32 | uint64(p[5])<<40 | uint64(p[6])<<48 | uint64(p[7])<<56
 172  }
 173  
 174  func le32(p []byte) uint32 {
 175  	return uint32(p[0]) | uint32(p[1])<<8 | uint32(p[2])<<16 | uint32(p[3])<<24
 176  }
 177  
 178  func decodeVarint(data []byte) (uint64, []byte, error) {
 179  	var i int
 180  	var u uint64
 181  	for i = 0; ; i++ {
 182  		if i >= 10 || i >= len(data) {
 183  			return 0, nil, errors.New("bad varint")
 184  		}
 185  		u |= uint64(data[i]&0x7F) << uint(7*i)
 186  		if data[i]&0x80 == 0 {
 187  			return u, data[i+1:], nil
 188  		}
 189  	}
 190  }
 191  
 192  func decodeField(b *buffer, data []byte) ([]byte, error) {
 193  	x, data, err := decodeVarint(data)
 194  	if err != nil {
 195  		return nil, err
 196  	}
 197  	b.field = int(x >> 3)
 198  	b.typ = int(x & 7)
 199  	b.data = nil
 200  	b.u64 = 0
 201  	switch b.typ {
 202  	case 0:
 203  		b.u64, data, err = decodeVarint(data)
 204  		if err != nil {
 205  			return nil, err
 206  		}
 207  	case 1:
 208  		if len(data) < 8 {
 209  			return nil, errors.New("not enough data")
 210  		}
 211  		b.u64 = le64(data[:8])
 212  		data = data[8:]
 213  	case 2:
 214  		var n uint64
 215  		n, data, err = decodeVarint(data)
 216  		if err != nil {
 217  			return nil, err
 218  		}
 219  		if n > uint64(len(data)) {
 220  			return nil, errors.New("too much data")
 221  		}
 222  		b.data = data[:n]
 223  		data = data[n:]
 224  	case 5:
 225  		if len(data) < 4 {
 226  			return nil, errors.New("not enough data")
 227  		}
 228  		b.u64 = uint64(le32(data[:4]))
 229  		data = data[4:]
 230  	default:
 231  		return nil, fmt.Errorf("unknown wire type: %d", b.typ)
 232  	}
 233  
 234  	return data, nil
 235  }
 236  
 237  func checkType(b *buffer, typ int) error {
 238  	if b.typ != typ {
 239  		return errors.New("type mismatch")
 240  	}
 241  	return nil
 242  }
 243  
 244  func decodeMessage(b *buffer, m message) error {
 245  	if err := checkType(b, 2); err != nil {
 246  		return err
 247  	}
 248  	dec := m.decoder()
 249  	data := b.data
 250  	for len(data) > 0 {
 251  		// pull varint field# + type
 252  		var err error
 253  		data, err = decodeField(b, data)
 254  		if err != nil {
 255  			return err
 256  		}
 257  		if b.field >= len(dec) || dec[b.field] == nil {
 258  			continue
 259  		}
 260  		if err := dec[b.field](b, m); err != nil {
 261  			return err
 262  		}
 263  	}
 264  	return nil
 265  }
 266  
 267  func decodeInt64(b *buffer, x *int64) error {
 268  	if err := checkType(b, 0); err != nil {
 269  		return err
 270  	}
 271  	*x = int64(b.u64)
 272  	return nil
 273  }
 274  
 275  func decodeInt64s(b *buffer, x *[]int64) error {
 276  	if b.typ == 2 {
 277  		// Packed encoding
 278  		data := b.data
 279  		for len(data) > 0 {
 280  			var u uint64
 281  			var err error
 282  
 283  			if u, data, err = decodeVarint(data); err != nil {
 284  				return err
 285  			}
 286  			*x = append(*x, int64(u))
 287  		}
 288  		return nil
 289  	}
 290  	var i int64
 291  	if err := decodeInt64(b, &i); err != nil {
 292  		return err
 293  	}
 294  	*x = append(*x, i)
 295  	return nil
 296  }
 297  
 298  func decodeUint64(b *buffer, x *uint64) error {
 299  	if err := checkType(b, 0); err != nil {
 300  		return err
 301  	}
 302  	*x = b.u64
 303  	return nil
 304  }
 305  
 306  func decodeUint64s(b *buffer, x *[]uint64) error {
 307  	if b.typ == 2 {
 308  		data := b.data
 309  		// Packed encoding
 310  		for len(data) > 0 {
 311  			var u uint64
 312  			var err error
 313  
 314  			if u, data, err = decodeVarint(data); err != nil {
 315  				return err
 316  			}
 317  			*x = append(*x, u)
 318  		}
 319  		return nil
 320  	}
 321  	var u uint64
 322  	if err := decodeUint64(b, &u); err != nil {
 323  		return err
 324  	}
 325  	*x = append(*x, u)
 326  	return nil
 327  }
 328  
 329  func decodeString(b *buffer, x *[]byte) error {
 330  	if err := checkType(b, 2); err != nil {
 331  		return err
 332  	}
 333  	*x = []byte(b.data)
 334  	return nil
 335  }
 336  
 337  func decodeStrings(b *buffer, x *[][]byte) error {
 338  	var s []byte
 339  	if err := decodeString(b, &s); err != nil {
 340  		return err
 341  	}
 342  	*x = append(*x, s)
 343  	return nil
 344  }
 345  
 346  func decodeBool(b *buffer, x *bool) error {
 347  	if err := checkType(b, 0); err != nil {
 348  		return err
 349  	}
 350  	if int64(b.u64) == 0 {
 351  		*x = false
 352  	} else {
 353  		*x = true
 354  	}
 355  	return nil
 356  }
 357