struct_codec.go raw
1 // Copyright (C) MongoDB, Inc. 2017-present.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may
4 // not use this file except in compliance with the License. You may obtain
5 // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7 package bsoncodec
8
9 import (
10 "errors"
11 "fmt"
12 "reflect"
13 "sort"
14 "strings"
15 "sync"
16 "time"
17
18 "go.mongodb.org/mongo-driver/bson/bsonoptions"
19 "go.mongodb.org/mongo-driver/bson/bsonrw"
20 "go.mongodb.org/mongo-driver/bson/bsontype"
21 )
22
// DecodeError is returned when unmarshalling BSON bytes into a native Go type fails. It records
// the path of BSON keys leading to the failing value together with the underlying error.
type DecodeError struct {
	// keys holds the key path innermost-first: each enclosing level appends its own key while
	// the error travels up the call stack.
	keys []string
	// wrapped is the error that caused the decode failure.
	wrapped error
}

// Unwrap returns the underlying error.
func (de *DecodeError) Unwrap() error {
	return de.wrapped
}

// Error implements the error interface. The key path is rendered outermost-first, with the keys
// joined by ".".
func (de *DecodeError) Error() string {
	// de.keys is stored innermost-first, so go through Keys() to get the top-down order.
	path := strings.Join(de.Keys(), ".")
	return fmt.Sprintf("error decoding key %s: %v", path, de.wrapped)
}

