uint_codec.go raw

   1  // Copyright (C) MongoDB, Inc. 2017-present.
   2  //
   3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
   4  // not use this file except in compliance with the License. You may obtain
   5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
   6  
   7  package bsoncodec
   8  
   9  import (
  10  	"fmt"
  11  	"math"
  12  	"reflect"
  13  
  14  	"go.mongodb.org/mongo-driver/bson/bsonoptions"
  15  	"go.mongodb.org/mongo-driver/bson/bsonrw"
  16  	"go.mongodb.org/mongo-driver/bson/bsontype"
  17  )
  18  
  // UIntCodec is the Codec used for uint values. It handles the Go kinds uint,
// uint8, uint16, uint32 and uint64.
//
// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the
// UIntCodec registered.
type UIntCodec struct {
	// EncodeToMinSize makes EncodeValue write uint and uint32 values that are at
	// most math.MaxInt32 as a BSON int32 instead of an int64. It has no effect on
	// uint64 values; uint8 and uint16 values are always written as int32.
	//
	// Deprecated: Use bson.Encoder.IntMinSize instead.
	EncodeToMinSize bool
}
  30  
  31  var (
  32  	defaultUIntCodec = NewUIntCodec()
  33  
  34  	// Assert that defaultUIntCodec satisfies the typeDecoder interface, which allows it to be used
  35  	// by collection type decoders (e.g. map, slice, etc) to set individual values in a collection.
  36  	_ typeDecoder = defaultUIntCodec
  37  )
  38  
  39  // NewUIntCodec returns a UIntCodec with options opts.
  40  //
  41  // Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the
  42  // UIntCodec registered.
  43  func NewUIntCodec(opts ...*bsonoptions.UIntCodecOptions) *UIntCodec {
  44  	uintOpt := bsonoptions.MergeUIntCodecOptions(opts...)
  45  
  46  	codec := UIntCodec{}
  47  	if uintOpt.EncodeToMinSize != nil {
  48  		codec.EncodeToMinSize = *uintOpt.EncodeToMinSize
  49  	}
  50  	return &codec
  51  }
  52  
  53  // EncodeValue is the ValueEncoder for uint types.
  54  func (uic *UIntCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
  55  	switch val.Kind() {
  56  	case reflect.Uint8, reflect.Uint16:
  57  		return vw.WriteInt32(int32(val.Uint()))
  58  	case reflect.Uint, reflect.Uint32, reflect.Uint64:
  59  		u64 := val.Uint()
  60  
  61  		// If ec.MinSize or if encodeToMinSize is true for a non-uint64 value we should write val as an int32
  62  		useMinSize := ec.MinSize || (uic.EncodeToMinSize && val.Kind() != reflect.Uint64)
  63  
  64  		if u64 <= math.MaxInt32 && useMinSize {
  65  			return vw.WriteInt32(int32(u64))
  66  		}
  67  		if u64 > math.MaxInt64 {
  68  			return fmt.Errorf("%d overflows int64", u64)
  69  		}
  70  		return vw.WriteInt64(int64(u64))
  71  	}
  72  
  73  	return ValueEncoderError{
  74  		Name:     "UintEncodeValue",
  75  		Kinds:    []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint},
  76  		Received: val,
  77  	}
  78  }
  79  
  80  func (uic *UIntCodec) decodeType(dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) {
  81  	var i64 int64
  82  	var err error
  83  	switch vrType := vr.Type(); vrType {
  84  	case bsontype.Int32:
  85  		i32, err := vr.ReadInt32()
  86  		if err != nil {
  87  			return emptyValue, err
  88  		}
  89  		i64 = int64(i32)
  90  	case bsontype.Int64:
  91  		i64, err = vr.ReadInt64()
  92  		if err != nil {
  93  			return emptyValue, err
  94  		}
  95  	case bsontype.Double:
  96  		f64, err := vr.ReadDouble()
  97  		if err != nil {
  98  			return emptyValue, err
  99  		}
 100  		if !dc.Truncate && math.Floor(f64) != f64 {
 101  			return emptyValue, errCannotTruncate
 102  		}
 103  		if f64 > float64(math.MaxInt64) {
 104  			return emptyValue, fmt.Errorf("%g overflows int64", f64)
 105  		}
 106  		i64 = int64(f64)
 107  	case bsontype.Boolean:
 108  		b, err := vr.ReadBoolean()
 109  		if err != nil {
 110  			return emptyValue, err
 111  		}
 112  		if b {
 113  			i64 = 1
 114  		}
 115  	case bsontype.Null:
 116  		if err = vr.ReadNull(); err != nil {
 117  			return emptyValue, err
 118  		}
 119  	case bsontype.Undefined:
 120  		if err = vr.ReadUndefined(); err != nil {
 121  			return emptyValue, err
 122  		}
 123  	default:
 124  		return emptyValue, fmt.Errorf("cannot decode %v into an integer type", vrType)
 125  	}
 126  
 127  	switch t.Kind() {
 128  	case reflect.Uint8:
 129  		if i64 < 0 || i64 > math.MaxUint8 {
 130  			return emptyValue, fmt.Errorf("%d overflows uint8", i64)
 131  		}
 132  
 133  		return reflect.ValueOf(uint8(i64)), nil
 134  	case reflect.Uint16:
 135  		if i64 < 0 || i64 > math.MaxUint16 {
 136  			return emptyValue, fmt.Errorf("%d overflows uint16", i64)
 137  		}
 138  
 139  		return reflect.ValueOf(uint16(i64)), nil
 140  	case reflect.Uint32:
 141  		if i64 < 0 || i64 > math.MaxUint32 {
 142  			return emptyValue, fmt.Errorf("%d overflows uint32", i64)
 143  		}
 144  
 145  		return reflect.ValueOf(uint32(i64)), nil
 146  	case reflect.Uint64:
 147  		if i64 < 0 {
 148  			return emptyValue, fmt.Errorf("%d overflows uint64", i64)
 149  		}
 150  
 151  		return reflect.ValueOf(uint64(i64)), nil
 152  	case reflect.Uint:
 153  		if i64 < 0 || int64(uint(i64)) != i64 { // Can we fit this inside of an uint
 154  			return emptyValue, fmt.Errorf("%d overflows uint", i64)
 155  		}
 156  
 157  		return reflect.ValueOf(uint(i64)), nil
 158  	default:
 159  		return emptyValue, ValueDecoderError{
 160  			Name:     "UintDecodeValue",
 161  			Kinds:    []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint},
 162  			Received: reflect.Zero(t),
 163  		}
 164  	}
 165  }
 166  
 167  // DecodeValue is the ValueDecoder for uint types.
 168  func (uic *UIntCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
 169  	if !val.CanSet() {
 170  		return ValueDecoderError{
 171  			Name:     "UintDecodeValue",
 172  			Kinds:    []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint},
 173  			Received: val,
 174  		}
 175  	}
 176  
 177  	elem, err := uic.decodeType(dc, vr, val.Type())
 178  	if err != nil {
 179  		return err
 180  	}
 181  
 182  	val.SetUint(elem.Uint())
 183  	return nil
 184  }
 185