1 // Copyright 2010 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4 5 package proto
6 7 import (
8 "errors"
9 "fmt"
10 "reflect"
11 12 "google.golang.org/protobuf/encoding/protowire"
13 "google.golang.org/protobuf/proto"
14 "google.golang.org/protobuf/reflect/protoreflect"
15 "google.golang.org/protobuf/reflect/protoregistry"
16 "google.golang.org/protobuf/runtime/protoiface"
17 "google.golang.org/protobuf/runtime/protoimpl"
18 )
19 20 type (
21 // ExtensionDesc represents an extension descriptor and
22 // is used to interact with an extension field in a message.
23 //
24 // Variables of this type are generated in code by protoc-gen-go.
25 ExtensionDesc = protoimpl.ExtensionInfo
26 27 // ExtensionRange represents a range of message extensions.
28 // Used in code generated by protoc-gen-go.
29 ExtensionRange = protoiface.ExtensionRangeV1
30 31 // Deprecated: Do not use; this is an internal type.
32 Extension = protoimpl.ExtensionFieldV1
33 34 // Deprecated: Do not use; this is an internal type.
35 XXX_InternalExtensions = protoimpl.ExtensionFields
36 )
var (
	// ErrMissingExtension is the error GetExtension reports when the requested
	// extension is neither populated in the message nor has a default value.
	ErrMissingExtension = errors.New("proto: missing extension")

	// errNotExtendable is returned when a message cannot be reflected as a
	// valid message or, for most operations, declares no extension ranges.
	errNotExtendable = errors.New("proto: not an extendable proto.Message")
)
42 43 // HasExtension reports whether the extension field is present in m
44 // either as an explicitly populated field or as an unknown field.
45 func HasExtension(m Message, xt *ExtensionDesc) (has bool) {
46 mr := MessageReflect(m)
47 if mr == nil || !mr.IsValid() {
48 return false
49 }
50 51 // Check whether any populated known field matches the field number.
52 xtd := xt.TypeDescriptor()
53 if isValidExtension(mr.Descriptor(), xtd) {
54 has = mr.Has(xtd)
55 } else {
56 mr.Range(func(fd protoreflect.FieldDescriptor, _ protoreflect.Value) bool {
57 has = int32(fd.Number()) == xt.Field
58 return !has
59 })
60 }
61 62 // Check whether any unknown field matches the field number.
63 for b := mr.GetUnknown(); !has && len(b) > 0; {
64 num, _, n := protowire.ConsumeField(b)
65 has = int32(num) == xt.Field
66 b = b[n:]
67 }
68 return has
69 }
70 71 // ClearExtension removes the extension field from m
72 // either as an explicitly populated field or as an unknown field.
73 func ClearExtension(m Message, xt *ExtensionDesc) {
74 mr := MessageReflect(m)
75 if mr == nil || !mr.IsValid() {
76 return
77 }
78 79 xtd := xt.TypeDescriptor()
80 if isValidExtension(mr.Descriptor(), xtd) {
81 mr.Clear(xtd)
82 } else {
83 mr.Range(func(fd protoreflect.FieldDescriptor, _ protoreflect.Value) bool {
84 if int32(fd.Number()) == xt.Field {
85 mr.Clear(fd)
86 return false
87 }
88 return true
89 })
90 }
91 clearUnknown(mr, fieldNum(xt.Field))
92 }
93 94 // ClearAllExtensions clears all extensions from m.
95 // This includes populated fields and unknown fields in the extension range.
96 func ClearAllExtensions(m Message) {
97 mr := MessageReflect(m)
98 if mr == nil || !mr.IsValid() {
99 return
100 }
101 102 mr.Range(func(fd protoreflect.FieldDescriptor, _ protoreflect.Value) bool {
103 if fd.IsExtension() {
104 mr.Clear(fd)
105 }
106 return true
107 })
108 clearUnknown(mr, mr.Descriptor().ExtensionRanges())
109 }
110 111 // GetExtension retrieves a proto2 extended field from m.
112 //
113 // If the descriptor is type complete (i.e., ExtensionDesc.ExtensionType is non-nil),
114 // then GetExtension parses the encoded field and returns a Go value of the specified type.
115 // If the field is not present, then the default value is returned (if one is specified),
116 // otherwise ErrMissingExtension is reported.
117 //
118 // If the descriptor is type incomplete (i.e., ExtensionDesc.ExtensionType is nil),
119 // then GetExtension returns the raw encoded bytes for the extension field.
120 func GetExtension(m Message, xt *ExtensionDesc) (interface{}, error) {
121 mr := MessageReflect(m)
122 if mr == nil || !mr.IsValid() || mr.Descriptor().ExtensionRanges().Len() == 0 {
123 return nil, errNotExtendable
124 }
125 126 // Retrieve the unknown fields for this extension field.
127 var bo protoreflect.RawFields
128 for bi := mr.GetUnknown(); len(bi) > 0; {
129 num, _, n := protowire.ConsumeField(bi)
130 if int32(num) == xt.Field {
131 bo = append(bo, bi[:n]...)
132 }
133 bi = bi[n:]
134 }
135 136 // For type incomplete descriptors, only retrieve the unknown fields.
137 if xt.ExtensionType == nil {
138 return []byte(bo), nil
139 }
140 141 // If the extension field only exists as unknown fields, unmarshal it.
142 // This is rarely done since proto.Unmarshal eagerly unmarshals extensions.
143 xtd := xt.TypeDescriptor()
144 if !isValidExtension(mr.Descriptor(), xtd) {
145 return nil, fmt.Errorf("proto: bad extended type; %T does not extend %T", xt.ExtendedType, m)
146 }
147 if !mr.Has(xtd) && len(bo) > 0 {
148 m2 := mr.New()
149 if err := (proto.UnmarshalOptions{
150 Resolver: extensionResolver{xt},
151 }.Unmarshal(bo, m2.Interface())); err != nil {
152 return nil, err
153 }
154 if m2.Has(xtd) {
155 mr.Set(xtd, m2.Get(xtd))
156 clearUnknown(mr, fieldNum(xt.Field))
157 }
158 }
159 160 // Check whether the message has the extension field set or a default.
161 var pv protoreflect.Value
162 switch {
163 case mr.Has(xtd):
164 pv = mr.Get(xtd)
165 case xtd.HasDefault():
166 pv = xtd.Default()
167 default:
168 return nil, ErrMissingExtension
169 }
170 171 v := xt.InterfaceOf(pv)
172 rv := reflect.ValueOf(v)
173 if isScalarKind(rv.Kind()) {
174 rv2 := reflect.New(rv.Type())
175 rv2.Elem().Set(rv)
176 v = rv2.Interface()
177 }
178 return v, nil
179 }
180 181 // extensionResolver is a custom extension resolver that stores a single
182 // extension type that takes precedence over the global registry.
183 type extensionResolver struct{ xt protoreflect.ExtensionType }
184 185 func (r extensionResolver) FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error) {
186 if xtd := r.xt.TypeDescriptor(); xtd.FullName() == field {
187 return r.xt, nil
188 }
189 return protoregistry.GlobalTypes.FindExtensionByName(field)
190 }
191 192 func (r extensionResolver) FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) {
193 if xtd := r.xt.TypeDescriptor(); xtd.ContainingMessage().FullName() == message && xtd.Number() == field {
194 return r.xt, nil
195 }
196 return protoregistry.GlobalTypes.FindExtensionByNumber(message, field)
197 }
198 199 // GetExtensions returns a list of the extensions values present in m,
200 // corresponding with the provided list of extension descriptors, xts.
201 // If an extension is missing in m, the corresponding value is nil.
202 func GetExtensions(m Message, xts []*ExtensionDesc) ([]interface{}, error) {
203 mr := MessageReflect(m)
204 if mr == nil || !mr.IsValid() {
205 return nil, errNotExtendable
206 }
207 208 vs := make([]interface{}, len(xts))
209 for i, xt := range xts {
210 v, err := GetExtension(m, xt)
211 if err != nil {
212 if err == ErrMissingExtension {
213 continue
214 }
215 return vs, err
216 }
217 vs[i] = v
218 }
219 return vs, nil
220 }
221 222 // SetExtension sets an extension field in m to the provided value.
223 func SetExtension(m Message, xt *ExtensionDesc, v interface{}) error {
224 mr := MessageReflect(m)
225 if mr == nil || !mr.IsValid() || mr.Descriptor().ExtensionRanges().Len() == 0 {
226 return errNotExtendable
227 }
228 229 rv := reflect.ValueOf(v)
230 if reflect.TypeOf(v) != reflect.TypeOf(xt.ExtensionType) {
231 return fmt.Errorf("proto: bad extension value type. got: %T, want: %T", v, xt.ExtensionType)
232 }
233 if rv.Kind() == reflect.Ptr {
234 if rv.IsNil() {
235 return fmt.Errorf("proto: SetExtension called with nil value of type %T", v)
236 }
237 if isScalarKind(rv.Elem().Kind()) {
238 v = rv.Elem().Interface()
239 }
240 }
241 242 xtd := xt.TypeDescriptor()
243 if !isValidExtension(mr.Descriptor(), xtd) {
244 return fmt.Errorf("proto: bad extended type; %T does not extend %T", xt.ExtendedType, m)
245 }
246 mr.Set(xtd, xt.ValueOf(v))
247 clearUnknown(mr, fieldNum(xt.Field))
248 return nil
249 }
250 251 // SetRawExtension inserts b into the unknown fields of m.
252 //
253 // Deprecated: Use Message.ProtoReflect.SetUnknown instead.
254 func SetRawExtension(m Message, fnum int32, b []byte) {
255 mr := MessageReflect(m)
256 if mr == nil || !mr.IsValid() {
257 return
258 }
259 260 // Verify that the raw field is valid.
261 for b0 := b; len(b0) > 0; {
262 num, _, n := protowire.ConsumeField(b0)
263 if int32(num) != fnum {
264 panic(fmt.Sprintf("mismatching field number: got %d, want %d", num, fnum))
265 }
266 b0 = b0[n:]
267 }
268 269 ClearExtension(m, &ExtensionDesc{Field: fnum})
270 mr.SetUnknown(append(mr.GetUnknown(), b...))
271 }
272 273 // ExtensionDescs returns a list of extension descriptors found in m,
274 // containing descriptors for both populated extension fields in m and
275 // also unknown fields of m that are in the extension range.
276 // For the later case, an type incomplete descriptor is provided where only
277 // the ExtensionDesc.Field field is populated.
278 // The order of the extension descriptors is undefined.
279 func ExtensionDescs(m Message) ([]*ExtensionDesc, error) {
280 mr := MessageReflect(m)
281 if mr == nil || !mr.IsValid() || mr.Descriptor().ExtensionRanges().Len() == 0 {
282 return nil, errNotExtendable
283 }
284 285 // Collect a set of known extension descriptors.
286 extDescs := make(map[protoreflect.FieldNumber]*ExtensionDesc)
287 mr.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
288 if fd.IsExtension() {
289 xt := fd.(protoreflect.ExtensionTypeDescriptor)
290 if xd, ok := xt.Type().(*ExtensionDesc); ok {
291 extDescs[fd.Number()] = xd
292 }
293 }
294 return true
295 })
296 297 // Collect a set of unknown extension descriptors.
298 extRanges := mr.Descriptor().ExtensionRanges()
299 for b := mr.GetUnknown(); len(b) > 0; {
300 num, _, n := protowire.ConsumeField(b)
301 if extRanges.Has(num) && extDescs[num] == nil {
302 extDescs[num] = nil
303 }
304 b = b[n:]
305 }
306 307 // Transpose the set of descriptors into a list.
308 var xts []*ExtensionDesc
309 for num, xt := range extDescs {
310 if xt == nil {
311 xt = &ExtensionDesc{Field: int32(num)}
312 }
313 xts = append(xts, xt)
314 }
315 return xts, nil
316 }
317 318 // isValidExtension reports whether xtd is a valid extension descriptor for md.
319 func isValidExtension(md protoreflect.MessageDescriptor, xtd protoreflect.ExtensionTypeDescriptor) bool {
320 return xtd.ContainingMessage() == md && md.ExtensionRanges().Has(xtd.Number())
321 }
// isScalarKind reports whether k is a protobuf scalar kind (bytes excluded).
// It exists for historical reasons: v1 represents scalars as *T while v2
// uses plain T, so callers use it to decide when to add or strip a pointer.
func isScalarKind(k reflect.Kind) bool {
	switch k {
	case reflect.Bool, reflect.String:
		return true
	case reflect.Int32, reflect.Int64, reflect.Uint32, reflect.Uint64:
		return true
	case reflect.Float32, reflect.Float64:
		return true
	}
	return false
}
334 335 // clearUnknown removes unknown fields from m where remover.Has reports true.
336 func clearUnknown(m protoreflect.Message, remover interface {
337 Has(protoreflect.FieldNumber) bool
338 }) {
339 var bo protoreflect.RawFields
340 for bi := m.GetUnknown(); len(bi) > 0; {
341 num, _, n := protowire.ConsumeField(bi)
342 if !remover.Has(num) {
343 bo = append(bo, bi[:n]...)
344 }
345 bi = bi[n:]
346 }
347 if bi := m.GetUnknown(); len(bi) != len(bo) {
348 m.SetUnknown(bo)
349 }
350 }
351 352 type fieldNum protoreflect.FieldNumber
353 354 func (n1 fieldNum) Has(n2 protoreflect.FieldNumber) bool {
355 return protoreflect.FieldNumber(n1) == n2
356 }
357