// Keys returns the BSON key path that caused an error as a slice of strings. The keys in the slice are in top-down
// order. For example, if the document being unmarshalled was {a: {b: {c: 1}}} and the value for c was supposed to be
// a string, the keys slice will be ["a", "b", "c"]. The returned slice is a fresh copy.
func (de *DecodeError) Keys() []string {
	n := len(de.keys)
	ordered := make([]string, n)
	for i, key := range de.keys {
		ordered[n-1-i] = key
	}
	return ordered
}
53
// Zeroer lets a custom type report for itself whether it is in its zero state.
// Struct types that do not implement Zeroer, or whose IsZero returns false,
// are considered to be not zero.
type Zeroer interface {
	// IsZero reports whether the receiver should be treated as the zero value.
	IsZero() bool
}
60
61 // StructCodec is the Codec used for struct values.
62 //
63 // Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the
64 // StructCodec registered.
65 type StructCodec struct {
66 cache sync.Map // map[reflect.Type]*structDescription
67 parser StructTagParser
68
69 // DecodeZeroStruct causes DecodeValue to delete any existing values from Go structs in the
70 // destination value passed to Decode before unmarshaling BSON documents into them.
71 //
72 // Deprecated: Use bson.Decoder.ZeroStructs instead.
73 DecodeZeroStruct bool
74
75 // DecodeDeepZeroInline causes DecodeValue to delete any existing values from Go structs in the
76 // destination value passed to Decode before unmarshaling BSON documents into them.
77 //
78 // Deprecated: DecodeDeepZeroInline will not be supported in Go Driver 2.0.
79 DecodeDeepZeroInline bool
80
81 // EncodeOmitDefaultStruct causes the Encoder to consider the zero value for a struct (e.g.
82 // MyStruct{}) as empty and omit it from the marshaled BSON when the "omitempty" struct tag
83 // option is set.
84 //
85 // Deprecated: Use bson.Encoder.OmitZeroStruct instead.
86 EncodeOmitDefaultStruct bool
87
88 // AllowUnexportedFields allows encoding and decoding values from un-exported struct fields.
89 //
90 // Deprecated: AllowUnexportedFields does not work on recent versions of Go and will not be
91 // supported in Go Driver 2.0.
92 AllowUnexportedFields bool
93
94 // OverwriteDuplicatedInlinedFields, if false, causes EncodeValue to return an error if there is
95 // a duplicate field in the marshaled BSON when the "inline" struct tag option is set. The
96 // default value is true.
97 //
98 // Deprecated: Use bson.Encoder.ErrorOnInlineDuplicates instead.
99 OverwriteDuplicatedInlinedFields bool
100 }
101
102 var _ ValueEncoder = &StructCodec{}
103 var _ ValueDecoder = &StructCodec{}
104
105 // NewStructCodec returns a StructCodec that uses p for struct tag parsing.
106 //
107 // Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the
108 // StructCodec registered.
109 func NewStructCodec(p StructTagParser, opts ...*bsonoptions.StructCodecOptions) (*StructCodec, error) {
110 if p == nil {
111 return nil, errors.New("a StructTagParser must be provided to NewStructCodec")
112 }
113
114 structOpt := bsonoptions.MergeStructCodecOptions(opts...)
115
116 codec := &StructCodec{
117 parser: p,
118 }
119
120 if structOpt.DecodeZeroStruct != nil {
121 codec.DecodeZeroStruct = *structOpt.DecodeZeroStruct
122 }
123 if structOpt.DecodeDeepZeroInline != nil {
124 codec.DecodeDeepZeroInline = *structOpt.DecodeDeepZeroInline
125 }
126 if structOpt.EncodeOmitDefaultStruct != nil {
127 codec.EncodeOmitDefaultStruct = *structOpt.EncodeOmitDefaultStruct
128 }
129 if structOpt.OverwriteDuplicatedInlinedFields != nil {
130 codec.OverwriteDuplicatedInlinedFields = *structOpt.OverwriteDuplicatedInlinedFields
131 }
132 if structOpt.AllowUnexportedFields != nil {
133 codec.AllowUnexportedFields = *structOpt.AllowUnexportedFields
134 }
135
136 return codec, nil
137 }
138
139 // EncodeValue handles encoding generic struct types.
140 func (sc *StructCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
141 if !val.IsValid() || val.Kind() != reflect.Struct {
142 return ValueEncoderError{Name: "StructCodec.EncodeValue", Kinds: []reflect.Kind{reflect.Struct}, Received: val}
143 }
144
145 sd, err := sc.describeStruct(ec.Registry, val.Type(), ec.useJSONStructTags, ec.errorOnInlineDuplicates)
146 if err != nil {
147 return err
148 }
149
150 dw, err := vw.WriteDocument()
151 if err != nil {
152 return err
153 }
154 var rv reflect.Value
155 for _, desc := range sd.fl {
156 if desc.inline == nil {
157 rv = val.Field(desc.idx)
158 } else {
159 rv, err = fieldByIndexErr(val, desc.inline)
160 if err != nil {
161 continue
162 }
163 }
164
165 desc.encoder, rv, err = defaultValueEncoders.lookupElementEncoder(ec, desc.encoder, rv)
166
167 if err != nil && err != errInvalidValue {
168 return err
169 }
170
171 if err == errInvalidValue {
172 if desc.omitEmpty {
173 continue
174 }
175 vw2, err := dw.WriteDocumentElement(desc.name)
176 if err != nil {
177 return err
178 }
179 err = vw2.WriteNull()
180 if err != nil {
181 return err
182 }
183 continue
184 }
185
186 if desc.encoder == nil {
187 return ErrNoEncoder{Type: rv.Type()}
188 }
189
190 encoder := desc.encoder
191
192 var zero bool
193 if cz, ok := encoder.(CodecZeroer); ok {
194 zero = cz.IsTypeZero(rv.Interface())
195 } else if rv.Kind() == reflect.Interface {
196 // isZero will not treat an interface rv as an interface, so we need to check for the
197 // zero interface separately.
198 zero = rv.IsNil()
199 } else {
200 zero = isZero(rv, sc.EncodeOmitDefaultStruct || ec.omitZeroStruct)
201 }
202 if desc.omitEmpty && zero {
203 continue
204 }
205
206 vw2, err := dw.WriteDocumentElement(desc.name)
207 if err != nil {
208 return err
209 }
210
211 ectx := EncodeContext{
212 Registry: ec.Registry,
213 MinSize: desc.minSize || ec.MinSize,
214 errorOnInlineDuplicates: ec.errorOnInlineDuplicates,
215 stringifyMapKeysWithFmt: ec.stringifyMapKeysWithFmt,
216 nilMapAsEmpty: ec.nilMapAsEmpty,
217 nilSliceAsEmpty: ec.nilSliceAsEmpty,
218 nilByteSliceAsEmpty: ec.nilByteSliceAsEmpty,
219 omitZeroStruct: ec.omitZeroStruct,
220 useJSONStructTags: ec.useJSONStructTags,
221 }
222 err = encoder.EncodeValue(ectx, vw2, rv)
223 if err != nil {
224 return err
225 }
226 }
227
228 if sd.inlineMap >= 0 {
229 rv := val.Field(sd.inlineMap)
230 collisionFn := func(key string) bool {
231 _, exists := sd.fm[key]
232 return exists
233 }
234
235 return defaultMapCodec.mapEncodeValue(ec, dw, rv, collisionFn)
236 }
237
238 return dw.WriteDocumentEnd()
239 }
240
241 func newDecodeError(key string, original error) error {
242 de, ok := original.(*DecodeError)
243 if !ok {
244 return &DecodeError{
245 keys: []string{key},
246 wrapped: original,
247 }
248 }
249
250 de.keys = append(de.keys, key)
251 return de
252 }
253
254 // DecodeValue implements the Codec interface.
255 // By default, map types in val will not be cleared. If a map has existing key/value pairs, it will be extended with the new ones from vr.
256 // For slices, the decoder will set the length of the slice to zero and append all elements. The underlying array will not be cleared.
257 func (sc *StructCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
258 if !val.CanSet() || val.Kind() != reflect.Struct {
259 return ValueDecoderError{Name: "StructCodec.DecodeValue", Kinds: []reflect.Kind{reflect.Struct}, Received: val}
260 }
261
262 switch vrType := vr.Type(); vrType {
263 case bsontype.Type(0), bsontype.EmbeddedDocument:
264 case bsontype.Null:
265 if err := vr.ReadNull(); err != nil {
266 return err
267 }
268
269 val.Set(reflect.Zero(val.Type()))
270 return nil
271 case bsontype.Undefined:
272 if err := vr.ReadUndefined(); err != nil {
273 return err
274 }
275
276 val.Set(reflect.Zero(val.Type()))
277 return nil
278 default:
279 return fmt.Errorf("cannot decode %v into a %s", vrType, val.Type())
280 }
281
282 sd, err := sc.describeStruct(dc.Registry, val.Type(), dc.useJSONStructTags, false)
283 if err != nil {
284 return err
285 }
286
287 if sc.DecodeZeroStruct || dc.zeroStructs {
288 val.Set(reflect.Zero(val.Type()))
289 }
290 if sc.DecodeDeepZeroInline && sd.inline {
291 val.Set(deepZero(val.Type()))
292 }
293
294 var decoder ValueDecoder
295 var inlineMap reflect.Value
296 if sd.inlineMap >= 0 {
297 inlineMap = val.Field(sd.inlineMap)
298 decoder, err = dc.LookupDecoder(inlineMap.Type().Elem())
299 if err != nil {
300 return err
301 }
302 }
303
304 dr, err := vr.ReadDocument()
305 if err != nil {
306 return err
307 }
308
309 for {
310 name, vr, err := dr.ReadElement()
311 if err == bsonrw.ErrEOD {
312 break
313 }
314 if err != nil {
315 return err
316 }
317
318 fd, exists := sd.fm[name]
319 if !exists {
320 // if the original name isn't found in the struct description, try again with the name in lowercase
321 // this could match if a BSON tag isn't specified because by default, describeStruct lowercases all field
322 // names
323 fd, exists = sd.fm[strings.ToLower(name)]
324 }
325
326 if !exists {
327 if sd.inlineMap < 0 {
328 // The encoding/json package requires a flag to return on error for non-existent fields.
329 // This functionality seems appropriate for the struct codec.
330 err = vr.Skip()
331 if err != nil {
332 return err
333 }
334 continue
335 }
336
337 if inlineMap.IsNil() {
338 inlineMap.Set(reflect.MakeMap(inlineMap.Type()))
339 }
340
341 elem := reflect.New(inlineMap.Type().Elem()).Elem()
342 dc.Ancestor = inlineMap.Type()
343 err = decoder.DecodeValue(dc, vr, elem)
344 if err != nil {
345 return err
346 }
347 inlineMap.SetMapIndex(reflect.ValueOf(name), elem)
348 continue
349 }
350
351 var field reflect.Value
352 if fd.inline == nil {
353 field = val.Field(fd.idx)
354 } else {
355 field, err = getInlineField(val, fd.inline)
356 if err != nil {
357 return err
358 }
359 }
360
361 if !field.CanSet() { // Being settable is a super set of being addressable.
362 innerErr := fmt.Errorf("field %v is not settable", field)
363 return newDecodeError(fd.name, innerErr)
364 }
365 if field.Kind() == reflect.Ptr && field.IsNil() {
366 field.Set(reflect.New(field.Type().Elem()))
367 }
368 field = field.Addr()
369
370 dctx := DecodeContext{
371 Registry: dc.Registry,
372 Truncate: fd.truncate || dc.Truncate,
373 defaultDocumentType: dc.defaultDocumentType,
374 binaryAsSlice: dc.binaryAsSlice,
375 useJSONStructTags: dc.useJSONStructTags,
376 useLocalTimeZone: dc.useLocalTimeZone,
377 zeroMaps: dc.zeroMaps,
378 zeroStructs: dc.zeroStructs,
379 }
380
381 if fd.decoder == nil {
382 return newDecodeError(fd.name, ErrNoDecoder{Type: field.Elem().Type()})
383 }
384
385 err = fd.decoder.DecodeValue(dctx, vr, field.Elem())
386 if err != nil {
387 return newDecodeError(fd.name, err)
388 }
389 }
390
391 return nil
392 }
393
394 func isZero(v reflect.Value, omitZeroStruct bool) bool {
395 kind := v.Kind()
396 if (kind != reflect.Ptr || !v.IsNil()) && v.Type().Implements(tZeroer) {
397 return v.Interface().(Zeroer).IsZero()
398 }
399 if kind == reflect.Struct {
400 if !omitZeroStruct {
401 return false
402 }
403 vt := v.Type()
404 if vt == tTime {
405 return v.Interface().(time.Time).IsZero()
406 }
407 numField := vt.NumField()
408 for i := 0; i < numField; i++ {
409 ff := vt.Field(i)
410 if ff.PkgPath != "" && !ff.Anonymous {
411 continue // Private field
412 }
413 if !isZero(v.Field(i), omitZeroStruct) {
414 return false
415 }
416 }
417 return true
418 }
419 return !v.IsValid() || v.IsZero()
420 }
421
422 type structDescription struct {
423 fm map[string]fieldDescription
424 fl []fieldDescription
425 inlineMap int
426 inline bool
427 }
428
429 type fieldDescription struct {
430 name string // BSON key name
431 fieldName string // struct field name
432 idx int
433 omitEmpty bool
434 minSize bool
435 truncate bool
436 inline []int
437 encoder ValueEncoder
438 decoder ValueDecoder
439 }
440
441 type byIndex []fieldDescription
442
443 func (bi byIndex) Len() int { return len(bi) }
444
445 func (bi byIndex) Swap(i, j int) { bi[i], bi[j] = bi[j], bi[i] }
446
447 func (bi byIndex) Less(i, j int) bool {
448 // If a field is inlined, its index in the top level struct is stored at inline[0]
449 iIdx, jIdx := bi[i].idx, bi[j].idx
450 if len(bi[i].inline) > 0 {
451 iIdx = bi[i].inline[0]
452 }
453 if len(bi[j].inline) > 0 {
454 jIdx = bi[j].inline[0]
455 }
456 if iIdx != jIdx {
457 return iIdx < jIdx
458 }
459 for k, biik := range bi[i].inline {
460 if k >= len(bi[j].inline) {
461 return false
462 }
463 if biik != bi[j].inline[k] {
464 return biik < bi[j].inline[k]
465 }
466 }
467 return len(bi[i].inline) < len(bi[j].inline)
468 }
469
470 func (sc *StructCodec) describeStruct(
471 r *Registry,
472 t reflect.Type,
473 useJSONStructTags bool,
474 errorOnDuplicates bool,
475 ) (*structDescription, error) {
476 // We need to analyze the struct, including getting the tags, collecting
477 // information about inlining, and create a map of the field name to the field.
478 if v, ok := sc.cache.Load(t); ok {
479 return v.(*structDescription), nil
480 }
481 // TODO(charlie): Only describe the struct once when called
482 // concurrently with the same type.
483 ds, err := sc.describeStructSlow(r, t, useJSONStructTags, errorOnDuplicates)
484 if err != nil {
485 return nil, err
486 }
487 if v, loaded := sc.cache.LoadOrStore(t, ds); loaded {
488 ds = v.(*structDescription)
489 }
490 return ds, nil
491 }
492
493 func (sc *StructCodec) describeStructSlow(
494 r *Registry,
495 t reflect.Type,
496 useJSONStructTags bool,
497 errorOnDuplicates bool,
498 ) (*structDescription, error) {
499 numFields := t.NumField()
500 sd := &structDescription{
501 fm: make(map[string]fieldDescription, numFields),
502 fl: make([]fieldDescription, 0, numFields),
503 inlineMap: -1,
504 }
505
506 var fields []fieldDescription
507 for i := 0; i < numFields; i++ {
508 sf := t.Field(i)
509 if sf.PkgPath != "" && (!sc.AllowUnexportedFields || !sf.Anonymous) {
510 // field is private or unexported fields aren't allowed, ignore
511 continue
512 }
513
514 sfType := sf.Type
515 encoder, err := r.LookupEncoder(sfType)
516 if err != nil {
517 encoder = nil
518 }
519 decoder, err := r.LookupDecoder(sfType)
520 if err != nil {
521 decoder = nil
522 }
523
524 description := fieldDescription{
525 fieldName: sf.Name,
526 idx: i,
527 encoder: encoder,
528 decoder: decoder,
529 }
530
531 var stags StructTags
532 // If the caller requested that we use JSON struct tags, use the JSONFallbackStructTagParser
533 // instead of the parser defined on the codec.
534 if useJSONStructTags {
535 stags, err = JSONFallbackStructTagParser.ParseStructTags(sf)
536 } else {
537 stags, err = sc.parser.ParseStructTags(sf)
538 }
539 if err != nil {
540 return nil, err
541 }
542 if stags.Skip {
543 continue
544 }
545 description.name = stags.Name
546 description.omitEmpty = stags.OmitEmpty
547 description.minSize = stags.MinSize
548 description.truncate = stags.Truncate
549
550 if stags.Inline {
551 sd.inline = true
552 switch sfType.Kind() {
553 case reflect.Map:
554 if sd.inlineMap >= 0 {
555 return nil, errors.New("(struct " + t.String() + ") multiple inline maps")
556 }
557 if sfType.Key() != tString {
558 return nil, errors.New("(struct " + t.String() + ") inline map must have a string keys")
559 }
560 sd.inlineMap = description.idx
561 case reflect.Ptr:
562 sfType = sfType.Elem()
563 if sfType.Kind() != reflect.Struct {
564 return nil, fmt.Errorf("(struct %s) inline fields must be a struct, a struct pointer, or a map", t.String())
565 }
566 fallthrough
567 case reflect.Struct:
568 inlinesf, err := sc.describeStruct(r, sfType, useJSONStructTags, errorOnDuplicates)
569 if err != nil {
570 return nil, err
571 }
572 for _, fd := range inlinesf.fl {
573 if fd.inline == nil {
574 fd.inline = []int{i, fd.idx}
575 } else {
576 fd.inline = append([]int{i}, fd.inline...)
577 }
578 fields = append(fields, fd)
579
580 }
581 default:
582 return nil, fmt.Errorf("(struct %s) inline fields must be a struct, a struct pointer, or a map", t.String())
583 }
584 continue
585 }
586 fields = append(fields, description)
587 }
588
589 // Sort fieldDescriptions by name and use dominance rules to determine which should be added for each name
590 sort.Slice(fields, func(i, j int) bool {
591 x := fields
592 // sort field by name, breaking ties with depth, then
593 // breaking ties with index sequence.
594 if x[i].name != x[j].name {
595 return x[i].name < x[j].name
596 }
597 if len(x[i].inline) != len(x[j].inline) {
598 return len(x[i].inline) < len(x[j].inline)
599 }
600 return byIndex(x).Less(i, j)
601 })
602
603 for advance, i := 0, 0; i < len(fields); i += advance {
604 // One iteration per name.
605 // Find the sequence of fields with the name of this first field.
606 fi := fields[i]
607 name := fi.name
608 for advance = 1; i+advance < len(fields); advance++ {
609 fj := fields[i+advance]
610 if fj.name != name {
611 break
612 }
613 }
614 if advance == 1 { // Only one field with this name
615 sd.fl = append(sd.fl, fi)
616 sd.fm[name] = fi
617 continue
618 }
619 dominant, ok := dominantField(fields[i : i+advance])
620 if !ok || !sc.OverwriteDuplicatedInlinedFields || errorOnDuplicates {
621 return nil, fmt.Errorf("struct %s has duplicated key %s", t.String(), name)
622 }
623 sd.fl = append(sd.fl, dominant)
624 sd.fm[name] = dominant
625 }
626
627 sort.Sort(byIndex(sd.fl))
628
629 return sd, nil
630 }
631
632 // dominantField looks through the fields, all of which are known to
633 // have the same name, to find the single field that dominates the
634 // others using Go's inlining rules. If there are multiple top-level
635 // fields, the boolean will be false: This condition is an error in Go
636 // and we skip all the fields.
637 func dominantField(fields []fieldDescription) (fieldDescription, bool) {
638 // The fields are sorted in increasing index-length order, then by presence of tag.
639 // That means that the first field is the dominant one. We need only check
640 // for error cases: two fields at top level.
641 if len(fields) > 1 &&
642 len(fields[0].inline) == len(fields[1].inline) {
643 return fieldDescription{}, false
644 }
645 return fields[0], true
646 }
647
// fieldByIndexErr is reflect.Value.FieldByIndex with the panic converted into an error. It
// returns the nested field of v selected by index, or a non-nil error if FieldByIndex panics
// (for example when the path runs through a nil embedded struct pointer).
func fieldByIndexErr(v reflect.Value, index []int) (result reflect.Value, err error) {
	defer func() {
		recovered := recover()
		if recovered == nil {
			return
		}
		switch r := recovered.(type) {
		case string:
			err = fmt.Errorf("%s", r)
		case error:
			err = r
		default:
			// Never report success after a panic, whatever the panic value was.
			err = fmt.Errorf("%v", r)
		}
	}()

	result = v.FieldByIndex(index)
	return result, nil
}
663
664 func getInlineField(val reflect.Value, index []int) (reflect.Value, error) {
665 field, err := fieldByIndexErr(val, index)
666 if err == nil {
667 return field, nil
668 }
669
670 // if parent of this element doesn't exist, fix its parent
671 inlineParent := index[:len(index)-1]
672 var fParent reflect.Value
673 if fParent, err = fieldByIndexErr(val, inlineParent); err != nil {
674 fParent, err = getInlineField(val, inlineParent)
675 if err != nil {
676 return fParent, err
677 }
678 }
679 fParent.Set(reflect.New(fParent.Type().Elem()))
680
681 return fieldByIndexErr(val, index)
682 }
683
// deepZero returns a new, addressable zero value of struct type st. Every nested struct field of
// the result is itself zero and every pointer field is nil. For a type whose kind is not
// reflect.Struct it returns the invalid (zero) reflect.Value.
//
// The previous implementation tried to rebuild nested struct fields by calling Elem on the
// field's struct type, which reflect documents as a panic, so any struct with an exported
// struct-valued field could not be zeroed. A freshly allocated struct is already zero all the
// way down, so no per-field work is needed.
func deepZero(st reflect.Type) (result reflect.Value) {
	if st.Kind() != reflect.Struct {
		return result
	}
	return reflect.New(st).Elem()
}
702
// recursivePointerTo returns a pointer to a new zero value of v's type (after dereferencing v if
// it is a pointer). When that type is a struct, every field of v that is a non-nil pointer to a
// struct gets a freshly allocated counterpart in the result, built by the same rule; all other
// fields are left at their zero value.
func recursivePointerTo(v reflect.Value) reflect.Value {
	src := reflect.Indirect(v)
	out := reflect.New(src.Type())
	if src.Kind() != reflect.Struct {
		return out
	}

	dst := out.Elem()
	for idx := 0; idx < src.NumField(); idx++ {
		field := src.Field(idx)
		if field.Kind() != reflect.Ptr || field.Elem().Kind() != reflect.Struct {
			continue
		}
		dst.Field(idx).Set(recursivePointerTo(field))
	}

	return out
}
719