assertion_compare.go raw

   1  package assert
   2  
   3  import (
   4  	"bytes"
   5  	"fmt"
   6  	"reflect"
   7  	"time"
   8  )
   9  
// CompareType is an exported alias of the package's internal comparison
// result type.
//
// Deprecated: CompareType has only ever been for internal use and has accidentally been published since v1.6.0. Do not use it.
type CompareType = compareResult
  12  
  13  type compareResult int
  14  
  15  const (
  16  	compareLess compareResult = iota - 1
  17  	compareEqual
  18  	compareGreater
  19  )
  20  
  21  var (
  22  	intType   = reflect.TypeOf(int(1))
  23  	int8Type  = reflect.TypeOf(int8(1))
  24  	int16Type = reflect.TypeOf(int16(1))
  25  	int32Type = reflect.TypeOf(int32(1))
  26  	int64Type = reflect.TypeOf(int64(1))
  27  
  28  	uintType   = reflect.TypeOf(uint(1))
  29  	uint8Type  = reflect.TypeOf(uint8(1))
  30  	uint16Type = reflect.TypeOf(uint16(1))
  31  	uint32Type = reflect.TypeOf(uint32(1))
  32  	uint64Type = reflect.TypeOf(uint64(1))
  33  
  34  	uintptrType = reflect.TypeOf(uintptr(1))
  35  
  36  	float32Type = reflect.TypeOf(float32(1))
  37  	float64Type = reflect.TypeOf(float64(1))
  38  
  39  	stringType = reflect.TypeOf("")
  40  
  41  	timeType  = reflect.TypeOf(time.Time{})
  42  	bytesType = reflect.TypeOf([]byte{})
  43  )
  44  
  45  func compare(obj1, obj2 interface{}, kind reflect.Kind) (compareResult, bool) {
  46  	obj1Value := reflect.ValueOf(obj1)
  47  	obj2Value := reflect.ValueOf(obj2)
  48  
  49  	// throughout this switch we try and avoid calling .Convert() if possible,
  50  	// as this has a pretty big performance impact
  51  	switch kind {
  52  	case reflect.Int:
  53  		{
  54  			intobj1, ok := obj1.(int)
  55  			if !ok {
  56  				intobj1 = obj1Value.Convert(intType).Interface().(int)
  57  			}
  58  			intobj2, ok := obj2.(int)
  59  			if !ok {
  60  				intobj2 = obj2Value.Convert(intType).Interface().(int)
  61  			}
  62  			if intobj1 > intobj2 {
  63  				return compareGreater, true
  64  			}
  65  			if intobj1 == intobj2 {
  66  				return compareEqual, true
  67  			}
  68  			if intobj1 < intobj2 {
  69  				return compareLess, true
  70  			}
  71  		}
  72  	case reflect.Int8:
  73  		{
  74  			int8obj1, ok := obj1.(int8)
  75  			if !ok {
  76  				int8obj1 = obj1Value.Convert(int8Type).Interface().(int8)
  77  			}
  78  			int8obj2, ok := obj2.(int8)
  79  			if !ok {
  80  				int8obj2 = obj2Value.Convert(int8Type).Interface().(int8)
  81  			}
  82  			if int8obj1 > int8obj2 {
  83  				return compareGreater, true
  84  			}
  85  			if int8obj1 == int8obj2 {
  86  				return compareEqual, true
  87  			}
  88  			if int8obj1 < int8obj2 {
  89  				return compareLess, true
  90  			}
  91  		}
  92  	case reflect.Int16:
  93  		{
  94  			int16obj1, ok := obj1.(int16)
  95  			if !ok {
  96  				int16obj1 = obj1Value.Convert(int16Type).Interface().(int16)
  97  			}
  98  			int16obj2, ok := obj2.(int16)
  99  			if !ok {
 100  				int16obj2 = obj2Value.Convert(int16Type).Interface().(int16)
 101  			}
 102  			if int16obj1 > int16obj2 {
 103  				return compareGreater, true
 104  			}
 105  			if int16obj1 == int16obj2 {
 106  				return compareEqual, true
 107  			}
 108  			if int16obj1 < int16obj2 {
 109  				return compareLess, true
 110  			}
 111  		}
 112  	case reflect.Int32:
 113  		{
 114  			int32obj1, ok := obj1.(int32)
 115  			if !ok {
 116  				int32obj1 = obj1Value.Convert(int32Type).Interface().(int32)
 117  			}
 118  			int32obj2, ok := obj2.(int32)
 119  			if !ok {
 120  				int32obj2 = obj2Value.Convert(int32Type).Interface().(int32)
 121  			}
 122  			if int32obj1 > int32obj2 {
 123  				return compareGreater, true
 124  			}
 125  			if int32obj1 == int32obj2 {
 126  				return compareEqual, true
 127  			}
 128  			if int32obj1 < int32obj2 {
 129  				return compareLess, true
 130  			}
 131  		}
 132  	case reflect.Int64:
 133  		{
 134  			int64obj1, ok := obj1.(int64)
 135  			if !ok {
 136  				int64obj1 = obj1Value.Convert(int64Type).Interface().(int64)
 137  			}
 138  			int64obj2, ok := obj2.(int64)
 139  			if !ok {
 140  				int64obj2 = obj2Value.Convert(int64Type).Interface().(int64)
 141  			}
 142  			if int64obj1 > int64obj2 {
 143  				return compareGreater, true
 144  			}
 145  			if int64obj1 == int64obj2 {
 146  				return compareEqual, true
 147  			}
 148  			if int64obj1 < int64obj2 {
 149  				return compareLess, true
 150  			}
 151  		}
 152  	case reflect.Uint:
 153  		{
 154  			uintobj1, ok := obj1.(uint)
 155  			if !ok {
 156  				uintobj1 = obj1Value.Convert(uintType).Interface().(uint)
 157  			}
 158  			uintobj2, ok := obj2.(uint)
 159  			if !ok {
 160  				uintobj2 = obj2Value.Convert(uintType).Interface().(uint)
 161  			}
 162  			if uintobj1 > uintobj2 {
 163  				return compareGreater, true
 164  			}
 165  			if uintobj1 == uintobj2 {
 166  				return compareEqual, true
 167  			}
 168  			if uintobj1 < uintobj2 {
 169  				return compareLess, true
 170  			}
 171  		}
 172  	case reflect.Uint8:
 173  		{
 174  			uint8obj1, ok := obj1.(uint8)
 175  			if !ok {
 176  				uint8obj1 = obj1Value.Convert(uint8Type).Interface().(uint8)
 177  			}
 178  			uint8obj2, ok := obj2.(uint8)
 179  			if !ok {
 180  				uint8obj2 = obj2Value.Convert(uint8Type).Interface().(uint8)
 181  			}
 182  			if uint8obj1 > uint8obj2 {
 183  				return compareGreater, true
 184  			}
 185  			if uint8obj1 == uint8obj2 {
 186  				return compareEqual, true
 187  			}
 188  			if uint8obj1 < uint8obj2 {
 189  				return compareLess, true
 190  			}
 191  		}
 192  	case reflect.Uint16:
 193  		{
 194  			uint16obj1, ok := obj1.(uint16)
 195  			if !ok {
 196  				uint16obj1 = obj1Value.Convert(uint16Type).Interface().(uint16)
 197  			}
 198  			uint16obj2, ok := obj2.(uint16)
 199  			if !ok {
 200  				uint16obj2 = obj2Value.Convert(uint16Type).Interface().(uint16)
 201  			}
 202  			if uint16obj1 > uint16obj2 {
 203  				return compareGreater, true
 204  			}
 205  			if uint16obj1 == uint16obj2 {
 206  				return compareEqual, true
 207  			}
 208  			if uint16obj1 < uint16obj2 {
 209  				return compareLess, true
 210  			}
 211  		}
 212  	case reflect.Uint32:
 213  		{
 214  			uint32obj1, ok := obj1.(uint32)
 215  			if !ok {
 216  				uint32obj1 = obj1Value.Convert(uint32Type).Interface().(uint32)
 217  			}
 218  			uint32obj2, ok := obj2.(uint32)
 219  			if !ok {
 220  				uint32obj2 = obj2Value.Convert(uint32Type).Interface().(uint32)
 221  			}
 222  			if uint32obj1 > uint32obj2 {
 223  				return compareGreater, true
 224  			}
 225  			if uint32obj1 == uint32obj2 {
 226  				return compareEqual, true
 227  			}
 228  			if uint32obj1 < uint32obj2 {
 229  				return compareLess, true
 230  			}
 231  		}
 232  	case reflect.Uint64:
 233  		{
 234  			uint64obj1, ok := obj1.(uint64)
 235  			if !ok {
 236  				uint64obj1 = obj1Value.Convert(uint64Type).Interface().(uint64)
 237  			}
 238  			uint64obj2, ok := obj2.(uint64)
 239  			if !ok {
 240  				uint64obj2 = obj2Value.Convert(uint64Type).Interface().(uint64)
 241  			}
 242  			if uint64obj1 > uint64obj2 {
 243  				return compareGreater, true
 244  			}
 245  			if uint64obj1 == uint64obj2 {
 246  				return compareEqual, true
 247  			}
 248  			if uint64obj1 < uint64obj2 {
 249  				return compareLess, true
 250  			}
 251  		}
 252  	case reflect.Float32:
 253  		{
 254  			float32obj1, ok := obj1.(float32)
 255  			if !ok {
 256  				float32obj1 = obj1Value.Convert(float32Type).Interface().(float32)
 257  			}
 258  			float32obj2, ok := obj2.(float32)
 259  			if !ok {
 260  				float32obj2 = obj2Value.Convert(float32Type).Interface().(float32)
 261  			}
 262  			if float32obj1 > float32obj2 {
 263  				return compareGreater, true
 264  			}
 265  			if float32obj1 == float32obj2 {
 266  				return compareEqual, true
 267  			}
 268  			if float32obj1 < float32obj2 {
 269  				return compareLess, true
 270  			}
 271  		}
 272  	case reflect.Float64:
 273  		{
 274  			float64obj1, ok := obj1.(float64)
 275  			if !ok {
 276  				float64obj1 = obj1Value.Convert(float64Type).Interface().(float64)
 277  			}
 278  			float64obj2, ok := obj2.(float64)
 279  			if !ok {
 280  				float64obj2 = obj2Value.Convert(float64Type).Interface().(float64)
 281  			}
 282  			if float64obj1 > float64obj2 {
 283  				return compareGreater, true
 284  			}
 285  			if float64obj1 == float64obj2 {
 286  				return compareEqual, true
 287  			}
 288  			if float64obj1 < float64obj2 {
 289  				return compareLess, true
 290  			}
 291  		}
 292  	case reflect.String:
 293  		{
 294  			stringobj1, ok := obj1.(string)
 295  			if !ok {
 296  				stringobj1 = obj1Value.Convert(stringType).Interface().(string)
 297  			}
 298  			stringobj2, ok := obj2.(string)
 299  			if !ok {
 300  				stringobj2 = obj2Value.Convert(stringType).Interface().(string)
 301  			}
 302  			if stringobj1 > stringobj2 {
 303  				return compareGreater, true
 304  			}
 305  			if stringobj1 == stringobj2 {
 306  				return compareEqual, true
 307  			}
 308  			if stringobj1 < stringobj2 {
 309  				return compareLess, true
 310  			}
 311  		}
 312  	// Check for known struct types we can check for compare results.
 313  	case reflect.Struct:
 314  		{
 315  			// All structs enter here. We're not interested in most types.
 316  			if !obj1Value.CanConvert(timeType) {
 317  				break
 318  			}
 319  
 320  			// time.Time can be compared!
 321  			timeObj1, ok := obj1.(time.Time)
 322  			if !ok {
 323  				timeObj1 = obj1Value.Convert(timeType).Interface().(time.Time)
 324  			}
 325  
 326  			timeObj2, ok := obj2.(time.Time)
 327  			if !ok {
 328  				timeObj2 = obj2Value.Convert(timeType).Interface().(time.Time)
 329  			}
 330  
 331  			if timeObj1.Before(timeObj2) {
 332  				return compareLess, true
 333  			}
 334  			if timeObj1.Equal(timeObj2) {
 335  				return compareEqual, true
 336  			}
 337  			return compareGreater, true
 338  		}
 339  	case reflect.Slice:
 340  		{
 341  			// We only care about the []byte type.
 342  			if !obj1Value.CanConvert(bytesType) {
 343  				break
 344  			}
 345  
 346  			// []byte can be compared!
 347  			bytesObj1, ok := obj1.([]byte)
 348  			if !ok {
 349  				bytesObj1 = obj1Value.Convert(bytesType).Interface().([]byte)
 350  
 351  			}
 352  			bytesObj2, ok := obj2.([]byte)
 353  			if !ok {
 354  				bytesObj2 = obj2Value.Convert(bytesType).Interface().([]byte)
 355  			}
 356  
 357  			return compareResult(bytes.Compare(bytesObj1, bytesObj2)), true
 358  		}
 359  	case reflect.Uintptr:
 360  		{
 361  			uintptrObj1, ok := obj1.(uintptr)
 362  			if !ok {
 363  				uintptrObj1 = obj1Value.Convert(uintptrType).Interface().(uintptr)
 364  			}
 365  			uintptrObj2, ok := obj2.(uintptr)
 366  			if !ok {
 367  				uintptrObj2 = obj2Value.Convert(uintptrType).Interface().(uintptr)
 368  			}
 369  			if uintptrObj1 > uintptrObj2 {
 370  				return compareGreater, true
 371  			}
 372  			if uintptrObj1 == uintptrObj2 {
 373  				return compareEqual, true
 374  			}
 375  			if uintptrObj1 < uintptrObj2 {
 376  				return compareLess, true
 377  			}
 378  		}
 379  	}
 380  
 381  	return compareEqual, false
 382  }
 383  
 384  // Greater asserts that the first element is greater than the second
 385  //
 386  //	assert.Greater(t, 2, 1)
 387  //	assert.Greater(t, float64(2), float64(1))
 388  //	assert.Greater(t, "b", "a")
 389  func Greater(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
 390  	if h, ok := t.(tHelper); ok {
 391  		h.Helper()
 392  	}
 393  	failMessage := fmt.Sprintf("\"%v\" is not greater than \"%v\"", e1, e2)
 394  	return compareTwoValues(t, e1, e2, []compareResult{compareGreater}, failMessage, msgAndArgs...)
 395  }
 396  
 397  // GreaterOrEqual asserts that the first element is greater than or equal to the second
 398  //
 399  //	assert.GreaterOrEqual(t, 2, 1)
 400  //	assert.GreaterOrEqual(t, 2, 2)
 401  //	assert.GreaterOrEqual(t, "b", "a")
 402  //	assert.GreaterOrEqual(t, "b", "b")
 403  func GreaterOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
 404  	if h, ok := t.(tHelper); ok {
 405  		h.Helper()
 406  	}
 407  	failMessage := fmt.Sprintf("\"%v\" is not greater than or equal to \"%v\"", e1, e2)
 408  	return compareTwoValues(t, e1, e2, []compareResult{compareGreater, compareEqual}, failMessage, msgAndArgs...)
 409  }
 410  
 411  // Less asserts that the first element is less than the second
 412  //
 413  //	assert.Less(t, 1, 2)
 414  //	assert.Less(t, float64(1), float64(2))
 415  //	assert.Less(t, "a", "b")
 416  func Less(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
 417  	if h, ok := t.(tHelper); ok {
 418  		h.Helper()
 419  	}
 420  	failMessage := fmt.Sprintf("\"%v\" is not less than \"%v\"", e1, e2)
 421  	return compareTwoValues(t, e1, e2, []compareResult{compareLess}, failMessage, msgAndArgs...)
 422  }
 423  
 424  // LessOrEqual asserts that the first element is less than or equal to the second
 425  //
 426  //	assert.LessOrEqual(t, 1, 2)
 427  //	assert.LessOrEqual(t, 2, 2)
 428  //	assert.LessOrEqual(t, "a", "b")
 429  //	assert.LessOrEqual(t, "b", "b")
 430  func LessOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
 431  	if h, ok := t.(tHelper); ok {
 432  		h.Helper()
 433  	}
 434  	failMessage := fmt.Sprintf("\"%v\" is not less than or equal to \"%v\"", e1, e2)
 435  	return compareTwoValues(t, e1, e2, []compareResult{compareLess, compareEqual}, failMessage, msgAndArgs...)
 436  }
 437  
 438  // Positive asserts that the specified element is positive
 439  //
 440  //	assert.Positive(t, 1)
 441  //	assert.Positive(t, 1.23)
 442  func Positive(t TestingT, e interface{}, msgAndArgs ...interface{}) bool {
 443  	if h, ok := t.(tHelper); ok {
 444  		h.Helper()
 445  	}
 446  	zero := reflect.Zero(reflect.TypeOf(e))
 447  	failMessage := fmt.Sprintf("\"%v\" is not positive", e)
 448  	return compareTwoValues(t, e, zero.Interface(), []compareResult{compareGreater}, failMessage, msgAndArgs...)
 449  }
 450  
 451  // Negative asserts that the specified element is negative
 452  //
 453  //	assert.Negative(t, -1)
 454  //	assert.Negative(t, -1.23)
 455  func Negative(t TestingT, e interface{}, msgAndArgs ...interface{}) bool {
 456  	if h, ok := t.(tHelper); ok {
 457  		h.Helper()
 458  	}
 459  	zero := reflect.Zero(reflect.TypeOf(e))
 460  	failMessage := fmt.Sprintf("\"%v\" is not negative", e)
 461  	return compareTwoValues(t, e, zero.Interface(), []compareResult{compareLess}, failMessage, msgAndArgs...)
 462  }
 463  
 464  func compareTwoValues(t TestingT, e1 interface{}, e2 interface{}, allowedComparesResults []compareResult, failMessage string, msgAndArgs ...interface{}) bool {
 465  	if h, ok := t.(tHelper); ok {
 466  		h.Helper()
 467  	}
 468  
 469  	e1Kind := reflect.ValueOf(e1).Kind()
 470  	e2Kind := reflect.ValueOf(e2).Kind()
 471  	if e1Kind != e2Kind {
 472  		return Fail(t, "Elements should be the same type", msgAndArgs...)
 473  	}
 474  
 475  	compareResult, isComparable := compare(e1, e2, e1Kind)
 476  	if !isComparable {
 477  		return Fail(t, fmt.Sprintf(`Can not compare type "%T"`, e1), msgAndArgs...)
 478  	}
 479  
 480  	if !containsValue(allowedComparesResults, compareResult) {
 481  		return Fail(t, failMessage, msgAndArgs...)
 482  	}
 483  
 484  	return true
 485  }
 486  
 487  func containsValue(values []compareResult, value compareResult) bool {
 488  	for _, v := range values {
 489  		if v == value {
 490  			return true
 491  		}
 492  	}
 493  
 494  	return false
 495  }
 496