value_equal.go raw

   1  // Copyright 2022 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  package protoreflect
   6  
   7  import (
   8  	"bytes"
   9  	"fmt"
  10  	"math"
  11  	"reflect"
  12  
  13  	"google.golang.org/protobuf/encoding/protowire"
  14  )
  15  
  16  // Equal reports whether v1 and v2 are recursively equal.
  17  //
  18  //   - Values of different types are always unequal.
  19  //
  20  //   - Bytes values are equal if they contain identical bytes.
  21  //     Empty bytes (regardless of nil-ness) are considered equal.
  22  //
  23  //   - Floating point values are equal if they contain the same value.
  24  //     Unlike the == operator, a NaN is equal to another NaN.
  25  //
  26  //   - Enums are equal if they contain the same number.
  27  //     Since [Value] does not contain an enum descriptor,
  28  //     enum values do not consider the type of the enum.
  29  //
  30  //   - Other scalar values are equal if they contain the same value.
  31  //
  32  //   - [Message] values are equal if they belong to the same message descriptor,
  33  //     have the same set of populated known and extension field values,
  34  //     and the same set of unknown fields values.
  35  //
  36  //   - [List] values are equal if they are the same length and
  37  //     each corresponding element is equal.
  38  //
  39  //   - [Map] values are equal if they have the same set of keys and
  40  //     the corresponding value for each key is equal.
  41  func (v1 Value) Equal(v2 Value) bool {
  42  	return equalValue(v1, v2)
  43  }
  44  
  45  func equalValue(x, y Value) bool {
  46  	eqType := x.typ == y.typ
  47  	switch x.typ {
  48  	case nilType:
  49  		return eqType
  50  	case boolType:
  51  		return eqType && x.Bool() == y.Bool()
  52  	case int32Type, int64Type:
  53  		return eqType && x.Int() == y.Int()
  54  	case uint32Type, uint64Type:
  55  		return eqType && x.Uint() == y.Uint()
  56  	case float32Type, float64Type:
  57  		return eqType && equalFloat(x.Float(), y.Float())
  58  	case stringType:
  59  		return eqType && x.String() == y.String()
  60  	case bytesType:
  61  		return eqType && bytes.Equal(x.Bytes(), y.Bytes())
  62  	case enumType:
  63  		return eqType && x.Enum() == y.Enum()
  64  	default:
  65  		switch x := x.Interface().(type) {
  66  		case Message:
  67  			y, ok := y.Interface().(Message)
  68  			return ok && equalMessage(x, y)
  69  		case List:
  70  			y, ok := y.Interface().(List)
  71  			return ok && equalList(x, y)
  72  		case Map:
  73  			y, ok := y.Interface().(Map)
  74  			return ok && equalMap(x, y)
  75  		default:
  76  			panic(fmt.Sprintf("unknown type: %T", x))
  77  		}
  78  	}
  79  }
  80  
  81  // equalFloat compares two floats, where NaNs are treated as equal.
  82  func equalFloat(x, y float64) bool {
  83  	if math.IsNaN(x) || math.IsNaN(y) {
  84  		return math.IsNaN(x) && math.IsNaN(y)
  85  	}
  86  	return x == y
  87  }
  88  
  89  // equalMessage compares two messages.
  90  func equalMessage(mx, my Message) bool {
  91  	if mx.Descriptor() != my.Descriptor() {
  92  		return false
  93  	}
  94  
  95  	nx := 0
  96  	equal := true
  97  	mx.Range(func(fd FieldDescriptor, vx Value) bool {
  98  		nx++
  99  		vy := my.Get(fd)
 100  		equal = my.Has(fd) && equalValue(vx, vy)
 101  		return equal
 102  	})
 103  	if !equal {
 104  		return false
 105  	}
 106  	ny := 0
 107  	my.Range(func(fd FieldDescriptor, vx Value) bool {
 108  		ny++
 109  		return true
 110  	})
 111  	if nx != ny {
 112  		return false
 113  	}
 114  
 115  	return equalUnknown(mx.GetUnknown(), my.GetUnknown())
 116  }
 117  
 118  // equalList compares two lists.
 119  func equalList(x, y List) bool {
 120  	if x.Len() != y.Len() {
 121  		return false
 122  	}
 123  	for i := x.Len() - 1; i >= 0; i-- {
 124  		if !equalValue(x.Get(i), y.Get(i)) {
 125  			return false
 126  		}
 127  	}
 128  	return true
 129  }
 130  
 131  // equalMap compares two maps.
 132  func equalMap(x, y Map) bool {
 133  	if x.Len() != y.Len() {
 134  		return false
 135  	}
 136  	equal := true
 137  	x.Range(func(k MapKey, vx Value) bool {
 138  		vy := y.Get(k)
 139  		equal = y.Has(k) && equalValue(vx, vy)
 140  		return equal
 141  	})
 142  	return equal
 143  }
 144  
 145  // equalUnknown compares unknown fields by direct comparison on the raw bytes
 146  // of each individual field number.
 147  func equalUnknown(x, y RawFields) bool {
 148  	if len(x) != len(y) {
 149  		return false
 150  	}
 151  	if bytes.Equal([]byte(x), []byte(y)) {
 152  		return true
 153  	}
 154  
 155  	mx := make(map[FieldNumber]RawFields)
 156  	my := make(map[FieldNumber]RawFields)
 157  	for len(x) > 0 {
 158  		fnum, _, n := protowire.ConsumeField(x)
 159  		mx[fnum] = append(mx[fnum], x[:n]...)
 160  		x = x[n:]
 161  	}
 162  	for len(y) > 0 {
 163  		fnum, _, n := protowire.ConsumeField(y)
 164  		my[fnum] = append(my[fnum], y[:n]...)
 165  		y = y[n:]
 166  	}
 167  	return reflect.DeepEqual(mx, my)
 168  }
 169