buffer.go raw

   1  // Copyright 2019 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  package proto
   6  
   7  import (
   8  	"errors"
   9  	"fmt"
  10  
  11  	"google.golang.org/protobuf/encoding/prototext"
  12  	"google.golang.org/protobuf/encoding/protowire"
  13  	"google.golang.org/protobuf/runtime/protoimpl"
  14  )
  15  
  16  const (
  17  	WireVarint     = 0
  18  	WireFixed32    = 5
  19  	WireFixed64    = 1
  20  	WireBytes      = 2
  21  	WireStartGroup = 3
  22  	WireEndGroup   = 4
  23  )
  24  
  25  // EncodeVarint returns the varint encoded bytes of v.
  26  func EncodeVarint(v uint64) []byte {
  27  	return protowire.AppendVarint(nil, v)
  28  }
  29  
  30  // SizeVarint returns the length of the varint encoded bytes of v.
  31  // This is equal to len(EncodeVarint(v)).
  32  func SizeVarint(v uint64) int {
  33  	return protowire.SizeVarint(v)
  34  }
  35  
  36  // DecodeVarint parses a varint encoded integer from b,
  37  // returning the integer value and the length of the varint.
  38  // It returns (0, 0) if there is a parse error.
  39  func DecodeVarint(b []byte) (uint64, int) {
  40  	v, n := protowire.ConsumeVarint(b)
  41  	if n < 0 {
  42  		return 0, 0
  43  	}
  44  	return v, n
  45  }
  46  
  47  // Buffer is a buffer for encoding and decoding the protobuf wire format.
  48  // It may be reused between invocations to reduce memory usage.
  49  type Buffer struct {
  50  	buf           []byte
  51  	idx           int
  52  	deterministic bool
  53  }
  54  
  55  // NewBuffer allocates a new Buffer initialized with buf,
  56  // where the contents of buf are considered the unread portion of the buffer.
  57  func NewBuffer(buf []byte) *Buffer {
  58  	return &Buffer{buf: buf}
  59  }
  60  
  61  // SetDeterministic specifies whether to use deterministic serialization.
  62  //
  63  // Deterministic serialization guarantees that for a given binary, equal
  64  // messages will always be serialized to the same bytes. This implies:
  65  //
  66  //   - Repeated serialization of a message will return the same bytes.
  67  //   - Different processes of the same binary (which may be executing on
  68  //     different machines) will serialize equal messages to the same bytes.
  69  //
  70  // Note that the deterministic serialization is NOT canonical across
  71  // languages. It is not guaranteed to remain stable over time. It is unstable
  72  // across different builds with schema changes due to unknown fields.
  73  // Users who need canonical serialization (e.g., persistent storage in a
  74  // canonical form, fingerprinting, etc.) should define their own
  75  // canonicalization specification and implement their own serializer rather
  76  // than relying on this API.
  77  //
  78  // If deterministic serialization is requested, map entries will be sorted
  79  // by keys in lexographical order. This is an implementation detail and
  80  // subject to change.
  81  func (b *Buffer) SetDeterministic(deterministic bool) {
  82  	b.deterministic = deterministic
  83  }
  84  
  85  // SetBuf sets buf as the internal buffer,
  86  // where the contents of buf are considered the unread portion of the buffer.
  87  func (b *Buffer) SetBuf(buf []byte) {
  88  	b.buf = buf
  89  	b.idx = 0
  90  }
  91  
  92  // Reset clears the internal buffer of all written and unread data.
  93  func (b *Buffer) Reset() {
  94  	b.buf = b.buf[:0]
  95  	b.idx = 0
  96  }
  97  
  98  // Bytes returns the internal buffer.
  99  func (b *Buffer) Bytes() []byte {
 100  	return b.buf
 101  }
 102  
 103  // Unread returns the unread portion of the buffer.
 104  func (b *Buffer) Unread() []byte {
 105  	return b.buf[b.idx:]
 106  }
 107  
 108  // Marshal appends the wire-format encoding of m to the buffer.
 109  func (b *Buffer) Marshal(m Message) error {
 110  	var err error
 111  	b.buf, err = marshalAppend(b.buf, m, b.deterministic)
 112  	return err
 113  }
 114  
 115  // Unmarshal parses the wire-format message in the buffer and
 116  // places the decoded results in m.
 117  // It does not reset m before unmarshaling.
 118  func (b *Buffer) Unmarshal(m Message) error {
 119  	err := UnmarshalMerge(b.Unread(), m)
 120  	b.idx = len(b.buf)
 121  	return err
 122  }
 123  
 124  type unknownFields struct{ XXX_unrecognized protoimpl.UnknownFields }
 125  
 126  func (m *unknownFields) String() string { panic("not implemented") }
 127  func (m *unknownFields) Reset()         { panic("not implemented") }
 128  func (m *unknownFields) ProtoMessage()  { panic("not implemented") }
 129  
 130  // DebugPrint dumps the encoded bytes of b with a header and footer including s
 131  // to stdout. This is only intended for debugging.
 132  func (*Buffer) DebugPrint(s string, b []byte) {
 133  	m := MessageReflect(new(unknownFields))
 134  	m.SetUnknown(b)
 135  	b, _ = prototext.MarshalOptions{AllowPartial: true, Indent: "\t"}.Marshal(m.Interface())
 136  	fmt.Printf("==== %s ====\n%s==== %s ====\n", s, b, s)
 137  }
 138  
 139  // EncodeVarint appends an unsigned varint encoding to the buffer.
 140  func (b *Buffer) EncodeVarint(v uint64) error {
 141  	b.buf = protowire.AppendVarint(b.buf, v)
 142  	return nil
 143  }
 144  
 145  // EncodeZigzag32 appends a 32-bit zig-zag varint encoding to the buffer.
 146  func (b *Buffer) EncodeZigzag32(v uint64) error {
 147  	return b.EncodeVarint(uint64((uint32(v) << 1) ^ uint32((int32(v) >> 31))))
 148  }
 149  
 150  // EncodeZigzag64 appends a 64-bit zig-zag varint encoding to the buffer.
 151  func (b *Buffer) EncodeZigzag64(v uint64) error {
 152  	return b.EncodeVarint(uint64((uint64(v) << 1) ^ uint64((int64(v) >> 63))))
 153  }
 154  
 155  // EncodeFixed32 appends a 32-bit little-endian integer to the buffer.
 156  func (b *Buffer) EncodeFixed32(v uint64) error {
 157  	b.buf = protowire.AppendFixed32(b.buf, uint32(v))
 158  	return nil
 159  }
 160  
 161  // EncodeFixed64 appends a 64-bit little-endian integer to the buffer.
 162  func (b *Buffer) EncodeFixed64(v uint64) error {
 163  	b.buf = protowire.AppendFixed64(b.buf, uint64(v))
 164  	return nil
 165  }
 166  
 167  // EncodeRawBytes appends a length-prefixed raw bytes to the buffer.
 168  func (b *Buffer) EncodeRawBytes(v []byte) error {
 169  	b.buf = protowire.AppendBytes(b.buf, v)
 170  	return nil
 171  }
 172  
 173  // EncodeStringBytes appends a length-prefixed raw bytes to the buffer.
 174  // It does not validate whether v contains valid UTF-8.
 175  func (b *Buffer) EncodeStringBytes(v string) error {
 176  	b.buf = protowire.AppendString(b.buf, v)
 177  	return nil
 178  }
 179  
 180  // EncodeMessage appends a length-prefixed encoded message to the buffer.
 181  func (b *Buffer) EncodeMessage(m Message) error {
 182  	var err error
 183  	b.buf = protowire.AppendVarint(b.buf, uint64(Size(m)))
 184  	b.buf, err = marshalAppend(b.buf, m, b.deterministic)
 185  	return err
 186  }
 187  
 188  // DecodeVarint consumes an encoded unsigned varint from the buffer.
 189  func (b *Buffer) DecodeVarint() (uint64, error) {
 190  	v, n := protowire.ConsumeVarint(b.buf[b.idx:])
 191  	if n < 0 {
 192  		return 0, protowire.ParseError(n)
 193  	}
 194  	b.idx += n
 195  	return uint64(v), nil
 196  }
 197  
 198  // DecodeZigzag32 consumes an encoded 32-bit zig-zag varint from the buffer.
 199  func (b *Buffer) DecodeZigzag32() (uint64, error) {
 200  	v, err := b.DecodeVarint()
 201  	if err != nil {
 202  		return 0, err
 203  	}
 204  	return uint64((uint32(v) >> 1) ^ uint32((int32(v&1)<<31)>>31)), nil
 205  }
 206  
 207  // DecodeZigzag64 consumes an encoded 64-bit zig-zag varint from the buffer.
 208  func (b *Buffer) DecodeZigzag64() (uint64, error) {
 209  	v, err := b.DecodeVarint()
 210  	if err != nil {
 211  		return 0, err
 212  	}
 213  	return uint64((uint64(v) >> 1) ^ uint64((int64(v&1)<<63)>>63)), nil
 214  }
 215  
 216  // DecodeFixed32 consumes a 32-bit little-endian integer from the buffer.
 217  func (b *Buffer) DecodeFixed32() (uint64, error) {
 218  	v, n := protowire.ConsumeFixed32(b.buf[b.idx:])
 219  	if n < 0 {
 220  		return 0, protowire.ParseError(n)
 221  	}
 222  	b.idx += n
 223  	return uint64(v), nil
 224  }
 225  
 226  // DecodeFixed64 consumes a 64-bit little-endian integer from the buffer.
 227  func (b *Buffer) DecodeFixed64() (uint64, error) {
 228  	v, n := protowire.ConsumeFixed64(b.buf[b.idx:])
 229  	if n < 0 {
 230  		return 0, protowire.ParseError(n)
 231  	}
 232  	b.idx += n
 233  	return uint64(v), nil
 234  }
 235  
 236  // DecodeRawBytes consumes a length-prefixed raw bytes from the buffer.
 237  // If alloc is specified, it returns a copy the raw bytes
 238  // rather than a sub-slice of the buffer.
 239  func (b *Buffer) DecodeRawBytes(alloc bool) ([]byte, error) {
 240  	v, n := protowire.ConsumeBytes(b.buf[b.idx:])
 241  	if n < 0 {
 242  		return nil, protowire.ParseError(n)
 243  	}
 244  	b.idx += n
 245  	if alloc {
 246  		v = append([]byte(nil), v...)
 247  	}
 248  	return v, nil
 249  }
 250  
 251  // DecodeStringBytes consumes a length-prefixed raw bytes from the buffer.
 252  // It does not validate whether the raw bytes contain valid UTF-8.
 253  func (b *Buffer) DecodeStringBytes() (string, error) {
 254  	v, n := protowire.ConsumeString(b.buf[b.idx:])
 255  	if n < 0 {
 256  		return "", protowire.ParseError(n)
 257  	}
 258  	b.idx += n
 259  	return v, nil
 260  }
 261  
 262  // DecodeMessage consumes a length-prefixed message from the buffer.
 263  // It does not reset m before unmarshaling.
 264  func (b *Buffer) DecodeMessage(m Message) error {
 265  	v, err := b.DecodeRawBytes(false)
 266  	if err != nil {
 267  		return err
 268  	}
 269  	return UnmarshalMerge(v, m)
 270  }
 271  
 272  // DecodeGroup consumes a message group from the buffer.
 273  // It assumes that the start group marker has already been consumed and
 274  // consumes all bytes until (and including the end group marker).
 275  // It does not reset m before unmarshaling.
 276  func (b *Buffer) DecodeGroup(m Message) error {
 277  	v, n, err := consumeGroup(b.buf[b.idx:])
 278  	if err != nil {
 279  		return err
 280  	}
 281  	b.idx += n
 282  	return UnmarshalMerge(v, m)
 283  }
 284  
 285  // consumeGroup parses b until it finds an end group marker, returning
 286  // the raw bytes of the message (excluding the end group marker) and the
 287  // the total length of the message (including the end group marker).
 288  func consumeGroup(b []byte) ([]byte, int, error) {
 289  	b0 := b
 290  	depth := 1 // assume this follows a start group marker
 291  	for {
 292  		_, wtyp, tagLen := protowire.ConsumeTag(b)
 293  		if tagLen < 0 {
 294  			return nil, 0, protowire.ParseError(tagLen)
 295  		}
 296  		b = b[tagLen:]
 297  
 298  		var valLen int
 299  		switch wtyp {
 300  		case protowire.VarintType:
 301  			_, valLen = protowire.ConsumeVarint(b)
 302  		case protowire.Fixed32Type:
 303  			_, valLen = protowire.ConsumeFixed32(b)
 304  		case protowire.Fixed64Type:
 305  			_, valLen = protowire.ConsumeFixed64(b)
 306  		case protowire.BytesType:
 307  			_, valLen = protowire.ConsumeBytes(b)
 308  		case protowire.StartGroupType:
 309  			depth++
 310  		case protowire.EndGroupType:
 311  			depth--
 312  		default:
 313  			return nil, 0, errors.New("proto: cannot parse reserved wire type")
 314  		}
 315  		if valLen < 0 {
 316  			return nil, 0, protowire.ParseError(valLen)
 317  		}
 318  		b = b[valLen:]
 319  
 320  		if depth == 0 {
 321  			return b0[:len(b0)-len(b)-tagLen], len(b0) - len(b), nil
 322  		}
 323  	}
 324  }
 325