1 // Copyright 2018 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4 5 package proto
6 7 import (
8 "google.golang.org/protobuf/encoding/protowire"
9 "google.golang.org/protobuf/internal/encoding/messageset"
10 "google.golang.org/protobuf/internal/errors"
11 "google.golang.org/protobuf/internal/genid"
12 "google.golang.org/protobuf/internal/pragma"
13 "google.golang.org/protobuf/reflect/protoreflect"
14 "google.golang.org/protobuf/reflect/protoregistry"
15 "google.golang.org/protobuf/runtime/protoiface"
16 )
// UnmarshalOptions configures the unmarshaler.
//
// Example usage:
//
//	err := UnmarshalOptions{DiscardUnknown: true}.Unmarshal(b, m)
type UnmarshalOptions struct {
	pragma.NoUnkeyedLiterals

	// Merge merges the input into the destination message.
	// The default behavior is to always reset the message before unmarshaling,
	// unless Merge is specified.
	Merge bool

	// AllowPartial accepts input for messages that will result in missing
	// required fields. If AllowPartial is false (the default), Unmarshal will
	// return an error if there are any missing required fields.
	AllowPartial bool

	// If DiscardUnknown is set, unknown fields are ignored.
	DiscardUnknown bool

	// Resolver is used for looking up types when unmarshaling extension fields.
	// If nil, this defaults to using protoregistry.GlobalTypes.
	Resolver interface {
		FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)
		FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error)
	}

	// RecursionLimit limits how deeply messages may be nested.
	// If zero, a default limit (protowire.DefaultRecursionLimit) is applied.
	RecursionLimit int

	// NoLazyDecoding turns off lazy decoding, which otherwise is enabled by
	// default. Lazy decoding only affects submessages (annotated with [lazy =
	// true] in the .proto file) within messages that use the Opaque API.
	NoLazyDecoding bool
}
56 57 // Unmarshal parses the wire-format message in b and places the result in m.
58 // The provided message must be mutable (e.g., a non-nil pointer to a message).
59 //
60 // See the [UnmarshalOptions] type if you need more control.
61 func Unmarshal(b []byte, m Message) error {
62 _, err := UnmarshalOptions{RecursionLimit: protowire.DefaultRecursionLimit}.unmarshal(b, m.ProtoReflect())
63 return err
64 }
65 66 // Unmarshal parses the wire-format message in b and places the result in m.
67 // The provided message must be mutable (e.g., a non-nil pointer to a message).
68 func (o UnmarshalOptions) Unmarshal(b []byte, m Message) error {
69 if o.RecursionLimit == 0 {
70 o.RecursionLimit = protowire.DefaultRecursionLimit
71 }
72 _, err := o.unmarshal(b, m.ProtoReflect())
73 return err
74 }
75 76 // UnmarshalState parses a wire-format message and places the result in m.
77 //
78 // This method permits fine-grained control over the unmarshaler.
79 // Most users should use [Unmarshal] instead.
80 func (o UnmarshalOptions) UnmarshalState(in protoiface.UnmarshalInput) (protoiface.UnmarshalOutput, error) {
81 if o.RecursionLimit == 0 {
82 o.RecursionLimit = protowire.DefaultRecursionLimit
83 }
84 return o.unmarshal(in.Buf, in.Message)
85 }
86 87 // unmarshal is a centralized function that all unmarshal operations go through.
88 // For profiling purposes, avoid changing the name of this function or
89 // introducing other code paths for unmarshal that do not go through this.
90 func (o UnmarshalOptions) unmarshal(b []byte, m protoreflect.Message) (out protoiface.UnmarshalOutput, err error) {
91 if o.Resolver == nil {
92 o.Resolver = protoregistry.GlobalTypes
93 }
94 if !o.Merge {
95 Reset(m.Interface())
96 }
97 allowPartial := o.AllowPartial
98 o.Merge = true
99 o.AllowPartial = true
100 methods := protoMethods(m)
101 if methods != nil && methods.Unmarshal != nil &&
102 !(o.DiscardUnknown && methods.Flags&protoiface.SupportUnmarshalDiscardUnknown == 0) {
103 in := protoiface.UnmarshalInput{
104 Message: m,
105 Buf: b,
106 Resolver: o.Resolver,
107 Depth: o.RecursionLimit,
108 }
109 if o.DiscardUnknown {
110 in.Flags |= protoiface.UnmarshalDiscardUnknown
111 }
112 113 if !allowPartial {
114 // This does not affect how current unmarshal functions work, it just allows them
115 // to record this for lazy the decoding case.
116 in.Flags |= protoiface.UnmarshalCheckRequired
117 }
118 if o.NoLazyDecoding {
119 in.Flags |= protoiface.UnmarshalNoLazyDecoding
120 }
121 122 out, err = methods.Unmarshal(in)
123 } else {
124 if o.RecursionLimit--; o.RecursionLimit < 0 {
125 return out, errRecursionDepth
126 }
127 err = o.unmarshalMessageSlow(b, m)
128 }
129 if err != nil {
130 return out, err
131 }
132 if allowPartial || (out.Flags&protoiface.UnmarshalInitialized != 0) {
133 return out, nil
134 }
135 return out, checkInitialized(m)
136 }
137 138 func (o UnmarshalOptions) unmarshalMessage(b []byte, m protoreflect.Message) error {
139 _, err := o.unmarshal(b, m)
140 return err
141 }
142 143 func (o UnmarshalOptions) unmarshalMessageSlow(b []byte, m protoreflect.Message) error {
144 md := m.Descriptor()
145 if messageset.IsMessageSet(md) {
146 return o.unmarshalMessageSet(b, m)
147 }
148 fields := md.Fields()
149 for len(b) > 0 {
150 // Parse the tag (field number and wire type).
151 num, wtyp, tagLen := protowire.ConsumeTag(b)
152 if tagLen < 0 {
153 return errDecode
154 }
155 if num > protowire.MaxValidNumber {
156 return errDecode
157 }
158 159 // Find the field descriptor for this field number.
160 fd := fields.ByNumber(num)
161 if fd == nil && md.ExtensionRanges().Has(num) {
162 extType, err := o.Resolver.FindExtensionByNumber(md.FullName(), num)
163 if err != nil && err != protoregistry.NotFound {
164 return errors.New("%v: unable to resolve extension %v: %v", md.FullName(), num, err)
165 }
166 if extType != nil {
167 fd = extType.TypeDescriptor()
168 }
169 }
170 var err error
171 if fd == nil {
172 err = errUnknown
173 }
174 175 // Parse the field value.
176 var valLen int
177 switch {
178 case err != nil:
179 case fd.IsList():
180 valLen, err = o.unmarshalList(b[tagLen:], wtyp, m.Mutable(fd).List(), fd)
181 case fd.IsMap():
182 valLen, err = o.unmarshalMap(b[tagLen:], wtyp, m.Mutable(fd).Map(), fd)
183 default:
184 valLen, err = o.unmarshalSingular(b[tagLen:], wtyp, m, fd)
185 }
186 if err != nil {
187 if err != errUnknown {
188 return err
189 }
190 valLen = protowire.ConsumeFieldValue(num, wtyp, b[tagLen:])
191 if valLen < 0 {
192 return errDecode
193 }
194 if !o.DiscardUnknown {
195 m.SetUnknown(append(m.GetUnknown(), b[:tagLen+valLen]...))
196 }
197 }
198 b = b[tagLen+valLen:]
199 }
200 return nil
201 }
202 203 func (o UnmarshalOptions) unmarshalSingular(b []byte, wtyp protowire.Type, m protoreflect.Message, fd protoreflect.FieldDescriptor) (n int, err error) {
204 v, n, err := o.unmarshalScalar(b, wtyp, fd)
205 if err != nil {
206 return 0, err
207 }
208 switch fd.Kind() {
209 case protoreflect.GroupKind, protoreflect.MessageKind:
210 m2 := m.Mutable(fd).Message()
211 if err := o.unmarshalMessage(v.Bytes(), m2); err != nil {
212 return n, err
213 }
214 default:
215 // Non-message scalars replace the previous value.
216 m.Set(fd, v)
217 }
218 return n, nil
219 }
220 221 func (o UnmarshalOptions) unmarshalMap(b []byte, wtyp protowire.Type, mapv protoreflect.Map, fd protoreflect.FieldDescriptor) (n int, err error) {
222 if o.RecursionLimit--; o.RecursionLimit < 0 {
223 return 0, errRecursionDepth
224 }
225 if wtyp != protowire.BytesType {
226 return 0, errUnknown
227 }
228 b, n = protowire.ConsumeBytes(b)
229 if n < 0 {
230 return 0, errDecode
231 }
232 var (
233 keyField = fd.MapKey()
234 valField = fd.MapValue()
235 key protoreflect.Value
236 val protoreflect.Value
237 haveKey bool
238 haveVal bool
239 )
240 switch valField.Kind() {
241 case protoreflect.GroupKind, protoreflect.MessageKind:
242 val = mapv.NewValue()
243 }
244 // Map entries are represented as a two-element message with fields
245 // containing the key and value.
246 for len(b) > 0 {
247 num, wtyp, n := protowire.ConsumeTag(b)
248 if n < 0 {
249 return 0, errDecode
250 }
251 if num > protowire.MaxValidNumber {
252 return 0, errDecode
253 }
254 b = b[n:]
255 err = errUnknown
256 switch num {
257 case genid.MapEntry_Key_field_number:
258 key, n, err = o.unmarshalScalar(b, wtyp, keyField)
259 if err != nil {
260 break
261 }
262 haveKey = true
263 case genid.MapEntry_Value_field_number:
264 var v protoreflect.Value
265 v, n, err = o.unmarshalScalar(b, wtyp, valField)
266 if err != nil {
267 break
268 }
269 switch valField.Kind() {
270 case protoreflect.GroupKind, protoreflect.MessageKind:
271 if err := o.unmarshalMessage(v.Bytes(), val.Message()); err != nil {
272 return 0, err
273 }
274 default:
275 val = v
276 }
277 haveVal = true
278 }
279 if err == errUnknown {
280 n = protowire.ConsumeFieldValue(num, wtyp, b)
281 if n < 0 {
282 return 0, errDecode
283 }
284 } else if err != nil {
285 return 0, err
286 }
287 b = b[n:]
288 }
289 // Every map entry should have entries for key and value, but this is not strictly required.
290 if !haveKey {
291 key = keyField.Default()
292 }
293 if !haveVal {
294 switch valField.Kind() {
295 case protoreflect.GroupKind, protoreflect.MessageKind:
296 default:
297 val = valField.Default()
298 }
299 }
300 mapv.Set(key.MapKey(), val)
301 return n, nil
302 }
303 304 // errUnknown is used internally to indicate fields which should be added
305 // to the unknown field set of a message. It is never returned from an exported
306 // function.
307 var errUnknown = errors.New("BUG: internal error (unknown)")
308 309 var errDecode = errors.New("cannot parse invalid wire-format data")
310 311 var errRecursionDepth = errors.New("exceeded maximum recursion depth")
312