map_codec.go raw

   1  // Copyright (C) MongoDB, Inc. 2017-present.
   2  //
   3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
   4  // not use this file except in compliance with the License. You may obtain
   5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
   6  
   7  package bsoncodec
   8  
   9  import (
  10  	"encoding"
  11  	"fmt"
  12  	"reflect"
  13  	"strconv"
  14  
  15  	"go.mongodb.org/mongo-driver/bson/bsonoptions"
  16  	"go.mongodb.org/mongo-driver/bson/bsonrw"
  17  	"go.mongodb.org/mongo-driver/bson/bsontype"
  18  )
  19  
// defaultMapCodec is the package-level MapCodec, built by NewMapCodec with no options.
var defaultMapCodec = NewMapCodec()

// MapCodec is the Codec used for map values.
//
// Each boolean field below can also be enabled per operation through the encode/decode
// context; the codec applies a behavior when either its own field or the context
// setting requests it.
//
// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the
// MapCodec registered.
type MapCodec struct {
	// DecodeZerosMap causes DecodeValue to delete any existing values from Go maps in the destination
	// value passed to Decode before unmarshaling BSON documents into them.
	//
	// Deprecated: Use bson.Decoder.ZeroMaps instead.
	DecodeZerosMap bool

	// EncodeNilAsEmpty causes EncodeValue to marshal nil Go maps as empty BSON documents instead of
	// BSON null.
	//
	// Deprecated: Use bson.Encoder.NilMapAsEmpty instead.
	EncodeNilAsEmpty bool

	// EncodeKeysWithStringer causes the Encoder to convert Go map keys to BSON document field name
	// strings using fmt.Sprint instead of the default string conversion logic.
	//
	// The flag is also read when decoding: while it is set, KeyUnmarshaler is not used for
	// map keys, and float32/float64 key types are parsed instead of being rejected as
	// unsupported.
	//
	// Deprecated: Use bson.Encoder.StringifyMapKeysWithFmt instead.
	EncodeKeysWithStringer bool
}
  45  
// KeyMarshaler is the interface implemented by an object that can marshal itself into a string key.
// This applies to types used as map keys and is similar to encoding.TextMarshaler.
//
// When MapCodec encodes a map key whose kind is not string (and keys are not being
// stringified with fmt), KeyMarshaler is checked before encoding.TextMarshaler. A nil
// pointer key that implements KeyMarshaler is encoded as the empty string without
// calling MarshalKey.
type KeyMarshaler interface {
	MarshalKey() (key string, err error)
}
  51  
