bufferreader.go raw

   1  // Copyright 2024 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  // Helper code for parsing a protocol buffer
   6  
   7  package protolazy
   8  
   9  import (
  10  	"errors"
  11  	"fmt"
  12  	"io"
  13  
  14  	"google.golang.org/protobuf/encoding/protowire"
  15  )
  16  
  17  // BufferReader is a structure encapsulating a protobuf and a current position
  18  type BufferReader struct {
  19  	Buf []byte
  20  	Pos int
  21  }
  22  
  23  // NewBufferReader creates a new BufferRead from a protobuf
  24  func NewBufferReader(buf []byte) BufferReader {
  25  	return BufferReader{Buf: buf, Pos: 0}
  26  }
  27  
  28  var errOutOfBounds = errors.New("protobuf decoding: out of bounds")
  29  var errOverflow = errors.New("proto: integer overflow")
  30  
  31  func (b *BufferReader) DecodeVarintSlow() (x uint64, err error) {
  32  	i := b.Pos
  33  	l := len(b.Buf)
  34  
  35  	for shift := uint(0); shift < 64; shift += 7 {
  36  		if i >= l {
  37  			err = io.ErrUnexpectedEOF
  38  			return
  39  		}
  40  		v := b.Buf[i]
  41  		i++
  42  		x |= (uint64(v) & 0x7F) << shift
  43  		if v < 0x80 {
  44  			b.Pos = i
  45  			return
  46  		}
  47  	}
  48  
  49  	// The number is too large to represent in a 64-bit value.
  50  	err = errOverflow
  51  	return
  52  }
  53  
  54  // decodeVarint decodes a varint at the current position
  55  func (b *BufferReader) DecodeVarint() (x uint64, err error) {
  56  	i := b.Pos
  57  	buf := b.Buf
  58  
  59  	if i >= len(buf) {
  60  		return 0, io.ErrUnexpectedEOF
  61  	} else if buf[i] < 0x80 {
  62  		b.Pos++
  63  		return uint64(buf[i]), nil
  64  	} else if len(buf)-i < 10 {
  65  		return b.DecodeVarintSlow()
  66  	}
  67  
  68  	var v uint64
  69  	// we already checked the first byte
  70  	x = uint64(buf[i]) & 127
  71  	i++
  72  
  73  	v = uint64(buf[i])
  74  	i++
  75  	x |= (v & 127) << 7
  76  	if v < 128 {
  77  		goto done
  78  	}
  79  
  80  	v = uint64(buf[i])
  81  	i++
  82  	x |= (v & 127) << 14
  83  	if v < 128 {
  84  		goto done
  85  	}
  86  
  87  	v = uint64(buf[i])
  88  	i++
  89  	x |= (v & 127) << 21
  90  	if v < 128 {
  91  		goto done
  92  	}
  93  
  94  	v = uint64(buf[i])
  95  	i++
  96  	x |= (v & 127) << 28
  97  	if v < 128 {
  98  		goto done
  99  	}
 100  
 101  	v = uint64(buf[i])
 102  	i++
 103  	x |= (v & 127) << 35
 104  	if v < 128 {
 105  		goto done
 106  	}
 107  
 108  	v = uint64(buf[i])
 109  	i++
 110  	x |= (v & 127) << 42
 111  	if v < 128 {
 112  		goto done
 113  	}
 114  
 115  	v = uint64(buf[i])
 116  	i++
 117  	x |= (v & 127) << 49
 118  	if v < 128 {
 119  		goto done
 120  	}
 121  
 122  	v = uint64(buf[i])
 123  	i++
 124  	x |= (v & 127) << 56
 125  	if v < 128 {
 126  		goto done
 127  	}
 128  
 129  	v = uint64(buf[i])
 130  	i++
 131  	x |= (v & 127) << 63
 132  	if v < 128 {
 133  		goto done
 134  	}
 135  
 136  	return 0, errOverflow
 137  
 138  done:
 139  	b.Pos = i
 140  	return
 141  }
 142  
 143  // decodeVarint32 decodes a varint32 at the current position
 144  func (b *BufferReader) DecodeVarint32() (x uint32, err error) {
 145  	i := b.Pos
 146  	buf := b.Buf
 147  
 148  	if i >= len(buf) {
 149  		return 0, io.ErrUnexpectedEOF
 150  	} else if buf[i] < 0x80 {
 151  		b.Pos++
 152  		return uint32(buf[i]), nil
 153  	} else if len(buf)-i < 5 {
 154  		v, err := b.DecodeVarintSlow()
 155  		return uint32(v), err
 156  	}
 157  
 158  	var v uint32
 159  	// we already checked the first byte
 160  	x = uint32(buf[i]) & 127
 161  	i++
 162  
 163  	v = uint32(buf[i])
 164  	i++
 165  	x |= (v & 127) << 7
 166  	if v < 128 {
 167  		goto done
 168  	}
 169  
 170  	v = uint32(buf[i])
 171  	i++
 172  	x |= (v & 127) << 14
 173  	if v < 128 {
 174  		goto done
 175  	}
 176  
 177  	v = uint32(buf[i])
 178  	i++
 179  	x |= (v & 127) << 21
 180  	if v < 128 {
 181  		goto done
 182  	}
 183  
 184  	v = uint32(buf[i])
 185  	i++
 186  	x |= (v & 127) << 28
 187  	if v < 128 {
 188  		goto done
 189  	}
 190  
 191  	return 0, errOverflow
 192  
 193  done:
 194  	b.Pos = i
 195  	return
 196  }
 197  
 198  // skipValue skips a value in the protobuf, based on the specified tag
 199  func (b *BufferReader) SkipValue(tag uint32) (err error) {
 200  	wireType := tag & 0x7
 201  	switch protowire.Type(wireType) {
 202  	case protowire.VarintType:
 203  		err = b.SkipVarint()
 204  	case protowire.Fixed64Type:
 205  		err = b.SkipFixed64()
 206  	case protowire.BytesType:
 207  		var n uint32
 208  		n, err = b.DecodeVarint32()
 209  		if err == nil {
 210  			err = b.Skip(int(n))
 211  		}
 212  	case protowire.StartGroupType:
 213  		err = b.SkipGroup(tag)
 214  	case protowire.Fixed32Type:
 215  		err = b.SkipFixed32()
 216  	default:
 217  		err = fmt.Errorf("Unexpected wire type (%d)", wireType)
 218  	}
 219  	return
 220  }
 221  
 222  // skipGroup skips a group with the specified tag.  It executes efficiently using a tag stack
 223  func (b *BufferReader) SkipGroup(tag uint32) (err error) {
 224  	tagStack := make([]uint32, 0, 16)
 225  	tagStack = append(tagStack, tag)
 226  	var n uint32
 227  	for len(tagStack) > 0 {
 228  		tag, err = b.DecodeVarint32()
 229  		if err != nil {
 230  			return err
 231  		}
 232  		switch protowire.Type(tag & 0x7) {
 233  		case protowire.VarintType:
 234  			err = b.SkipVarint()
 235  		case protowire.Fixed64Type:
 236  			err = b.Skip(8)
 237  		case protowire.BytesType:
 238  			n, err = b.DecodeVarint32()
 239  			if err == nil {
 240  				err = b.Skip(int(n))
 241  			}
 242  		case protowire.StartGroupType:
 243  			tagStack = append(tagStack, tag)
 244  		case protowire.Fixed32Type:
 245  			err = b.SkipFixed32()
 246  		case protowire.EndGroupType:
 247  			if protoFieldNumber(tagStack[len(tagStack)-1]) == protoFieldNumber(tag) {
 248  				tagStack = tagStack[:len(tagStack)-1]
 249  			} else {
 250  				err = fmt.Errorf("end group tag %d does not match begin group tag %d at pos %d",
 251  					protoFieldNumber(tag), protoFieldNumber(tagStack[len(tagStack)-1]), b.Pos)
 252  			}
 253  		}
 254  		if err != nil {
 255  			return err
 256  		}
 257  	}
 258  	return nil
 259  }
 260  
 261  // skipVarint effiently skips a varint
 262  func (b *BufferReader) SkipVarint() (err error) {
 263  	i := b.Pos
 264  
 265  	if len(b.Buf)-i < 10 {
 266  		// Use DecodeVarintSlow() to check for buffer overflow, but ignore result
 267  		if _, err := b.DecodeVarintSlow(); err != nil {
 268  			return err
 269  		}
 270  		return nil
 271  	}
 272  
 273  	if b.Buf[i] < 0x80 {
 274  		goto out
 275  	}
 276  	i++
 277  
 278  	if b.Buf[i] < 0x80 {
 279  		goto out
 280  	}
 281  	i++
 282  
 283  	if b.Buf[i] < 0x80 {
 284  		goto out
 285  	}
 286  	i++
 287  
 288  	if b.Buf[i] < 0x80 {
 289  		goto out
 290  	}
 291  	i++
 292  
 293  	if b.Buf[i] < 0x80 {
 294  		goto out
 295  	}
 296  	i++
 297  
 298  	if b.Buf[i] < 0x80 {
 299  		goto out
 300  	}
 301  	i++
 302  
 303  	if b.Buf[i] < 0x80 {
 304  		goto out
 305  	}
 306  	i++
 307  
 308  	if b.Buf[i] < 0x80 {
 309  		goto out
 310  	}
 311  	i++
 312  
 313  	if b.Buf[i] < 0x80 {
 314  		goto out
 315  	}
 316  	i++
 317  
 318  	if b.Buf[i] < 0x80 {
 319  		goto out
 320  	}
 321  	return errOverflow
 322  
 323  out:
 324  	b.Pos = i + 1
 325  	return nil
 326  }
 327  
 328  // skip skips the specified number of bytes
 329  func (b *BufferReader) Skip(n int) (err error) {
 330  	if len(b.Buf) < b.Pos+n {
 331  		return io.ErrUnexpectedEOF
 332  	}
 333  	b.Pos += n
 334  	return
 335  }
 336  
 337  // skipFixed64 skips a fixed64
 338  func (b *BufferReader) SkipFixed64() (err error) {
 339  	return b.Skip(8)
 340  }
 341  
 342  // skipFixed32 skips a fixed32
 343  func (b *BufferReader) SkipFixed32() (err error) {
 344  	return b.Skip(4)
 345  }
 346  
 347  // skipBytes skips a set of bytes
 348  func (b *BufferReader) SkipBytes() (err error) {
 349  	n, err := b.DecodeVarint32()
 350  	if err != nil {
 351  		return err
 352  	}
 353  	return b.Skip(int(n))
 354  }
 355  
 356  // Done returns whether we are at the end of the protobuf
 357  func (b *BufferReader) Done() bool {
 358  	return b.Pos == len(b.Buf)
 359  }
 360  
 361  // Remaining returns how many bytes remain
 362  func (b *BufferReader) Remaining() int {
 363  	return len(b.Buf) - b.Pos
 364  }
 365