1 // Copyright 2018 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4 5 // Package protowire parses and formats the raw wire encoding.
6 // See https://protobuf.dev/programming-guides/encoding.
7 //
8 // For marshaling and unmarshaling entire protobuf messages,
9 // use the [google.golang.org/protobuf/proto] package instead.
10 package protowire
11 12 import (
13 "io"
14 "math"
15 "math/bits"
16 17 "google.golang.org/protobuf/internal/errors"
18 )
// Number represents the field number.
type Number int32

// Field number bounds and related limits.
const (
	// MinValidNumber is the smallest field number accepted by IsValid.
	MinValidNumber Number = 1

	// FirstReservedNumber and LastReservedNumber bound an inclusive range of
	// reserved field numbers. IsValid does not exclude this range.
	FirstReservedNumber Number = 19000
	LastReservedNumber  Number = 19999

	// MaxValidNumber is the largest field number accepted by IsValid (2^29-1).
	MaxValidNumber Number = 1<<29 - 1

	// DefaultRecursionLimit is the group-nesting budget used by
	// ConsumeFieldValue.
	DefaultRecursionLimit = 10000
)

// IsValid reports whether the field number is semantically valid, that is,
// whether n lies in the closed interval [MinValidNumber, MaxValidNumber].
func (n Number) IsValid() bool {
	if n < MinValidNumber {
		return false
	}
	return n <= MaxValidNumber
}
// Type represents the wire type.
type Type int8

// Wire types, listed in numeric order.
const (
	VarintType     Type = 0 // base-128 varint
	Fixed64Type    Type = 1 // 8-byte little-endian value
	BytesType      Type = 2 // varint length prefix followed by that many bytes
	StartGroupType Type = 3 // start of a group, closed by a matching EndGroupType tag
	EndGroupType   Type = 4 // end-of-group marker
	Fixed32Type    Type = 5 // 4-byte little-endian value
)
// Error codes returned as negative lengths by the Consume* functions.
// ParseError converts them into error values.
const (
	errCodeTruncated      = -1
	errCodeFieldNumber    = -2
	errCodeOverflow       = -3
	errCodeReserved       = -4
	errCodeEndGroup       = -5
	errCodeRecursionDepth = -6
)
// errFieldNumber is what ParseError reports for errCodeFieldNumber.
var errFieldNumber = errors.New("invalid field number")

// errOverflow is what ParseError reports for errCodeOverflow.
var errOverflow = errors.New("variable length integer overflow")

// errReserved is what ParseError reports for errCodeReserved.
var errReserved = errors.New("cannot parse reserved wire type")

// errEndGroup is what ParseError reports for errCodeEndGroup.
var errEndGroup = errors.New("mismatching end group marker")