// KeyUnmarshaler is the interface implemented by an object that can unmarshal a string representation
// of itself. This applies to types used as map keys and is similar to encoding.TextUnmarshaler.
//
// UnmarshalKey must be able to decode the form generated by MarshalKey.
// UnmarshalKey must copy the text if it wishes to retain the text
// after returning.
//
// When MapCodec decodes a map key, UnmarshalKey is used if a pointer to the key type
// implements KeyUnmarshaler and the codec's EncodeKeysWithStringer option is not set. In
// that case it takes precedence over encoding.TextUnmarshaler.
type KeyUnmarshaler interface {
	UnmarshalKey(key string) error
}
  61  
  62  // NewMapCodec returns a MapCodec with options opts.
  63  //
  64  // Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the
  65  // MapCodec registered.
  66  func NewMapCodec(opts ...*bsonoptions.MapCodecOptions) *MapCodec {
  67  	mapOpt := bsonoptions.MergeMapCodecOptions(opts...)
  68  
  69  	codec := MapCodec{}
  70  	if mapOpt.DecodeZerosMap != nil {
  71  		codec.DecodeZerosMap = *mapOpt.DecodeZerosMap
  72  	}
  73  	if mapOpt.EncodeNilAsEmpty != nil {
  74  		codec.EncodeNilAsEmpty = *mapOpt.EncodeNilAsEmpty
  75  	}
  76  	if mapOpt.EncodeKeysWithStringer != nil {
  77  		codec.EncodeKeysWithStringer = *mapOpt.EncodeKeysWithStringer
  78  	}
  79  	return &codec
  80  }
  81  
  82  // EncodeValue is the ValueEncoder for map[*]* types.
  83  func (mc *MapCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
  84  	if !val.IsValid() || val.Kind() != reflect.Map {
  85  		return ValueEncoderError{Name: "MapEncodeValue", Kinds: []reflect.Kind{reflect.Map}, Received: val}
  86  	}
  87  
  88  	if val.IsNil() && !mc.EncodeNilAsEmpty && !ec.nilMapAsEmpty {
  89  		// If we have a nil map but we can't WriteNull, that means we're probably trying to encode
  90  		// to a TopLevel document. We can't currently tell if this is what actually happened, but if
  91  		// there's a deeper underlying problem, the error will also be returned from WriteDocument,
  92  		// so just continue. The operations on a map reflection value are valid, so we can call
  93  		// MapKeys within mapEncodeValue without a problem.
  94  		err := vw.WriteNull()
  95  		if err == nil {
  96  			return nil
  97  		}
  98  	}
  99  
 100  	dw, err := vw.WriteDocument()
 101  	if err != nil {
 102  		return err
 103  	}
 104  
 105  	return mc.mapEncodeValue(ec, dw, val, nil)
 106  }
 107  
 108  // mapEncodeValue handles encoding of the values of a map. The collisionFn returns
 109  // true if the provided key exists, this is mainly used for inline maps in the
 110  // struct codec.
 111  func (mc *MapCodec) mapEncodeValue(ec EncodeContext, dw bsonrw.DocumentWriter, val reflect.Value, collisionFn func(string) bool) error {
 112  
 113  	elemType := val.Type().Elem()
 114  	encoder, err := ec.LookupEncoder(elemType)
 115  	if err != nil && elemType.Kind() != reflect.Interface {
 116  		return err
 117  	}
 118  
 119  	keys := val.MapKeys()
 120  	for _, key := range keys {
 121  		keyStr, err := mc.encodeKey(key, ec.stringifyMapKeysWithFmt)
 122  		if err != nil {
 123  			return err
 124  		}
 125  
 126  		if collisionFn != nil && collisionFn(keyStr) {
 127  			return fmt.Errorf("Key %s of inlined map conflicts with a struct field name", key)
 128  		}
 129  
 130  		currEncoder, currVal, lookupErr := defaultValueEncoders.lookupElementEncoder(ec, encoder, val.MapIndex(key))
 131  		if lookupErr != nil && lookupErr != errInvalidValue {
 132  			return lookupErr
 133  		}
 134  
 135  		vw, err := dw.WriteDocumentElement(keyStr)
 136  		if err != nil {
 137  			return err
 138  		}
 139  
 140  		if lookupErr == errInvalidValue {
 141  			err = vw.WriteNull()
 142  			if err != nil {
 143  				return err
 144  			}
 145  			continue
 146  		}
 147  
 148  		err = currEncoder.EncodeValue(ec, vw, currVal)
 149  		if err != nil {
 150  			return err
 151  		}
 152  	}
 153  
 154  	return dw.WriteDocumentEnd()
 155  }
 156  
 157  // DecodeValue is the ValueDecoder for map[string/decimal]* types.
 158  func (mc *MapCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
 159  	if val.Kind() != reflect.Map || (!val.CanSet() && val.IsNil()) {
 160  		return ValueDecoderError{Name: "MapDecodeValue", Kinds: []reflect.Kind{reflect.Map}, Received: val}
 161  	}
 162  
 163  	switch vrType := vr.Type(); vrType {
 164  	case bsontype.Type(0), bsontype.EmbeddedDocument:
 165  	case bsontype.Null:
 166  		val.Set(reflect.Zero(val.Type()))
 167  		return vr.ReadNull()
 168  	case bsontype.Undefined:
 169  		val.Set(reflect.Zero(val.Type()))
 170  		return vr.ReadUndefined()
 171  	default:
 172  		return fmt.Errorf("cannot decode %v into a %s", vrType, val.Type())
 173  	}
 174  
 175  	dr, err := vr.ReadDocument()
 176  	if err != nil {
 177  		return err
 178  	}
 179  
 180  	if val.IsNil() {
 181  		val.Set(reflect.MakeMap(val.Type()))
 182  	}
 183  
 184  	if val.Len() > 0 && (mc.DecodeZerosMap || dc.zeroMaps) {
 185  		clearMap(val)
 186  	}
 187  
 188  	eType := val.Type().Elem()
 189  	decoder, err := dc.LookupDecoder(eType)
 190  	if err != nil {
 191  		return err
 192  	}
 193  	eTypeDecoder, _ := decoder.(typeDecoder)
 194  
 195  	if eType == tEmpty {
 196  		dc.Ancestor = val.Type()
 197  	}
 198  
 199  	keyType := val.Type().Key()
 200  
 201  	for {
 202  		key, vr, err := dr.ReadElement()
 203  		if err == bsonrw.ErrEOD {
 204  			break
 205  		}
 206  		if err != nil {
 207  			return err
 208  		}
 209  
 210  		k, err := mc.decodeKey(key, keyType)
 211  		if err != nil {
 212  			return err
 213  		}
 214  
 215  		elem, err := decodeTypeOrValueWithInfo(decoder, eTypeDecoder, dc, vr, eType, true)
 216  		if err != nil {
 217  			return newDecodeError(key, err)
 218  		}
 219  
 220  		val.SetMapIndex(k, elem)
 221  	}
 222  	return nil
 223  }
 224  
 225  func clearMap(m reflect.Value) {
 226  	var none reflect.Value
 227  	for _, k := range m.MapKeys() {
 228  		m.SetMapIndex(k, none)
 229  	}
 230  }
 231  
 232  func (mc *MapCodec) encodeKey(val reflect.Value, encodeKeysWithStringer bool) (string, error) {
 233  	if mc.EncodeKeysWithStringer || encodeKeysWithStringer {
 234  		return fmt.Sprint(val), nil
 235  	}
 236  
 237  	// keys of any string type are used directly
 238  	if val.Kind() == reflect.String {
 239  		return val.String(), nil
 240  	}
 241  	// KeyMarshalers are marshaled
 242  	if km, ok := val.Interface().(KeyMarshaler); ok {
 243  		if val.Kind() == reflect.Ptr && val.IsNil() {
 244  			return "", nil
 245  		}
 246  		buf, err := km.MarshalKey()
 247  		if err == nil {
 248  			return buf, nil
 249  		}
 250  		return "", err
 251  	}
 252  	// keys implement encoding.TextMarshaler are marshaled.
 253  	if km, ok := val.Interface().(encoding.TextMarshaler); ok {
 254  		if val.Kind() == reflect.Ptr && val.IsNil() {
 255  			return "", nil
 256  		}
 257  
 258  		buf, err := km.MarshalText()
 259  		if err != nil {
 260  			return "", err
 261  		}
 262  
 263  		return string(buf), nil
 264  	}
 265  
 266  	switch val.Kind() {
 267  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
 268  		return strconv.FormatInt(val.Int(), 10), nil
 269  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
 270  		return strconv.FormatUint(val.Uint(), 10), nil
 271  	}
 272  	return "", fmt.Errorf("unsupported key type: %v", val.Type())
 273  }
 274  
 275  var keyUnmarshalerType = reflect.TypeOf((*KeyUnmarshaler)(nil)).Elem()
 276  var textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
 277  
 278  func (mc *MapCodec) decodeKey(key string, keyType reflect.Type) (reflect.Value, error) {
 279  	keyVal := reflect.ValueOf(key)
 280  	var err error
 281  	switch {
 282  	// First, if EncodeKeysWithStringer is not enabled, try to decode withKeyUnmarshaler
 283  	case !mc.EncodeKeysWithStringer && reflect.PtrTo(keyType).Implements(keyUnmarshalerType):
 284  		keyVal = reflect.New(keyType)
 285  		v := keyVal.Interface().(KeyUnmarshaler)
 286  		err = v.UnmarshalKey(key)
 287  		keyVal = keyVal.Elem()
 288  	// Try to decode encoding.TextUnmarshalers.
 289  	case reflect.PtrTo(keyType).Implements(textUnmarshalerType):
 290  		keyVal = reflect.New(keyType)
 291  		v := keyVal.Interface().(encoding.TextUnmarshaler)
 292  		err = v.UnmarshalText([]byte(key))
 293  		keyVal = keyVal.Elem()
 294  	// Otherwise, go to type specific behavior
 295  	default:
 296  		switch keyType.Kind() {
 297  		case reflect.String:
 298  			keyVal = reflect.ValueOf(key).Convert(keyType)
 299  		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
 300  			n, parseErr := strconv.ParseInt(key, 10, 64)
 301  			if parseErr != nil || reflect.Zero(keyType).OverflowInt(n) {
 302  				err = fmt.Errorf("failed to unmarshal number key %v", key)
 303  			}
 304  			keyVal = reflect.ValueOf(n).Convert(keyType)
 305  		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
 306  			n, parseErr := strconv.ParseUint(key, 10, 64)
 307  			if parseErr != nil || reflect.Zero(keyType).OverflowUint(n) {
 308  				err = fmt.Errorf("failed to unmarshal number key %v", key)
 309  				break
 310  			}
 311  			keyVal = reflect.ValueOf(n).Convert(keyType)
 312  		case reflect.Float32, reflect.Float64:
 313  			if mc.EncodeKeysWithStringer {
 314  				parsed, err := strconv.ParseFloat(key, 64)
 315  				if err != nil {
 316  					return keyVal, fmt.Errorf("Map key is defined to be a decimal type (%v) but got error %v", keyType.Kind(), err)
 317  				}
 318  				keyVal = reflect.ValueOf(parsed)
 319  				break
 320  			}
 321  			fallthrough
 322  		default:
 323  			return keyVal, fmt.Errorf("unsupported key type: %v", keyType)
 324  		}
 325  	}
 326  	return keyVal, err
 327  }
 328