convert.go raw

   1  // Copyright 2018 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  package impl
   6  
   7  import (
   8  	"fmt"
   9  	"reflect"
  10  
  11  	"google.golang.org/protobuf/reflect/protoreflect"
  12  )
  13  
  14  // unwrapper unwraps the value to the underlying value.
  15  // This is implemented by List and Map.
  16  type unwrapper interface {
  17  	protoUnwrap() any
  18  }
  19  
  20  // A Converter coverts to/from Go reflect.Value types and protobuf protoreflect.Value types.
  21  type Converter interface {
  22  	// PBValueOf converts a reflect.Value to a protoreflect.Value.
  23  	PBValueOf(reflect.Value) protoreflect.Value
  24  
  25  	// GoValueOf converts a protoreflect.Value to a reflect.Value.
  26  	GoValueOf(protoreflect.Value) reflect.Value
  27  
  28  	// IsValidPB returns whether a protoreflect.Value is compatible with this type.
  29  	IsValidPB(protoreflect.Value) bool
  30  
  31  	// IsValidGo returns whether a reflect.Value is compatible with this type.
  32  	IsValidGo(reflect.Value) bool
  33  
  34  	// New returns a new field value.
  35  	// For scalars, it returns the default value of the field.
  36  	// For composite types, it returns a new mutable value.
  37  	New() protoreflect.Value
  38  
  39  	// Zero returns a new field value.
  40  	// For scalars, it returns the default value of the field.
  41  	// For composite types, it returns an immutable, empty value.
  42  	Zero() protoreflect.Value
  43  }
  44  
  45  // NewConverter matches a Go type with a protobuf field and returns a Converter
  46  // that converts between the two. Enums must be a named int32 kind that
  47  // implements protoreflect.Enum, and messages must be pointer to a named
  48  // struct type that implements protoreflect.ProtoMessage.
  49  //
  50  // This matcher deliberately supports a wider range of Go types than what
  51  // protoc-gen-go historically generated to be able to automatically wrap some
  52  // v1 messages generated by other forks of protoc-gen-go.
  53  func NewConverter(t reflect.Type, fd protoreflect.FieldDescriptor) Converter {
  54  	switch {
  55  	case fd.IsList():
  56  		return newListConverter(t, fd)
  57  	case fd.IsMap():
  58  		return newMapConverter(t, fd)
  59  	default:
  60  		return newSingularConverter(t, fd)
  61  	}
  62  }
  63  
  64  var (
  65  	boolType    = reflect.TypeOf(bool(false))
  66  	int32Type   = reflect.TypeOf(int32(0))
  67  	int64Type   = reflect.TypeOf(int64(0))
  68  	uint32Type  = reflect.TypeOf(uint32(0))
  69  	uint64Type  = reflect.TypeOf(uint64(0))
  70  	float32Type = reflect.TypeOf(float32(0))
  71  	float64Type = reflect.TypeOf(float64(0))
  72  	stringType  = reflect.TypeOf(string(""))
  73  	bytesType   = reflect.TypeOf([]byte(nil))
  74  	byteType    = reflect.TypeOf(byte(0))
  75  )
  76  
  77  var (
  78  	boolZero    = protoreflect.ValueOfBool(false)
  79  	int32Zero   = protoreflect.ValueOfInt32(0)
  80  	int64Zero   = protoreflect.ValueOfInt64(0)
  81  	uint32Zero  = protoreflect.ValueOfUint32(0)
  82  	uint64Zero  = protoreflect.ValueOfUint64(0)
  83  	float32Zero = protoreflect.ValueOfFloat32(0)
  84  	float64Zero = protoreflect.ValueOfFloat64(0)
  85  	stringZero  = protoreflect.ValueOfString("")
  86  	bytesZero   = protoreflect.ValueOfBytes(nil)
  87  )
  88  
  89  func newSingularConverter(t reflect.Type, fd protoreflect.FieldDescriptor) Converter {
  90  	defVal := func(fd protoreflect.FieldDescriptor, zero protoreflect.Value) protoreflect.Value {
  91  		if fd.Cardinality() == protoreflect.Repeated {
  92  			// Default isn't defined for repeated fields.
  93  			return zero
  94  		}
  95  		return fd.Default()
  96  	}
  97  	switch fd.Kind() {
  98  	case protoreflect.BoolKind:
  99  		if t.Kind() == reflect.Bool {
 100  			return &boolConverter{t, defVal(fd, boolZero)}
 101  		}
 102  	case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
 103  		if t.Kind() == reflect.Int32 {
 104  			return &int32Converter{t, defVal(fd, int32Zero)}
 105  		}
 106  	case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
 107  		if t.Kind() == reflect.Int64 {
 108  			return &int64Converter{t, defVal(fd, int64Zero)}
 109  		}
 110  	case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
 111  		if t.Kind() == reflect.Uint32 {
 112  			return &uint32Converter{t, defVal(fd, uint32Zero)}
 113  		}
 114  	case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
 115  		if t.Kind() == reflect.Uint64 {
 116  			return &uint64Converter{t, defVal(fd, uint64Zero)}
 117  		}
 118  	case protoreflect.FloatKind:
 119  		if t.Kind() == reflect.Float32 {
 120  			return &float32Converter{t, defVal(fd, float32Zero)}
 121  		}
 122  	case protoreflect.DoubleKind:
 123  		if t.Kind() == reflect.Float64 {
 124  			return &float64Converter{t, defVal(fd, float64Zero)}
 125  		}
 126  	case protoreflect.StringKind:
 127  		if t.Kind() == reflect.String || (t.Kind() == reflect.Slice && t.Elem() == byteType) {
 128  			return &stringConverter{t, defVal(fd, stringZero)}
 129  		}
 130  	case protoreflect.BytesKind:
 131  		if t.Kind() == reflect.String || (t.Kind() == reflect.Slice && t.Elem() == byteType) {
 132  			return &bytesConverter{t, defVal(fd, bytesZero)}
 133  		}
 134  	case protoreflect.EnumKind:
 135  		// Handle enums, which must be a named int32 type.
 136  		if t.Kind() == reflect.Int32 {
 137  			return newEnumConverter(t, fd)
 138  		}
 139  	case protoreflect.MessageKind, protoreflect.GroupKind:
 140  		return newMessageConverter(t)
 141  	}
 142  	panic(fmt.Sprintf("invalid Go type %v for field %v", t, fd.FullName()))
 143  }
 144  
 145  type boolConverter struct {
 146  	goType reflect.Type
 147  	def    protoreflect.Value
 148  }
 149  
 150  func (c *boolConverter) PBValueOf(v reflect.Value) protoreflect.Value {
 151  	if v.Type() != c.goType {
 152  		panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), c.goType))
 153  	}
 154  	return protoreflect.ValueOfBool(v.Bool())
 155  }
 156  func (c *boolConverter) GoValueOf(v protoreflect.Value) reflect.Value {
 157  	return reflect.ValueOf(v.Bool()).Convert(c.goType)
 158  }
 159  func (c *boolConverter) IsValidPB(v protoreflect.Value) bool {
 160  	_, ok := v.Interface().(bool)
 161  	return ok
 162  }
 163  func (c *boolConverter) IsValidGo(v reflect.Value) bool {
 164  	return v.IsValid() && v.Type() == c.goType
 165  }
 166  func (c *boolConverter) New() protoreflect.Value  { return c.def }
 167  func (c *boolConverter) Zero() protoreflect.Value { return c.def }
 168  
 169  type int32Converter struct {
 170  	goType reflect.Type
 171  	def    protoreflect.Value
 172  }
 173  
 174  func (c *int32Converter) PBValueOf(v reflect.Value) protoreflect.Value {
 175  	if v.Type() != c.goType {
 176  		panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), c.goType))
 177  	}
 178  	return protoreflect.ValueOfInt32(int32(v.Int()))
 179  }
 180  func (c *int32Converter) GoValueOf(v protoreflect.Value) reflect.Value {
 181  	return reflect.ValueOf(int32(v.Int())).Convert(c.goType)
 182  }
 183  func (c *int32Converter) IsValidPB(v protoreflect.Value) bool {
 184  	_, ok := v.Interface().(int32)
 185  	return ok
 186  }
 187  func (c *int32Converter) IsValidGo(v reflect.Value) bool {
 188  	return v.IsValid() && v.Type() == c.goType
 189  }
 190  func (c *int32Converter) New() protoreflect.Value  { return c.def }
 191  func (c *int32Converter) Zero() protoreflect.Value { return c.def }
 192  
 193  type int64Converter struct {
 194  	goType reflect.Type
 195  	def    protoreflect.Value
 196  }
 197  
 198  func (c *int64Converter) PBValueOf(v reflect.Value) protoreflect.Value {
 199  	if v.Type() != c.goType {
 200  		panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), c.goType))
 201  	}
 202  	return protoreflect.ValueOfInt64(int64(v.Int()))
 203  }
 204  func (c *int64Converter) GoValueOf(v protoreflect.Value) reflect.Value {
 205  	return reflect.ValueOf(int64(v.Int())).Convert(c.goType)
 206  }
 207  func (c *int64Converter) IsValidPB(v protoreflect.Value) bool {
 208  	_, ok := v.Interface().(int64)
 209  	return ok
 210  }
 211  func (c *int64Converter) IsValidGo(v reflect.Value) bool {
 212  	return v.IsValid() && v.Type() == c.goType
 213  }
 214  func (c *int64Converter) New() protoreflect.Value  { return c.def }
 215  func (c *int64Converter) Zero() protoreflect.Value { return c.def }
 216  
 217  type uint32Converter struct {
 218  	goType reflect.Type
 219  	def    protoreflect.Value
 220  }
 221  
 222  func (c *uint32Converter) PBValueOf(v reflect.Value) protoreflect.Value {
 223  	if v.Type() != c.goType {
 224  		panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), c.goType))
 225  	}
 226  	return protoreflect.ValueOfUint32(uint32(v.Uint()))
 227  }
 228  func (c *uint32Converter) GoValueOf(v protoreflect.Value) reflect.Value {
 229  	return reflect.ValueOf(uint32(v.Uint())).Convert(c.goType)
 230  }
 231  func (c *uint32Converter) IsValidPB(v protoreflect.Value) bool {
 232  	_, ok := v.Interface().(uint32)
 233  	return ok
 234  }
 235  func (c *uint32Converter) IsValidGo(v reflect.Value) bool {
 236  	return v.IsValid() && v.Type() == c.goType
 237  }
 238  func (c *uint32Converter) New() protoreflect.Value  { return c.def }
 239  func (c *uint32Converter) Zero() protoreflect.Value { return c.def }
 240  
 241  type uint64Converter struct {
 242  	goType reflect.Type
 243  	def    protoreflect.Value
 244  }
 245  
 246  func (c *uint64Converter) PBValueOf(v reflect.Value) protoreflect.Value {
 247  	if v.Type() != c.goType {
 248  		panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), c.goType))
 249  	}
 250  	return protoreflect.ValueOfUint64(uint64(v.Uint()))
 251  }
 252  func (c *uint64Converter) GoValueOf(v protoreflect.Value) reflect.Value {
 253  	return reflect.ValueOf(uint64(v.Uint())).Convert(c.goType)
 254  }
 255  func (c *uint64Converter) IsValidPB(v protoreflect.Value) bool {
 256  	_, ok := v.Interface().(uint64)
 257  	return ok
 258  }
 259  func (c *uint64Converter) IsValidGo(v reflect.Value) bool {
 260  	return v.IsValid() && v.Type() == c.goType
 261  }
 262  func (c *uint64Converter) New() protoreflect.Value  { return c.def }
 263  func (c *uint64Converter) Zero() protoreflect.Value { return c.def }
 264  
 265  type float32Converter struct {
 266  	goType reflect.Type
 267  	def    protoreflect.Value
 268  }
 269  
 270  func (c *float32Converter) PBValueOf(v reflect.Value) protoreflect.Value {
 271  	if v.Type() != c.goType {
 272  		panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), c.goType))
 273  	}
 274  	return protoreflect.ValueOfFloat32(float32(v.Float()))
 275  }
 276  func (c *float32Converter) GoValueOf(v protoreflect.Value) reflect.Value {
 277  	return reflect.ValueOf(float32(v.Float())).Convert(c.goType)
 278  }
 279  func (c *float32Converter) IsValidPB(v protoreflect.Value) bool {
 280  	_, ok := v.Interface().(float32)
 281  	return ok
 282  }
 283  func (c *float32Converter) IsValidGo(v reflect.Value) bool {
 284  	return v.IsValid() && v.Type() == c.goType
 285  }
 286  func (c *float32Converter) New() protoreflect.Value  { return c.def }
 287  func (c *float32Converter) Zero() protoreflect.Value { return c.def }
 288  
 289  type float64Converter struct {
 290  	goType reflect.Type
 291  	def    protoreflect.Value
 292  }
 293  
 294  func (c *float64Converter) PBValueOf(v reflect.Value) protoreflect.Value {
 295  	if v.Type() != c.goType {
 296  		panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), c.goType))
 297  	}
 298  	return protoreflect.ValueOfFloat64(float64(v.Float()))
 299  }
 300  func (c *float64Converter) GoValueOf(v protoreflect.Value) reflect.Value {
 301  	return reflect.ValueOf(float64(v.Float())).Convert(c.goType)
 302  }
 303  func (c *float64Converter) IsValidPB(v protoreflect.Value) bool {
 304  	_, ok := v.Interface().(float64)
 305  	return ok
 306  }
 307  func (c *float64Converter) IsValidGo(v reflect.Value) bool {
 308  	return v.IsValid() && v.Type() == c.goType
 309  }
 310  func (c *float64Converter) New() protoreflect.Value  { return c.def }
 311  func (c *float64Converter) Zero() protoreflect.Value { return c.def }
 312  
 313  type stringConverter struct {
 314  	goType reflect.Type
 315  	def    protoreflect.Value
 316  }
 317  
 318  func (c *stringConverter) PBValueOf(v reflect.Value) protoreflect.Value {
 319  	if v.Type() != c.goType {
 320  		panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), c.goType))
 321  	}
 322  	return protoreflect.ValueOfString(v.Convert(stringType).String())
 323  }
 324  func (c *stringConverter) GoValueOf(v protoreflect.Value) reflect.Value {
 325  	// protoreflect.Value.String never panics, so we go through an interface
 326  	// conversion here to check the type.
 327  	s := v.Interface().(string)
 328  	if c.goType.Kind() == reflect.Slice && s == "" {
 329  		return reflect.Zero(c.goType) // ensure empty string is []byte(nil)
 330  	}
 331  	return reflect.ValueOf(s).Convert(c.goType)
 332  }
 333  func (c *stringConverter) IsValidPB(v protoreflect.Value) bool {
 334  	_, ok := v.Interface().(string)
 335  	return ok
 336  }
 337  func (c *stringConverter) IsValidGo(v reflect.Value) bool {
 338  	return v.IsValid() && v.Type() == c.goType
 339  }
 340  func (c *stringConverter) New() protoreflect.Value  { return c.def }
 341  func (c *stringConverter) Zero() protoreflect.Value { return c.def }
 342  
 343  type bytesConverter struct {
 344  	goType reflect.Type
 345  	def    protoreflect.Value
 346  }
 347  
 348  func (c *bytesConverter) PBValueOf(v reflect.Value) protoreflect.Value {
 349  	if v.Type() != c.goType {
 350  		panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), c.goType))
 351  	}
 352  	if c.goType.Kind() == reflect.String && v.Len() == 0 {
 353  		return protoreflect.ValueOfBytes(nil) // ensure empty string is []byte(nil)
 354  	}
 355  	return protoreflect.ValueOfBytes(v.Convert(bytesType).Bytes())
 356  }
 357  func (c *bytesConverter) GoValueOf(v protoreflect.Value) reflect.Value {
 358  	return reflect.ValueOf(v.Bytes()).Convert(c.goType)
 359  }
 360  func (c *bytesConverter) IsValidPB(v protoreflect.Value) bool {
 361  	_, ok := v.Interface().([]byte)
 362  	return ok
 363  }
 364  func (c *bytesConverter) IsValidGo(v reflect.Value) bool {
 365  	return v.IsValid() && v.Type() == c.goType
 366  }
 367  func (c *bytesConverter) New() protoreflect.Value  { return c.def }
 368  func (c *bytesConverter) Zero() protoreflect.Value { return c.def }
 369  
 370  type enumConverter struct {
 371  	goType reflect.Type
 372  	def    protoreflect.Value
 373  }
 374  
 375  func newEnumConverter(goType reflect.Type, fd protoreflect.FieldDescriptor) Converter {
 376  	var def protoreflect.Value
 377  	if fd.Cardinality() == protoreflect.Repeated {
 378  		def = protoreflect.ValueOfEnum(fd.Enum().Values().Get(0).Number())
 379  	} else {
 380  		def = fd.Default()
 381  	}
 382  	return &enumConverter{goType, def}
 383  }
 384  
 385  func (c *enumConverter) PBValueOf(v reflect.Value) protoreflect.Value {
 386  	if v.Type() != c.goType {
 387  		panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), c.goType))
 388  	}
 389  	return protoreflect.ValueOfEnum(protoreflect.EnumNumber(v.Int()))
 390  }
 391  
 392  func (c *enumConverter) GoValueOf(v protoreflect.Value) reflect.Value {
 393  	return reflect.ValueOf(v.Enum()).Convert(c.goType)
 394  }
 395  
 396  func (c *enumConverter) IsValidPB(v protoreflect.Value) bool {
 397  	_, ok := v.Interface().(protoreflect.EnumNumber)
 398  	return ok
 399  }
 400  
 401  func (c *enumConverter) IsValidGo(v reflect.Value) bool {
 402  	return v.IsValid() && v.Type() == c.goType
 403  }
 404  
 405  func (c *enumConverter) New() protoreflect.Value {
 406  	return c.def
 407  }
 408  
 409  func (c *enumConverter) Zero() protoreflect.Value {
 410  	return c.def
 411  }
 412  
 413  type messageConverter struct {
 414  	goType reflect.Type
 415  }
 416  
 417  func newMessageConverter(goType reflect.Type) Converter {
 418  	return &messageConverter{goType}
 419  }
 420  
 421  func (c *messageConverter) PBValueOf(v reflect.Value) protoreflect.Value {
 422  	if v.Type() != c.goType {
 423  		panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), c.goType))
 424  	}
 425  	if c.isNonPointer() {
 426  		if v.CanAddr() {
 427  			v = v.Addr() // T => *T
 428  		} else {
 429  			v = reflect.Zero(reflect.PtrTo(v.Type()))
 430  		}
 431  	}
 432  	if m, ok := v.Interface().(protoreflect.ProtoMessage); ok {
 433  		return protoreflect.ValueOfMessage(m.ProtoReflect())
 434  	}
 435  	return protoreflect.ValueOfMessage(legacyWrapMessage(v))
 436  }
 437  
 438  func (c *messageConverter) GoValueOf(v protoreflect.Value) reflect.Value {
 439  	m := v.Message()
 440  	var rv reflect.Value
 441  	if u, ok := m.(unwrapper); ok {
 442  		rv = reflect.ValueOf(u.protoUnwrap())
 443  	} else {
 444  		rv = reflect.ValueOf(m.Interface())
 445  	}
 446  	if c.isNonPointer() {
 447  		if rv.Type() != reflect.PtrTo(c.goType) {
 448  			panic(fmt.Sprintf("invalid type: got %v, want %v", rv.Type(), reflect.PtrTo(c.goType)))
 449  		}
 450  		if !rv.IsNil() {
 451  			rv = rv.Elem() // *T => T
 452  		} else {
 453  			rv = reflect.Zero(rv.Type().Elem())
 454  		}
 455  	}
 456  	if rv.Type() != c.goType {
 457  		panic(fmt.Sprintf("invalid type: got %v, want %v", rv.Type(), c.goType))
 458  	}
 459  	return rv
 460  }
 461  
 462  func (c *messageConverter) IsValidPB(v protoreflect.Value) bool {
 463  	m := v.Message()
 464  	var rv reflect.Value
 465  	if u, ok := m.(unwrapper); ok {
 466  		rv = reflect.ValueOf(u.protoUnwrap())
 467  	} else {
 468  		rv = reflect.ValueOf(m.Interface())
 469  	}
 470  	if c.isNonPointer() {
 471  		return rv.Type() == reflect.PtrTo(c.goType)
 472  	}
 473  	return rv.Type() == c.goType
 474  }
 475  
 476  func (c *messageConverter) IsValidGo(v reflect.Value) bool {
 477  	return v.IsValid() && v.Type() == c.goType
 478  }
 479  
 480  func (c *messageConverter) New() protoreflect.Value {
 481  	if c.isNonPointer() {
 482  		return c.PBValueOf(reflect.New(c.goType).Elem())
 483  	}
 484  	return c.PBValueOf(reflect.New(c.goType.Elem()))
 485  }
 486  
 487  func (c *messageConverter) Zero() protoreflect.Value {
 488  	return c.PBValueOf(reflect.Zero(c.goType))
 489  }
 490  
 491  // isNonPointer reports whether the type is a non-pointer type.
 492  // This never occurs for generated message types.
 493  func (c *messageConverter) isNonPointer() bool {
 494  	return c.goType.Kind() != reflect.Ptr
 495  }
 496