Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 17 additions & 55 deletions mock/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -948,75 +948,56 @@ func (args Arguments) Is(objects ...interface{}) bool {
return true
}

type outputRenderer func() string

// Diff gets a string describing the differences between the arguments
// and the specified objects.
//
// Returns the diff string and number of differences found.
func (args Arguments) Diff(objects []interface{}) (string, int) {
// TODO: could return string as error and nil for No difference

var outputBuilder strings.Builder
output := "\n"
var differences int

maxArgCount := len(args)
if len(objects) > maxArgCount {
maxArgCount = len(objects)
}

outputRenderers := []outputRenderer{}

for i := 0; i < maxArgCount; i++ {
i := i
var actual, expected interface{}
var actualFmt, expectedFmt func() string
var actualFmt, expectedFmt string

if len(objects) <= i {
actual = "(Missing)"
actualFmt = func() string {
return "(Missing)"
}
actualFmt = "(Missing)"
} else {
actual = objects[i]
actualFmt = func() string {
return fmt.Sprintf("(%[1]T=%[1]v)", actual)
}
actualFmt = fmt.Sprintf("(%[1]T=%[1]v)", actual)
}

if len(args) <= i {
expected = "(Missing)"
expectedFmt = func() string {
return "(Missing)"
}
expectedFmt = "(Missing)"
} else {
expected = args[i]
expectedFmt = func() string {
return fmt.Sprintf("(%[1]T=%[1]v)", expected)
}
expectedFmt = fmt.Sprintf("(%[1]T=%[1]v)", expected)
}

if matcher, ok := expected.(argumentMatcher); ok {
var matches bool
func() {
defer func() {
if r := recover(); r != nil {
actualFmt = func() string {
return fmt.Sprintf("panic in argument matcher: %v", r)
}
actualFmt = fmt.Sprintf("panic in argument matcher: %v", r)
}
}()
matches = matcher.Matches(actual)
}()
if matches {
outputRenderers = append(outputRenderers, func() string {
return fmt.Sprintf("\t%d: PASS: %s matched by %s\n", i, actualFmt(), matcher)
})
output = fmt.Sprintf("%s\t%d: PASS: %s matched by %s\n", output, i, actualFmt, matcher)
} else {
differences++
outputRenderers = append(outputRenderers, func() string {
return fmt.Sprintf("\t%d: FAIL: %s not matched by %s\n", i, actualFmt(), matcher)
})
output = fmt.Sprintf("%s\t%d: FAIL: %s not matched by %s\n", output, i, actualFmt, matcher)
}
} else {
switch expected := expected.(type) {
Expand All @@ -1025,17 +1006,13 @@ func (args Arguments) Diff(objects []interface{}) (string, int) {
if reflect.TypeOf(actual).Name() != string(expected) && reflect.TypeOf(actual).String() != string(expected) {
// not match
differences++
outputRenderers = append(outputRenderers, func() string {
return fmt.Sprintf("\t%d: FAIL: type %s != type %s - %s\n", i, expected, reflect.TypeOf(actual).Name(), actualFmt())
})
output = fmt.Sprintf("%s\t%d: FAIL: type %s != type %s - %s\n", output, i, expected, reflect.TypeOf(actual).Name(), actualFmt)
}
case *IsTypeArgument:
actualT := reflect.TypeOf(actual)
if actualT != expected.t {
differences++
outputRenderers = append(outputRenderers, func() string {
return fmt.Sprintf("\t%d: FAIL: type %s != type %s - %s\n", i, expected.t.Name(), actualT.Name(), actualFmt())
})
output = fmt.Sprintf("%s\t%d: FAIL: type %s != type %s - %s\n", output, i, expected.t.Name(), actualT.Name(), actualFmt)
}
case *FunctionalOptionsArgument:
var name string
Expand All @@ -1046,36 +1023,26 @@ func (args Arguments) Diff(objects []interface{}) (string, int) {
const tName = "[]interface{}"
if name != reflect.TypeOf(actual).String() && len(expected.values) != 0 {
differences++
outputRenderers = append(outputRenderers, func() string {
return fmt.Sprintf("\t%d: FAIL: type %s != type %s - %s\n", i, tName, reflect.TypeOf(actual).Name(), actualFmt())
})
output = fmt.Sprintf("%s\t%d: FAIL: type %s != type %s - %s\n", output, i, tName, reflect.TypeOf(actual).Name(), actualFmt)
} else {
if ef, af := assertOpts(expected.values, actual); ef == "" && af == "" {
// match
outputRenderers = append(outputRenderers, func() string {
return fmt.Sprintf("\t%d: PASS: %s == %s\n", i, tName, tName)
})
output = fmt.Sprintf("%s\t%d: PASS: %s == %s\n", output, i, tName, tName)
} else {
// not match
differences++
outputRenderers = append(outputRenderers, func() string {
return fmt.Sprintf("\t%d: FAIL: %s != %s\n", i, af, ef)
})
output = fmt.Sprintf("%s\t%d: FAIL: %s != %s\n", output, i, af, ef)
}
}

default:
if assert.ObjectsAreEqual(expected, Anything) || assert.ObjectsAreEqual(actual, Anything) || assert.ObjectsAreEqual(actual, expected) {
// match
outputRenderers = append(outputRenderers, func() string {
return fmt.Sprintf("\t%d: PASS: %s == %s\n", i, actualFmt(), expectedFmt())
})
output = fmt.Sprintf("%s\t%d: PASS: %s == %s\n", output, i, actualFmt, expectedFmt)
} else {
// not match
differences++
outputRenderers = append(outputRenderers, func() string {
return fmt.Sprintf("\t%d: FAIL: %s != %s\n", i, actualFmt(), expectedFmt())
})
output = fmt.Sprintf("%s\t%d: FAIL: %s != %s\n", output, i, actualFmt, expectedFmt)
}
}
}
Expand All @@ -1086,12 +1053,7 @@ func (args Arguments) Diff(objects []interface{}) (string, int) {
return "No differences.", differences
}

outputBuilder.WriteString("\n")
for _, r := range outputRenderers {
outputBuilder.WriteString(r())
}

return outputBuilder.String(), differences
return output, differences
}

// Assert compares the arguments with the specified objects and fails if
Expand Down
20 changes: 20 additions & 0 deletions mock/mock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"regexp"
"runtime"
"strconv"
"sync"
"testing"
"time"
Expand Down Expand Up @@ -2421,3 +2422,22 @@ type user interface {
type mockUser struct{ Mock }

func (m *mockUser) Use(c caller) { m.Called(c) }

type mutatingStringer struct {
N int
s string
}

func (m *mutatingStringer) String() string {
m.s = strconv.Itoa(m.N)
return m.s
}

func TestIssue1785ArgumentWithMutatingStringer(t *testing.T) {
m := &Mock{}
m.On("Method", &mutatingStringer{N: 2})
m.On("Method", &mutatingStringer{N: 1})
m.MethodCalled("Method", &mutatingStringer{N: 1})
m.MethodCalled("Method", &mutatingStringer{N: 2})
m.AssertExpectations(t)
}