// errParse is ParseError's fallback for negative codes that have no more
// specific error.
var errParse = errors.New("parse error")
65 66 // ParseError converts an error code into an error value.
67 // This returns nil if n is a non-negative number.
68 func ParseError(n int) error {
69 if n >= 0 {
70 return nil
71 }
72 switch n {
73 case errCodeTruncated:
74 return io.ErrUnexpectedEOF
75 case errCodeFieldNumber:
76 return errFieldNumber
77 case errCodeOverflow:
78 return errOverflow
79 case errCodeReserved:
80 return errReserved
81 case errCodeEndGroup:
82 return errEndGroup
83 default:
84 return errParse
85 }
86 }
87 88 // ConsumeField parses an entire field record (both tag and value) and returns
89 // the field number, the wire type, and the total length.
90 // This returns a negative length upon an error (see [ParseError]).
91 //
92 // The total length includes the tag header and the end group marker (if the
93 // field is a group).
94 func ConsumeField(b []byte) (Number, Type, int) {
95 num, typ, n := ConsumeTag(b)
96 if n < 0 {
97 return 0, 0, n // forward error code
98 }
99 m := ConsumeFieldValue(num, typ, b[n:])
100 if m < 0 {
101 return 0, 0, m // forward error code
102 }
103 return num, typ, n + m
104 }
105 106 // ConsumeFieldValue parses a field value and returns its length.
107 // This assumes that the field [Number] and wire [Type] have already been parsed.
108 // This returns a negative length upon an error (see [ParseError]).
109 //
110 // When parsing a group, the length includes the end group marker and
111 // the end group is verified to match the starting field number.
112 func ConsumeFieldValue(num Number, typ Type, b []byte) (n int) {
113 return consumeFieldValueD(num, typ, b, DefaultRecursionLimit)
114 }
115 116 func consumeFieldValueD(num Number, typ Type, b []byte, depth int) (n int) {
117 switch typ {
118 case VarintType:
119 _, n = ConsumeVarint(b)
120 return n
121 case Fixed32Type:
122 _, n = ConsumeFixed32(b)
123 return n
124 case Fixed64Type:
125 _, n = ConsumeFixed64(b)
126 return n
127 case BytesType:
128 _, n = ConsumeBytes(b)
129 return n
130 case StartGroupType:
131 if depth < 0 {
132 return errCodeRecursionDepth
133 }
134 n0 := len(b)
135 for {
136 num2, typ2, n := ConsumeTag(b)
137 if n < 0 {
138 return n // forward error code
139 }
140 b = b[n:]
141 if typ2 == EndGroupType {
142 if num != num2 {
143 return errCodeEndGroup
144 }
145 return n0 - len(b)
146 }
147 148 n = consumeFieldValueD(num2, typ2, b, depth-1)
149 if n < 0 {
150 return n // forward error code
151 }
152 b = b[n:]
153 }
154 case EndGroupType:
155 return errCodeEndGroup
156 default:
157 return errCodeReserved
158 }
159 }
160 161 // AppendTag encodes num and typ as a varint-encoded tag and appends it to b.
162 func AppendTag(b []byte, num Number, typ Type) []byte {
163 return AppendVarint(b, EncodeTag(num, typ))
164 }
165 166 // ConsumeTag parses b as a varint-encoded tag, reporting its length.
167 // This returns a negative length upon an error (see [ParseError]).
168 func ConsumeTag(b []byte) (Number, Type, int) {
169 v, n := ConsumeVarint(b)
170 if n < 0 {
171 return 0, 0, n // forward error code
172 }
173 num, typ := DecodeTag(v)
174 if num < MinValidNumber {
175 return 0, 0, errCodeFieldNumber
176 }
177 return num, typ, n
178 }
179 180 func SizeTag(num Number) int {
181 return SizeVarint(EncodeTag(num, 0)) // wire type has no effect on size
182 }
// AppendVarint appends v to b as a varint-encoded uint64.
//
// The encoding stores 7 bits per byte, least-significant group first, and sets
// the high bit (0x80) on every byte except the last. The switch picks the
// shortest encoding (1 to 10 bytes) for v, so the number of bytes appended is
// the value SizeVarint(v) reports. Each length has its own unrolled case so
// that the bytes are produced with a single append call.
func AppendVarint(b []byte, v uint64) []byte {
	switch {
	case v < 1<<7:
		// Single byte: no continuation bit needed.
		b = append(b, byte(v))
	case v < 1<<14:
		// In every multi-byte case, each non-final byte keeps its 7-bit group
		// and sets the continuation bit. The final byte needs no mask because
		// the case bound guarantees it is below 0x80.
		b = append(b,
			byte((v>>0)&0x7f|0x80),
			byte(v>>7))
	case v < 1<<21:
		b = append(b,
			byte((v>>0)&0x7f|0x80),
			byte((v>>7)&0x7f|0x80),
			byte(v>>14))
	case v < 1<<28:
		b = append(b,
			byte((v>>0)&0x7f|0x80),
			byte((v>>7)&0x7f|0x80),
			byte((v>>14)&0x7f|0x80),
			byte(v>>21))
	case v < 1<<35:
		b = append(b,
			byte((v>>0)&0x7f|0x80),
			byte((v>>7)&0x7f|0x80),
			byte((v>>14)&0x7f|0x80),
			byte((v>>21)&0x7f|0x80),
			byte(v>>28))
	case v < 1<<42:
		b = append(b,
			byte((v>>0)&0x7f|0x80),
			byte((v>>7)&0x7f|0x80),
			byte((v>>14)&0x7f|0x80),
			byte((v>>21)&0x7f|0x80),
			byte((v>>28)&0x7f|0x80),
			byte(v>>35))
	case v < 1<<49:
		b = append(b,
			byte((v>>0)&0x7f|0x80),
			byte((v>>7)&0x7f|0x80),
			byte((v>>14)&0x7f|0x80),
			byte((v>>21)&0x7f|0x80),
			byte((v>>28)&0x7f|0x80),
			byte((v>>35)&0x7f|0x80),
			byte(v>>42))
	case v < 1<<56:
		b = append(b,
			byte((v>>0)&0x7f|0x80),
			byte((v>>7)&0x7f|0x80),
			byte((v>>14)&0x7f|0x80),
			byte((v>>21)&0x7f|0x80),
			byte((v>>28)&0x7f|0x80),
			byte((v>>35)&0x7f|0x80),
			byte((v>>42)&0x7f|0x80),
			byte(v>>49))
	case v < 1<<63:
		b = append(b,
			byte((v>>0)&0x7f|0x80),
			byte((v>>7)&0x7f|0x80),
			byte((v>>14)&0x7f|0x80),
			byte((v>>21)&0x7f|0x80),
			byte((v>>28)&0x7f|0x80),
			byte((v>>35)&0x7f|0x80),
			byte((v>>42)&0x7f|0x80),
			byte((v>>49)&0x7f|0x80),
			byte(v>>56))
	default:
		// v >= 1<<63: only bit 63 is left for the 10th byte, so that byte is
		// always 1.
		b = append(b,
			byte((v>>0)&0x7f|0x80),
			byte((v>>7)&0x7f|0x80),
			byte((v>>14)&0x7f|0x80),
			byte((v>>21)&0x7f|0x80),
			byte((v>>28)&0x7f|0x80),
			byte((v>>35)&0x7f|0x80),
			byte((v>>42)&0x7f|0x80),
			byte((v>>49)&0x7f|0x80),
			byte((v>>56)&0x7f|0x80),
			1)
	}
	return b
}
// ConsumeVarint parses b as a varint-encoded uint64, reporting its length.
// This returns a negative length upon an error (see [ParseError]).
//
// Each byte contributes its low 7 bits, least-significant group first. A set
// high bit (0x80) means another byte follows. The loop is unrolled for up to
// 10 bytes. At step i the byte is added shifted left by 7*i. If its
// continuation bit was set, that bit's contribution (0x80 << 7*i) is
// subtracted back out before the next byte is read.
//
// Non-minimal encodings (extra 0x80 padding bytes ending in 0x00) are accepted
// as long as they fit in 10 bytes.
//
// Errors:
//   - errCodeTruncated if b ends while a continuation bit is still set,
//     including when b is empty.
//   - errCodeOverflow if the 10th byte is greater than 1, meaning the value
//     does not fit in 64 bits.
func ConsumeVarint(b []byte) (v uint64, n int) {
	var y uint64
	if len(b) <= 0 {
		return 0, errCodeTruncated
	}
	// Fast path: values below 0x80 occupy a single byte.
	v = uint64(b[0])
	if v < 0x80 {
		return v, 1
	}
	v -= 0x80 // drop the continuation bit of byte 0

	// Each following step repeats the pattern: bounds check, add the next
	// byte at its 7-bit offset, return if it has no continuation bit, and
	// otherwise remove that continuation bit from v.
	if len(b) <= 1 {
		return 0, errCodeTruncated
	}
	y = uint64(b[1])
	v += y << 7
	if y < 0x80 {
		return v, 2
	}
	v -= 0x80 << 7

	if len(b) <= 2 {
		return 0, errCodeTruncated
	}
	y = uint64(b[2])
	v += y << 14
	if y < 0x80 {
		return v, 3
	}
	v -= 0x80 << 14

	if len(b) <= 3 {
		return 0, errCodeTruncated
	}
	y = uint64(b[3])
	v += y << 21
	if y < 0x80 {
		return v, 4
	}
	v -= 0x80 << 21

	if len(b) <= 4 {
		return 0, errCodeTruncated
	}
	y = uint64(b[4])
	v += y << 28
	if y < 0x80 {
		return v, 5
	}
	v -= 0x80 << 28

	if len(b) <= 5 {
		return 0, errCodeTruncated
	}
	y = uint64(b[5])
	v += y << 35
	if y < 0x80 {
		return v, 6
	}
	v -= 0x80 << 35

	if len(b) <= 6 {
		return 0, errCodeTruncated
	}
	y = uint64(b[6])
	v += y << 42
	if y < 0x80 {
		return v, 7
	}
	v -= 0x80 << 42

	if len(b) <= 7 {
		return 0, errCodeTruncated
	}
	y = uint64(b[7])
	v += y << 49
	if y < 0x80 {
		return v, 8
	}
	v -= 0x80 << 49

	if len(b) <= 8 {
		return 0, errCodeTruncated
	}
	y = uint64(b[8])
	v += y << 56
	if y < 0x80 {
		return v, 9
	}
	v -= 0x80 << 56

	if len(b) <= 9 {
		return 0, errCodeTruncated
	}
	// The 10th byte can only supply bit 63. Any value other than 0 or 1
	// would need more than 64 bits, so it is reported as an overflow.
	y = uint64(b[9])
	v += y << 63
	if y < 2 {
		return v, 10
	}
	return 0, errCodeOverflow
}
// SizeVarint returns the encoded size of a varint.
// The size is guaranteed to be within 1 and 10, inclusive.
//
// The result is the number of bytes the minimal encoding of v occupies: one
// byte per started group of 7 significant bits, with v == 0 still taking one
// byte.
func SizeVarint(v uint64) int {
	// This computes 1 + (bits.Len64(v)-1)/7.
	// 9/64 is a good enough approximation of 1/7
	//
	// The Go compiler can translate the bits.LeadingZeros64 call into the LZCNT
	// instruction, which is very fast on CPUs from the last few years. The
	// specific way of expressing the calculation matches C++ Protobuf, see
	// https://godbolt.org/z/4P3h53oM4 for the C++ code and how gcc/clang
	// optimize that function for GOAMD64=v1 and GOAMD64=v3 (-march=haswell).

	// By OR'ing v with 1, we guarantee that v is never 0, without changing the
	// result of SizeVarint. LZCNT is not defined for 0, meaning the compiler
	// needs to add extra instructions to handle that case.
	//
	// The Go compiler currently (go1.24.4) does not make use of this knowledge.
	// This opportunity (removing the XOR instruction, which handles the 0 case)
	// results in a small (1%) performance win across CPU architectures.
	//
	// Independently of avoiding the 0 case, we need the v |= 1 line because
	// it allows the Go compiler to eliminate an extra XCHGL barrier.
	v |= 1

	// It would be clearer to write log2value := 63 - uint32(...), but
	// writing uint32(...) ^ 63 is much more efficient (-14% ARM, -20% Intel).
	// Proof of identity for our value range [0..63]:
	// https://go.dev/play/p/Pdn9hEWYakX
	log2value := uint32(bits.LeadingZeros64(v)) ^ 63
	// (log2value*9 + 64 + 9) / 64 == 1 + (log2value+1)*9/64 in integer
	// arithmetic. For every log2value in [0, 63] this equals
	// 1 + log2value/7, i.e. the byte count described above.
	return int((log2value*9 + (64 + 9)) / 64)
}
// AppendFixed32 appends the four little-endian bytes of v to b, least
// significant byte first, and returns the extended slice.
func AppendFixed32(b []byte, v uint32) []byte {
	return append(b, byte(v), byte(v>>8), byte(v>>16), byte(v>>24))
}
409 410 // ConsumeFixed32 parses b as a little-endian uint32, reporting its length.
411 // This returns a negative length upon an error (see [ParseError]).
412 func ConsumeFixed32(b []byte) (v uint32, n int) {
413 if len(b) < 4 {
414 return 0, errCodeTruncated
415 }
416 v = uint32(b[0])<<0 | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24
417 return v, 4
418 }
// SizeFixed32 reports the encoded size of a fixed32, which is constant: 4.
func SizeFixed32() int {
	const fixed32Size = 4
	return fixed32Size
}
// AppendFixed64 appends the eight little-endian bytes of v to b, least
// significant byte first, and returns the extended slice.
func AppendFixed64(b []byte, v uint64) []byte {
	return append(b,
		byte(v), byte(v>>8), byte(v>>16), byte(v>>24),
		byte(v>>32), byte(v>>40), byte(v>>48), byte(v>>56))
}
437 438 // ConsumeFixed64 parses b as a little-endian uint64, reporting its length.
439 // This returns a negative length upon an error (see [ParseError]).
440 func ConsumeFixed64(b []byte) (v uint64, n int) {
441 if len(b) < 8 {
442 return 0, errCodeTruncated
443 }
444 v = uint64(b[0])<<0 | uint64(b[1])<<8 | uint64(b[2])<<16 | uint64(b[3])<<24 | uint64(b[4])<<32 | uint64(b[5])<<40 | uint64(b[6])<<48 | uint64(b[7])<<56
445 return v, 8
446 }
// SizeFixed64 reports the encoded size of a fixed64, which is constant: 8.
func SizeFixed64() int {
	const fixed64Size = 8
	return fixed64Size
}
452 453 // AppendBytes appends v to b as a length-prefixed bytes value.
454 func AppendBytes(b []byte, v []byte) []byte {
455 return append(AppendVarint(b, uint64(len(v))), v...)
456 }
457 458 // ConsumeBytes parses b as a length-prefixed bytes value, reporting its length.
459 // This returns a negative length upon an error (see [ParseError]).
460 func ConsumeBytes(b []byte) (v []byte, n int) {
461 m, n := ConsumeVarint(b)
462 if n < 0 {
463 return nil, n // forward error code
464 }
465 if m > uint64(len(b[n:])) {
466 return nil, errCodeTruncated
467 }
468 return b[n:][:m], n + int(m)
469 }
470 471 // SizeBytes returns the encoded size of a length-prefixed bytes value,
472 // given only the length.
473 func SizeBytes(n int) int {
474 return SizeVarint(uint64(n)) + n
475 }
476 477 // AppendString appends v to b as a length-prefixed bytes value.
478 func AppendString(b []byte, v string) []byte {
479 return append(AppendVarint(b, uint64(len(v))), v...)
480 }
481 482 // ConsumeString parses b as a length-prefixed bytes value, reporting its length.
483 // This returns a negative length upon an error (see [ParseError]).
484 func ConsumeString(b []byte) (v string, n int) {
485 bb, n := ConsumeBytes(b)
486 return string(bb), n
487 }
488 489 // AppendGroup appends v to b as group value, with a trailing end group marker.
490 // The value v must not contain the end marker.
491 func AppendGroup(b []byte, num Number, v []byte) []byte {
492 return AppendVarint(append(b, v...), EncodeTag(num, EndGroupType))
493 }
494 495 // ConsumeGroup parses b as a group value until the trailing end group marker,
496 // and verifies that the end marker matches the provided num. The value v
497 // does not contain the end marker, while the length does contain the end marker.
498 // This returns a negative length upon an error (see [ParseError]).
499 func ConsumeGroup(num Number, b []byte) (v []byte, n int) {
500 n = ConsumeFieldValue(num, StartGroupType, b)
501 if n < 0 {
502 return nil, n // forward error code
503 }
504 b = b[:n]
505 506 // Truncate off end group marker, but need to handle denormalized varints.
507 // Assuming end marker is never 0 (which is always the case since
508 // EndGroupType is non-zero), we can truncate all trailing bytes where the
509 // lower 7 bits are all zero (implying that the varint is denormalized).
510 for len(b) > 0 && b[len(b)-1]&0x7f == 0 {
511 b = b[:len(b)-1]
512 }
513 b = b[:len(b)-SizeTag(num)]
514 return b, n
515 }
516 517 // SizeGroup returns the encoded size of a group, given only the length.
518 func SizeGroup(num Number, n int) int {
519 return n + SizeTag(num)
520 }
521 522 // DecodeTag decodes the field [Number] and wire [Type] from its unified form.
523 // The [Number] is -1 if the decoded field number overflows int32.
524 // Other than overflow, this does not check for field number validity.
525 func DecodeTag(x uint64) (Number, Type) {
526 // NOTE: MessageSet allows for larger field numbers than normal.
527 if x>>3 > uint64(math.MaxInt32) {
528 return -1, 0
529 }
530 return Number(x >> 3), Type(x & 7)
531 }
532 533 // EncodeTag encodes the field [Number] and wire [Type] into its unified form.
534 func EncodeTag(num Number, typ Type) uint64 {
535 return uint64(num)<<3 | uint64(typ&7)
536 }
// DecodeZigZag decodes a zig-zag-encoded uint64 as an int64.
//
// Input:  {…,  5,  3,  1, 0,  2,  4,  6, …}
// Output: {…, -3, -2, -1, 0, +1, +2, +3, …}
func DecodeZigZag(x uint64) int64 {
	// The low bit selects the sign. Negating it yields an all-ones mask for
	// odd inputs and zero for even ones.
	signMask := -int64(x & 1)
	return int64(x>>1) ^ signMask
}
// EncodeZigZag encodes an int64 as a zig-zag-encoded uint64.
//
// Input:  {…, -3, -2, -1, 0, +1, +2, +3, …}
// Output: {…,  5,  3,  1, 0,  2,  4,  6, …}
func EncodeZigZag(x int64) uint64 {
	// The arithmetic shift fills with the sign bit: all ones for negative x,
	// zero otherwise.
	sign := uint64(x >> 63)
	return uint64(x)<<1 ^ sign
}
// DecodeBool decodes a uint64 as a bool: zero is false, any other value is
// true.
//
// Input:  {    0,    1,    2, …}
// Output: {false, true, true, …}
func DecodeBool(x uint64) bool {
	return x > 0
}
// EncodeBool encodes a bool as a uint64: true is 1 and false is 0.
//
// Input:  {false, true}
// Output: {    0,    1}
func EncodeBool(x bool) uint64 {
	var v uint64
	if x {
		v = 1
	}
	return v
}
572