decode.go raw

   1  // Copyright 2019 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  package impl
   6  
   7  import (
   8  	"math/bits"
   9  
  10  	"google.golang.org/protobuf/encoding/protowire"
  11  	"google.golang.org/protobuf/internal/errors"
  12  	"google.golang.org/protobuf/internal/flags"
  13  	"google.golang.org/protobuf/proto"
  14  	"google.golang.org/protobuf/reflect/protoreflect"
  15  	"google.golang.org/protobuf/reflect/protoregistry"
  16  	"google.golang.org/protobuf/runtime/protoiface"
  17  )
  18  
  19  var errDecode = errors.New("cannot parse invalid wire-format data")
  20  var errRecursionDepth = errors.New("exceeded maximum recursion depth")
  21  
// unmarshalOptions carries the per-call settings used while decoding a
// message: the input flags, the resolver used to look up extension types,
// and the remaining recursion budget.
type unmarshalOptions struct {
	// flags are the protoiface.UnmarshalInputFlags for this operation; the
	// accessor methods (DiscardUnknown, AliasBuffer, Validated,
	// NoLazyDecoding) each test one bit.
	flags    protoiface.UnmarshalInputFlags
	// resolver looks up extension types; unmarshalExtension calls
	// FindExtensionByNumber for extensions not already present in the message.
	resolver interface {
		FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)
		FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error)
	}
	// depth is the remaining nesting budget. unmarshalPointer decrements it on
	// entry and fails with errRecursionDepth once it drops below zero.
	depth int
}
  30  
  31  func (o unmarshalOptions) Options() proto.UnmarshalOptions {
  32  	return proto.UnmarshalOptions{
  33  		Merge:          true,
  34  		AllowPartial:   true,
  35  		DiscardUnknown: o.DiscardUnknown(),
  36  		Resolver:       o.resolver,
  37  
  38  		NoLazyDecoding: o.NoLazyDecoding(),
  39  	}
  40  }
  41  
  42  func (o unmarshalOptions) DiscardUnknown() bool {
  43  	return o.flags&protoiface.UnmarshalDiscardUnknown != 0
  44  }
  45  
  46  func (o unmarshalOptions) AliasBuffer() bool { return o.flags&protoiface.UnmarshalAliasBuffer != 0 }
  47  func (o unmarshalOptions) Validated() bool   { return o.flags&protoiface.UnmarshalValidated != 0 }
  48  func (o unmarshalOptions) NoLazyDecoding() bool {
  49  	return o.flags&protoiface.UnmarshalNoLazyDecoding != 0
  50  }
  51  
  52  func (o unmarshalOptions) CanBeLazy() bool {
  53  	if o.resolver != protoregistry.GlobalTypes {
  54  		return false
  55  	}
  56  	// We ignore the UnmarshalInvalidateSizeCache even though it's not in the default set
  57  	return (o.flags & ^(protoiface.UnmarshalAliasBuffer | protoiface.UnmarshalValidated | protoiface.UnmarshalCheckRequired)) == 0
  58  }
  59  
  60  var lazyUnmarshalOptions = unmarshalOptions{
  61  	resolver: protoregistry.GlobalTypes,
  62  
  63  	flags: protoiface.UnmarshalAliasBuffer | protoiface.UnmarshalValidated,
  64  
  65  	depth: protowire.DefaultRecursionLimit,
  66  }
  67  
// unmarshalOutput is the internal result of decoding a message, field or
// extension value.
type unmarshalOutput struct {
	n           int // number of bytes consumed
	initialized bool // whether the decoded value is known to be initialized; unmarshalPointerEager clears it when a required field is missing
}
  72  
  73  // unmarshal is protoreflect.Methods.Unmarshal.
  74  func (mi *MessageInfo) unmarshal(in protoiface.UnmarshalInput) (protoiface.UnmarshalOutput, error) {
  75  	var p pointer
  76  	if ms, ok := in.Message.(*messageState); ok {
  77  		p = ms.pointer()
  78  	} else {
  79  		p = in.Message.(*messageReflectWrapper).pointer()
  80  	}
  81  	out, err := mi.unmarshalPointer(in.Buf, p, 0, unmarshalOptions{
  82  		flags:    in.Flags,
  83  		resolver: in.Resolver,
  84  		depth:    in.Depth,
  85  	})
  86  	var flags protoiface.UnmarshalOutputFlags
  87  	if out.initialized {
  88  		flags |= protoiface.UnmarshalInitialized
  89  	}
  90  	return protoiface.UnmarshalOutput{
  91  		Flags: flags,
  92  	}, err
  93  }
  94  
  95  // errUnknown is returned during unmarshaling to indicate a parse error that
  96  // should result in a field being placed in the unknown fields section (for example,
  97  // when the wire type doesn't match) as opposed to the entire unmarshal operation
  98  // failing (for example, when a field extends past the available input).
  99  //
 100  // This is a sentinel error which should never be visible to the user.
 101  var errUnknown = errors.New("unknown")
 102  
 103  func (mi *MessageInfo) unmarshalPointer(b []byte, p pointer, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, err error) {
 104  	mi.init()
 105  	if opts.depth--; opts.depth < 0 {
 106  		return out, errRecursionDepth
 107  	}
 108  	if flags.ProtoLegacy && mi.isMessageSet {
 109  		return unmarshalMessageSet(mi, b, p, opts)
 110  	}
 111  
 112  	lazyDecoding := LazyEnabled() // default
 113  	if opts.NoLazyDecoding() {
 114  		lazyDecoding = false // explicitly disabled
 115  	}
 116  	if mi.lazyOffset.IsValid() && lazyDecoding {
 117  		return mi.unmarshalPointerLazy(b, p, groupTag, opts)
 118  	}
 119  	return mi.unmarshalPointerEager(b, p, groupTag, opts)
 120  }
 121  
