decode.go raw

   1  // Copyright 2018 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  package proto
   6  
   7  import (
   8  	"google.golang.org/protobuf/encoding/protowire"
   9  	"google.golang.org/protobuf/internal/encoding/messageset"
  10  	"google.golang.org/protobuf/internal/errors"
  11  	"google.golang.org/protobuf/internal/genid"
  12  	"google.golang.org/protobuf/internal/pragma"
  13  	"google.golang.org/protobuf/reflect/protoreflect"
  14  	"google.golang.org/protobuf/reflect/protoregistry"
  15  	"google.golang.org/protobuf/runtime/protoiface"
  16  )
  17  
// UnmarshalOptions configures the unmarshaler.
//
// The zero value is usable: the message is reset before decoding, required
// fields are checked, unknown fields are kept, and default values are chosen
// for Resolver and RecursionLimit.
//
// Example usage:
//
//	err := UnmarshalOptions{DiscardUnknown: true}.Unmarshal(b, m)
type UnmarshalOptions struct {
	pragma.NoUnkeyedLiterals

	// Merge merges the input into the destination message.
	// The default behavior is to always reset the message before unmarshaling,
	// unless Merge is specified.
	Merge bool

	// AllowPartial accepts input for messages that will result in missing
	// required fields. If AllowPartial is false (the default), Unmarshal will
	// return an error if there are any missing required fields.
	AllowPartial bool

	// DiscardUnknown specifies whether unknown fields are ignored.
	// If false (the default), unknown fields are appended to the message's
	// unknown field set instead of being dropped.
	DiscardUnknown bool

	// Resolver is used for looking up types when unmarshaling extension fields.
	// If nil, this defaults to using protoregistry.GlobalTypes.
	Resolver interface {
		FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)
		FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error)
	}

	// RecursionLimit limits how deeply messages may be nested.
	// If zero, protowire.DefaultRecursionLimit is applied.
	RecursionLimit int

	// NoLazyDecoding turns off lazy decoding, which otherwise is enabled by
	// default. Lazy decoding only affects submessages (annotated with [lazy =
	// true] in the .proto file) within messages that use the Opaque API.
	NoLazyDecoding bool
}
  56  
  57  // Unmarshal parses the wire-format message in b and places the result in m.
  58  // The provided message must be mutable (e.g., a non-nil pointer to a message).
  59  //
  60  // See the [UnmarshalOptions] type if you need more control.
  61  func Unmarshal(b []byte, m Message) error {
  62  	_, err := UnmarshalOptions{RecursionLimit: protowire.DefaultRecursionLimit}.unmarshal(b, m.ProtoReflect())
  63  	return err
  64  }
  65  
  66  // Unmarshal parses the wire-format message in b and places the result in m.
  67  // The provided message must be mutable (e.g., a non-nil pointer to a message).
  68  func (o UnmarshalOptions) Unmarshal(b []byte, m Message) error {
  69  	if o.RecursionLimit == 0 {
  70  		o.RecursionLimit = protowire.DefaultRecursionLimit
  71  	}
  72  	_, err := o.unmarshal(b, m.ProtoReflect())
  73  	return err
  74  }
  75  
  76  // UnmarshalState parses a wire-format message and places the result in m.
  77  //
  78  // This method permits fine-grained control over the unmarshaler.
  79  // Most users should use [Unmarshal] instead.
  80  func (o UnmarshalOptions) UnmarshalState(in protoiface.UnmarshalInput) (protoiface.UnmarshalOutput, error) {
  81  	if o.RecursionLimit == 0 {
  82  		o.RecursionLimit = protowire.DefaultRecursionLimit
  83  	}
  84  	return o.unmarshal(in.Buf, in.Message)
  85  }
  86  
  87  // unmarshal is a centralized function that all unmarshal operations go through.
  88  // For profiling purposes, avoid changing the name of this function or
  89  // introducing other code paths for unmarshal that do not go through this.
  90  func (o UnmarshalOptions) unmarshal(b []byte, m protoreflect.Message) (out protoiface.UnmarshalOutput, err error) {
  91  	if o.Resolver == nil {
  92  		o.Resolver = protoregistry.GlobalTypes
  93  	}
  94  	if !o.Merge {
  95  		Reset(m.Interface())
  96  	}
  97  	allowPartial := o.AllowPartial
  98  	o.Merge = true
  99  	o.AllowPartial = true
 100  	methods := protoMethods(m)
 101  	if methods != nil && methods.Unmarshal != nil &&
 102  		!(o.DiscardUnknown && methods.Flags&protoiface.SupportUnmarshalDiscardUnknown == 0) {
 103  		in := protoiface.UnmarshalInput{
 104  			Message:  m,
 105  			Buf:      b,
 106  			Resolver: o.Resolver,
 107  			Depth:    o.RecursionLimit,
 108  		}
 109  		if o.DiscardUnknown {
 110  			in.Flags |= protoiface.UnmarshalDiscardUnknown
 111  		}
 112  
 113  		if !allowPartial {
 114  			// This does not affect how current unmarshal functions work, it just allows them
 115  			// to record this for lazy the decoding case.
 116  			in.Flags |= protoiface.UnmarshalCheckRequired
 117  		}
 118  		if o.NoLazyDecoding {
 119  			in.Flags |= protoiface.UnmarshalNoLazyDecoding
 120  		}
 121  
 122  		out, err = methods.Unmarshal(in)
 123  	} else {
 124  		if o.RecursionLimit--; o.RecursionLimit < 0 {
 125  			return out, errRecursionDepth
 126  		}
 127  		err = o.unmarshalMessageSlow(b, m)
 128  	}
 129  	if err != nil {
 130  		return out, err
 131  	}
 132  	if allowPartial || (out.Flags&protoiface.UnmarshalInitialized != 0) {
 133  		return out, nil
 134  	}
 135  	return out, checkInitialized(m)
 136  }
 137  
 138  func (o UnmarshalOptions) unmarshalMessage(b []byte, m protoreflect.Message) error {
 139  	_, err := o.unmarshal(b, m)
 140  	return err
 141  }
 142  
 143  func (o UnmarshalOptions) unmarshalMessageSlow(b []byte, m protoreflect.Message) error {
 144  	md := m.Descriptor()
 145  	if messageset.IsMessageSet(md) {
 146  		return o.unmarshalMessageSet(b, m)
 147  	}
 148  	fields := md.Fields()
 149  	for len(b) > 0 {
 150  		// Parse the tag (field number and wire type).
 151  		num, wtyp, tagLen := protowire.ConsumeTag(b)
 152  		if tagLen < 0 {
 153  			return errDecode
 154  		}
 155  		if num > protowire.MaxValidNumber {
 156  			return errDecode
 157  		}
 158  
 159  		// Find the field descriptor for this field number.
 160  		fd := fields.ByNumber(num)
 161  		if fd == nil && md.ExtensionRanges().Has(num) {
 162  			extType, err := o.Resolver.FindExtensionByNumber(md.FullName(), num)
 163  			if err != nil && err != protoregistry.NotFound {
 164  				return errors.New("%v: unable to resolve extension %v: %v", md.FullName(), num, err)
 165  			}
 166  			if extType != nil {
 167  				fd = extType.TypeDescriptor()
 168  			}
 169  		}
 170  		var err error
 171  		if fd == nil {
 172  			err = errUnknown
 173  		}
 174  
 175  		// Parse the field value.
 176  		var valLen int
 177  		switch {
 178  		case err != nil:
 179  		case fd.IsList():
 180  			valLen, err = o.unmarshalList(b[tagLen:], wtyp, m.Mutable(fd).List(), fd)
 181  		case fd.IsMap():
 182  			valLen, err = o.unmarshalMap(b[tagLen:], wtyp, m.Mutable(fd).Map(), fd)
 183  		default:
 184  			valLen, err = o.unmarshalSingular(b[tagLen:], wtyp, m, fd)
 185  		}
 186  		if err != nil {
 187  			if err != errUnknown {
 188  				return err
 189  			}
 190  			valLen = protowire.ConsumeFieldValue(num, wtyp, b[tagLen:])
 191  			if valLen < 0 {
 192  				return errDecode
 193  			}
 194  			if !o.DiscardUnknown {
 195  				m.SetUnknown(append(m.GetUnknown(), b[:tagLen+valLen]...))
 196  			}
 197  		}
 198  		b = b[tagLen+valLen:]
 199  	}
 200  	return nil
 201  }
 202  
 203  func (o UnmarshalOptions) unmarshalSingular(b []byte, wtyp protowire.Type, m protoreflect.Message, fd protoreflect.FieldDescriptor) (n int, err error) {
 204  	v, n, err := o.unmarshalScalar(b, wtyp, fd)
 205  	if err != nil {
 206  		return 0, err
 207  	}
 208  	switch fd.Kind() {
 209  	case protoreflect.GroupKind, protoreflect.MessageKind:
 210  		m2 := m.Mutable(fd).Message()
 211  		if err := o.unmarshalMessage(v.Bytes(), m2); err != nil {
 212  			return n, err
 213  		}
 214  	default:
 215  		// Non-message scalars replace the previous value.
 216  		m.Set(fd, v)
 217  	}
 218  	return n, nil
 219  }
 220  
