lazy.go raw

   1  // Copyright 2024 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  package impl
   6  
   7  import (
   8  	"fmt"
   9  	"math/bits"
  10  	"os"
  11  	"reflect"
  12  	"sort"
  13  	"sync/atomic"
  14  
  15  	"google.golang.org/protobuf/encoding/protowire"
  16  	"google.golang.org/protobuf/internal/errors"
  17  	"google.golang.org/protobuf/internal/protolazy"
  18  	"google.golang.org/protobuf/reflect/protoreflect"
  19  	preg "google.golang.org/protobuf/reflect/protoregistry"
  20  	piface "google.golang.org/protobuf/runtime/protoiface"
  21  )
  22  
  23  var enableLazy int32 = func() int32 {
  24  	if os.Getenv("GOPROTODEBUG") == "nolazy" {
  25  		return 0
  26  	}
  27  	return 1
  28  }()
  29  
  30  // EnableLazyUnmarshal enables lazy unmarshaling.
  31  func EnableLazyUnmarshal(enable bool) {
  32  	if enable {
  33  		atomic.StoreInt32(&enableLazy, 1)
  34  		return
  35  	}
  36  	atomic.StoreInt32(&enableLazy, 0)
  37  }
  38  
  39  // LazyEnabled reports whether lazy unmarshalling is currently enabled.
  40  func LazyEnabled() bool {
  41  	return atomic.LoadInt32(&enableLazy) != 0
  42  }
  43  
  44  // UnmarshalField unmarshals a field in a message.
  45  func UnmarshalField(m interface{}, num protowire.Number) {
  46  	switch m := m.(type) {
  47  	case *messageState:
  48  		m.messageInfo().lazyUnmarshal(m.pointer(), num)
  49  	case *messageReflectWrapper:
  50  		m.messageInfo().lazyUnmarshal(m.pointer(), num)
  51  	default:
  52  		panic(fmt.Sprintf("unsupported wrapper type %T", m))
  53  	}
  54  }
  55  
  56  func (mi *MessageInfo) lazyUnmarshal(p pointer, num protoreflect.FieldNumber) {
  57  	var f *coderFieldInfo
  58  	if int(num) < len(mi.denseCoderFields) {
  59  		f = mi.denseCoderFields[num]
  60  	} else {
  61  		f = mi.coderFields[num]
  62  	}
  63  	if f == nil {
  64  		panic(fmt.Sprintf("lazyUnmarshal: field info for %v.%v", mi.Desc.FullName(), num))
  65  	}
  66  	lazy := *p.Apply(mi.lazyOffset).LazyInfoPtr()
  67  	start, end, found, _, multipleEntries := lazy.FindFieldInProto(uint32(num))
  68  	if !found && multipleEntries == nil {
  69  		panic(fmt.Sprintf("lazyUnmarshal: can't find field data for %v.%v", mi.Desc.FullName(), num))
  70  	}
  71  	// The actual pointer in the message can not be set until the whole struct is filled in, otherwise we will have races.
  72  	// Create another pointer and set it atomically, if we won the race and the pointer in the original message is still nil.
  73  	fp := pointerOfValue(reflect.New(f.ft))
  74  	if multipleEntries != nil {
  75  		for _, entry := range multipleEntries {
  76  			mi.unmarshalField(lazy.Buffer()[entry.Start:entry.End], fp, f, lazy, lazy.UnmarshalFlags())
  77  		}
  78  	} else {
  79  		mi.unmarshalField(lazy.Buffer()[start:end], fp, f, lazy, lazy.UnmarshalFlags())
  80  	}
  81  	p.Apply(f.offset).AtomicSetPointerIfNil(fp.Elem())
  82  }
  83  
  84  func (mi *MessageInfo) unmarshalField(b []byte, p pointer, f *coderFieldInfo, lazyInfo *protolazy.XXX_lazyUnmarshalInfo, flags piface.UnmarshalInputFlags) error {
  85  	opts := lazyUnmarshalOptions
  86  	opts.flags |= flags
  87  	for len(b) > 0 {
  88  		// Parse the tag (field number and wire type).
  89  		var tag uint64
  90  		if b[0] < 0x80 {
  91  			tag = uint64(b[0])
  92  			b = b[1:]
  93  		} else if len(b) >= 2 && b[1] < 128 {
  94  			tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
  95  			b = b[2:]
  96  		} else {
  97  			var n int
  98  			tag, n = protowire.ConsumeVarint(b)
  99  			if n < 0 {
 100  				return errors.New("invalid wire data")
 101  			}
 102  			b = b[n:]
 103  		}
 104  		var num protowire.Number
 105  		if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) {
 106  			return errors.New("invalid wire data")
 107  		} else {
 108  			num = protowire.Number(n)
 109  		}
 110  		wtyp := protowire.Type(tag & 7)
 111  		if num == f.num {
 112  			o, err := f.funcs.unmarshal(b, p, wtyp, f, opts)
 113  			if err == nil {
 114  				b = b[o.n:]
 115  				continue
 116  			}
 117  			if err != errUnknown {
 118  				return err
 119  			}
 120  		}
 121  		n := protowire.ConsumeFieldValue(num, wtyp, b)
 122  		if n < 0 {
 123  			return errors.New("invalid wire data")
 124  		}
 125  		b = b[n:]
 126  	}
 127  	return nil
 128  }
 129  
 130  func (mi *MessageInfo) skipField(b []byte, f *coderFieldInfo, wtyp protowire.Type, opts unmarshalOptions) (out unmarshalOutput, _ ValidationStatus) {
 131  	fmi := f.validation.mi
 132  	if fmi == nil {
 133  		fd := mi.Desc.Fields().ByNumber(f.num)
 134  		if fd == nil {
 135  			return out, ValidationUnknown
 136  		}
 137  		messageName := fd.Message().FullName()
 138  		messageType, err := preg.GlobalTypes.FindMessageByName(messageName)
 139  		if err != nil {
 140  			return out, ValidationUnknown
 141  		}
 142  		var ok bool
 143  		fmi, ok = messageType.(*MessageInfo)
 144  		if !ok {
 145  			return out, ValidationUnknown
 146  		}
 147  	}
 148  	fmi.init()
 149  	switch f.validation.typ {
 150  	case validationTypeMessage:
 151  		if wtyp != protowire.BytesType {
 152  			return out, ValidationWrongWireType
 153  		}
 154  		v, n := protowire.ConsumeBytes(b)
 155  		if n < 0 {
 156  			return out, ValidationInvalid
 157  		}
 158  		out, st := fmi.validate(v, 0, opts)
 159  		out.n = n
 160  		return out, st
 161  	case validationTypeGroup:
 162  		if wtyp != protowire.StartGroupType {
 163  			return out, ValidationWrongWireType
 164  		}
 165  		out, st := fmi.validate(b, f.num, opts)
 166  		return out, st
 167  	default:
 168  		return out, ValidationUnknown
 169  	}
 170  }
 171  
 172  // unmarshalPointerLazy is similar to unmarshalPointerEager, but it
 173  // specifically handles lazy unmarshalling.  it expects lazyOffset and
 174  // presenceOffset to both be valid.
 175  func (mi *MessageInfo) unmarshalPointerLazy(b []byte, p pointer, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, err error) {
 176  	initialized := true
 177  	var requiredMask uint64
 178  	var lazy **protolazy.XXX_lazyUnmarshalInfo
 179  	var presence presence
 180  	var lazyIndex []protolazy.IndexEntry
 181  	var lastNum protowire.Number
 182  	outOfOrder := false
 183  	lazyDecode := false
 184  	presence = p.Apply(mi.presenceOffset).PresenceInfo()
 185  	lazy = p.Apply(mi.lazyOffset).LazyInfoPtr()
 186  	if !presence.AnyPresent(mi.presenceSize) {
 187  		if opts.CanBeLazy() {
 188  			// If the message contains existing data, we need to merge into it.
 189  			// Lazy unmarshaling doesn't merge, so only enable it when the
 190  			// message is empty (has no presence bitmap).
 191  			lazyDecode = true
 192  			if *lazy == nil {
 193  				*lazy = &protolazy.XXX_lazyUnmarshalInfo{}
 194  			}
 195  			(*lazy).SetUnmarshalFlags(opts.flags)
 196  			if !opts.AliasBuffer() {
 197  				// Make a copy of the buffer for lazy unmarshaling.
 198  				// Set the AliasBuffer flag so recursive unmarshal
 199  				// operations reuse the copy.
 200  				b = append([]byte{}, b...)
 201  				opts.flags |= piface.UnmarshalAliasBuffer
 202  			}
 203  			(*lazy).SetBuffer(b)
 204  		}
 205  	}
 206  	// Track special handling of lazy fields.
 207  	//
 208  	// In the common case, all fields are lazyValidateOnly (and lazyFields remains nil).
 209  	// In the event that validation for a field fails, this map tracks handling of the field.
 210  	type lazyAction uint8
 211  	const (
 212  		lazyValidateOnly   lazyAction = iota // validate the field only
 213  		lazyUnmarshalNow                     // eagerly unmarshal the field
 214  		lazyUnmarshalLater                   // unmarshal the field after the message is fully processed
 215  	)
 216  	var lazyFields map[*coderFieldInfo]lazyAction
 217  	var exts *map[int32]ExtensionField
 218  	start := len(b)
 219  	pos := 0
 220  	for len(b) > 0 {
 221  		// Parse the tag (field number and wire type).
 222  		var tag uint64
 223  		if b[0] < 0x80 {
 224  			tag = uint64(b[0])
 225  			b = b[1:]
 226  		} else if len(b) >= 2 && b[1] < 128 {
 227  			tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
 228  			b = b[2:]
 229  		} else {
 230  			var n int
 231  			tag, n = protowire.ConsumeVarint(b)
 232  			if n < 0 {
 233  				return out, errDecode
 234  			}
 235  			b = b[n:]
 236  		}
 237  		var num protowire.Number
 238  		if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) {
 239  			return out, errors.New("invalid field number")
 240  		} else {
 241  			num = protowire.Number(n)
 242  		}
 243  		wtyp := protowire.Type(tag & 7)
 244  
 245  		if wtyp == protowire.EndGroupType {
 246  			if num != groupTag {
 247  				return out, errors.New("mismatching end group marker")
 248  			}
 249  			groupTag = 0
 250  			break
 251  		}
 252  
 253  		var f *coderFieldInfo
 254  		if int(num) < len(mi.denseCoderFields) {
 255  			f = mi.denseCoderFields[num]
 256  		} else {
 257  			f = mi.coderFields[num]
 258  		}
 259  		var n int
 260  		err := errUnknown
 261  		discardUnknown := false
 262  	Field:
 263  		switch {
 264  		case f != nil:
 265  			if f.funcs.unmarshal == nil {
 266  				break
 267  			}
 268  			if f.isLazy && lazyDecode {
 269  				switch {
 270  				case lazyFields == nil || lazyFields[f] == lazyValidateOnly:
 271  					// Attempt to validate this field and leave it for later lazy unmarshaling.
 272  					o, valid := mi.skipField(b, f, wtyp, opts)
 273  					switch valid {
 274  					case ValidationValid:
 275  						// Skip over the valid field and continue.
 276  						err = nil
 277  						presence.SetPresentUnatomic(f.presenceIndex, mi.presenceSize)
 278  						requiredMask |= f.validation.requiredBit
 279  						if !o.initialized {
 280  							initialized = false
 281  						}
 282  						n = o.n
 283  						break Field
 284  					case ValidationInvalid:
 285  						return out, errors.New("invalid proto wire format")
 286  					case ValidationWrongWireType:
 287  						break Field
 288  					case ValidationUnknown:
 289  						if lazyFields == nil {
 290  							lazyFields = make(map[*coderFieldInfo]lazyAction)
 291  						}
 292  						if presence.Present(f.presenceIndex) {
 293  							// We were unable to determine if the field is valid or not,
 294  							// and we've already skipped over at least one instance of this
 295  							// field. Clear the presence bit (so if we stop decoding early,
 296  							// we don't leave a partially-initialized field around) and flag
 297  							// the field for unmarshaling before we return.
 298  							presence.ClearPresent(f.presenceIndex)
 299  							lazyFields[f] = lazyUnmarshalLater
 300  							discardUnknown = true
 301  							break Field
 302  						} else {
 303  							// We were unable to determine if the field is valid or not,
 304  							// but this is the first time we've seen it. Flag it as needing
 305  							// eager unmarshaling and fall through to the eager unmarshal case below.
 306  							lazyFields[f] = lazyUnmarshalNow
 307  						}
 308  					}
 309  				case lazyFields[f] == lazyUnmarshalLater:
 310  					// This field will be unmarshaled in a separate pass below.
 311  					// Skip over it here.
 312  					discardUnknown = true
 313  					break Field
 314  				default:
 315  					// Eagerly unmarshal the field.
 316  				}
 317  			}
 318  			if f.isLazy && !lazyDecode && presence.Present(f.presenceIndex) {
 319  				if p.Apply(f.offset).AtomicGetPointer().IsNil() {
 320  					mi.lazyUnmarshal(p, f.num)
 321  				}
 322  			}
 323  			var o unmarshalOutput
 324  			o, err = f.funcs.unmarshal(b, p.Apply(f.offset), wtyp, f, opts)
 325  			n = o.n
 326  			if err != nil {
 327  				break
 328  			}
 329  			requiredMask |= f.validation.requiredBit
 330  			if f.funcs.isInit != nil && !o.initialized {
 331  				initialized = false
 332  			}
 333  			if f.presenceIndex != noPresence {
 334  				presence.SetPresentUnatomic(f.presenceIndex, mi.presenceSize)
 335  			}
 336  		default:
 337  			// Possible extension.
 338  			if exts == nil && mi.extensionOffset.IsValid() {
 339  				exts = p.Apply(mi.extensionOffset).Extensions()
 340  				if *exts == nil {
 341  					*exts = make(map[int32]ExtensionField)
 342  				}
 343  			}
 344  			if exts == nil {
 345  				break
 346  			}
 347  			var o unmarshalOutput
 348  			o, err = mi.unmarshalExtension(b, num, wtyp, *exts, opts)
 349  			if err != nil {
 350  				break
 351  			}
 352  			n = o.n
 353  			if !o.initialized {
 354  				initialized = false
 355  			}
 356  		}
 357  		if err != nil {
 358  			if err != errUnknown {
 359  				return out, err
 360  			}
 361  			n = protowire.ConsumeFieldValue(num, wtyp, b)
 362  			if n < 0 {
 363  				return out, errDecode
 364  			}
 365  			if !discardUnknown && !opts.DiscardUnknown() && mi.unknownOffset.IsValid() {
 366  				u := mi.mutableUnknownBytes(p)
 367  				*u = protowire.AppendTag(*u, num, wtyp)
 368  				*u = append(*u, b[:n]...)
 369  			}
 370  		}
 371  		b = b[n:]
 372  		end := start - len(b)
 373  		if lazyDecode && f != nil && f.isLazy {
 374  			if num != lastNum {
 375  				lazyIndex = append(lazyIndex, protolazy.IndexEntry{
 376  					FieldNum: uint32(num),
 377  					Start:    uint32(pos),
 378  					End:      uint32(end),
 379  				})
 380  			} else {
 381  				i := len(lazyIndex) - 1
 382  				lazyIndex[i].End = uint32(end)
 383  				lazyIndex[i].MultipleContiguous = true
 384  			}
 385  		}
 386  		if num < lastNum {
 387  			outOfOrder = true
 388  		}
 389  		pos = end
 390  		lastNum = num
 391  	}
 392  	if groupTag != 0 {
 393  		return out, errors.New("missing end group marker")
 394  	}
 395  	if lazyFields != nil {
 396  		// Some fields failed validation, and now need to be unmarshaled.
 397  		for f, action := range lazyFields {
 398  			if action != lazyUnmarshalLater {
 399  				continue
 400  			}
 401  			initialized = false
 402  			if *lazy == nil {
 403  				*lazy = &protolazy.XXX_lazyUnmarshalInfo{}
 404  			}
 405  			if err := mi.unmarshalField((*lazy).Buffer(), p.Apply(f.offset), f, *lazy, opts.flags); err != nil {
 406  				return out, err
 407  			}
 408  			presence.SetPresentUnatomic(f.presenceIndex, mi.presenceSize)
 409  		}
 410  	}
 411  	if lazyDecode {
 412  		if outOfOrder {
 413  			sort.Slice(lazyIndex, func(i, j int) bool {
 414  				return lazyIndex[i].FieldNum < lazyIndex[j].FieldNum ||
 415  					(lazyIndex[i].FieldNum == lazyIndex[j].FieldNum &&
 416  						lazyIndex[i].Start < lazyIndex[j].Start)
 417  			})
 418  		}
 419  		if *lazy == nil {
 420  			*lazy = &protolazy.XXX_lazyUnmarshalInfo{}
 421  		}
 422  
 423  		(*lazy).SetIndex(lazyIndex)
 424  	}
 425  	if mi.numRequiredFields > 0 && bits.OnesCount64(requiredMask) != int(mi.numRequiredFields) {
 426  		initialized = false
 427  	}
 428  	if initialized {
 429  		out.initialized = true
 430  	}
 431  	out.n = start - len(b)
 432  	return out, nil
 433  }
 434