extensions.go raw

   1  // Copyright 2010 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  package proto
   6  
   7  import (
   8  	"errors"
   9  	"fmt"
  10  	"reflect"
  11  
  12  	"google.golang.org/protobuf/encoding/protowire"
  13  	"google.golang.org/protobuf/proto"
  14  	"google.golang.org/protobuf/reflect/protoreflect"
  15  	"google.golang.org/protobuf/reflect/protoregistry"
  16  	"google.golang.org/protobuf/runtime/protoiface"
  17  	"google.golang.org/protobuf/runtime/protoimpl"
  18  )
  19  
  20  type (
  21  	// ExtensionDesc represents an extension descriptor and
  22  	// is used to interact with an extension field in a message.
  23  	//
  24  	// Variables of this type are generated in code by protoc-gen-go.
  25  	ExtensionDesc = protoimpl.ExtensionInfo
  26  
  27  	// ExtensionRange represents a range of message extensions.
  28  	// Used in code generated by protoc-gen-go.
  29  	ExtensionRange = protoiface.ExtensionRangeV1
  30  
  31  	// Deprecated: Do not use; this is an internal type.
  32  	Extension = protoimpl.ExtensionFieldV1
  33  
  34  	// Deprecated: Do not use; this is an internal type.
  35  	XXX_InternalExtensions = protoimpl.ExtensionFields
  36  )
  37  
  38  // ErrMissingExtension reports whether the extension was not present.
  39  var ErrMissingExtension = errors.New("proto: missing extension")
  40  
  41  var errNotExtendable = errors.New("proto: not an extendable proto.Message")
  42  
  43  // HasExtension reports whether the extension field is present in m
  44  // either as an explicitly populated field or as an unknown field.
  45  func HasExtension(m Message, xt *ExtensionDesc) (has bool) {
  46  	mr := MessageReflect(m)
  47  	if mr == nil || !mr.IsValid() {
  48  		return false
  49  	}
  50  
  51  	// Check whether any populated known field matches the field number.
  52  	xtd := xt.TypeDescriptor()
  53  	if isValidExtension(mr.Descriptor(), xtd) {
  54  		has = mr.Has(xtd)
  55  	} else {
  56  		mr.Range(func(fd protoreflect.FieldDescriptor, _ protoreflect.Value) bool {
  57  			has = int32(fd.Number()) == xt.Field
  58  			return !has
  59  		})
  60  	}
  61  
  62  	// Check whether any unknown field matches the field number.
  63  	for b := mr.GetUnknown(); !has && len(b) > 0; {
  64  		num, _, n := protowire.ConsumeField(b)
  65  		has = int32(num) == xt.Field
  66  		b = b[n:]
  67  	}
  68  	return has
  69  }
  70  
  71  // ClearExtension removes the extension field from m
  72  // either as an explicitly populated field or as an unknown field.
  73  func ClearExtension(m Message, xt *ExtensionDesc) {
  74  	mr := MessageReflect(m)
  75  	if mr == nil || !mr.IsValid() {
  76  		return
  77  	}
  78  
  79  	xtd := xt.TypeDescriptor()
  80  	if isValidExtension(mr.Descriptor(), xtd) {
  81  		mr.Clear(xtd)
  82  	} else {
  83  		mr.Range(func(fd protoreflect.FieldDescriptor, _ protoreflect.Value) bool {
  84  			if int32(fd.Number()) == xt.Field {
  85  				mr.Clear(fd)
  86  				return false
  87  			}
  88  			return true
  89  		})
  90  	}
  91  	clearUnknown(mr, fieldNum(xt.Field))
  92  }
  93  
  94  // ClearAllExtensions clears all extensions from m.
  95  // This includes populated fields and unknown fields in the extension range.
  96  func ClearAllExtensions(m Message) {
  97  	mr := MessageReflect(m)
  98  	if mr == nil || !mr.IsValid() {
  99  		return
 100  	}
 101  
 102  	mr.Range(func(fd protoreflect.FieldDescriptor, _ protoreflect.Value) bool {
 103  		if fd.IsExtension() {
 104  			mr.Clear(fd)
 105  		}
 106  		return true
 107  	})
 108  	clearUnknown(mr, mr.Descriptor().ExtensionRanges())
 109  }
 110  
 111  // GetExtension retrieves a proto2 extended field from m.
 112  //
 113  // If the descriptor is type complete (i.e., ExtensionDesc.ExtensionType is non-nil),
 114  // then GetExtension parses the encoded field and returns a Go value of the specified type.
 115  // If the field is not present, then the default value is returned (if one is specified),
 116  // otherwise ErrMissingExtension is reported.
 117  //
 118  // If the descriptor is type incomplete (i.e., ExtensionDesc.ExtensionType is nil),
 119  // then GetExtension returns the raw encoded bytes for the extension field.
 120  func GetExtension(m Message, xt *ExtensionDesc) (interface{}, error) {
 121  	mr := MessageReflect(m)
 122  	if mr == nil || !mr.IsValid() || mr.Descriptor().ExtensionRanges().Len() == 0 {
 123  		return nil, errNotExtendable
 124  	}
 125  
 126  	// Retrieve the unknown fields for this extension field.
 127  	var bo protoreflect.RawFields
 128  	for bi := mr.GetUnknown(); len(bi) > 0; {
 129  		num, _, n := protowire.ConsumeField(bi)
 130  		if int32(num) == xt.Field {
 131  			bo = append(bo, bi[:n]...)
 132  		}
 133  		bi = bi[n:]
 134  	}
 135  
 136  	// For type incomplete descriptors, only retrieve the unknown fields.
 137  	if xt.ExtensionType == nil {
 138  		return []byte(bo), nil
 139  	}
 140  
 141  	// If the extension field only exists as unknown fields, unmarshal it.
 142  	// This is rarely done since proto.Unmarshal eagerly unmarshals extensions.
 143  	xtd := xt.TypeDescriptor()
 144  	if !isValidExtension(mr.Descriptor(), xtd) {
 145  		return nil, fmt.Errorf("proto: bad extended type; %T does not extend %T", xt.ExtendedType, m)
 146  	}
 147  	if !mr.Has(xtd) && len(bo) > 0 {
 148  		m2 := mr.New()
 149  		if err := (proto.UnmarshalOptions{
 150  			Resolver: extensionResolver{xt},
 151  		}.Unmarshal(bo, m2.Interface())); err != nil {
 152  			return nil, err
 153  		}
 154  		if m2.Has(xtd) {
 155  			mr.Set(xtd, m2.Get(xtd))
 156  			clearUnknown(mr, fieldNum(xt.Field))
 157  		}
 158  	}
 159  
 160  	// Check whether the message has the extension field set or a default.
 161  	var pv protoreflect.Value
 162  	switch {
 163  	case mr.Has(xtd):
 164  		pv = mr.Get(xtd)
 165  	case xtd.HasDefault():
 166  		pv = xtd.Default()
 167  	default:
 168  		return nil, ErrMissingExtension
 169  	}
 170  
 171  	v := xt.InterfaceOf(pv)
 172  	rv := reflect.ValueOf(v)
 173  	if isScalarKind(rv.Kind()) {
 174  		rv2 := reflect.New(rv.Type())
 175  		rv2.Elem().Set(rv)
 176  		v = rv2.Interface()
 177  	}
 178  	return v, nil
 179  }
 180  
 181  // extensionResolver is a custom extension resolver that stores a single
 182  // extension type that takes precedence over the global registry.
 183  type extensionResolver struct{ xt protoreflect.ExtensionType }
 184  
 185  func (r extensionResolver) FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error) {
 186  	if xtd := r.xt.TypeDescriptor(); xtd.FullName() == field {
 187  		return r.xt, nil
 188  	}
 189  	return protoregistry.GlobalTypes.FindExtensionByName(field)
 190  }
 191  
 192  func (r extensionResolver) FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) {
 193  	if xtd := r.xt.TypeDescriptor(); xtd.ContainingMessage().FullName() == message && xtd.Number() == field {
 194  		return r.xt, nil
 195  	}
 196  	return protoregistry.GlobalTypes.FindExtensionByNumber(message, field)
 197  }
 198  
 199  // GetExtensions returns a list of the extensions values present in m,
 200  // corresponding with the provided list of extension descriptors, xts.
 201  // If an extension is missing in m, the corresponding value is nil.
 202  func GetExtensions(m Message, xts []*ExtensionDesc) ([]interface{}, error) {
 203  	mr := MessageReflect(m)
 204  	if mr == nil || !mr.IsValid() {
 205  		return nil, errNotExtendable
 206  	}
 207  
 208  	vs := make([]interface{}, len(xts))
 209  	for i, xt := range xts {
 210  		v, err := GetExtension(m, xt)
 211  		if err != nil {
 212  			if err == ErrMissingExtension {
 213  				continue
 214  			}
 215  			return vs, err
 216  		}
 217  		vs[i] = v
 218  	}
 219  	return vs, nil
 220  }
 221  
 222  // SetExtension sets an extension field in m to the provided value.
 223  func SetExtension(m Message, xt *ExtensionDesc, v interface{}) error {
 224  	mr := MessageReflect(m)
 225  	if mr == nil || !mr.IsValid() || mr.Descriptor().ExtensionRanges().Len() == 0 {
 226  		return errNotExtendable
 227  	}
 228  
 229  	rv := reflect.ValueOf(v)
 230  	if reflect.TypeOf(v) != reflect.TypeOf(xt.ExtensionType) {
 231  		return fmt.Errorf("proto: bad extension value type. got: %T, want: %T", v, xt.ExtensionType)
 232  	}
 233  	if rv.Kind() == reflect.Ptr {
 234  		if rv.IsNil() {
 235  			return fmt.Errorf("proto: SetExtension called with nil value of type %T", v)
 236  		}
 237  		if isScalarKind(rv.Elem().Kind()) {
 238  			v = rv.Elem().Interface()
 239  		}
 240  	}
 241  
 242  	xtd := xt.TypeDescriptor()
 243  	if !isValidExtension(mr.Descriptor(), xtd) {
 244  		return fmt.Errorf("proto: bad extended type; %T does not extend %T", xt.ExtendedType, m)
 245  	}
 246  	mr.Set(xtd, xt.ValueOf(v))
 247  	clearUnknown(mr, fieldNum(xt.Field))
 248  	return nil
 249  }
 250  
 251  // SetRawExtension inserts b into the unknown fields of m.
 252  //
 253  // Deprecated: Use Message.ProtoReflect.SetUnknown instead.
 254  func SetRawExtension(m Message, fnum int32, b []byte) {
 255  	mr := MessageReflect(m)
 256  	if mr == nil || !mr.IsValid() {
 257  		return
 258  	}
 259  
 260  	// Verify that the raw field is valid.
 261  	for b0 := b; len(b0) > 0; {
 262  		num, _, n := protowire.ConsumeField(b0)
 263  		if int32(num) != fnum {
 264  			panic(fmt.Sprintf("mismatching field number: got %d, want %d", num, fnum))
 265  		}
 266  		b0 = b0[n:]
 267  	}
 268  
 269  	ClearExtension(m, &ExtensionDesc{Field: fnum})
 270  	mr.SetUnknown(append(mr.GetUnknown(), b...))
 271  }
 272  
 273  // ExtensionDescs returns a list of extension descriptors found in m,
 274  // containing descriptors for both populated extension fields in m and
 275  // also unknown fields of m that are in the extension range.
 276  // For the later case, an type incomplete descriptor is provided where only
 277  // the ExtensionDesc.Field field is populated.
 278  // The order of the extension descriptors is undefined.
 279  func ExtensionDescs(m Message) ([]*ExtensionDesc, error) {
 280  	mr := MessageReflect(m)
 281  	if mr == nil || !mr.IsValid() || mr.Descriptor().ExtensionRanges().Len() == 0 {
 282  		return nil, errNotExtendable
 283  	}
 284  
 285  	// Collect a set of known extension descriptors.
 286  	extDescs := make(map[protoreflect.FieldNumber]*ExtensionDesc)
 287  	mr.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
 288  		if fd.IsExtension() {
 289  			xt := fd.(protoreflect.ExtensionTypeDescriptor)
 290  			if xd, ok := xt.Type().(*ExtensionDesc); ok {
 291  				extDescs[fd.Number()] = xd
 292  			}
 293  		}
 294  		return true
 295  	})
 296  
 297  	// Collect a set of unknown extension descriptors.
 298  	extRanges := mr.Descriptor().ExtensionRanges()
 299  	for b := mr.GetUnknown(); len(b) > 0; {
 300  		num, _, n := protowire.ConsumeField(b)
 301  		if extRanges.Has(num) && extDescs[num] == nil {
 302  			extDescs[num] = nil
 303  		}
 304  		b = b[n:]
 305  	}
 306  
 307  	// Transpose the set of descriptors into a list.
 308  	var xts []*ExtensionDesc
 309  	for num, xt := range extDescs {
 310  		if xt == nil {
 311  			xt = &ExtensionDesc{Field: int32(num)}
 312  		}
 313  		xts = append(xts, xt)
 314  	}
 315  	return xts, nil
 316  }
 317  
 318  // isValidExtension reports whether xtd is a valid extension descriptor for md.
 319  func isValidExtension(md protoreflect.MessageDescriptor, xtd protoreflect.ExtensionTypeDescriptor) bool {
 320  	return xtd.ContainingMessage() == md && md.ExtensionRanges().Has(xtd.Number())
 321  }
 322  
 323  // isScalarKind reports whether k is a protobuf scalar kind (except bytes).
 324  // This function exists for historical reasons since the representation of
 325  // scalars differs between v1 and v2, where v1 uses *T and v2 uses T.
 326  func isScalarKind(k reflect.Kind) bool {
 327  	switch k {
 328  	case reflect.Bool, reflect.Int32, reflect.Int64, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.String:
 329  		return true
 330  	default:
 331  		return false
 332  	}
 333  }
 334  
 335  // clearUnknown removes unknown fields from m where remover.Has reports true.
 336  func clearUnknown(m protoreflect.Message, remover interface {
 337  	Has(protoreflect.FieldNumber) bool
 338  }) {
 339  	var bo protoreflect.RawFields
 340  	for bi := m.GetUnknown(); len(bi) > 0; {
 341  		num, _, n := protowire.ConsumeField(bi)
 342  		if !remover.Has(num) {
 343  			bo = append(bo, bi[:n]...)
 344  		}
 345  		bi = bi[n:]
 346  	}
 347  	if bi := m.GetUnknown(); len(bi) != len(bo) {
 348  		m.SetUnknown(bo)
 349  	}
 350  }
 351  
 352  type fieldNum protoreflect.FieldNumber
 353  
 354  func (n1 fieldNum) Has(n2 protoreflect.FieldNumber) bool {
 355  	return protoreflect.FieldNumber(n1) == n2
 356  }
 357