codec_map.go raw

   1  // Copyright 2019 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  package impl
   6  
   7  import (
   8  	"reflect"
   9  	"sort"
  10  
  11  	"google.golang.org/protobuf/encoding/protowire"
  12  	"google.golang.org/protobuf/internal/errors"
  13  	"google.golang.org/protobuf/internal/genid"
  14  	"google.golang.org/protobuf/reflect/protoreflect"
  15  )
  16  
  17  type mapInfo struct {
  18  	goType     reflect.Type
  19  	keyWiretag uint64
  20  	valWiretag uint64
  21  	keyFuncs   valueCoderFuncs
  22  	valFuncs   valueCoderFuncs
  23  	keyZero    protoreflect.Value
  24  	keyKind    protoreflect.Kind
  25  	conv       *mapConverter
  26  }
  27  
  28  func encoderFuncsForMap(fd protoreflect.FieldDescriptor, ft reflect.Type) (valueMessage *MessageInfo, funcs pointerCoderFuncs) {
  29  	// TODO: Consider generating specialized map coders.
  30  	keyField := fd.MapKey()
  31  	valField := fd.MapValue()
  32  	keyWiretag := protowire.EncodeTag(1, wireTypes[keyField.Kind()])
  33  	valWiretag := protowire.EncodeTag(2, wireTypes[valField.Kind()])
  34  	keyFuncs := encoderFuncsForValue(keyField)
  35  	valFuncs := encoderFuncsForValue(valField)
  36  	conv := newMapConverter(ft, fd)
  37  
  38  	mapi := &mapInfo{
  39  		goType:     ft,
  40  		keyWiretag: keyWiretag,
  41  		valWiretag: valWiretag,
  42  		keyFuncs:   keyFuncs,
  43  		valFuncs:   valFuncs,
  44  		keyZero:    keyField.Default(),
  45  		keyKind:    keyField.Kind(),
  46  		conv:       conv,
  47  	}
  48  	if valField.Kind() == protoreflect.MessageKind {
  49  		valueMessage = getMessageInfo(ft.Elem())
  50  	}
  51  
  52  	funcs = pointerCoderFuncs{
  53  		size: func(p pointer, f *coderFieldInfo, opts marshalOptions) int {
  54  			return sizeMap(p.AsValueOf(ft).Elem(), mapi, f, opts)
  55  		},
  56  		marshal: func(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
  57  			return appendMap(b, p.AsValueOf(ft).Elem(), mapi, f, opts)
  58  		},
  59  		unmarshal: func(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (unmarshalOutput, error) {
  60  			mp := p.AsValueOf(ft)
  61  			if mp.Elem().IsNil() {
  62  				mp.Elem().Set(reflect.MakeMap(mapi.goType))
  63  			}
  64  			if f.mi == nil {
  65  				return consumeMap(b, mp.Elem(), wtyp, mapi, f, opts)
  66  			} else {
  67  				return consumeMapOfMessage(b, mp.Elem(), wtyp, mapi, f, opts)
  68  			}
  69  		},
  70  	}
  71  	switch valField.Kind() {
  72  	case protoreflect.MessageKind:
  73  		funcs.merge = mergeMapOfMessage
  74  	case protoreflect.BytesKind:
  75  		funcs.merge = mergeMapOfBytes
  76  	default:
  77  		funcs.merge = mergeMap
  78  	}
  79  	if valFuncs.isInit != nil {
  80  		funcs.isInit = func(p pointer, f *coderFieldInfo) error {
  81  			return isInitMap(p.AsValueOf(ft).Elem(), mapi, f)
  82  		}
  83  	}
  84  	return valueMessage, funcs
  85  }
  86  
  87  const (
  88  	mapKeyTagSize = 1 // field 1, tag size 1.
  89  	mapValTagSize = 1 // field 2, tag size 2.
  90  )
  91  
  92  func sizeMap(mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) int {
  93  	if mapv.Len() == 0 {
  94  		return 0
  95  	}
  96  	n := 0
  97  	iter := mapv.MapRange()
  98  	for iter.Next() {
  99  		key := mapi.conv.keyConv.PBValueOf(iter.Key()).MapKey()
 100  		keySize := mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
 101  		var valSize int
 102  		value := mapi.conv.valConv.PBValueOf(iter.Value())
 103  		if f.mi == nil {
 104  			valSize = mapi.valFuncs.size(value, mapValTagSize, opts)
 105  		} else {
 106  			p := pointerOfValue(iter.Value())
 107  			valSize += mapValTagSize
 108  			valSize += protowire.SizeBytes(f.mi.sizePointer(p, opts))
 109  		}
 110  		n += f.tagsize + protowire.SizeBytes(keySize+valSize)
 111  	}
 112  	return n
 113  }
 114  
 115  func consumeMap(b []byte, mapv reflect.Value, wtyp protowire.Type, mapi *mapInfo, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
 116  	if opts.depth--; opts.depth < 0 {
 117  		return out, errRecursionDepth
 118  	}
 119  	if wtyp != protowire.BytesType {
 120  		return out, errUnknown
 121  	}
 122  	b, n := protowire.ConsumeBytes(b)
 123  	if n < 0 {
 124  		return out, errDecode
 125  	}
 126  	var (
 127  		key = mapi.keyZero
 128  		val = mapi.conv.valConv.New()
 129  	)
 130  	for len(b) > 0 {
 131  		num, wtyp, n := protowire.ConsumeTag(b)
 132  		if n < 0 {
 133  			return out, errDecode
 134  		}
 135  		if num > protowire.MaxValidNumber {
 136  			return out, errDecode
 137  		}
 138  		b = b[n:]
 139  		err := errUnknown
 140  		switch num {
 141  		case genid.MapEntry_Key_field_number:
 142  			var v protoreflect.Value
 143  			var o unmarshalOutput
 144  			v, o, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
 145  			if err != nil {
 146  				break
 147  			}
 148  			key = v
 149  			n = o.n
 150  		case genid.MapEntry_Value_field_number:
 151  			var v protoreflect.Value
 152  			var o unmarshalOutput
 153  			v, o, err = mapi.valFuncs.unmarshal(b, val, num, wtyp, opts)
 154  			if err != nil {
 155  				break
 156  			}
 157  			val = v
 158  			n = o.n
 159  		}
 160  		if err == errUnknown {
 161  			n = protowire.ConsumeFieldValue(num, wtyp, b)
 162  			if n < 0 {
 163  				return out, errDecode
 164  			}
 165  		} else if err != nil {
 166  			return out, err
 167  		}
 168  		b = b[n:]
 169  	}
 170  	mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), mapi.conv.valConv.GoValueOf(val))
 171  	out.n = n
 172  	return out, nil
 173  }
 174  
 175  func consumeMapOfMessage(b []byte, mapv reflect.Value, wtyp protowire.Type, mapi *mapInfo, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
 176  	if opts.depth--; opts.depth < 0 {
 177  		return out, errRecursionDepth
 178  	}
 179  	if wtyp != protowire.BytesType {
 180  		return out, errUnknown
 181  	}
 182  	b, n := protowire.ConsumeBytes(b)
 183  	if n < 0 {
 184  		return out, errDecode
 185  	}
 186  	var (
 187  		key = mapi.keyZero
 188  		val = reflect.New(f.mi.GoReflectType.Elem())
 189  	)
 190  	for len(b) > 0 {
 191  		num, wtyp, n := protowire.ConsumeTag(b)
 192  		if n < 0 {
 193  			return out, errDecode
 194  		}
 195  		if num > protowire.MaxValidNumber {
 196  			return out, errDecode
 197  		}
 198  		b = b[n:]
 199  		err := errUnknown
 200  		switch num {
 201  		case 1:
 202  			var v protoreflect.Value
 203  			var o unmarshalOutput
 204  			v, o, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
 205  			if err != nil {
 206  				break
 207  			}
 208  			key = v
 209  			n = o.n
 210  		case 2:
 211  			if wtyp != protowire.BytesType {
 212  				break
 213  			}
 214  			var v []byte
 215  			v, n = protowire.ConsumeBytes(b)
 216  			if n < 0 {
 217  				return out, errDecode
 218  			}
 219  			var o unmarshalOutput
 220  			o, err = f.mi.unmarshalPointer(v, pointerOfValue(val), 0, opts)
 221  			if o.initialized {
 222  				// Consider this map item initialized so long as we see
 223  				// an initialized value.
 224  				out.initialized = true
 225  			}
 226  		}
 227  		if err == errUnknown {
 228  			n = protowire.ConsumeFieldValue(num, wtyp, b)
 229  			if n < 0 {
 230  				return out, errDecode
 231  			}
 232  		} else if err != nil {
 233  			return out, err
 234  		}
 235  		b = b[n:]
 236  	}
 237  	mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), val)
 238  	out.n = n
 239  	return out, nil
 240  }
 241  
 242  func appendMapItem(b []byte, keyrv, valrv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
 243  	if f.mi == nil {
 244  		key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey()
 245  		val := mapi.conv.valConv.PBValueOf(valrv)
 246  		size := 0
 247  		size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
 248  		size += mapi.valFuncs.size(val, mapValTagSize, opts)
 249  		b = protowire.AppendVarint(b, uint64(size))
 250  		before := len(b)
 251  		b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
 252  		if err != nil {
 253  			return nil, err
 254  		}
 255  		b, err = mapi.valFuncs.marshal(b, val, mapi.valWiretag, opts)
 256  		if measuredSize := len(b) - before; size != measuredSize && err == nil {
 257  			return nil, errors.MismatchedSizeCalculation(size, measuredSize)
 258  		}
 259  		return b, err
 260  	} else {
 261  		key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey()
 262  		val := pointerOfValue(valrv)
 263  		valSize := f.mi.sizePointer(val, opts)
 264  		size := 0
 265  		size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
 266  		size += mapValTagSize + protowire.SizeBytes(valSize)
 267  		b = protowire.AppendVarint(b, uint64(size))
 268  		b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
 269  		if err != nil {
 270  			return nil, err
 271  		}
 272  		b = protowire.AppendVarint(b, mapi.valWiretag)
 273  		b = protowire.AppendVarint(b, uint64(valSize))
 274  		before := len(b)
 275  		b, err = f.mi.marshalAppendPointer(b, val, opts)
 276  		if measuredSize := len(b) - before; valSize != measuredSize && err == nil {
 277  			return nil, errors.MismatchedSizeCalculation(valSize, measuredSize)
 278  		}
 279  		return b, err
 280  	}
 281  }
 282  
 283  func appendMap(b []byte, mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
 284  	if mapv.Len() == 0 {
 285  		return b, nil
 286  	}
 287  	if opts.Deterministic() {
 288  		return appendMapDeterministic(b, mapv, mapi, f, opts)
 289  	}
 290  	iter := mapv.MapRange()
 291  	for iter.Next() {
 292  		var err error
 293  		b = protowire.AppendVarint(b, f.wiretag)
 294  		b, err = appendMapItem(b, iter.Key(), iter.Value(), mapi, f, opts)
 295  		if err != nil {
 296  			return b, err
 297  		}
 298  	}
 299  	return b, nil
 300  }
 301  
 302  func appendMapDeterministic(b []byte, mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
 303  	keys := mapv.MapKeys()
 304  	sort.Slice(keys, func(i, j int) bool {
 305  		switch keys[i].Kind() {
 306  		case reflect.Bool:
 307  			return !keys[i].Bool() && keys[j].Bool()
 308  		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
 309  			return keys[i].Int() < keys[j].Int()
 310  		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
 311  			return keys[i].Uint() < keys[j].Uint()
 312  		case reflect.Float32, reflect.Float64:
 313  			return keys[i].Float() < keys[j].Float()
 314  		case reflect.String:
 315  			return keys[i].String() < keys[j].String()
 316  		default:
 317  			panic("invalid kind: " + keys[i].Kind().String())
 318  		}
 319  	})
 320  	for _, key := range keys {
 321  		var err error
 322  		b = protowire.AppendVarint(b, f.wiretag)
 323  		b, err = appendMapItem(b, key, mapv.MapIndex(key), mapi, f, opts)
 324  		if err != nil {
 325  			return b, err
 326  		}
 327  	}
 328  	return b, nil
 329  }
 330  
 331  func isInitMap(mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo) error {
 332  	if mi := f.mi; mi != nil {
 333  		mi.init()
 334  		if !mi.needsInitCheck {
 335  			return nil
 336  		}
 337  		iter := mapv.MapRange()
 338  		for iter.Next() {
 339  			val := pointerOfValue(iter.Value())
 340  			if err := mi.checkInitializedPointer(val); err != nil {
 341  				return err
 342  			}
 343  		}
 344  	} else {
 345  		iter := mapv.MapRange()
 346  		for iter.Next() {
 347  			val := mapi.conv.valConv.PBValueOf(iter.Value())
 348  			if err := mapi.valFuncs.isInit(val); err != nil {
 349  				return err
 350  			}
 351  		}
 352  	}
 353  	return nil
 354  }
 355  
 356  func mergeMap(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
 357  	dstm := dst.AsValueOf(f.ft).Elem()
 358  	srcm := src.AsValueOf(f.ft).Elem()
 359  	if srcm.Len() == 0 {
 360  		return
 361  	}
 362  	if dstm.IsNil() {
 363  		dstm.Set(reflect.MakeMap(f.ft))
 364  	}
 365  	iter := srcm.MapRange()
 366  	for iter.Next() {
 367  		dstm.SetMapIndex(iter.Key(), iter.Value())
 368  	}
 369  }
 370  
 371  func mergeMapOfBytes(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
 372  	dstm := dst.AsValueOf(f.ft).Elem()
 373  	srcm := src.AsValueOf(f.ft).Elem()
 374  	if srcm.Len() == 0 {
 375  		return
 376  	}
 377  	if dstm.IsNil() {
 378  		dstm.Set(reflect.MakeMap(f.ft))
 379  	}
 380  	iter := srcm.MapRange()
 381  	for iter.Next() {
 382  		dstm.SetMapIndex(iter.Key(), reflect.ValueOf(append(emptyBuf[:], iter.Value().Bytes()...)))
 383  	}
 384  }
 385  
 386  func mergeMapOfMessage(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
 387  	dstm := dst.AsValueOf(f.ft).Elem()
 388  	srcm := src.AsValueOf(f.ft).Elem()
 389  	if srcm.Len() == 0 {
 390  		return
 391  	}
 392  	if dstm.IsNil() {
 393  		dstm.Set(reflect.MakeMap(f.ft))
 394  	}
 395  	iter := srcm.MapRange()
 396  	for iter.Next() {
 397  		val := reflect.New(f.ft.Elem().Elem())
 398  		if f.mi != nil {
 399  			f.mi.mergePointer(pointerOfValue(val), pointerOfValue(iter.Value()), opts)
 400  		} else {
 401  			opts.Merge(asMessage(val), asMessage(iter.Value()))
 402  		}
 403  		dstm.SetMapIndex(iter.Key(), val)
 404  	}
 405  }
 406