// unmarshalMap decodes a single map entry for the map field fd from b and
// stores it in mapv.
//
// On success it returns the length reported by protowire.ConsumeBytes for the
// entry. It returns errUnknown when wtyp is not protowire.BytesType,
// errDecode for malformed input, and errRecursionDepth when the recursion
// limit is exhausted. A missing key, or a missing non-message value, is
// replaced by the corresponding field default.
func (o UnmarshalOptions) unmarshalMap(b []byte, wtyp protowire.Type, mapv protoreflect.Map, fd protoreflect.FieldDescriptor) (n int, err error) {
	// A map entry counts as one nesting level. o is a copy, so the decrement
	// is seen only by the nested calls made from this function.
	if o.RecursionLimit--; o.RecursionLimit < 0 {
		return 0, errRecursionDepth
	}
	if wtyp != protowire.BytesType {
		return 0, errUnknown
	}
	// From here on b is the entry payload and n the size of the whole entry.
	b, n = protowire.ConsumeBytes(b)
	if n < 0 {
		return 0, errDecode
	}
	var (
		keyField = fd.MapKey()
		valField = fd.MapValue()
		key      protoreflect.Value
		val      protoreflect.Value
		haveKey  bool
		haveVal  bool
	)
	// Message-typed values are allocated up front so that the value field can
	// be unmarshaled directly into them.
	switch valField.Kind() {
	case protoreflect.GroupKind, protoreflect.MessageKind:
		val = mapv.NewValue()
	}
	// Map entries are represented as a two-element message with fields
	// containing the key and value.
	for len(b) > 0 {
		// NOTE: wtyp and n declared here shadow the parameter and the named
		// result; the outer n keeps the entry size returned at the end.
		num, wtyp, n := protowire.ConsumeTag(b)
		if n < 0 {
			return 0, errDecode
		}
		if num > protowire.MaxValidNumber {
			return 0, errDecode
		}
		b = b[n:]
		// Assume the field is unknown until one of the cases below claims it.
		err = errUnknown
		switch num {
		case genid.MapEntry_Key_field_number:
			key, n, err = o.unmarshalScalar(b, wtyp, keyField)
			if err != nil {
				// Leaves only the switch; err is examined below.
				break
			}
			haveKey = true
		case genid.MapEntry_Value_field_number:
			var v protoreflect.Value
			v, n, err = o.unmarshalScalar(b, wtyp, valField)
			if err != nil {
				// Leaves only the switch; err is examined below.
				break
			}
			switch valField.Kind() {
			case protoreflect.GroupKind, protoreflect.MessageKind:
				if err := o.unmarshalMessage(v.Bytes(), val.Message()); err != nil {
					return 0, err
				}
			default:
				val = v
			}
			haveVal = true
		}
		// Fields other than key/value, and fields the scalar decoder reported
		// as errUnknown, are skipped; they are not recorded anywhere.
		if err == errUnknown {
			n = protowire.ConsumeFieldValue(num, wtyp, b)
			if n < 0 {
				return 0, errDecode
			}
		} else if err != nil {
			return 0, err
		}
		b = b[n:]
	}
	// Every map entry should have entries for key and value, but this is not strictly required.
	if !haveKey {
		key = keyField.Default()
	}
	if !haveVal {
		switch valField.Kind() {
		case protoreflect.GroupKind, protoreflect.MessageKind:
			// Keep the value allocated via mapv.NewValue above.
		default:
			val = valField.Default()
		}
	}
	mapv.Set(key.MapKey(), val)
	return n, nil
}
 303  
 304  // errUnknown is used internally to indicate fields which should be added
 305  // to the unknown field set of a message. It is never returned from an exported
 306  // function.
 307  var errUnknown = errors.New("BUG: internal error (unknown)")
 308  
 309  var errDecode = errors.New("cannot parse invalid wire-format data")
 310  
 311  var errRecursionDepth = errors.New("exceeded maximum recursion depth")
 312