map_codec.go raw
1 // Copyright (C) MongoDB, Inc. 2017-present.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may
4 // not use this file except in compliance with the License. You may obtain
5 // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7 package bsoncodec
8
9 import (
10 "encoding"
11 "fmt"
12 "reflect"
13 "strconv"
14
15 "go.mongodb.org/mongo-driver/bson/bsonoptions"
16 "go.mongodb.org/mongo-driver/bson/bsonrw"
17 "go.mongodb.org/mongo-driver/bson/bsontype"
18 )
19
// defaultMapCodec is a package-level MapCodec constructed by NewMapCodec with no
// options.
var defaultMapCodec = NewMapCodec()
21
// MapCodec is the Codec used for map values.
//
// Its exported fields are codec-level switches. When encoding and when deciding
// whether to clear a destination map, each switch is OR-ed with the matching
// setting carried by the EncodeContext or DecodeContext, so enabling the behavior
// in either place turns it on.
//
// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the
// MapCodec registered.
type MapCodec struct {
	// DecodeZerosMap causes DecodeValue to delete any existing values from Go maps in the destination
	// value passed to Decode before unmarshaling BSON documents into them.
	//
	// Deprecated: Use bson.Decoder.ZeroMaps instead.
	DecodeZerosMap bool

	// EncodeNilAsEmpty causes EncodeValue to marshal nil Go maps as empty BSON documents instead of
	// BSON null.
	//
	// Deprecated: Use bson.Encoder.NilMapAsEmpty instead.
	EncodeNilAsEmpty bool

	// EncodeKeysWithStringer causes the Encoder to convert Go map keys to BSON document field name
	// strings using fmt.Sprint instead of the default string conversion logic.
	//
	// The flag also changes key decoding: when it is set, KeyUnmarshaler is not
	// consulted and floating-point key types are parsed with strconv.ParseFloat
	// instead of being rejected as unsupported.
	//
	// Deprecated: Use bson.Encoder.StringifyMapKeysWithFmt instead.
	EncodeKeysWithStringer bool
}
45
// KeyMarshaler is the interface implemented by an object that can marshal itself into a string key.
// This applies to types used as map keys and is similar to encoding.TextMarshaler.
//
// When encoding map keys, string-kind keys are used directly and never reach
// MarshalKey. For other keys, KeyMarshaler is checked before
// encoding.TextMarshaler. A nil pointer key that implements KeyMarshaler is
// encoded as the empty string without calling MarshalKey. KeyMarshaler is not
// consulted at all when keys are stringified with fmt.Sprint.
type KeyMarshaler interface {
	// MarshalKey returns the string to use as the BSON document field name. A
	// non-nil error aborts encoding of the map.
	MarshalKey() (key string, err error)
}
51
// KeyUnmarshaler is the interface implemented by an object that can unmarshal a string representation
// of itself. This applies to types used as map keys and is similar to encoding.TextUnmarshaler.
//
// UnmarshalKey must be able to decode the form generated by MarshalKey.
// UnmarshalKey must copy the text if it wishes to retain the text
// after returning.
//
// When decoding, the method is looked up on a pointer to the map's key type and
// is invoked on a pointer to a newly allocated zero value of that type.
// KeyUnmarshaler takes precedence over encoding.TextUnmarshaler, and it is
// skipped when MapCodec.EncodeKeysWithStringer is set.
type KeyUnmarshaler interface {
	// UnmarshalKey populates the receiver from the BSON document field name key.
	UnmarshalKey(key string) error
}
61
62 // NewMapCodec returns a MapCodec with options opts.
63 //
64 // Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the
65 // MapCodec registered.
66 func NewMapCodec(opts ...*bsonoptions.MapCodecOptions) *MapCodec {
67 mapOpt := bsonoptions.MergeMapCodecOptions(opts...)
68
69 codec := MapCodec{}
70 if mapOpt.DecodeZerosMap != nil {
71 codec.DecodeZerosMap = *mapOpt.DecodeZerosMap
72 }
73 if mapOpt.EncodeNilAsEmpty != nil {
74 codec.EncodeNilAsEmpty = *mapOpt.EncodeNilAsEmpty
75 }
76 if mapOpt.EncodeKeysWithStringer != nil {
77 codec.EncodeKeysWithStringer = *mapOpt.EncodeKeysWithStringer
78 }
79 return &codec
80 }
81
82 // EncodeValue is the ValueEncoder for map[*]* types.
83 func (mc *MapCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
84 if !val.IsValid() || val.Kind() != reflect.Map {
85 return ValueEncoderError{Name: "MapEncodeValue", Kinds: []reflect.Kind{reflect.Map}, Received: val}
86 }
87
88 if val.IsNil() && !mc.EncodeNilAsEmpty && !ec.nilMapAsEmpty {
89 // If we have a nil map but we can't WriteNull, that means we're probably trying to encode
90 // to a TopLevel document. We can't currently tell if this is what actually happened, but if
91 // there's a deeper underlying problem, the error will also be returned from WriteDocument,
92 // so just continue. The operations on a map reflection value are valid, so we can call
93 // MapKeys within mapEncodeValue without a problem.
94 err := vw.WriteNull()
95 if err == nil {
96 return nil
97 }
98 }
99
100 dw, err := vw.WriteDocument()
101 if err != nil {
102 return err
103 }
104
105 return mc.mapEncodeValue(ec, dw, val, nil)
106 }
107
108 // mapEncodeValue handles encoding of the values of a map. The collisionFn returns
109 // true if the provided key exists, this is mainly used for inline maps in the
110 // struct codec.
111 func (mc *MapCodec) mapEncodeValue(ec EncodeContext, dw bsonrw.DocumentWriter, val reflect.Value, collisionFn func(string) bool) error {
112
113 elemType := val.Type().Elem()
114 encoder, err := ec.LookupEncoder(elemType)
115 if err != nil && elemType.Kind() != reflect.Interface {
116 return err
117 }
118
119 keys := val.MapKeys()
120 for _, key := range keys {
121 keyStr, err := mc.encodeKey(key, ec.stringifyMapKeysWithFmt)
122 if err != nil {
123 return err
124 }
125
126 if collisionFn != nil && collisionFn(keyStr) {
127 return fmt.Errorf("Key %s of inlined map conflicts with a struct field name", key)
128 }
129
130 currEncoder, currVal, lookupErr := defaultValueEncoders.lookupElementEncoder(ec, encoder, val.MapIndex(key))
131 if lookupErr != nil && lookupErr != errInvalidValue {
132 return lookupErr
133 }
134
135 vw, err := dw.WriteDocumentElement(keyStr)
136 if err != nil {
137 return err
138 }
139
140 if lookupErr == errInvalidValue {
141 err = vw.WriteNull()
142 if err != nil {
143 return err
144 }
145 continue
146 }
147
148 err = currEncoder.EncodeValue(ec, vw, currVal)
149 if err != nil {
150 return err
151 }
152 }
153
154 return dw.WriteDocumentEnd()
155 }
156
157 // DecodeValue is the ValueDecoder for map[string/decimal]* types.
158 func (mc *MapCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
159 if val.Kind() != reflect.Map || (!val.CanSet() && val.IsNil()) {
160 return ValueDecoderError{Name: "MapDecodeValue", Kinds: []reflect.Kind{reflect.Map}, Received: val}
161 }
162
163 switch vrType := vr.Type(); vrType {
164 case bsontype.Type(0), bsontype.EmbeddedDocument:
165 case bsontype.Null:
166 val.Set(reflect.Zero(val.Type()))
167 return vr.ReadNull()
168 case bsontype.Undefined:
169 val.Set(reflect.Zero(val.Type()))
170 return vr.ReadUndefined()
171 default:
172 return fmt.Errorf("cannot decode %v into a %s", vrType, val.Type())
173 }
174
175 dr, err := vr.ReadDocument()
176 if err != nil {
177 return err
178 }
179
180 if val.IsNil() {
181 val.Set(reflect.MakeMap(val.Type()))
182 }
183
184 if val.Len() > 0 && (mc.DecodeZerosMap || dc.zeroMaps) {
185 clearMap(val)
186 }
187
188 eType := val.Type().Elem()
189 decoder, err := dc.LookupDecoder(eType)
190 if err != nil {
191 return err
192 }
193 eTypeDecoder, _ := decoder.(typeDecoder)
194
195 if eType == tEmpty {
196 dc.Ancestor = val.Type()
197 }
198
199 keyType := val.Type().Key()
200
201 for {
202 key, vr, err := dr.ReadElement()
203 if err == bsonrw.ErrEOD {
204 break
205 }
206 if err != nil {
207 return err
208 }
209
210 k, err := mc.decodeKey(key, keyType)
211 if err != nil {
212 return err
213 }
214
215 elem, err := decodeTypeOrValueWithInfo(decoder, eTypeDecoder, dc, vr, eType, true)
216 if err != nil {
217 return newDecodeError(key, err)
218 }
219
220 val.SetMapIndex(k, elem)
221 }
222 return nil
223 }
224
// clearMap removes every entry from the map held by m. The key set is
// snapshotted with MapKeys first and each key is then deleted individually.
func clearMap(m reflect.Value) {
	keys := m.MapKeys()
	for i := range keys {
		// Passing the zero reflect.Value to SetMapIndex deletes the key.
		m.SetMapIndex(keys[i], reflect.Value{})
	}
}
231
232 func (mc *MapCodec) encodeKey(val reflect.Value, encodeKeysWithStringer bool) (string, error) {
233 if mc.EncodeKeysWithStringer || encodeKeysWithStringer {
234 return fmt.Sprint(val), nil
235 }
236
237 // keys of any string type are used directly
238 if val.Kind() == reflect.String {
239 return val.String(), nil
240 }
241 // KeyMarshalers are marshaled
242 if km, ok := val.Interface().(KeyMarshaler); ok {
243 if val.Kind() == reflect.Ptr && val.IsNil() {
244 return "", nil
245 }
246 buf, err := km.MarshalKey()
247 if err == nil {
248 return buf, nil
249 }
250 return "", err
251 }
252 // keys implement encoding.TextMarshaler are marshaled.
253 if km, ok := val.Interface().(encoding.TextMarshaler); ok {
254 if val.Kind() == reflect.Ptr && val.IsNil() {
255 return "", nil
256 }
257
258 buf, err := km.MarshalText()
259 if err != nil {
260 return "", err
261 }
262
263 return string(buf), nil
264 }
265
266 switch val.Kind() {
267 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
268 return strconv.FormatInt(val.Int(), 10), nil
269 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
270 return strconv.FormatUint(val.Uint(), 10), nil
271 }
272 return "", fmt.Errorf("unsupported key type: %v", val.Type())
273 }
274
275 var keyUnmarshalerType = reflect.TypeOf((*KeyUnmarshaler)(nil)).Elem()
276 var textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
277
278 func (mc *MapCodec) decodeKey(key string, keyType reflect.Type) (reflect.Value, error) {
279 keyVal := reflect.ValueOf(key)
280 var err error
281 switch {
282 // First, if EncodeKeysWithStringer is not enabled, try to decode withKeyUnmarshaler
283 case !mc.EncodeKeysWithStringer && reflect.PtrTo(keyType).Implements(keyUnmarshalerType):
284 keyVal = reflect.New(keyType)
285 v := keyVal.Interface().(KeyUnmarshaler)
286 err = v.UnmarshalKey(key)
287 keyVal = keyVal.Elem()
288 // Try to decode encoding.TextUnmarshalers.
289 case reflect.PtrTo(keyType).Implements(textUnmarshalerType):
290 keyVal = reflect.New(keyType)
291 v := keyVal.Interface().(encoding.TextUnmarshaler)
292 err = v.UnmarshalText([]byte(key))
293 keyVal = keyVal.Elem()
294 // Otherwise, go to type specific behavior
295 default:
296 switch keyType.Kind() {
297 case reflect.String:
298 keyVal = reflect.ValueOf(key).Convert(keyType)
299 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
300 n, parseErr := strconv.ParseInt(key, 10, 64)
301 if parseErr != nil || reflect.Zero(keyType).OverflowInt(n) {
302 err = fmt.Errorf("failed to unmarshal number key %v", key)
303 }
304 keyVal = reflect.ValueOf(n).Convert(keyType)
305 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
306 n, parseErr := strconv.ParseUint(key, 10, 64)
307 if parseErr != nil || reflect.Zero(keyType).OverflowUint(n) {
308 err = fmt.Errorf("failed to unmarshal number key %v", key)
309 break
310 }
311 keyVal = reflect.ValueOf(n).Convert(keyType)
312 case reflect.Float32, reflect.Float64:
313 if mc.EncodeKeysWithStringer {
314 parsed, err := strconv.ParseFloat(key, 64)
315 if err != nil {
316 return keyVal, fmt.Errorf("Map key is defined to be a decimal type (%v) but got error %v", keyType.Kind(), err)
317 }
318 keyVal = reflect.ValueOf(parsed)
319 break
320 }
321 fallthrough
322 default:
323 return keyVal, fmt.Errorf("unsupported key type: %v", keyType)
324 }
325 }
326 return keyVal, err
327 }
328