objectid.go raw

   1  // Copyright (C) MongoDB, Inc. 2017-present.
   2  //
   3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
   4  // not use this file except in compliance with the License. You may obtain
   5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
   6  //
   7  // Based on gopkg.in/mgo.v2/bson by Gustavo Niemeyer
   8  // See THIRD-PARTY-NOTICES for original license terms.
   9  
  10  package primitive
  11  
  12  import (
  13  	"crypto/rand"
  14  	"encoding"
  15  	"encoding/binary"
  16  	"encoding/hex"
  17  	"encoding/json"
  18  	"errors"
  19  	"fmt"
  20  	"io"
  21  	"sync/atomic"
  22  	"time"
  23  )
  24  
  25  // ErrInvalidHex indicates that a hex string cannot be converted to an ObjectID.
  26  var ErrInvalidHex = errors.New("the provided hex string is not a valid ObjectID")
  27  
  28  // ObjectID is the BSON ObjectID type.
  29  type ObjectID [12]byte
  30  
  31  // NilObjectID is the zero value for ObjectID.
  32  var NilObjectID ObjectID
  33  
  34  var objectIDCounter = readRandomUint32()
  35  var processUnique = processUniqueBytes()
  36  
  37  var _ encoding.TextMarshaler = ObjectID{}
  38  var _ encoding.TextUnmarshaler = &ObjectID{}
  39  
  40  // NewObjectID generates a new ObjectID.
  41  func NewObjectID() ObjectID {
  42  	return NewObjectIDFromTimestamp(time.Now())
  43  }
  44  
  45  // NewObjectIDFromTimestamp generates a new ObjectID based on the given time.
  46  func NewObjectIDFromTimestamp(timestamp time.Time) ObjectID {
  47  	var b [12]byte
  48  
  49  	binary.BigEndian.PutUint32(b[0:4], uint32(timestamp.Unix()))
  50  	copy(b[4:9], processUnique[:])
  51  	putUint24(b[9:12], atomic.AddUint32(&objectIDCounter, 1))
  52  
  53  	return b
  54  }
  55  
  56  // Timestamp extracts the time part of the ObjectId.
  57  func (id ObjectID) Timestamp() time.Time {
  58  	unixSecs := binary.BigEndian.Uint32(id[0:4])
  59  	return time.Unix(int64(unixSecs), 0).UTC()
  60  }
  61  
  62  // Hex returns the hex encoding of the ObjectID as a string.
  63  func (id ObjectID) Hex() string {
  64  	var buf [24]byte
  65  	hex.Encode(buf[:], id[:])
  66  	return string(buf[:])
  67  }
  68  
  69  func (id ObjectID) String() string {
  70  	return fmt.Sprintf("ObjectID(%q)", id.Hex())
  71  }
  72  
  73  // IsZero returns true if id is the empty ObjectID.
  74  func (id ObjectID) IsZero() bool {
  75  	return id == NilObjectID
  76  }
  77  
  78  // ObjectIDFromHex creates a new ObjectID from a hex string. It returns an error if the hex string is not a
  79  // valid ObjectID.
  80  func ObjectIDFromHex(s string) (ObjectID, error) {
  81  	if len(s) != 24 {
  82  		return NilObjectID, ErrInvalidHex
  83  	}
  84  
  85  	var oid [12]byte
  86  	_, err := hex.Decode(oid[:], []byte(s))
  87  	if err != nil {
  88  		return NilObjectID, err
  89  	}
  90  
  91  	return oid, nil
  92  }
  93  
  94  // IsValidObjectID returns true if the provided hex string represents a valid ObjectID and false if not.
  95  //
  96  // Deprecated: Use ObjectIDFromHex and check the error instead.
  97  func IsValidObjectID(s string) bool {
  98  	_, err := ObjectIDFromHex(s)
  99  	return err == nil
 100  }
 101  
 102  // MarshalText returns the ObjectID as UTF-8-encoded text. Implementing this allows us to use ObjectID
 103  // as a map key when marshalling JSON. See https://pkg.go.dev/encoding#TextMarshaler
 104  func (id ObjectID) MarshalText() ([]byte, error) {
 105  	return []byte(id.Hex()), nil
 106  }
 107  
 108  // UnmarshalText populates the byte slice with the ObjectID. Implementing this allows us to use ObjectID
 109  // as a map key when unmarshalling JSON. See https://pkg.go.dev/encoding#TextUnmarshaler
 110  func (id *ObjectID) UnmarshalText(b []byte) error {
 111  	oid, err := ObjectIDFromHex(string(b))
 112  	if err != nil {
 113  		return err
 114  	}
 115  	*id = oid
 116  	return nil
 117  }
 118  
 119  // MarshalJSON returns the ObjectID as a string
 120  func (id ObjectID) MarshalJSON() ([]byte, error) {
 121  	return json.Marshal(id.Hex())
 122  }
 123  
 124  // UnmarshalJSON populates the byte slice with the ObjectID. If the byte slice is 24 bytes long, it
 125  // will be populated with the hex representation of the ObjectID. If the byte slice is twelve bytes
 126  // long, it will be populated with the BSON representation of the ObjectID. This method also accepts empty strings and
 127  // decodes them as NilObjectID. For any other inputs, an error will be returned.
 128  func (id *ObjectID) UnmarshalJSON(b []byte) error {
 129  	// Ignore "null" to keep parity with the standard library. Decoding a JSON null into a non-pointer ObjectID field
 130  	// will leave the field unchanged. For pointer values, encoding/json will set the pointer to nil and will not
 131  	// enter the UnmarshalJSON hook.
 132  	if string(b) == "null" {
 133  		return nil
 134  	}
 135  
 136  	var err error
 137  	switch len(b) {
 138  	case 12:
 139  		copy(id[:], b)
 140  	default:
 141  		// Extended JSON
 142  		var res interface{}
 143  		err := json.Unmarshal(b, &res)
 144  		if err != nil {
 145  			return err
 146  		}
 147  		str, ok := res.(string)
 148  		if !ok {
 149  			m, ok := res.(map[string]interface{})
 150  			if !ok {
 151  				return errors.New("not an extended JSON ObjectID")
 152  			}
 153  			oid, ok := m["$oid"]
 154  			if !ok {
 155  				return errors.New("not an extended JSON ObjectID")
 156  			}
 157  			str, ok = oid.(string)
 158  			if !ok {
 159  				return errors.New("not an extended JSON ObjectID")
 160  			}
 161  		}
 162  
 163  		// An empty string is not a valid ObjectID, but we treat it as a special value that decodes as NilObjectID.
 164  		if len(str) == 0 {
 165  			copy(id[:], NilObjectID[:])
 166  			return nil
 167  		}
 168  
 169  		if len(str) != 24 {
 170  			return fmt.Errorf("cannot unmarshal into an ObjectID, the length must be 24 but it is %d", len(str))
 171  		}
 172  
 173  		_, err = hex.Decode(id[:], []byte(str))
 174  		if err != nil {
 175  			return err
 176  		}
 177  	}
 178  
 179  	return err
 180  }
 181  
 182  func processUniqueBytes() [5]byte {
 183  	var b [5]byte
 184  	_, err := io.ReadFull(rand.Reader, b[:])
 185  	if err != nil {
 186  		panic(fmt.Errorf("cannot initialize objectid package with crypto.rand.Reader: %v", err))
 187  	}
 188  
 189  	return b
 190  }
 191  
 192  func readRandomUint32() uint32 {
 193  	var b [4]byte
 194  	_, err := io.ReadFull(rand.Reader, b[:])
 195  	if err != nil {
 196  		panic(fmt.Errorf("cannot initialize objectid package with crypto.rand.Reader: %v", err))
 197  	}
 198  
 199  	return (uint32(b[0]) << 0) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24)
 200  }
 201  
 202  func putUint24(b []byte, v uint32) {
 203  	b[0] = byte(v >> 16)
 204  	b[1] = byte(v >> 8)
 205  	b[2] = byte(v)
 206  }
 207