lazy.go raw

   1  // Copyright 2024 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  // Package protolazy contains internal data structures for lazy message decoding.
   6  package protolazy
   7  
   8  import (
   9  	"fmt"
  10  	"sort"
  11  
  12  	"google.golang.org/protobuf/encoding/protowire"
  13  	piface "google.golang.org/protobuf/runtime/protoiface"
  14  )
  15  
  16  // IndexEntry is the structure for an index of the fields in a message of a
  17  // proto (not descending to sub-messages)
  18  type IndexEntry struct {
  19  	FieldNum uint32
  20  	// first byte of this tag/field
  21  	Start uint32
  22  	// first byte after a contiguous sequence of bytes for this tag/field, which could
  23  	// include a single encoding of the field, or multiple encodings for the field
  24  	End uint32
  25  	// True if this protobuf segment includes multiple encodings of the field
  26  	MultipleContiguous bool
  27  }
  28  
// XXX_lazyUnmarshalInfo has information about a particular lazily decoded
// message: the wire-format bytes it was unmarshalled from, an index of the
// fields within those bytes (built on demand), and the flags Unmarshal was
// called with.
//
// Deprecated: Do not use. This will be deleted in the near future.
type XXX_lazyUnmarshalInfo struct {
	// Index of fields and their positions in the protobuf for this
	// message. It is a pointer to a slice so that it can be published
	// atomically. The pointer is set lazily, when/if the index is first
	// needed (or explicitly via SetIndex), and must always be stored and
	// loaded atomically (atomicStoreIndex / atomicLoadIndex).
	index *[]IndexEntry
	// The protobuf associated with this lazily decoded message. It is
	// only set during proto.Unmarshal(). It doesn't need to be set and
	// loaded atomically, since any simultaneous set (Unmarshal) and read
	// (during a get) would already be a race in the app code.
	Protobuf []byte
	// The flags present when Unmarshal was originally called for this
	// particular message; read back by UnmarshalFlags and AllowedPartial.
	unmarshalFlags piface.UnmarshalInputFlags
}
  47  
  48  // The Buffer and SetBuffer methods let v2/internal/impl interact with
  49  // XXX_lazyUnmarshalInfo via an interface, to avoid an import cycle.
  50  
  51  // Buffer returns the lazy unmarshal buffer.
  52  //
  53  // Deprecated: Do not use. This will be deleted in the near future.
  54  func (lazy *XXX_lazyUnmarshalInfo) Buffer() []byte {
  55  	return lazy.Protobuf
  56  }
  57  
  58  // SetBuffer sets the lazy unmarshal buffer.
  59  //
  60  // Deprecated: Do not use. This will be deleted in the near future.
  61  func (lazy *XXX_lazyUnmarshalInfo) SetBuffer(b []byte) {
  62  	lazy.Protobuf = b
  63  }
  64  
  65  // SetUnmarshalFlags is called to set a copy of the original unmarshalInputFlags.
  66  // The flags should reflect how Unmarshal was called.
  67  func (lazy *XXX_lazyUnmarshalInfo) SetUnmarshalFlags(f piface.UnmarshalInputFlags) {
  68  	lazy.unmarshalFlags = f
  69  }
  70  
  71  // UnmarshalFlags returns the original unmarshalInputFlags.
  72  func (lazy *XXX_lazyUnmarshalInfo) UnmarshalFlags() piface.UnmarshalInputFlags {
  73  	return lazy.unmarshalFlags
  74  }
  75  
  76  // AllowedPartial returns true if the user originally unmarshalled this message with
  77  // AllowPartial set to true
  78  func (lazy *XXX_lazyUnmarshalInfo) AllowedPartial() bool {
  79  	return (lazy.unmarshalFlags & piface.UnmarshalCheckRequired) == 0
  80  }
  81  
  82  func protoFieldNumber(tag uint32) uint32 {
  83  	return tag >> 3
  84  }
  85  
  86  // buildIndex builds an index of the specified protobuf, return the index
  87  // array and an error.
  88  func buildIndex(buf []byte) ([]IndexEntry, error) {
  89  	index := make([]IndexEntry, 0, 16)
  90  	var lastProtoFieldNum uint32
  91  	var outOfOrder bool
  92  
  93  	var r BufferReader = NewBufferReader(buf)
  94  
  95  	for !r.Done() {
  96  		var tag uint32
  97  		var err error
  98  		var curPos = r.Pos
  99  		// INLINED: tag, err = r.DecodeVarint32()
 100  		{
 101  			i := r.Pos
 102  			buf := r.Buf
 103  
 104  			if i >= len(buf) {
 105  				return nil, errOutOfBounds
 106  			} else if buf[i] < 0x80 {
 107  				r.Pos++
 108  				tag = uint32(buf[i])
 109  			} else if r.Remaining() < 5 {
 110  				var v uint64
 111  				v, err = r.DecodeVarintSlow()
 112  				tag = uint32(v)
 113  			} else {
 114  				var v uint32
 115  				// we already checked the first byte
 116  				tag = uint32(buf[i]) & 127
 117  				i++
 118  
 119  				v = uint32(buf[i])
 120  				i++
 121  				tag |= (v & 127) << 7
 122  				if v < 128 {
 123  					goto done
 124  				}
 125  
 126  				v = uint32(buf[i])
 127  				i++
 128  				tag |= (v & 127) << 14
 129  				if v < 128 {
 130  					goto done
 131  				}
 132  
 133  				v = uint32(buf[i])
 134  				i++
 135  				tag |= (v & 127) << 21
 136  				if v < 128 {
 137  					goto done
 138  				}
 139  
 140  				v = uint32(buf[i])
 141  				i++
 142  				tag |= (v & 127) << 28
 143  				if v < 128 {
 144  					goto done
 145  				}
 146  
 147  				return nil, errOutOfBounds
 148  
 149  			done:
 150  				r.Pos = i
 151  			}
 152  		}
 153  		// DONE: tag, err = r.DecodeVarint32()
 154  
 155  		fieldNum := protoFieldNumber(tag)
 156  		if fieldNum < lastProtoFieldNum {
 157  			outOfOrder = true
 158  		}
 159  
 160  		// Skip the current value -- will skip over an entire group as well.
 161  		// INLINED: err = r.SkipValue(tag)
 162  		wireType := tag & 0x7
 163  		switch protowire.Type(wireType) {
 164  		case protowire.VarintType:
 165  			// INLINED: err = r.SkipVarint()
 166  			i := r.Pos
 167  
 168  			if len(r.Buf)-i < 10 {
 169  				// Use DecodeVarintSlow() to skip while
 170  				// checking for buffer overflow, but ignore result
 171  				_, err = r.DecodeVarintSlow()
 172  				goto out2
 173  			}
 174  			if r.Buf[i] < 0x80 {
 175  				goto out
 176  			}
 177  			i++
 178  
 179  			if r.Buf[i] < 0x80 {
 180  				goto out
 181  			}
 182  			i++
 183  
 184  			if r.Buf[i] < 0x80 {
 185  				goto out
 186  			}
 187  			i++
 188  
 189  			if r.Buf[i] < 0x80 {
 190  				goto out
 191  			}
 192  			i++
 193  
 194  			if r.Buf[i] < 0x80 {
 195  				goto out
 196  			}
 197  			i++
 198  
 199  			if r.Buf[i] < 0x80 {
 200  				goto out
 201  			}
 202  			i++
 203  
 204  			if r.Buf[i] < 0x80 {
 205  				goto out
 206  			}
 207  			i++
 208  
 209  			if r.Buf[i] < 0x80 {
 210  				goto out
 211  			}
 212  			i++
 213  
 214  			if r.Buf[i] < 0x80 {
 215  				goto out
 216  			}
 217  			i++
 218  
 219  			if r.Buf[i] < 0x80 {
 220  				goto out
 221  			}
 222  			return nil, errOverflow
 223  		out:
 224  			r.Pos = i + 1
 225  			// DONE: err = r.SkipVarint()
 226  		case protowire.Fixed64Type:
 227  			err = r.SkipFixed64()
 228  		case protowire.BytesType:
 229  			var n uint32
 230  			n, err = r.DecodeVarint32()
 231  			if err == nil {
 232  				err = r.Skip(int(n))
 233  			}
 234  		case protowire.StartGroupType:
 235  			err = r.SkipGroup(tag)
 236  		case protowire.Fixed32Type:
 237  			err = r.SkipFixed32()
 238  		default:
 239  			err = fmt.Errorf("Unexpected wire type (%d)", wireType)
 240  		}
 241  		// DONE: err = r.SkipValue(tag)
 242  
 243  	out2:
 244  		if err != nil {
 245  			return nil, err
 246  		}
 247  		if fieldNum != lastProtoFieldNum {
 248  			index = append(index, IndexEntry{FieldNum: fieldNum,
 249  				Start: uint32(curPos),
 250  				End:   uint32(r.Pos)},
 251  			)
 252  		} else {
 253  			index[len(index)-1].End = uint32(r.Pos)
 254  			index[len(index)-1].MultipleContiguous = true
 255  		}
 256  		lastProtoFieldNum = fieldNum
 257  	}
 258  	if outOfOrder {
 259  		sort.Slice(index, func(i, j int) bool {
 260  			return index[i].FieldNum < index[j].FieldNum ||
 261  				(index[i].FieldNum == index[j].FieldNum &&
 262  					index[i].Start < index[j].Start)
 263  		})
 264  	}
 265  	return index, nil
 266  }
 267  
 268  func (lazy *XXX_lazyUnmarshalInfo) SizeField(num uint32) (size int) {
 269  	start, end, found, _, multipleEntries := lazy.FindFieldInProto(num)
 270  	if multipleEntries != nil {
 271  		for _, entry := range multipleEntries {
 272  			size += int(entry.End - entry.Start)
 273  		}
 274  		return size
 275  	}
 276  	if !found {
 277  		return 0
 278  	}
 279  	return int(end - start)
 280  }
 281  
 282  func (lazy *XXX_lazyUnmarshalInfo) AppendField(b []byte, num uint32) ([]byte, bool) {
 283  	start, end, found, _, multipleEntries := lazy.FindFieldInProto(num)
 284  	if multipleEntries != nil {
 285  		for _, entry := range multipleEntries {
 286  			b = append(b, lazy.Protobuf[entry.Start:entry.End]...)
 287  		}
 288  		return b, true
 289  	}
 290  	if !found {
 291  		return nil, false
 292  	}
 293  	b = append(b, lazy.Protobuf[start:end]...)
 294  	return b, true
 295  }
 296  
 297  func (lazy *XXX_lazyUnmarshalInfo) SetIndex(index []IndexEntry) {
 298  	atomicStoreIndex(&lazy.index, &index)
 299  }
 300  
 301  // FindFieldInProto looks for field fieldNum in lazyUnmarshalInfo information
 302  // (including protobuf), returns startOffset/endOffset/found.
 303  func (lazy *XXX_lazyUnmarshalInfo) FindFieldInProto(fieldNum uint32) (start, end uint32, found, multipleContiguous bool, multipleEntries []IndexEntry) {
 304  	if lazy.Protobuf == nil {
 305  		// There is no backing protobuf for this message -- it was made from a builder
 306  		return 0, 0, false, false, nil
 307  	}
 308  	index := atomicLoadIndex(&lazy.index)
 309  	if index == nil {
 310  		r, err := buildIndex(lazy.Protobuf)
 311  		if err != nil {
 312  			panic(fmt.Sprintf("findFieldInfo: error building index when looking for field %d: %v", fieldNum, err))
 313  		}
 314  		// lazy.index is a pointer to the slice returned by BuildIndex
 315  		index = &r
 316  		atomicStoreIndex(&lazy.index, index)
 317  	}
 318  	return lookupField(index, fieldNum)
 319  }
 320  
 321  // lookupField returns the offset at which the indicated field starts using
 322  // the index, offset immediately after field ends (including all instances of
 323  // a repeated field), and bools indicating if field was found and if there
 324  // are multiple encodings of the field in the byte range.
 325  //
 326  // To hande the uncommon case where there are repeated encodings for the same
 327  // field which are not consecutive in the protobuf (so we need to returns
 328  // multiple start/end offsets), we also return a slice multipleEntries.  If
 329  // multipleEntries is non-nil, then multiple entries were found, and the
 330  // values in the slice should be used, rather than start/end/found.
 331  func lookupField(indexp *[]IndexEntry, fieldNum uint32) (start, end uint32, found bool, multipleContiguous bool, multipleEntries []IndexEntry) {
 332  	// The pointer indexp to the index was already loaded atomically.
 333  	// The slice is uniquely associated with the pointer, so it doesn't
 334  	// need to be loaded atomically.
 335  	index := *indexp
 336  	for i, entry := range index {
 337  		if fieldNum == entry.FieldNum {
 338  			if i < len(index)-1 && entry.FieldNum == index[i+1].FieldNum {
 339  				// Handle the uncommon case where there are
 340  				// repeated entries for the same field which
 341  				// are not contiguous in the protobuf.
 342  				multiple := make([]IndexEntry, 1, 2)
 343  				multiple[0] = IndexEntry{fieldNum, entry.Start, entry.End, entry.MultipleContiguous}
 344  				i++
 345  				for i < len(index) && index[i].FieldNum == fieldNum {
 346  					multiple = append(multiple, IndexEntry{fieldNum, index[i].Start, index[i].End, index[i].MultipleContiguous})
 347  					i++
 348  				}
 349  				return 0, 0, false, false, multiple
 350  
 351  			}
 352  			return entry.Start, entry.End, true, entry.MultipleContiguous, nil
 353  		}
 354  		if fieldNum < entry.FieldNum {
 355  			return 0, 0, false, false, nil
 356  		}
 357  	}
 358  	return 0, 0, false, false, nil
 359  }
 360