// unmarshalPointerEager is the message unmarshalling function for all messages that are not lazy.
// The corresponding function for Lazy is in google_lazy.go.
//
// It decodes the wire-format fields in b into the message at p. If groupTag
// is non-zero, the message is being decoded as a group. Decoding then stops
// at the END_GROUP tag for groupTag. Any other END_GROUP tag, or running out
// of input before the matching one, is reported as errDecode.
//
// Some fields cannot be handled here:
//   - no coder, or a coder without an unmarshal func;
//   - a decoder that returns errUnknown;
//   - no extension storage on the message;
//   - an extension that unmarshalExtension reports as unknown.
//
// Such fields are skipped with protowire.ConsumeFieldValue. If
// opts.DiscardUnknown() is not set and the message has unknown-field
// storage, they are also appended verbatim (tag plus value) to the unknown
// bytes. Any other decoder error aborts decoding and is returned unchanged.
//
// On success, out.n is the number of bytes consumed from b. out.initialized
// reports whether every required field was seen and no decoded field or
// extension reported itself uninitialized.
func (mi *MessageInfo) unmarshalPointerEager(b []byte, p pointer, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, err error) {

	// initialized is cleared as soon as a decoded field or extension reports
	// itself uninitialized. Missing required fields are checked after the loop.
	initialized := true
	// requiredMask accumulates the requiredBit of every decoded field.
	var requiredMask uint64
	// exts is fetched on the first field number without a coder, and only if
	// the message has extension storage. The map is allocated if it is nil.
	var exts *map[int32]ExtensionField

	// presence is only set up for messages that track presence. It is
	// updated for each successfully decoded known field.
	var presence presence
	if mi.presenceOffset.IsValid() {
		presence = p.Apply(mi.presenceOffset).PresenceInfo()
	}

	// start is used to report how many bytes were consumed.
	start := len(b)
	for len(b) > 0 {
		// Parse the tag (field number and wire type).
		// There are inline fast paths for 1- and 2-byte varints; longer
		// tags go through protowire.ConsumeVarint.
		var tag uint64
		if b[0] < 0x80 {
			tag = uint64(b[0])
			b = b[1:]
		} else if len(b) >= 2 && b[1] < 128 {
			tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
			b = b[2:]
		} else {
			var n int
			tag, n = protowire.ConsumeVarint(b)
			if n < 0 {
				return out, errDecode
			}
			b = b[n:]
		}
		// Reject field numbers outside [MinValidNumber, MaxValidNumber].
		var num protowire.Number
		if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) {
			return out, errDecode
		} else {
			num = protowire.Number(n)
		}
		wtyp := protowire.Type(tag & 7)

		if wtyp == protowire.EndGroupType {
			// An END_GROUP must close the group being decoded. Clearing
			// groupTag marks it closed, and break leaves the for loop.
			if num != groupTag {
				return out, errDecode
			}
			groupTag = 0
			break
		}

		// Small field numbers index the dense slice; larger ones fall back to
		// coderFields.
		var f *coderFieldInfo
		if int(num) < len(mi.denseCoderFields) {
			f = mi.denseCoderFields[num]
		} else {
			f = mi.coderFields[num]
		}
		var n int
		// Treat the field as unknown unless a decoder successfully claims it.
		// (This err shadows the named result.)
		err := errUnknown
		switch {
		case f != nil:
			if f.funcs.unmarshal == nil {
				break // no decoder: keep errUnknown
			}
			var o unmarshalOutput
			o, err = f.funcs.unmarshal(b, p.Apply(f.offset), wtyp, f, opts)
			n = o.n
			if err != nil {
				break
			}
			requiredMask |= f.validation.requiredBit
			if f.funcs.isInit != nil && !o.initialized {
				initialized = false
			}

			if f.presenceIndex != noPresence {
				presence.SetPresentUnatomic(f.presenceIndex, mi.presenceSize)
			}

		default:
			// Possible extension.
			if exts == nil && mi.extensionOffset.IsValid() {
				exts = p.Apply(mi.extensionOffset).Extensions()
				if *exts == nil {
					*exts = make(map[int32]ExtensionField)
				}
			}
			if exts == nil {
				break // message has no extension storage: keep errUnknown
			}
			var o unmarshalOutput
			o, err = mi.unmarshalExtension(b, num, wtyp, *exts, opts)
			if err != nil {
				break
			}
			n = o.n
			if !o.initialized {
				initialized = false
			}
		}
		if err != nil {
			// errUnknown means skip the field, preserving it as unknown bytes
			// when allowed. Any other error is fatal.
			if err != errUnknown {
				return out, err
			}
			n = protowire.ConsumeFieldValue(num, wtyp, b)
			if n < 0 {
				return out, errDecode
			}
			if !opts.DiscardUnknown() && mi.unknownOffset.IsValid() {
				u := mi.mutableUnknownBytes(p)
				*u = protowire.AppendTag(*u, num, wtyp)
				*u = append(*u, b[:n]...)
			}
		}
		b = b[n:]
	}
	// Input ended while a group was still open.
	if groupTag != 0 {
		return out, errDecode
	}
	if mi.numRequiredFields > 0 && bits.OnesCount64(requiredMask) != int(mi.numRequiredFields) {
		initialized = false
	}
	if initialized {
		out.initialized = true
	}
	out.n = start - len(b)
	return out, nil
}
 246  
 247  func (mi *MessageInfo) unmarshalExtension(b []byte, num protowire.Number, wtyp protowire.Type, exts map[int32]ExtensionField, opts unmarshalOptions) (out unmarshalOutput, err error) {
 248  	x := exts[int32(num)]
 249  	xt := x.Type()
 250  	if xt == nil {
 251  		var err error
 252  		xt, err = opts.resolver.FindExtensionByNumber(mi.Desc.FullName(), num)
 253  		if err != nil {
 254  			if err == protoregistry.NotFound {
 255  				return out, errUnknown
 256  			}
 257  			return out, errors.New("%v: unable to resolve extension %v: %v", mi.Desc.FullName(), num, err)
 258  		}
 259  	}
 260  	xi := getExtensionFieldInfo(xt)
 261  	if xi.funcs.unmarshal == nil {
 262  		return out, errUnknown
 263  	}
 264  	if flags.LazyUnmarshalExtensions {
 265  		if opts.CanBeLazy() && x.canLazy(xt) {
 266  			out, valid := skipExtension(b, xi, num, wtyp, opts)
 267  			switch valid {
 268  			case ValidationValid:
 269  				if out.initialized {
 270  					x.appendLazyBytes(xt, xi, num, wtyp, b[:out.n])
 271  					exts[int32(num)] = x
 272  					return out, nil
 273  				}
 274  			case ValidationInvalid:
 275  				return out, errDecode
 276  			case ValidationUnknown:
 277  			}
 278  		}
 279  	}
 280  	ival := x.Value()
 281  	if !ival.IsValid() && xi.unmarshalNeedsValue {
 282  		// Create a new message, list, or map value to fill in.
 283  		// For enums, create a prototype value to let the unmarshal func know the
 284  		// concrete type.
 285  		ival = xt.New()
 286  	}
 287  	v, out, err := xi.funcs.unmarshal(b, ival, num, wtyp, opts)
 288  	if err != nil {
 289  		return out, err
 290  	}
 291  	if xi.funcs.isInit == nil {
 292  		out.initialized = true
 293  	}
 294  	x.Set(xt, v)
 295  	exts[int32(num)] = x
 296  	return out, nil
 297  }
 298  
 299  func skipExtension(b []byte, xi *extensionFieldInfo, num protowire.Number, wtyp protowire.Type, opts unmarshalOptions) (out unmarshalOutput, _ ValidationStatus) {
 300  	if xi.validation.mi == nil {
 301  		return out, ValidationUnknown
 302  	}
 303  	xi.validation.mi.init()
 304  	switch xi.validation.typ {
 305  	case validationTypeMessage:
 306  		if wtyp != protowire.BytesType {
 307  			return out, ValidationUnknown
 308  		}
 309  		v, n := protowire.ConsumeBytes(b)
 310  		if n < 0 {
 311  			return out, ValidationUnknown
 312  		}
 313  
 314  		if opts.Validated() {
 315  			out.initialized = true
 316  			out.n = n
 317  			return out, ValidationValid
 318  		}
 319  
 320  		out, st := xi.validation.mi.validate(v, 0, opts)
 321  		out.n = n
 322  		return out, st
 323  	case validationTypeGroup:
 324  		if wtyp != protowire.StartGroupType {
 325  			return out, ValidationUnknown
 326  		}
 327  		out, st := xi.validation.mi.validate(b, num, opts)
 328  		return out, st
 329  	default:
 330  		return out, ValidationUnknown
 331  	}
 332  }
 333