1 // Copyright 2019 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4 5 package proto
6 7 import (
8 "errors"
9 "fmt"
10 11 "google.golang.org/protobuf/encoding/prototext"
12 "google.golang.org/protobuf/encoding/protowire"
13 "google.golang.org/protobuf/runtime/protoimpl"
14 )
// Wire type identifiers of the protobuf binary format, listed in numeric
// order. They are deliberately untyped so they convert implicitly to any
// integer type a caller compares them against.
const (
	WireVarint     = 0 // varint-encoded integer
	WireFixed64    = 1 // 8-byte little-endian value
	WireBytes      = 2 // length-prefixed bytes
	WireStartGroup = 3 // start group marker
	WireEndGroup   = 4 // end group marker
	WireFixed32    = 5 // 4-byte little-endian value
)
24 25 // EncodeVarint returns the varint encoded bytes of v.
26 func EncodeVarint(v uint64) []byte {
27 return protowire.AppendVarint(nil, v)
28 }
29 30 // SizeVarint returns the length of the varint encoded bytes of v.
31 // This is equal to len(EncodeVarint(v)).
32 func SizeVarint(v uint64) int {
33 return protowire.SizeVarint(v)
34 }
35 36 // DecodeVarint parses a varint encoded integer from b,
37 // returning the integer value and the length of the varint.
38 // It returns (0, 0) if there is a parse error.
39 func DecodeVarint(b []byte) (uint64, int) {
40 v, n := protowire.ConsumeVarint(b)
41 if n < 0 {
42 return 0, 0
43 }
44 return v, n
45 }
// Buffer is a buffer for encoding and decoding the protobuf wire format.
// It may be reused between invocations to reduce memory usage.
type Buffer struct {
	buf           []byte // encoded bytes; Encode* methods append here and Decode* methods read from buf[idx:]
	idx           int    // read offset into buf; bytes before idx have already been consumed
	deterministic bool   // passed to marshalAppend by Marshal and EncodeMessage; set via SetDeterministic
}
54 55 // NewBuffer allocates a new Buffer initialized with buf,
56 // where the contents of buf are considered the unread portion of the buffer.
57 func NewBuffer(buf []byte) *Buffer {
58 return &Buffer{buf: buf}
59 }
60 61 // SetDeterministic specifies whether to use deterministic serialization.
62 //
63 // Deterministic serialization guarantees that for a given binary, equal
64 // messages will always be serialized to the same bytes. This implies:
65 //
66 // - Repeated serialization of a message will return the same bytes.
67 // - Different processes of the same binary (which may be executing on
68 // different machines) will serialize equal messages to the same bytes.
69 //
70 // Note that the deterministic serialization is NOT canonical across
71 // languages. It is not guaranteed to remain stable over time. It is unstable
72 // across different builds with schema changes due to unknown fields.
73 // Users who need canonical serialization (e.g., persistent storage in a
74 // canonical form, fingerprinting, etc.) should define their own
75 // canonicalization specification and implement their own serializer rather
76 // than relying on this API.
77 //
78 // If deterministic serialization is requested, map entries will be sorted
79 // by keys in lexographical order. This is an implementation detail and
80 // subject to change.
81 func (b *Buffer) SetDeterministic(deterministic bool) {
82 b.deterministic = deterministic
83 }
84 85 // SetBuf sets buf as the internal buffer,
86 // where the contents of buf are considered the unread portion of the buffer.
87 func (b *Buffer) SetBuf(buf []byte) {
88 b.buf = buf
89 b.idx = 0
90 }
91 92 // Reset clears the internal buffer of all written and unread data.
93 func (b *Buffer) Reset() {
94 b.buf = b.buf[:0]
95 b.idx = 0
96 }
97 98 // Bytes returns the internal buffer.
99 func (b *Buffer) Bytes() []byte {
100 return b.buf
101 }
102 103 // Unread returns the unread portion of the buffer.
104 func (b *Buffer) Unread() []byte {
105 return b.buf[b.idx:]
106 }
107 108 // Marshal appends the wire-format encoding of m to the buffer.
109 func (b *Buffer) Marshal(m Message) error {
110 var err error
111 b.buf, err = marshalAppend(b.buf, m, b.deterministic)
112 return err
113 }
114 115 // Unmarshal parses the wire-format message in the buffer and
116 // places the decoded results in m.
117 // It does not reset m before unmarshaling.
118 func (b *Buffer) Unmarshal(m Message) error {
119 err := UnmarshalMerge(b.Unread(), m)
120 b.idx = len(b.buf)
121 return err
122 }
123 124 type unknownFields struct{ XXX_unrecognized protoimpl.UnknownFields }
125 126 func (m *unknownFields) String() string { panic("not implemented") }
127 func (m *unknownFields) Reset() { panic("not implemented") }
128 func (m *unknownFields) ProtoMessage() { panic("not implemented") }
129 130 // DebugPrint dumps the encoded bytes of b with a header and footer including s
131 // to stdout. This is only intended for debugging.
132 func (*Buffer) DebugPrint(s string, b []byte) {
133 m := MessageReflect(new(unknownFields))
134 m.SetUnknown(b)
135 b, _ = prototext.MarshalOptions{AllowPartial: true, Indent: "\t"}.Marshal(m.Interface())
136 fmt.Printf("==== %s ====\n%s==== %s ====\n", s, b, s)
137 }
138 139 // EncodeVarint appends an unsigned varint encoding to the buffer.
140 func (b *Buffer) EncodeVarint(v uint64) error {
141 b.buf = protowire.AppendVarint(b.buf, v)
142 return nil
143 }
144 145 // EncodeZigzag32 appends a 32-bit zig-zag varint encoding to the buffer.
146 func (b *Buffer) EncodeZigzag32(v uint64) error {
147 return b.EncodeVarint(uint64((uint32(v) << 1) ^ uint32((int32(v) >> 31))))
148 }
149 150 // EncodeZigzag64 appends a 64-bit zig-zag varint encoding to the buffer.
151 func (b *Buffer) EncodeZigzag64(v uint64) error {
152 return b.EncodeVarint(uint64((uint64(v) << 1) ^ uint64((int64(v) >> 63))))
153 }
154 155 // EncodeFixed32 appends a 32-bit little-endian integer to the buffer.
156 func (b *Buffer) EncodeFixed32(v uint64) error {
157 b.buf = protowire.AppendFixed32(b.buf, uint32(v))
158 return nil
159 }
160 161 // EncodeFixed64 appends a 64-bit little-endian integer to the buffer.
162 func (b *Buffer) EncodeFixed64(v uint64) error {
163 b.buf = protowire.AppendFixed64(b.buf, uint64(v))
164 return nil
165 }
166 167 // EncodeRawBytes appends a length-prefixed raw bytes to the buffer.
168 func (b *Buffer) EncodeRawBytes(v []byte) error {
169 b.buf = protowire.AppendBytes(b.buf, v)
170 return nil
171 }
172 173 // EncodeStringBytes appends a length-prefixed raw bytes to the buffer.
174 // It does not validate whether v contains valid UTF-8.
175 func (b *Buffer) EncodeStringBytes(v string) error {
176 b.buf = protowire.AppendString(b.buf, v)
177 return nil
178 }
179 180 // EncodeMessage appends a length-prefixed encoded message to the buffer.
181 func (b *Buffer) EncodeMessage(m Message) error {
182 var err error
183 b.buf = protowire.AppendVarint(b.buf, uint64(Size(m)))
184 b.buf, err = marshalAppend(b.buf, m, b.deterministic)
185 return err
186 }
187 188 // DecodeVarint consumes an encoded unsigned varint from the buffer.
189 func (b *Buffer) DecodeVarint() (uint64, error) {
190 v, n := protowire.ConsumeVarint(b.buf[b.idx:])
191 if n < 0 {
192 return 0, protowire.ParseError(n)
193 }
194 b.idx += n
195 return uint64(v), nil
196 }
197 198 // DecodeZigzag32 consumes an encoded 32-bit zig-zag varint from the buffer.
199 func (b *Buffer) DecodeZigzag32() (uint64, error) {
200 v, err := b.DecodeVarint()
201 if err != nil {
202 return 0, err
203 }
204 return uint64((uint32(v) >> 1) ^ uint32((int32(v&1)<<31)>>31)), nil
205 }
206 207 // DecodeZigzag64 consumes an encoded 64-bit zig-zag varint from the buffer.
208 func (b *Buffer) DecodeZigzag64() (uint64, error) {
209 v, err := b.DecodeVarint()
210 if err != nil {
211 return 0, err
212 }
213 return uint64((uint64(v) >> 1) ^ uint64((int64(v&1)<<63)>>63)), nil
214 }
215 216 // DecodeFixed32 consumes a 32-bit little-endian integer from the buffer.
217 func (b *Buffer) DecodeFixed32() (uint64, error) {
218 v, n := protowire.ConsumeFixed32(b.buf[b.idx:])
219 if n < 0 {
220 return 0, protowire.ParseError(n)
221 }
222 b.idx += n
223 return uint64(v), nil
224 }
225 226 // DecodeFixed64 consumes a 64-bit little-endian integer from the buffer.
227 func (b *Buffer) DecodeFixed64() (uint64, error) {
228 v, n := protowire.ConsumeFixed64(b.buf[b.idx:])
229 if n < 0 {
230 return 0, protowire.ParseError(n)
231 }
232 b.idx += n
233 return uint64(v), nil
234 }
235 236 // DecodeRawBytes consumes a length-prefixed raw bytes from the buffer.
237 // If alloc is specified, it returns a copy the raw bytes
238 // rather than a sub-slice of the buffer.
239 func (b *Buffer) DecodeRawBytes(alloc bool) ([]byte, error) {
240 v, n := protowire.ConsumeBytes(b.buf[b.idx:])
241 if n < 0 {
242 return nil, protowire.ParseError(n)
243 }
244 b.idx += n
245 if alloc {
246 v = append([]byte(nil), v...)
247 }
248 return v, nil
249 }
250 251 // DecodeStringBytes consumes a length-prefixed raw bytes from the buffer.
252 // It does not validate whether the raw bytes contain valid UTF-8.
253 func (b *Buffer) DecodeStringBytes() (string, error) {
254 v, n := protowire.ConsumeString(b.buf[b.idx:])
255 if n < 0 {
256 return "", protowire.ParseError(n)
257 }
258 b.idx += n
259 return v, nil
260 }
261 262 // DecodeMessage consumes a length-prefixed message from the buffer.
263 // It does not reset m before unmarshaling.
264 func (b *Buffer) DecodeMessage(m Message) error {
265 v, err := b.DecodeRawBytes(false)
266 if err != nil {
267 return err
268 }
269 return UnmarshalMerge(v, m)
270 }
271 272 // DecodeGroup consumes a message group from the buffer.
273 // It assumes that the start group marker has already been consumed and
274 // consumes all bytes until (and including the end group marker).
275 // It does not reset m before unmarshaling.
276 func (b *Buffer) DecodeGroup(m Message) error {
277 v, n, err := consumeGroup(b.buf[b.idx:])
278 if err != nil {
279 return err
280 }
281 b.idx += n
282 return UnmarshalMerge(v, m)
283 }
// consumeGroup parses b until it finds the end group marker that closes the
// current group. It returns the raw bytes of the group's contents (excluding
// the end group marker) and the total number of bytes consumed (including the
// end group marker).
//
// Nested groups are tracked by depth only: field numbers are discarded, so a
// start marker and the end marker that closes it are not checked to match.
func consumeGroup(b []byte) ([]byte, int, error) {
	b0 := b    // original input, used to compute lengths once the group closes
	depth := 1 // assume this follows a start group marker
	for {
		_, wtyp, tagLen := protowire.ConsumeTag(b)
		if tagLen < 0 {
			return nil, 0, protowire.ParseError(tagLen)
		}
		b = b[tagLen:]

		// Skip the field's value. Group markers have no value of their own,
		// so for them valLen stays 0 and only depth changes.
		var valLen int
		switch wtyp {
		case protowire.VarintType:
			_, valLen = protowire.ConsumeVarint(b)
		case protowire.Fixed32Type:
			_, valLen = protowire.ConsumeFixed32(b)
		case protowire.Fixed64Type:
			_, valLen = protowire.ConsumeFixed64(b)
		case protowire.BytesType:
			_, valLen = protowire.ConsumeBytes(b)
		case protowire.StartGroupType:
			depth++
		case protowire.EndGroupType:
			depth--
		default:
			// Any other wire type is reserved and cannot be skipped.
			return nil, 0, errors.New("proto: cannot parse reserved wire type")
		}
		if valLen < 0 {
			return nil, 0, protowire.ParseError(valLen)
		}
		b = b[valLen:]

		if depth == 0 {
			// The tag just read was the closing end group marker (its valLen
			// is 0), so drop its tagLen bytes from the returned contents while
			// still counting them in the consumed length.
			return b0[:len(b0)-len(b)-tagLen], len(b0) - len(b), nil
		}
	}
}
325