value_writer.go raw

   1  // Copyright (C) MongoDB, Inc. 2017-present.
   2  //
   3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
   4  // not use this file except in compliance with the License. You may obtain
   5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
   6  
   7  package bsonrw
   8  
   9  import (
  10  	"errors"
  11  	"fmt"
  12  	"io"
  13  	"math"
  14  	"strconv"
  15  	"strings"
  16  	"sync"
  17  
  18  	"go.mongodb.org/mongo-driver/bson/bsontype"
  19  	"go.mongodb.org/mongo-driver/bson/primitive"
  20  	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
  21  )
  22  
  23  var _ ValueWriter = (*valueWriter)(nil)
  24  
  25  var vwPool = sync.Pool{
  26  	New: func() interface{} {
  27  		return new(valueWriter)
  28  	},
  29  }
  30  
  31  func putValueWriter(vw *valueWriter) {
  32  	if vw != nil {
  33  		vw.w = nil // don't leak the writer
  34  		vwPool.Put(vw)
  35  	}
  36  }
  37  
  38  // BSONValueWriterPool is a pool for BSON ValueWriters.
  39  //
  40  // Deprecated: BSONValueWriterPool will not be supported in Go Driver 2.0.
  41  type BSONValueWriterPool struct {
  42  	pool sync.Pool
  43  }
  44  
  45  // NewBSONValueWriterPool creates a new pool for ValueWriter instances that write to BSON.
  46  //
  47  // Deprecated: BSONValueWriterPool will not be supported in Go Driver 2.0.
  48  func NewBSONValueWriterPool() *BSONValueWriterPool {
  49  	return &BSONValueWriterPool{
  50  		pool: sync.Pool{
  51  			New: func() interface{} {
  52  				return new(valueWriter)
  53  			},
  54  		},
  55  	}
  56  }
  57  
  58  // Get retrieves a BSON ValueWriter from the pool and resets it to use w as the destination.
  59  //
  60  // Deprecated: BSONValueWriterPool will not be supported in Go Driver 2.0.
  61  func (bvwp *BSONValueWriterPool) Get(w io.Writer) ValueWriter {
  62  	vw := bvwp.pool.Get().(*valueWriter)
  63  
  64  	// TODO: Having to call reset here with the same buffer doesn't really make sense.
  65  	vw.reset(vw.buf)
  66  	vw.buf = vw.buf[:0]
  67  	vw.w = w
  68  	return vw
  69  }
  70  
  71  // GetAtModeElement retrieves a ValueWriterFlusher from the pool and resets it to use w as the destination.
  72  //
  73  // Deprecated: BSONValueWriterPool will not be supported in Go Driver 2.0.
  74  func (bvwp *BSONValueWriterPool) GetAtModeElement(w io.Writer) ValueWriterFlusher {
  75  	vw := bvwp.Get(w).(*valueWriter)
  76  	vw.push(mElement)
  77  	return vw
  78  }
  79  
  80  // Put inserts a ValueWriter into the pool. If the ValueWriter is not a BSON ValueWriter, nothing
  81  // happens and ok will be false.
  82  //
  83  // Deprecated: BSONValueWriterPool will not be supported in Go Driver 2.0.
  84  func (bvwp *BSONValueWriterPool) Put(vw ValueWriter) (ok bool) {
  85  	bvw, ok := vw.(*valueWriter)
  86  	if !ok {
  87  		return false
  88  	}
  89  
  90  	bvwp.pool.Put(bvw)
  91  	return true
  92  }
  93  
  94  // This is here so that during testing we can change it and not require
  95  // allocating a 4GB slice.
  96  var maxSize = math.MaxInt32
  97  
  98  var errNilWriter = errors.New("cannot create a ValueWriter from a nil io.Writer")
  99  
 100  type errMaxDocumentSizeExceeded struct {
 101  	size int64
 102  }
 103  
 104  func (mdse errMaxDocumentSizeExceeded) Error() string {
 105  	return fmt.Sprintf("document size (%d) is larger than the max int32", mdse.size)
 106  }
 107  
 108  type vwMode int
 109  
 110  const (
 111  	_ vwMode = iota
 112  	vwTopLevel
 113  	vwDocument
 114  	vwArray
 115  	vwValue
 116  	vwElement
 117  	vwCodeWithScope
 118  )
 119  
 120  func (vm vwMode) String() string {
 121  	var str string
 122  
 123  	switch vm {
 124  	case vwTopLevel:
 125  		str = "TopLevel"
 126  	case vwDocument:
 127  		str = "DocumentMode"
 128  	case vwArray:
 129  		str = "ArrayMode"
 130  	case vwValue:
 131  		str = "ValueMode"
 132  	case vwElement:
 133  		str = "ElementMode"
 134  	case vwCodeWithScope:
 135  		str = "CodeWithScopeMode"
 136  	default:
 137  		str = "UnknownMode"
 138  	}
 139  
 140  	return str
 141  }
 142  
 143  type vwState struct {
 144  	mode   mode
 145  	key    string
 146  	arrkey int
 147  	start  int32
 148  }
 149  
 150  type valueWriter struct {
 151  	w   io.Writer
 152  	buf []byte
 153  
 154  	stack []vwState
 155  	frame int64
 156  }
 157  
 158  func (vw *valueWriter) advanceFrame() {
 159  	vw.frame++
 160  	if vw.frame >= int64(len(vw.stack)) {
 161  		vw.stack = append(vw.stack, vwState{})
 162  	}
 163  }
 164  
 165  func (vw *valueWriter) push(m mode) {
 166  	vw.advanceFrame()
 167  
 168  	// Clean the stack
 169  	vw.stack[vw.frame] = vwState{mode: m}
 170  
 171  	switch m {
 172  	case mDocument, mArray, mCodeWithScope:
 173  		vw.reserveLength() // WARN: this is not needed
 174  	}
 175  }
 176  
 177  func (vw *valueWriter) reserveLength() {
 178  	vw.stack[vw.frame].start = int32(len(vw.buf))
 179  	vw.buf = append(vw.buf, 0x00, 0x00, 0x00, 0x00)
 180  }
 181  
 182  func (vw *valueWriter) pop() {
 183  	switch vw.stack[vw.frame].mode {
 184  	case mElement, mValue:
 185  		vw.frame--
 186  	case mDocument, mArray, mCodeWithScope:
 187  		vw.frame -= 2 // we pop twice to jump over the mElement: mDocument -> mElement -> mDocument/mTopLevel/etc...
 188  	}
 189  }
 190  
 191  // NewBSONValueWriter creates a ValueWriter that writes BSON to w.
 192  //
 193  // This ValueWriter will only write entire documents to the io.Writer and it
 194  // will buffer the document as it is built.
 195  func NewBSONValueWriter(w io.Writer) (ValueWriter, error) {
 196  	if w == nil {
 197  		return nil, errNilWriter
 198  	}
 199  	return newValueWriter(w), nil
 200  }
 201  
 202  func newValueWriter(w io.Writer) *valueWriter {
 203  	vw := new(valueWriter)
 204  	stack := make([]vwState, 1, 5)
 205  	stack[0] = vwState{mode: mTopLevel}
 206  	vw.w = w
 207  	vw.stack = stack
 208  
 209  	return vw
 210  }
 211  
 212  // TODO: only used in tests
 213  func newValueWriterFromSlice(buf []byte) *valueWriter {
 214  	vw := new(valueWriter)
 215  	stack := make([]vwState, 1, 5)
 216  	stack[0] = vwState{mode: mTopLevel}
 217  	vw.stack = stack
 218  	vw.buf = buf
 219  
 220  	return vw
 221  }
 222  
 223  func (vw *valueWriter) reset(buf []byte) {
 224  	if vw.stack == nil {
 225  		vw.stack = make([]vwState, 1, 5)
 226  	}
 227  	vw.stack = vw.stack[:1]
 228  	vw.stack[0] = vwState{mode: mTopLevel}
 229  	vw.buf = buf
 230  	vw.frame = 0
 231  	vw.w = nil
 232  }
 233  
 234  func (vw *valueWriter) invalidTransitionError(destination mode, name string, modes []mode) error {
 235  	te := TransitionError{
 236  		name:        name,
 237  		current:     vw.stack[vw.frame].mode,
 238  		destination: destination,
 239  		modes:       modes,
 240  		action:      "write",
 241  	}
 242  	if vw.frame != 0 {
 243  		te.parent = vw.stack[vw.frame-1].mode
 244  	}
 245  	return te
 246  }
 247  
 248  func (vw *valueWriter) writeElementHeader(t bsontype.Type, destination mode, callerName string, addmodes ...mode) error {
 249  	frame := &vw.stack[vw.frame]
 250  	switch frame.mode {
 251  	case mElement:
 252  		key := frame.key
 253  		if !isValidCString(key) {
 254  			return errors.New("BSON element key cannot contain null bytes")
 255  		}
 256  		vw.appendHeader(t, key)
 257  	case mValue:
 258  		vw.appendIntHeader(t, frame.arrkey)
 259  	default:
 260  		modes := []mode{mElement, mValue}
 261  		if addmodes != nil {
 262  			modes = append(modes, addmodes...)
 263  		}
 264  		return vw.invalidTransitionError(destination, callerName, modes)
 265  	}
 266  
 267  	return nil
 268  }
 269  
 270  func (vw *valueWriter) WriteValueBytes(t bsontype.Type, b []byte) error {
 271  	if err := vw.writeElementHeader(t, mode(0), "WriteValueBytes"); err != nil {
 272  		return err
 273  	}
 274  	vw.buf = append(vw.buf, b...)
 275  	vw.pop()
 276  	return nil
 277  }
 278  
 279  func (vw *valueWriter) WriteArray() (ArrayWriter, error) {
 280  	if err := vw.writeElementHeader(bsontype.Array, mArray, "WriteArray"); err != nil {
 281  		return nil, err
 282  	}
 283  
 284  	vw.push(mArray)
 285  
 286  	return vw, nil
 287  }
 288  
 289  func (vw *valueWriter) WriteBinary(b []byte) error {
 290  	return vw.WriteBinaryWithSubtype(b, 0x00)
 291  }
 292  
 293  func (vw *valueWriter) WriteBinaryWithSubtype(b []byte, btype byte) error {
 294  	if err := vw.writeElementHeader(bsontype.Binary, mode(0), "WriteBinaryWithSubtype"); err != nil {
 295  		return err
 296  	}
 297  
 298  	vw.buf = bsoncore.AppendBinary(vw.buf, btype, b)
 299  	vw.pop()
 300  	return nil
 301  }
 302  
 303  func (vw *valueWriter) WriteBoolean(b bool) error {
 304  	if err := vw.writeElementHeader(bsontype.Boolean, mode(0), "WriteBoolean"); err != nil {
 305  		return err
 306  	}
 307  
 308  	vw.buf = bsoncore.AppendBoolean(vw.buf, b)
 309  	vw.pop()
 310  	return nil
 311  }
 312  
 313  func (vw *valueWriter) WriteCodeWithScope(code string) (DocumentWriter, error) {
 314  	if err := vw.writeElementHeader(bsontype.CodeWithScope, mCodeWithScope, "WriteCodeWithScope"); err != nil {
 315  		return nil, err
 316  	}
 317  
 318  	// CodeWithScope is a different than other types because we need an extra
 319  	// frame on the stack. In the EndDocument code, we write the document
 320  	// length, pop, write the code with scope length, and pop. To simplify the
 321  	// pop code, we push a spacer frame that we'll always jump over.
 322  	vw.push(mCodeWithScope)
 323  	vw.buf = bsoncore.AppendString(vw.buf, code)
 324  	vw.push(mSpacer)
 325  	vw.push(mDocument)
 326  
 327  	return vw, nil
 328  }
 329  
 330  func (vw *valueWriter) WriteDBPointer(ns string, oid primitive.ObjectID) error {
 331  	if err := vw.writeElementHeader(bsontype.DBPointer, mode(0), "WriteDBPointer"); err != nil {
 332  		return err
 333  	}
 334  
 335  	vw.buf = bsoncore.AppendDBPointer(vw.buf, ns, oid)
 336  	vw.pop()
 337  	return nil
 338  }
 339  
 340  func (vw *valueWriter) WriteDateTime(dt int64) error {
 341  	if err := vw.writeElementHeader(bsontype.DateTime, mode(0), "WriteDateTime"); err != nil {
 342  		return err
 343  	}
 344  
 345  	vw.buf = bsoncore.AppendDateTime(vw.buf, dt)
 346  	vw.pop()
 347  	return nil
 348  }
 349  
 350  func (vw *valueWriter) WriteDecimal128(d128 primitive.Decimal128) error {
 351  	if err := vw.writeElementHeader(bsontype.Decimal128, mode(0), "WriteDecimal128"); err != nil {
 352  		return err
 353  	}
 354  
 355  	vw.buf = bsoncore.AppendDecimal128(vw.buf, d128)
 356  	vw.pop()
 357  	return nil
 358  }
 359  
 360  func (vw *valueWriter) WriteDouble(f float64) error {
 361  	if err := vw.writeElementHeader(bsontype.Double, mode(0), "WriteDouble"); err != nil {
 362  		return err
 363  	}
 364  
 365  	vw.buf = bsoncore.AppendDouble(vw.buf, f)
 366  	vw.pop()
 367  	return nil
 368  }
 369  
 370  func (vw *valueWriter) WriteInt32(i32 int32) error {
 371  	if err := vw.writeElementHeader(bsontype.Int32, mode(0), "WriteInt32"); err != nil {
 372  		return err
 373  	}
 374  
 375  	vw.buf = bsoncore.AppendInt32(vw.buf, i32)
 376  	vw.pop()
 377  	return nil
 378  }
 379  
 380  func (vw *valueWriter) WriteInt64(i64 int64) error {
 381  	if err := vw.writeElementHeader(bsontype.Int64, mode(0), "WriteInt64"); err != nil {
 382  		return err
 383  	}
 384  
 385  	vw.buf = bsoncore.AppendInt64(vw.buf, i64)
 386  	vw.pop()
 387  	return nil
 388  }
 389  
 390  func (vw *valueWriter) WriteJavascript(code string) error {
 391  	if err := vw.writeElementHeader(bsontype.JavaScript, mode(0), "WriteJavascript"); err != nil {
 392  		return err
 393  	}
 394  
 395  	vw.buf = bsoncore.AppendJavaScript(vw.buf, code)
 396  	vw.pop()
 397  	return nil
 398  }
 399  
 400  func (vw *valueWriter) WriteMaxKey() error {
 401  	if err := vw.writeElementHeader(bsontype.MaxKey, mode(0), "WriteMaxKey"); err != nil {
 402  		return err
 403  	}
 404  
 405  	vw.pop()
 406  	return nil
 407  }
 408  
 409  func (vw *valueWriter) WriteMinKey() error {
 410  	if err := vw.writeElementHeader(bsontype.MinKey, mode(0), "WriteMinKey"); err != nil {
 411  		return err
 412  	}
 413  
 414  	vw.pop()
 415  	return nil
 416  }
 417  
 418  func (vw *valueWriter) WriteNull() error {
 419  	if err := vw.writeElementHeader(bsontype.Null, mode(0), "WriteNull"); err != nil {
 420  		return err
 421  	}
 422  
 423  	vw.pop()
 424  	return nil
 425  }
 426  
 427  func (vw *valueWriter) WriteObjectID(oid primitive.ObjectID) error {
 428  	if err := vw.writeElementHeader(bsontype.ObjectID, mode(0), "WriteObjectID"); err != nil {
 429  		return err
 430  	}
 431  
 432  	vw.buf = bsoncore.AppendObjectID(vw.buf, oid)
 433  	vw.pop()
 434  	return nil
 435  }
 436  
 437  func (vw *valueWriter) WriteRegex(pattern string, options string) error {
 438  	if !isValidCString(pattern) || !isValidCString(options) {
 439  		return errors.New("BSON regex values cannot contain null bytes")
 440  	}
 441  	if err := vw.writeElementHeader(bsontype.Regex, mode(0), "WriteRegex"); err != nil {
 442  		return err
 443  	}
 444  
 445  	vw.buf = bsoncore.AppendRegex(vw.buf, pattern, sortStringAlphebeticAscending(options))
 446  	vw.pop()
 447  	return nil
 448  }
 449  
 450  func (vw *valueWriter) WriteString(s string) error {
 451  	if err := vw.writeElementHeader(bsontype.String, mode(0), "WriteString"); err != nil {
 452  		return err
 453  	}
 454  
 455  	vw.buf = bsoncore.AppendString(vw.buf, s)
 456  	vw.pop()
 457  	return nil
 458  }
 459  
 460  func (vw *valueWriter) WriteDocument() (DocumentWriter, error) {
 461  	if vw.stack[vw.frame].mode == mTopLevel {
 462  		vw.reserveLength()
 463  		return vw, nil
 464  	}
 465  	if err := vw.writeElementHeader(bsontype.EmbeddedDocument, mDocument, "WriteDocument", mTopLevel); err != nil {
 466  		return nil, err
 467  	}
 468  
 469  	vw.push(mDocument)
 470  	return vw, nil
 471  }
 472  
 473  func (vw *valueWriter) WriteSymbol(symbol string) error {
 474  	if err := vw.writeElementHeader(bsontype.Symbol, mode(0), "WriteSymbol"); err != nil {
 475  		return err
 476  	}
 477  
 478  	vw.buf = bsoncore.AppendSymbol(vw.buf, symbol)
 479  	vw.pop()
 480  	return nil
 481  }
 482  
 483  func (vw *valueWriter) WriteTimestamp(t uint32, i uint32) error {
 484  	if err := vw.writeElementHeader(bsontype.Timestamp, mode(0), "WriteTimestamp"); err != nil {
 485  		return err
 486  	}
 487  
 488  	vw.buf = bsoncore.AppendTimestamp(vw.buf, t, i)
 489  	vw.pop()
 490  	return nil
 491  }
 492  
 493  func (vw *valueWriter) WriteUndefined() error {
 494  	if err := vw.writeElementHeader(bsontype.Undefined, mode(0), "WriteUndefined"); err != nil {
 495  		return err
 496  	}
 497  
 498  	vw.pop()
 499  	return nil
 500  }
 501  
 502  func (vw *valueWriter) WriteDocumentElement(key string) (ValueWriter, error) {
 503  	switch vw.stack[vw.frame].mode {
 504  	case mTopLevel, mDocument:
 505  	default:
 506  		return nil, vw.invalidTransitionError(mElement, "WriteDocumentElement", []mode{mTopLevel, mDocument})
 507  	}
 508  
 509  	vw.push(mElement)
 510  	vw.stack[vw.frame].key = key
 511  
 512  	return vw, nil
 513  }
 514  
 515  func (vw *valueWriter) WriteDocumentEnd() error {
 516  	switch vw.stack[vw.frame].mode {
 517  	case mTopLevel, mDocument:
 518  	default:
 519  		return fmt.Errorf("incorrect mode to end document: %s", vw.stack[vw.frame].mode)
 520  	}
 521  
 522  	vw.buf = append(vw.buf, 0x00)
 523  
 524  	err := vw.writeLength()
 525  	if err != nil {
 526  		return err
 527  	}
 528  
 529  	if vw.stack[vw.frame].mode == mTopLevel {
 530  		if err = vw.Flush(); err != nil {
 531  			return err
 532  		}
 533  	}
 534  
 535  	vw.pop()
 536  
 537  	if vw.stack[vw.frame].mode == mCodeWithScope {
 538  		// We ignore the error here because of the guarantee of writeLength.
 539  		// See the docs for writeLength for more info.
 540  		_ = vw.writeLength()
 541  		vw.pop()
 542  	}
 543  	return nil
 544  }
 545  
 546  func (vw *valueWriter) Flush() error {
 547  	if vw.w == nil {
 548  		return nil
 549  	}
 550  
 551  	if _, err := vw.w.Write(vw.buf); err != nil {
 552  		return err
 553  	}
 554  	// reset buffer
 555  	vw.buf = vw.buf[:0]
 556  	return nil
 557  }
 558  
 559  func (vw *valueWriter) WriteArrayElement() (ValueWriter, error) {
 560  	if vw.stack[vw.frame].mode != mArray {
 561  		return nil, vw.invalidTransitionError(mValue, "WriteArrayElement", []mode{mArray})
 562  	}
 563  
 564  	arrkey := vw.stack[vw.frame].arrkey
 565  	vw.stack[vw.frame].arrkey++
 566  
 567  	vw.push(mValue)
 568  	vw.stack[vw.frame].arrkey = arrkey
 569  
 570  	return vw, nil
 571  }
 572  
 573  func (vw *valueWriter) WriteArrayEnd() error {
 574  	if vw.stack[vw.frame].mode != mArray {
 575  		return fmt.Errorf("incorrect mode to end array: %s", vw.stack[vw.frame].mode)
 576  	}
 577  
 578  	vw.buf = append(vw.buf, 0x00)
 579  
 580  	err := vw.writeLength()
 581  	if err != nil {
 582  		return err
 583  	}
 584  
 585  	vw.pop()
 586  	return nil
 587  }
 588  
 589  // NOTE: We assume that if we call writeLength more than once the same function
 590  // within the same function without altering the vw.buf that this method will
 591  // not return an error. If this changes ensure that the following methods are
 592  // updated:
 593  //
 594  // - WriteDocumentEnd
 595  func (vw *valueWriter) writeLength() error {
 596  	length := len(vw.buf)
 597  	if length > maxSize {
 598  		return errMaxDocumentSizeExceeded{size: int64(len(vw.buf))}
 599  	}
 600  	frame := &vw.stack[vw.frame]
 601  	length = length - int(frame.start)
 602  	start := frame.start
 603  
 604  	_ = vw.buf[start+3] // BCE
 605  	vw.buf[start+0] = byte(length)
 606  	vw.buf[start+1] = byte(length >> 8)
 607  	vw.buf[start+2] = byte(length >> 16)
 608  	vw.buf[start+3] = byte(length >> 24)
 609  	return nil
 610  }
 611  
 612  func isValidCString(cs string) bool {
 613  	// Disallow the zero byte in a cstring because the zero byte is used as the
 614  	// terminating character.
 615  	//
 616  	// It's safe to check bytes instead of runes because all multibyte UTF-8
 617  	// code points start with (binary) 11xxxxxx or 10xxxxxx, so 00000000 (i.e.
 618  	// 0) will never be part of a multibyte UTF-8 code point. This logic is the
 619  	// same as the "r < utf8.RuneSelf" case in strings.IndexRune but can be
 620  	// inlined.
 621  	//
 622  	// https://cs.opensource.google/go/go/+/refs/tags/go1.21.1:src/strings/strings.go;l=127
 623  	return strings.IndexByte(cs, 0) == -1
 624  }
 625  
 626  // appendHeader is the same as bsoncore.AppendHeader but does not check if the
 627  // key is a valid C string since the caller has already checked for that.
 628  //
 629  // The caller of this function must check if key is a valid C string.
 630  func (vw *valueWriter) appendHeader(t bsontype.Type, key string) {
 631  	vw.buf = bsoncore.AppendType(vw.buf, t)
 632  	vw.buf = append(vw.buf, key...)
 633  	vw.buf = append(vw.buf, 0x00)
 634  }
 635  
 636  func (vw *valueWriter) appendIntHeader(t bsontype.Type, key int) {
 637  	vw.buf = bsoncore.AppendType(vw.buf, t)
 638  	vw.buf = strconv.AppendInt(vw.buf, int64(key), 10)
 639  	vw.buf = append(vw.buf, 0x00)
 640  }
 641