extension.go raw

   1  // Copyright 2019 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  package proto
   6  
   7  import (
   8  	"google.golang.org/protobuf/reflect/protoreflect"
   9  )
  10  
  11  // HasExtension reports whether an extension field is populated.
  12  // It returns false if m is invalid or if xt does not extend m.
  13  func HasExtension(m Message, xt protoreflect.ExtensionType) bool {
  14  	// Treat nil message interface or descriptor as an empty message; no populated
  15  	// fields.
  16  	if m == nil || xt == nil {
  17  		return false
  18  	}
  19  
  20  	// As a special-case, we reports invalid or mismatching descriptors
  21  	// as always not being populated (since they aren't).
  22  	mr := m.ProtoReflect()
  23  	xd := xt.TypeDescriptor()
  24  	if mr.Descriptor() != xd.ContainingMessage() {
  25  		return false
  26  	}
  27  
  28  	return mr.Has(xd)
  29  }
  30  
  31  // ClearExtension clears an extension field such that subsequent
  32  // [HasExtension] calls return false.
  33  // It panics if m is invalid or if xt does not extend m.
  34  func ClearExtension(m Message, xt protoreflect.ExtensionType) {
  35  	m.ProtoReflect().Clear(xt.TypeDescriptor())
  36  }
  37  
  38  // GetExtension retrieves the value for an extension field.
  39  // If the field is unpopulated, it returns the default value for
  40  // scalars and an immutable, empty value for lists or messages.
  41  // It panics if xt does not extend m.
  42  //
  43  // The type of the value is dependent on the field type of the extension.
  44  // For extensions generated by protoc-gen-go, the Go type is as follows:
  45  //
  46  //	╔═══════════════════╤═════════════════════════╗
  47  //	║ Go type           │ Protobuf kind           ║
  48  //	╠═══════════════════╪═════════════════════════╣
  49  //	║ bool              │ bool                    ║
  50  //	║ int32             │ int32, sint32, sfixed32 ║
  51  //	║ int64             │ int64, sint64, sfixed64 ║
  52  //	║ uint32            │ uint32, fixed32         ║
  53  //	║ uint64            │ uint64, fixed64         ║
  54  //	║ float32           │ float                   ║
  55  //	║ float64           │ double                  ║
  56  //	║ string            │ string                  ║
  57  //	║ []byte            │ bytes                   ║
  58  //	║ protoreflect.Enum │ enum                    ║
  59  //	║ proto.Message     │ message, group          ║
  60  //	╚═══════════════════╧═════════════════════════╝
  61  //
  62  // The protoreflect.Enum and proto.Message types are the concrete Go type
  63  // associated with the named enum or message. Repeated fields are represented
  64  // using a Go slice of the base element type.
  65  //
  66  // If a generated extension descriptor variable is directly passed to
  67  // GetExtension, then the call should be followed immediately by a
  68  // type assertion to the expected output value. For example:
  69  //
  70  //	mm := proto.GetExtension(m, foopb.E_MyExtension).(*foopb.MyMessage)
  71  //
  72  // This pattern enables static analysis tools to verify that the asserted type
  73  // matches the Go type associated with the extension field and
  74  // also enables a possible future migration to a type-safe extension API.
  75  //
  76  // Since singular messages are the most common extension type, the pattern of
  77  // calling HasExtension followed by GetExtension may be simplified to:
  78  //
  79  //	if mm := proto.GetExtension(m, foopb.E_MyExtension).(*foopb.MyMessage); mm != nil {
  80  //	    ... // make use of mm
  81  //	}
  82  //
  83  // The mm variable is non-nil if and only if HasExtension reports true.
  84  func GetExtension(m Message, xt protoreflect.ExtensionType) any {
  85  	// Treat nil message interface as an empty message; return the default.
  86  	if m == nil {
  87  		return xt.InterfaceOf(xt.Zero())
  88  	}
  89  
  90  	return xt.InterfaceOf(m.ProtoReflect().Get(xt.TypeDescriptor()))
  91  }
  92  
  93  // SetExtension stores the value of an extension field.
  94  // It panics if m is invalid, xt does not extend m, or if type of v
  95  // is invalid for the specified extension field.
  96  //
  97  // The type of the value is dependent on the field type of the extension.
  98  // For extensions generated by protoc-gen-go, the Go type is as follows:
  99  //
 100  //	╔═══════════════════╤═════════════════════════╗
 101  //	║ Go type           │ Protobuf kind           ║
 102  //	╠═══════════════════╪═════════════════════════╣
 103  //	║ bool              │ bool                    ║
 104  //	║ int32             │ int32, sint32, sfixed32 ║
 105  //	║ int64             │ int64, sint64, sfixed64 ║
 106  //	║ uint32            │ uint32, fixed32         ║
 107  //	║ uint64            │ uint64, fixed64         ║
 108  //	║ float32           │ float                   ║
 109  //	║ float64           │ double                  ║
 110  //	║ string            │ string                  ║
 111  //	║ []byte            │ bytes                   ║
 112  //	║ protoreflect.Enum │ enum                    ║
 113  //	║ proto.Message     │ message, group          ║
 114  //	╚═══════════════════╧═════════════════════════╝
 115  //
 116  // The protoreflect.Enum and proto.Message types are the concrete Go type
 117  // associated with the named enum or message. Repeated fields are represented
 118  // using a Go slice of the base element type.
 119  //
 120  // If a generated extension descriptor variable is directly passed to
 121  // SetExtension (e.g., foopb.E_MyExtension), then the value should be a
 122  // concrete type that matches the expected Go type for the extension descriptor
 123  // so that static analysis tools can verify type correctness.
 124  // This also enables a possible future migration to a type-safe extension API.
 125  func SetExtension(m Message, xt protoreflect.ExtensionType, v any) {
 126  	xd := xt.TypeDescriptor()
 127  	pv := xt.ValueOf(v)
 128  
 129  	// Specially treat an invalid list, map, or message as clear.
 130  	isValid := true
 131  	switch {
 132  	case xd.IsList():
 133  		isValid = pv.List().IsValid()
 134  	case xd.IsMap():
 135  		isValid = pv.Map().IsValid()
 136  	case xd.Message() != nil:
 137  		isValid = pv.Message().IsValid()
 138  	}
 139  	if !isValid {
 140  		m.ProtoReflect().Clear(xd)
 141  		return
 142  	}
 143  
 144  	m.ProtoReflect().Set(xd, pv)
 145  }
 146  
 147  // RangeExtensions iterates over every populated extension field in m in an
 148  // undefined order, calling f for each extension type and value encountered.
 149  // It returns immediately if f returns false.
 150  // While iterating, mutating operations may only be performed
 151  // on the current extension field.
 152  func RangeExtensions(m Message, f func(protoreflect.ExtensionType, any) bool) {
 153  	// Treat nil message interface as an empty message; nothing to range over.
 154  	if m == nil {
 155  		return
 156  	}
 157  
 158  	m.ProtoReflect().Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
 159  		if fd.IsExtension() {
 160  			xt := fd.(protoreflect.ExtensionTypeDescriptor).Type()
 161  			vi := xt.InterfaceOf(v)
 162  			return f(xt, vi)
 163  		}
 164  		return true
 165  	})
 166  }
 167