1 package mock
2 3 import (
4 "errors"
5 "fmt"
6 "path"
7 "reflect"
8 "regexp"
9 "runtime"
10 "strings"
11 "sync"
12 "time"
13 14 "github.com/davecgh/go-spew/spew"
15 "github.com/pmezard/go-difflib/difflib"
16 "github.com/stretchr/objx"
17 18 "github.com/stretchr/testify/assert"
19 )
20 21 // regex for GCCGO functions
22 var gccgoRE = regexp.MustCompile(`\.pN\d+_`)
23 24 // TestingT is an interface wrapper around *testing.T
25 type TestingT interface {
26 Logf(format string, args ...interface{})
27 Errorf(format string, args ...interface{})
28 FailNow()
29 }
30 31 /*
32 Call
33 */
34 35 // Call represents a method call and is used for setting expectations,
36 // as well as recording activity.
37 type Call struct {
38 Parent *Mock
39 40 // The name of the method that was or will be called.
41 Method string
42 43 // Holds the arguments of the method.
44 Arguments Arguments
45 46 // Holds the arguments that should be returned when
47 // this method is called.
48 ReturnArguments Arguments
49 50 // Holds the caller info for the On() call
51 callerInfo []string
52 53 // The number of times to return the return arguments when setting
54 // expectations. 0 means to always return the value.
55 Repeatability int
56 57 // Amount of times this call has been called
58 totalCalls int
59 60 // Call to this method can be optional
61 optional bool
62 63 // Holds a channel that will be used to block the Return until it either
64 // receives a message or is closed. nil means it returns immediately.
65 WaitFor <-chan time.Time
66 67 waitTime time.Duration
68 69 // Holds a handler used to manipulate arguments content that are passed by
70 // reference. It's useful when mocking methods such as unmarshalers or
71 // decoders.
72 RunFn func(Arguments)
73 74 // PanicMsg holds msg to be used to mock panic on the function call
75 // if the PanicMsg is set to a non nil string the function call will panic
76 // irrespective of other settings
77 PanicMsg *string
78 79 // Calls which must be satisfied before this call can be
80 requires []*Call
81 }
82 83 func newCall(parent *Mock, methodName string, callerInfo []string, methodArguments Arguments, returnArguments Arguments) *Call {
84 return &Call{
85 Parent: parent,
86 Method: methodName,
87 Arguments: methodArguments,
88 ReturnArguments: returnArguments,
89 callerInfo: callerInfo,
90 Repeatability: 0,
91 WaitFor: nil,
92 RunFn: nil,
93 PanicMsg: nil,
94 }
95 }
96 97 func (c *Call) lock() {
98 c.Parent.mutex.Lock()
99 }
100 101 func (c *Call) unlock() {
102 c.Parent.mutex.Unlock()
103 }
104 105 // Return specifies the return arguments for the expectation.
106 //
107 // Mock.On("DoSomething").Return(errors.New("failed"))
108 func (c *Call) Return(returnArguments ...interface{}) *Call {
109 c.lock()
110 defer c.unlock()
111 112 c.ReturnArguments = returnArguments
113 114 return c
115 }
116 117 // Panic specifies if the function call should fail and the panic message
118 //
119 // Mock.On("DoSomething").Panic("test panic")
120 func (c *Call) Panic(msg string) *Call {
121 c.lock()
122 defer c.unlock()
123 124 c.PanicMsg = &msg
125 126 return c
127 }
128 129 // Once indicates that the mock should only return the value once.
130 //
131 // Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2).Once()
132 func (c *Call) Once() *Call {
133 return c.Times(1)
134 }
135 136 // Twice indicates that the mock should only return the value twice.
137 //
138 // Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2).Twice()
139 func (c *Call) Twice() *Call {
140 return c.Times(2)
141 }
142 143 // Times indicates that the mock should only return the indicated number
144 // of times.
145 //
146 // Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2).Times(5)
147 func (c *Call) Times(i int) *Call {
148 c.lock()
149 defer c.unlock()
150 c.Repeatability = i
151 return c
152 }
153 154 // WaitUntil sets the channel that will block the mock's return until its closed
155 // or a message is received.
156 //
157 // Mock.On("MyMethod", arg1, arg2).WaitUntil(time.After(time.Second))
158 func (c *Call) WaitUntil(w <-chan time.Time) *Call {
159 c.lock()
160 defer c.unlock()
161 c.WaitFor = w
162 return c
163 }
164 165 // After sets how long to block until the call returns
166 //
167 // Mock.On("MyMethod", arg1, arg2).After(time.Second)
168 func (c *Call) After(d time.Duration) *Call {
169 c.lock()
170 defer c.unlock()
171 c.waitTime = d
172 return c
173 }
174 175 // Run sets a handler to be called before returning. It can be used when
176 // mocking a method (such as an unmarshaler) that takes a pointer to a struct and
177 // sets properties in such struct
178 //
179 // Mock.On("Unmarshal", AnythingOfType("*map[string]interface{}")).Return().Run(func(args Arguments) {
180 // arg := args.Get(0).(*map[string]interface{})
181 // arg["foo"] = "bar"
182 // })
183 func (c *Call) Run(fn func(args Arguments)) *Call {
184 c.lock()
185 defer c.unlock()
186 c.RunFn = fn
187 return c
188 }
189 190 // Maybe allows the method call to be optional. Not calling an optional method
191 // will not cause an error while asserting expectations
192 func (c *Call) Maybe() *Call {
193 c.lock()
194 defer c.unlock()
195 c.optional = true
196 return c
197 }
198 199 // On chains a new expectation description onto the mocked interface. This
200 // allows syntax like.
201 //
202 // Mock.
203 // On("MyMethod", 1).Return(nil).
204 // On("MyOtherMethod", 'a', 'b', 'c').Return(errors.New("Some Error"))
205 //
206 //go:noinline
207 func (c *Call) On(methodName string, arguments ...interface{}) *Call {
208 return c.Parent.On(methodName, arguments...)
209 }
210 211 // Unset removes all mock handlers that satisfy the call instance arguments from being
212 // called. Only supported on call instances with static input arguments.
213 //
214 // For example, the only handler remaining after the following would be "MyMethod(2, 2)":
215 //
216 // Mock.
217 // On("MyMethod", 2, 2).Return(0).
218 // On("MyMethod", 3, 3).Return(0).
219 // On("MyMethod", Anything, Anything).Return(0)
220 // Mock.On("MyMethod", 3, 3).Unset()
221 func (c *Call) Unset() *Call {
222 var unlockOnce sync.Once
223 224 for _, arg := range c.Arguments {
225 if v := reflect.ValueOf(arg); v.Kind() == reflect.Func {
226 panic(fmt.Sprintf("cannot use Func in expectations. Use mock.AnythingOfType(\"%T\")", arg))
227 }
228 }
229 230 c.lock()
231 defer unlockOnce.Do(c.unlock)
232 233 foundMatchingCall := false
234 235 // in-place filter slice for calls to be removed - iterate from 0'th to last skipping unnecessary ones
236 var index int // write index
237 for _, call := range c.Parent.ExpectedCalls {
238 if call.Method == c.Method {
239 _, diffCount := call.Arguments.Diff(c.Arguments)
240 if diffCount == 0 {
241 foundMatchingCall = true
242 // Remove from ExpectedCalls - just skip it
243 continue
244 }
245 }
246 c.Parent.ExpectedCalls[index] = call
247 index++
248 }
249 // trim slice up to last copied index
250 c.Parent.ExpectedCalls = c.Parent.ExpectedCalls[:index]
251 252 if !foundMatchingCall {
253 unlockOnce.Do(c.unlock)
254 c.Parent.fail("\n\nmock: Could not find expected call\n-----------------------------\n\n%s\n\n",
255 callString(c.Method, c.Arguments, true),
256 )
257 }
258 259 return c
260 }
261 262 // NotBefore indicates that the mock should only be called after the referenced
263 // calls have been called as expected. The referenced calls may be from the
264 // same mock instance and/or other mock instances.
265 //
266 // Mock.On("Do").Return(nil).NotBefore(
267 // Mock.On("Init").Return(nil)
268 // )
269 func (c *Call) NotBefore(calls ...*Call) *Call {
270 c.lock()
271 defer c.unlock()
272 273 for _, call := range calls {
274 if call.Parent == nil {
275 panic("not before calls must be created with Mock.On()")
276 }
277 }
278 279 c.requires = append(c.requires, calls...)
280 return c
281 }
282 283 // InOrder defines the order in which the calls should be made
284 //
285 // For example:
286 //
287 // InOrder(
288 // Mock.On("init").Return(nil),
289 // Mock.On("Do").Return(nil),
290 // )
291 func InOrder(calls ...*Call) {
292 for i := 1; i < len(calls); i++ {
293 calls[i].NotBefore(calls[i-1])
294 }
295 }
296 297 // Mock is the workhorse used to track activity on another object.
298 // For an example of its usage, refer to the "Example Usage" section at the top
299 // of this document.
300 type Mock struct {
301 // Represents the calls that are expected of
302 // an object.
303 ExpectedCalls []*Call
304 305 // Holds the calls that were made to this mocked object.
306 Calls []Call
307 308 // test is An optional variable that holds the test struct, to be used when an
309 // invalid mock call was made.
310 test TestingT
311 312 // TestData holds any data that might be useful for testing. Testify ignores
313 // this data completely allowing you to do whatever you like with it.
314 testData objx.Map
315 316 mutex sync.Mutex
317 }
318 319 // String provides a %v format string for Mock.
320 // Note: this is used implicitly by Arguments.Diff if a Mock is passed.
321 // It exists because go's default %v formatting traverses the struct
322 // without acquiring the mutex, which is detected by go test -race.
323 func (m *Mock) String() string {
324 return fmt.Sprintf("%[1]T<%[1]p>", m)
325 }
326 327 // TestData holds any data that might be useful for testing. Testify ignores
328 // this data completely allowing you to do whatever you like with it.
329 func (m *Mock) TestData() objx.Map {
330 if m.testData == nil {
331 m.testData = make(objx.Map)
332 }
333 334 return m.testData
335 }
336 337 /*
338 Setting expectations
339 */
340 341 // Test sets the [TestingT] on which errors will be reported, otherwise errors
342 // will cause a panic.
343 // Test should not be called on an object that is going to be used in a
344 // goroutine other than the one running the test function.
345 func (m *Mock) Test(t TestingT) {
346 m.mutex.Lock()
347 defer m.mutex.Unlock()
348 m.test = t
349 }
350 351 // fail fails the current test with the given formatted format and args.
352 // In case that a test was defined, it uses the test APIs for failing a test,
353 // otherwise it uses panic.
354 func (m *Mock) fail(format string, args ...interface{}) {
355 m.mutex.Lock()
356 defer m.mutex.Unlock()
357 358 if m.test == nil {
359 panic(fmt.Sprintf(format, args...))
360 }
361 m.test.Errorf(format, args...)
362 m.test.FailNow()
363 }
364 365 // On starts a description of an expectation of the specified method
366 // being called.
367 //
368 // Mock.On("MyMethod", arg1, arg2)
369 func (m *Mock) On(methodName string, arguments ...interface{}) *Call {
370 for _, arg := range arguments {
371 if v := reflect.ValueOf(arg); v.Kind() == reflect.Func {
372 panic(fmt.Sprintf("cannot use Func in expectations. Use mock.AnythingOfType(\"%T\")", arg))
373 }
374 }
375 376 m.mutex.Lock()
377 defer m.mutex.Unlock()
378 379 c := newCall(m, methodName, assert.CallerInfo(), arguments, make([]interface{}, 0))
380 m.ExpectedCalls = append(m.ExpectedCalls, c)
381 return c
382 }
383 384 // /*
385 // Recording and responding to activity
386 // */
387 388 func (m *Mock) findExpectedCall(method string, arguments ...interface{}) (int, *Call) {
389 var expectedCall *Call
390 391 for i, call := range m.ExpectedCalls {
392 if call.Method == method {
393 _, diffCount := call.Arguments.Diff(arguments)
394 if diffCount == 0 {
395 expectedCall = call
396 if call.Repeatability > -1 {
397 return i, call
398 }
399 }
400 }
401 }
402 403 return -1, expectedCall
404 }
405 406 type matchCandidate struct {
407 call *Call
408 mismatch string
409 diffCount int
410 }
411 412 func (c matchCandidate) isBetterMatchThan(other matchCandidate) bool {
413 if c.call == nil {
414 return false
415 }
416 if other.call == nil {
417 return true
418 }
419 420 if c.diffCount > other.diffCount {
421 return false
422 }
423 if c.diffCount < other.diffCount {
424 return true
425 }
426 427 if c.call.Repeatability > 0 && other.call.Repeatability <= 0 {
428 return true
429 }
430 return false
431 }
432 433 func (m *Mock) findClosestCall(method string, arguments ...interface{}) (*Call, string) {
434 var bestMatch matchCandidate
435 436 for _, call := range m.expectedCalls() {
437 if call.Method == method {
438 439 errInfo, tempDiffCount := call.Arguments.Diff(arguments)
440 tempCandidate := matchCandidate{
441 call: call,
442 mismatch: errInfo,
443 diffCount: tempDiffCount,
444 }
445 if tempCandidate.isBetterMatchThan(bestMatch) {
446 bestMatch = tempCandidate
447 }
448 }
449 }
450 451 return bestMatch.call, bestMatch.mismatch
452 }
453 454 func callString(method string, arguments Arguments, includeArgumentValues bool) string {
455 var argValsString string
456 if includeArgumentValues {
457 var argVals []string
458 for argIndex, arg := range arguments {
459 if _, ok := arg.(*FunctionalOptionsArgument); ok {
460 argVals = append(argVals, fmt.Sprintf("%d: %s", argIndex, arg))
461 continue
462 }
463 argVals = append(argVals, fmt.Sprintf("%d: %#v", argIndex, arg))
464 }
465 argValsString = fmt.Sprintf("\n\t\t%s", strings.Join(argVals, "\n\t\t"))
466 }
467 468 return fmt.Sprintf("%s(%s)%s", method, arguments.String(), argValsString)
469 }
470 471 // Called tells the mock object that a method has been called, and gets an array
472 // of arguments to return. Panics if the call is unexpected (i.e. not preceded by
473 // appropriate .On .Return() calls)
474 // If Call.WaitFor is set, blocks until the channel is closed or receives a message.
475 func (m *Mock) Called(arguments ...interface{}) Arguments {
476 // get the calling function's name
477 pc, _, _, ok := runtime.Caller(1)
478 if !ok {
479 panic("Couldn't get the caller information")
480 }
481 functionPath := runtime.FuncForPC(pc).Name()
482 // Next four lines are required to use GCCGO function naming conventions.
483 // For Ex: github_com_docker_libkv_store_mock.WatchTree.pN39_github_com_docker_libkv_store_mock.Mock
484 // uses interface information unlike golang github.com/docker/libkv/store/mock.(*Mock).WatchTree
485 // With GCCGO we need to remove interface information starting from pN<dd>.
486 if gccgoRE.MatchString(functionPath) {
487 functionPath = gccgoRE.Split(functionPath, -1)[0]
488 }
489 parts := strings.Split(functionPath, ".")
490 functionName := parts[len(parts)-1]
491 return m.MethodCalled(functionName, arguments...)
492 }
493 494 // MethodCalled tells the mock object that the given method has been called, and gets
495 // an array of arguments to return. Panics if the call is unexpected (i.e. not preceded
496 // by appropriate .On .Return() calls)
497 // If Call.WaitFor is set, blocks until the channel is closed or receives a message.
498 func (m *Mock) MethodCalled(methodName string, arguments ...interface{}) Arguments {
499 m.mutex.Lock()
500 // TODO: could combine expected and closes in single loop
501 found, call := m.findExpectedCall(methodName, arguments...)
502 503 if found < 0 {
504 // expected call found, but it has already been called with repeatable times
505 if call != nil {
506 m.mutex.Unlock()
507 m.fail("\nassert: mock: The method has been called over %d times.\n\tEither do one more Mock.On(%#v).Return(...), or remove extra call.\n\tThis call was unexpected:\n\t\t%s\n\tat: %s", call.totalCalls, methodName, callString(methodName, arguments, true), assert.CallerInfo())
508 }
509 // we have to fail here - because we don't know what to do
510 // as the return arguments. This is because:
511 //
512 // a) this is a totally unexpected call to this method,
513 // b) the arguments are not what was expected, or
514 // c) the developer has forgotten to add an accompanying On...Return pair.
515 closestCall, mismatch := m.findClosestCall(methodName, arguments...)
516 m.mutex.Unlock()
517 518 if closestCall != nil {
519 m.fail("\n\nmock: Unexpected Method Call\n-----------------------------\n\n%s\n\nThe closest call I have is: \n\n%s\n\n%s\nDiff: %s\nat: %s\n",
520 callString(methodName, arguments, true),
521 callString(methodName, closestCall.Arguments, true),
522 diffArguments(closestCall.Arguments, arguments),
523 strings.TrimSpace(mismatch),
524 assert.CallerInfo(),
525 )
526 } else {
527 m.fail("\nassert: mock: I don't know what to return because the method call was unexpected.\n\tEither do Mock.On(%#v).Return(...) first, or remove the %s() call.\n\tThis method was unexpected:\n\t\t%s\n\tat: %s", methodName, methodName, callString(methodName, arguments, true), assert.CallerInfo())
528 }
529 }
530 531 for _, requirement := range call.requires {
532 if satisfied, _ := requirement.Parent.checkExpectation(requirement); !satisfied {
533 m.mutex.Unlock()
534 m.fail("mock: Unexpected Method Call\n-----------------------------\n\n%s\n\nMust not be called before%s:\n\n%s",
535 callString(call.Method, call.Arguments, true),
536 func() (s string) {
537 if requirement.totalCalls > 0 {
538 s = " another call of"
539 }
540 if call.Parent != requirement.Parent {
541 s += " method from another mock instance"
542 }
543 return
544 }(),
545 callString(requirement.Method, requirement.Arguments, true),
546 )
547 }
548 }
549 550 if call.Repeatability == 1 {
551 call.Repeatability = -1
552 } else if call.Repeatability > 1 {
553 call.Repeatability--
554 }
555 call.totalCalls++
556 557 // add the call
558 m.Calls = append(m.Calls, *newCall(m, methodName, assert.CallerInfo(), arguments, call.ReturnArguments))
559 m.mutex.Unlock()
560 561 // block if specified
562 if call.WaitFor != nil {
563 <-call.WaitFor
564 } else {
565 time.Sleep(call.waitTime)
566 }
567 568 m.mutex.Lock()
569 panicMsg := call.PanicMsg
570 m.mutex.Unlock()
571 if panicMsg != nil {
572 panic(*panicMsg)
573 }
574 575 m.mutex.Lock()
576 runFn := call.RunFn
577 m.mutex.Unlock()
578 579 if runFn != nil {
580 runFn(arguments)
581 }
582 583 m.mutex.Lock()
584 returnArgs := call.ReturnArguments
585 m.mutex.Unlock()
586 587 return returnArgs
588 }
589 590 /*
591 Assertions
592 */
593 594 type assertExpectationiser interface {
595 AssertExpectations(TestingT) bool
596 }
597 598 // AssertExpectationsForObjects asserts that everything specified with On and Return
599 // of the specified objects was in fact called as expected.
600 //
601 // Calls may have occurred in any order.
602 func AssertExpectationsForObjects(t TestingT, testObjects ...interface{}) bool {
603 if h, ok := t.(tHelper); ok {
604 h.Helper()
605 }
606 for _, obj := range testObjects {
607 if m, ok := obj.(*Mock); ok {
608 t.Logf("Deprecated mock.AssertExpectationsForObjects(myMock.Mock) use mock.AssertExpectationsForObjects(myMock)")
609 obj = m
610 }
611 m := obj.(assertExpectationiser)
612 if !m.AssertExpectations(t) {
613 t.Logf("Expectations didn't match for Mock: %+v", reflect.TypeOf(m))
614 return false
615 }
616 }
617 return true
618 }
619 620 // AssertExpectations asserts that everything specified with On and Return was
621 // in fact called as expected. Calls may have occurred in any order.
622 func (m *Mock) AssertExpectations(t TestingT) bool {
623 if s, ok := t.(interface{ Skipped() bool }); ok && s.Skipped() {
624 return true
625 }
626 if h, ok := t.(tHelper); ok {
627 h.Helper()
628 }
629 630 m.mutex.Lock()
631 defer m.mutex.Unlock()
632 var failedExpectations int
633 634 // iterate through each expectation
635 expectedCalls := m.expectedCalls()
636 for _, expectedCall := range expectedCalls {
637 satisfied, reason := m.checkExpectation(expectedCall)
638 if !satisfied {
639 failedExpectations++
640 t.Logf(reason)
641 }
642 }
643 644 if failedExpectations != 0 {
645 t.Errorf("FAIL: %d out of %d expectation(s) were met.\n\tThe code you are testing needs to make %d more call(s).\n\tat: %s", len(expectedCalls)-failedExpectations, len(expectedCalls), failedExpectations, assert.CallerInfo())
646 }
647 648 return failedExpectations == 0
649 }
650 651 func (m *Mock) checkExpectation(call *Call) (bool, string) {
652 if !call.optional && !m.methodWasCalled(call.Method, call.Arguments) && call.totalCalls == 0 {
653 return false, fmt.Sprintf("FAIL:\t%s(%s)\n\t\tat: %s", call.Method, call.Arguments.String(), call.callerInfo)
654 }
655 if call.Repeatability > 0 {
656 return false, fmt.Sprintf("FAIL:\t%s(%s)\n\t\tat: %s", call.Method, call.Arguments.String(), call.callerInfo)
657 }
658 return true, fmt.Sprintf("PASS:\t%s(%s)", call.Method, call.Arguments.String())
659 }
660 661 // AssertNumberOfCalls asserts that the method was called expectedCalls times.
662 func (m *Mock) AssertNumberOfCalls(t TestingT, methodName string, expectedCalls int) bool {
663 if h, ok := t.(tHelper); ok {
664 h.Helper()
665 }
666 m.mutex.Lock()
667 defer m.mutex.Unlock()
668 var actualCalls int
669 for _, call := range m.calls() {
670 if call.Method == methodName {
671 actualCalls++
672 }
673 }
674 return assert.Equal(t, expectedCalls, actualCalls, fmt.Sprintf("Expected number of calls (%d) of method %s does not match the actual number of calls (%d).", expectedCalls, methodName, actualCalls))
675 }
676 677 // AssertCalled asserts that the method was called.
678 // It can produce a false result when an argument is a pointer type and the underlying value changed after calling the mocked method.
679 func (m *Mock) AssertCalled(t TestingT, methodName string, arguments ...interface{}) bool {
680 if h, ok := t.(tHelper); ok {
681 h.Helper()
682 }
683 m.mutex.Lock()
684 defer m.mutex.Unlock()
685 if !m.methodWasCalled(methodName, arguments) {
686 var calledWithArgs []string
687 for _, call := range m.calls() {
688 calledWithArgs = append(calledWithArgs, fmt.Sprintf("%v", call.Arguments))
689 }
690 if len(calledWithArgs) == 0 {
691 return assert.Fail(t, "Should have called with given arguments",
692 fmt.Sprintf("Expected %q to have been called with:\n%v\nbut no actual calls happened", methodName, arguments))
693 }
694 return assert.Fail(t, "Should have called with given arguments",
695 fmt.Sprintf("Expected %q to have been called with:\n%v\nbut actual calls were:\n %v", methodName, arguments, strings.Join(calledWithArgs, "\n")))
696 }
697 return true
698 }
699 700 // AssertNotCalled asserts that the method was not called.
701 // It can produce a false result when an argument is a pointer type and the underlying value changed after calling the mocked method.
702 func (m *Mock) AssertNotCalled(t TestingT, methodName string, arguments ...interface{}) bool {
703 if h, ok := t.(tHelper); ok {
704 h.Helper()
705 }
706 m.mutex.Lock()
707 defer m.mutex.Unlock()
708 if m.methodWasCalled(methodName, arguments) {
709 return assert.Fail(t, "Should not have called with given arguments",
710 fmt.Sprintf("Expected %q to not have been called with:\n%v\nbut actually it was.", methodName, arguments))
711 }
712 return true
713 }
714 715 // IsMethodCallable checking that the method can be called
716 // If the method was called more than `Repeatability` return false
717 func (m *Mock) IsMethodCallable(t TestingT, methodName string, arguments ...interface{}) bool {
718 if h, ok := t.(tHelper); ok {
719 h.Helper()
720 }
721 m.mutex.Lock()
722 defer m.mutex.Unlock()
723 724 for _, v := range m.ExpectedCalls {
725 if v.Method != methodName {
726 continue
727 }
728 if len(arguments) != len(v.Arguments) {
729 continue
730 }
731 if v.Repeatability < v.totalCalls {
732 continue
733 }
734 if isArgsEqual(v.Arguments, arguments) {
735 return true
736 }
737 }
738 return false
739 }
740 741 // isArgsEqual compares arguments
742 func isArgsEqual(expected Arguments, args []interface{}) bool {
743 if len(expected) != len(args) {
744 return false
745 }
746 for i, v := range args {
747 if !reflect.DeepEqual(expected[i], v) {
748 return false
749 }
750 }
751 return true
752 }
753 754 func (m *Mock) methodWasCalled(methodName string, expected []interface{}) bool {
755 for _, call := range m.calls() {
756 if call.Method == methodName {
757 758 _, differences := Arguments(expected).Diff(call.Arguments)
759 760 if differences == 0 {
761 // found the expected call
762 return true
763 }
764 765 }
766 }
767 // we didn't find the expected call
768 return false
769 }
770 771 func (m *Mock) expectedCalls() []*Call {
772 return append([]*Call{}, m.ExpectedCalls...)
773 }
774 775 func (m *Mock) calls() []Call {
776 return append([]Call{}, m.Calls...)
777 }
778 779 /*
780 Arguments
781 */
782 783 // Arguments holds an array of method arguments or return values.
784 type Arguments []interface{}
785 786 const (
787 // Anything is used in Diff and Assert when the argument being tested
788 // shouldn't be taken into consideration.
789 Anything = "mock.Anything"
790 )
791 792 // AnythingOfTypeArgument contains the type of an argument
793 // for use when type checking. Used in [Arguments.Diff] and [Arguments.Assert].
794 //
795 // Deprecated: this is an implementation detail that must not be used. Use the [AnythingOfType] constructor instead, example:
796 //
797 // m.On("Do", mock.AnythingOfType("string"))
798 //
799 // All explicit type declarations can be replaced with interface{} as is expected by [Mock.On], example:
800 //
801 // func anyString interface{} {
802 // return mock.AnythingOfType("string")
803 // }
804 type AnythingOfTypeArgument = anythingOfTypeArgument
805 806 // anythingOfTypeArgument is a string that contains the type of an argument
807 // for use when type checking. Used in Diff and Assert.
808 type anythingOfTypeArgument string
809 810 // AnythingOfType returns a special value containing the
811 // name of the type to check for. The type name will be matched against the type name returned by [reflect.Type.String].
812 //
813 // Used in Diff and Assert.
814 //
815 // For example:
816 //
817 // args.Assert(t, AnythingOfType("string"), AnythingOfType("int"))
818 func AnythingOfType(t string) AnythingOfTypeArgument {
819 return anythingOfTypeArgument(t)
820 }
821 822 // IsTypeArgument is a struct that contains the type of an argument
823 // for use when type checking. This is an alternative to [AnythingOfType].
824 // Used in [Arguments.Diff] and [Arguments.Assert].
825 type IsTypeArgument struct {
826 t reflect.Type
827 }
828 829 // IsType returns an IsTypeArgument object containing the type to check for.
830 // You can provide a zero-value of the type to check. This is an
831 // alternative to [AnythingOfType]. Used in [Arguments.Diff] and [Arguments.Assert].
832 //
833 // For example:
834 //
835 // args.Assert(t, IsType(""), IsType(0))
836 func IsType(t interface{}) *IsTypeArgument {
837 return &IsTypeArgument{t: reflect.TypeOf(t)}
838 }
839 840 // FunctionalOptionsArgument contains a list of functional options arguments
841 // expected for use when matching a list of arguments.
842 type FunctionalOptionsArgument struct {
843 values []interface{}
844 }
845 846 // String returns the string representation of FunctionalOptionsArgument
847 func (f *FunctionalOptionsArgument) String() string {
848 var name string
849 if len(f.values) > 0 {
850 name = "[]" + reflect.TypeOf(f.values[0]).String()
851 }
852 853 return strings.Replace(fmt.Sprintf("%#v", f.values), "[]interface {}", name, 1)
854 }
855 856 // FunctionalOptions returns an [FunctionalOptionsArgument] object containing
857 // the expected functional-options to check for.
858 //
859 // For example:
860 //
861 // args.Assert(t, FunctionalOptions(foo.Opt1("strValue"), foo.Opt2(613)))
862 func FunctionalOptions(values ...interface{}) *FunctionalOptionsArgument {
863 return &FunctionalOptionsArgument{
864 values: values,
865 }
866 }
867 868 // argumentMatcher performs custom argument matching, returning whether or
869 // not the argument is matched by the expectation fixture function.
870 type argumentMatcher struct {
871 // fn is a function which accepts one argument, and returns a bool.
872 fn reflect.Value
873 }
874 875 func (f argumentMatcher) Matches(argument interface{}) bool {
876 expectType := f.fn.Type().In(0)
877 expectTypeNilSupported := false
878 switch expectType.Kind() {
879 case reflect.Interface, reflect.Chan, reflect.Func, reflect.Map, reflect.Slice, reflect.Ptr:
880 expectTypeNilSupported = true
881 }
882 883 argType := reflect.TypeOf(argument)
884 var arg reflect.Value
885 if argType == nil {
886 arg = reflect.New(expectType).Elem()
887 } else {
888 arg = reflect.ValueOf(argument)
889 }
890 891 if argType == nil && !expectTypeNilSupported {
892 panic(errors.New("attempting to call matcher with nil for non-nil expected type"))
893 }
894 if argType == nil || argType.AssignableTo(expectType) {
895 result := f.fn.Call([]reflect.Value{arg})
896 return result[0].Bool()
897 }
898 return false
899 }
900 901 func (f argumentMatcher) String() string {
902 return fmt.Sprintf("func(%s) bool", f.fn.Type().In(0).String())
903 }
904 905 // MatchedBy can be used to match a mock call based on only certain properties
906 // from a complex struct or some calculation. It takes a function that will be
907 // evaluated with the called argument and will return true when there's a match
908 // and false otherwise.
909 //
910 // Example:
911 //
912 // m.On("Do", MatchedBy(func(req *http.Request) bool { return req.Host == "example.com" }))
913 //
914 // fn must be a function accepting a single argument (of the expected type)
915 // which returns a bool. If fn doesn't match the required signature,
916 // MatchedBy() panics.
917 func MatchedBy(fn interface{}) argumentMatcher {
918 fnType := reflect.TypeOf(fn)
919 920 if fnType.Kind() != reflect.Func {
921 panic(fmt.Sprintf("assert: arguments: %s is not a func", fn))
922 }
923 if fnType.NumIn() != 1 {
924 panic(fmt.Sprintf("assert: arguments: %s does not take exactly one argument", fn))
925 }
926 if fnType.NumOut() != 1 || fnType.Out(0).Kind() != reflect.Bool {
927 panic(fmt.Sprintf("assert: arguments: %s does not return a bool", fn))
928 }
929 930 return argumentMatcher{fn: reflect.ValueOf(fn)}
931 }
932 933 // Get Returns the argument at the specified index.
934 func (args Arguments) Get(index int) interface{} {
935 if index+1 > len(args) {
936 panic(fmt.Sprintf("assert: arguments: Cannot call Get(%d) because there are %d argument(s).", index, len(args)))
937 }
938 return args[index]
939 }
940 941 // Is gets whether the objects match the arguments specified.
942 func (args Arguments) Is(objects ...interface{}) bool {
943 for i, obj := range args {
944 if obj != objects[i] {
945 return false
946 }
947 }
948 return true
949 }
950 951 // Diff gets a string describing the differences between the arguments
952 // and the specified objects.
953 //
954 // Returns the diff string and number of differences found.
955 func (args Arguments) Diff(objects []interface{}) (string, int) {
956 // TODO: could return string as error and nil for No difference
957 958 output := "\n"
959 var differences int
960 961 maxArgCount := len(args)
962 if len(objects) > maxArgCount {
963 maxArgCount = len(objects)
964 }
965 966 for i := 0; i < maxArgCount; i++ {
967 var actual, expected interface{}
968 var actualFmt, expectedFmt string
969 970 if len(objects) <= i {
971 actual = "(Missing)"
972 actualFmt = "(Missing)"
973 } else {
974 actual = objects[i]
975 actualFmt = fmt.Sprintf("(%[1]T=%[1]v)", actual)
976 }
977 978 if len(args) <= i {
979 expected = "(Missing)"
980 expectedFmt = "(Missing)"
981 } else {
982 expected = args[i]
983 expectedFmt = fmt.Sprintf("(%[1]T=%[1]v)", expected)
984 }
985 986 if matcher, ok := expected.(argumentMatcher); ok {
987 var matches bool
988 func() {
989 defer func() {
990 if r := recover(); r != nil {
991 actualFmt = fmt.Sprintf("panic in argument matcher: %v", r)
992 }
993 }()
994 matches = matcher.Matches(actual)
995 }()
996 if matches {
997 output = fmt.Sprintf("%s\t%d: PASS: %s matched by %s\n", output, i, actualFmt, matcher)
998 } else {
999 differences++
1000 output = fmt.Sprintf("%s\t%d: FAIL: %s not matched by %s\n", output, i, actualFmt, matcher)
1001 }
1002 } else {
1003 switch expected := expected.(type) {
1004 case anythingOfTypeArgument:
1005 // type checking
1006 if reflect.TypeOf(actual).Name() != string(expected) && reflect.TypeOf(actual).String() != string(expected) {
1007 // not match
1008 differences++
1009 output = fmt.Sprintf("%s\t%d: FAIL: type %s != type %s - %s\n", output, i, expected, reflect.TypeOf(actual).Name(), actualFmt)
1010 }
1011 case *IsTypeArgument:
1012 actualT := reflect.TypeOf(actual)
1013 if actualT != expected.t {
1014 differences++
1015 output = fmt.Sprintf("%s\t%d: FAIL: type %s != type %s - %s\n", output, i, expected.t.Name(), actualT.Name(), actualFmt)
1016 }
1017 case *FunctionalOptionsArgument:
1018 var name string
1019 if len(expected.values) > 0 {
1020 name = "[]" + reflect.TypeOf(expected.values[0]).String()
1021 }
1022 1023 const tName = "[]interface{}"
1024 if name != reflect.TypeOf(actual).String() && len(expected.values) != 0 {
1025 differences++
1026 output = fmt.Sprintf("%s\t%d: FAIL: type %s != type %s - %s\n", output, i, tName, reflect.TypeOf(actual).Name(), actualFmt)
1027 } else {
1028 if ef, af := assertOpts(expected.values, actual); ef == "" && af == "" {
1029 // match
1030 output = fmt.Sprintf("%s\t%d: PASS: %s == %s\n", output, i, tName, tName)
1031 } else {
1032 // not match
1033 differences++
1034 output = fmt.Sprintf("%s\t%d: FAIL: %s != %s\n", output, i, af, ef)
1035 }
1036 }
1037 1038 default:
1039 if assert.ObjectsAreEqual(expected, Anything) || assert.ObjectsAreEqual(actual, Anything) || assert.ObjectsAreEqual(actual, expected) {
1040 // match
1041 output = fmt.Sprintf("%s\t%d: PASS: %s == %s\n", output, i, actualFmt, expectedFmt)
1042 } else {
1043 // not match
1044 differences++
1045 output = fmt.Sprintf("%s\t%d: FAIL: %s != %s\n", output, i, actualFmt, expectedFmt)
1046 }
1047 }
1048 }
1049 1050 }
1051 1052 if differences == 0 {
1053 return "No differences.", differences
1054 }
1055 1056 return output, differences
1057 }
1058 1059 // Assert compares the arguments with the specified objects and fails if
1060 // they do not exactly match.
1061 func (args Arguments) Assert(t TestingT, objects ...interface{}) bool {
1062 if h, ok := t.(tHelper); ok {
1063 h.Helper()
1064 }
1065 1066 // get the differences
1067 diff, diffCount := args.Diff(objects)
1068 1069 if diffCount == 0 {
1070 return true
1071 }
1072 1073 // there are differences... report them...
1074 t.Logf(diff)
1075 t.Errorf("%sArguments do not match.", assert.CallerInfo())
1076 1077 return false
1078 }
1079 1080 // String gets the argument at the specified index. Panics if there is no argument, or
1081 // if the argument is of the wrong type.
1082 //
1083 // If no index is provided, String() returns a complete string representation
1084 // of the arguments.
1085 func (args Arguments) String(indexOrNil ...int) string {
1086 if len(indexOrNil) == 0 {
1087 // normal String() method - return a string representation of the args
1088 var argsStr []string
1089 for _, arg := range args {
1090 argsStr = append(argsStr, fmt.Sprintf("%T", arg)) // handles nil nicely
1091 }
1092 return strings.Join(argsStr, ",")
1093 } else if len(indexOrNil) == 1 {
1094 // Index has been specified - get the argument at that index
1095 index := indexOrNil[0]
1096 var s string
1097 var ok bool
1098 if s, ok = args.Get(index).(string); !ok {
1099 panic(fmt.Sprintf("assert: arguments: String(%d) failed because object wasn't correct type: %s", index, args.Get(index)))
1100 }
1101 return s
1102 }
1103 1104 panic(fmt.Sprintf("assert: arguments: Wrong number of arguments passed to String. Must be 0 or 1, not %d", len(indexOrNil)))
1105 }
1106 1107 // Int gets the argument at the specified index. Panics if there is no argument, or
1108 // if the argument is of the wrong type.
1109 func (args Arguments) Int(index int) int {
1110 var s int
1111 var ok bool
1112 if s, ok = args.Get(index).(int); !ok {
1113 panic(fmt.Sprintf("assert: arguments: Int(%d) failed because object wasn't correct type: %v", index, args.Get(index)))
1114 }
1115 return s
1116 }
1117 1118 // Error gets the argument at the specified index. Panics if there is no argument, or
1119 // if the argument is of the wrong type.
1120 func (args Arguments) Error(index int) error {
1121 obj := args.Get(index)
1122 var s error
1123 var ok bool
1124 if obj == nil {
1125 return nil
1126 }
1127 if s, ok = obj.(error); !ok {
1128 panic(fmt.Sprintf("assert: arguments: Error(%d) failed because object wasn't correct type: %v", index, obj))
1129 }
1130 return s
1131 }
1132 1133 // Bool gets the argument at the specified index. Panics if there is no argument, or
1134 // if the argument is of the wrong type.
1135 func (args Arguments) Bool(index int) bool {
1136 var s bool
1137 var ok bool
1138 if s, ok = args.Get(index).(bool); !ok {
1139 panic(fmt.Sprintf("assert: arguments: Bool(%d) failed because object wasn't correct type: %v", index, args.Get(index)))
1140 }
1141 return s
1142 }
1143 1144 func typeAndKind(v interface{}) (reflect.Type, reflect.Kind) {
1145 t := reflect.TypeOf(v)
1146 k := t.Kind()
1147 1148 if k == reflect.Ptr {
1149 t = t.Elem()
1150 k = t.Kind()
1151 }
1152 return t, k
1153 }
1154 1155 func diffArguments(expected Arguments, actual Arguments) string {
1156 if len(expected) != len(actual) {
1157 return fmt.Sprintf("Provided %v arguments, mocked for %v arguments", len(expected), len(actual))
1158 }
1159 1160 for x := range expected {
1161 if diffString := diff(expected[x], actual[x]); diffString != "" {
1162 return fmt.Sprintf("Difference found in argument %v:\n\n%s", x, diffString)
1163 }
1164 }
1165 1166 return ""
1167 }
1168 1169 // diff returns a diff of both values as long as both are of the same type and
1170 // are a struct, map, slice or array. Otherwise it returns an empty string.
1171 func diff(expected interface{}, actual interface{}) string {
1172 if expected == nil || actual == nil {
1173 return ""
1174 }
1175 1176 et, ek := typeAndKind(expected)
1177 at, _ := typeAndKind(actual)
1178 1179 if et != at {
1180 return ""
1181 }
1182 1183 if ek != reflect.Struct && ek != reflect.Map && ek != reflect.Slice && ek != reflect.Array {
1184 return ""
1185 }
1186 1187 e := spewConfig.Sdump(expected)
1188 a := spewConfig.Sdump(actual)
1189 1190 diff, _ := difflib.GetUnifiedDiffString(difflib.UnifiedDiff{
1191 A: difflib.SplitLines(e),
1192 B: difflib.SplitLines(a),
1193 FromFile: "Expected",
1194 FromDate: "",
1195 ToFile: "Actual",
1196 ToDate: "",
1197 Context: 1,
1198 })
1199 1200 return diff
1201 }
1202 1203 var spewConfig = spew.ConfigState{
1204 Indent: " ",
1205 DisablePointerAddresses: true,
1206 DisableCapacities: true,
1207 SortKeys: true,
1208 }
1209 1210 type tHelper interface {
1211 Helper()
1212 }
1213 1214 func assertOpts(expected, actual interface{}) (expectedFmt, actualFmt string) {
1215 expectedOpts := reflect.ValueOf(expected)
1216 actualOpts := reflect.ValueOf(actual)
1217 1218 var expectedFuncs []*runtime.Func
1219 var expectedNames []string
1220 for i := 0; i < expectedOpts.Len(); i++ {
1221 f := runtimeFunc(expectedOpts.Index(i).Interface())
1222 expectedFuncs = append(expectedFuncs, f)
1223 expectedNames = append(expectedNames, funcName(f))
1224 }
1225 var actualFuncs []*runtime.Func
1226 var actualNames []string
1227 for i := 0; i < actualOpts.Len(); i++ {
1228 f := runtimeFunc(actualOpts.Index(i).Interface())
1229 actualFuncs = append(actualFuncs, f)
1230 actualNames = append(actualNames, funcName(f))
1231 }
1232 1233 if expectedOpts.Len() != actualOpts.Len() {
1234 expectedFmt = fmt.Sprintf("%v", expectedNames)
1235 actualFmt = fmt.Sprintf("%v", actualNames)
1236 return
1237 }
1238 1239 for i := 0; i < expectedOpts.Len(); i++ {
1240 if !isFuncSame(expectedFuncs[i], actualFuncs[i]) {
1241 expectedFmt = expectedNames[i]
1242 actualFmt = actualNames[i]
1243 return
1244 }
1245 1246 expectedOpt := expectedOpts.Index(i).Interface()
1247 actualOpt := actualOpts.Index(i).Interface()
1248 1249 ot := reflect.TypeOf(expectedOpt)
1250 var expectedValues []reflect.Value
1251 var actualValues []reflect.Value
1252 if ot.NumIn() == 0 {
1253 return
1254 }
1255 1256 for i := 0; i < ot.NumIn(); i++ {
1257 vt := ot.In(i).Elem()
1258 expectedValues = append(expectedValues, reflect.New(vt))
1259 actualValues = append(actualValues, reflect.New(vt))
1260 }
1261 1262 reflect.ValueOf(expectedOpt).Call(expectedValues)
1263 reflect.ValueOf(actualOpt).Call(actualValues)
1264 1265 for i := 0; i < ot.NumIn(); i++ {
1266 if expectedArg, actualArg := expectedValues[i].Interface(), actualValues[i].Interface(); !assert.ObjectsAreEqual(expectedArg, actualArg) {
1267 expectedFmt = fmt.Sprintf("%s(%T) -> %#v", expectedNames[i], expectedArg, expectedArg)
1268 actualFmt = fmt.Sprintf("%s(%T) -> %#v", expectedNames[i], actualArg, actualArg)
1269 return
1270 }
1271 }
1272 }
1273 1274 return "", ""
1275 }
1276 1277 func runtimeFunc(opt interface{}) *runtime.Func {
1278 return runtime.FuncForPC(reflect.ValueOf(opt).Pointer())
1279 }
1280 1281 func funcName(f *runtime.Func) string {
1282 name := f.Name()
1283 trimmed := strings.TrimSuffix(path.Base(name), path.Ext(name))
1284 splitted := strings.Split(trimmed, ".")
1285 1286 if len(splitted) == 0 {
1287 return trimmed
1288 }
1289 1290 return splitted[len(splitted)-1]
1291 }
1292 1293 func isFuncSame(f1, f2 *runtime.Func) bool {
1294 f1File, f1Loc := f1.FileLine(f1.Entry())
1295 f2File, f2Loc := f2.FileLine(f2.Entry())
1296 1297 return f1File == f2File && f1Loc == f2Loc
1298 }
1299