Skip to content

Commit

Permalink
assert.Equal supports T.Equal(T) method along IsEqual
Browse files Browse the repository at this point in the history
  • Loading branch information
adamluzsi committed Jun 10, 2023
1 parent 4c49fb4 commit 0db1ad8
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 32 deletions.
60 changes: 37 additions & 23 deletions assert/Asserter.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,12 @@ func (a Asserter) NotPanic(blk func(), msg ...any) {
}

// Equal allows you to match if two entity is equal.
// if entities are implementing IsEqual function, then it will be used to check equality between each other.
// - IsEqual(oth T) bool
// - IsEqual(oth T) (bool, error)
//
// if entities are implementing IsEqual/Equal function, then it will be used to check equality between each other.
// - value.IsEqual(oth T) bool
// - value.IsEqual(oth T) (bool, error)
// - value.Equal(oth T) bool
// - value.Equal(oth T) (bool, error)
func (a Asserter) Equal(expected, actually any, msg ...any) {
a.TB.Helper()
const method = "Equal"
Expand Down Expand Up @@ -250,6 +253,8 @@ func (a Asserter) eq(exp, act any) bool {
return reflect.DeepEqual(exp, act)
}

var methodNamesForIsEqual = []string{"IsEqual", "Equal"}

func (a Asserter) tryIsEqual(exp, act any) (isEqual bool, ok bool) {
a.TB.Helper()
defer func() { recover() }()
Expand All @@ -260,32 +265,41 @@ func (a Asserter) tryIsEqual(exp, act any) (isEqual bool, ok bool) {
return false, false
}

method := expRV.MethodByName("IsEqual")
methodType := method.Type()
tryMethodName := func(methodName string) (bool, bool) {
method := expRV.MethodByName(methodName)
methodType := method.Type()

if methodType.NumIn() != 1 {
return false, false
}
if numOut := methodType.NumOut(); !(numOut == 1 || numOut == 2) {
return false, false
}
if methodType.In(0) != actRV.Type() {
return false, false
}
if methodType.NumIn() != 1 {
return false, false
}
if numOut := methodType.NumOut(); !(numOut == 1 || numOut == 2) {
return false, false
}
if methodType.In(0) != actRV.Type() {
return false, false
}

res := method.Call([]reflect.Value{actRV})
res := method.Call([]reflect.Value{actRV})

switch {
case methodType.NumOut() == 1: // IsEqual(T) (bool)
return res[0].Bool(), true
switch {
case methodType.NumOut() == 1: // IsEqual(T) (bool)
return res[0].Bool(), true

case methodType.NumOut() == 2: // IsEqual(T) (bool, error)
Must(a.TB).Nil(res[1].Interface())
return res[0].Bool(), true
case methodType.NumOut() == 2: // IsEqual(T) (bool, error)
Must(a.TB).Nil(res[1].Interface())
return res[0].Bool(), true

default:
return false, false
default:
return false, false
}
}

for _, methodName := range methodNamesForIsEqual {
if eq, ok := tryMethodName(methodName); ok {
return eq, ok
}
}
return false, false
}

func (a Asserter) Contain(haystack, needle any, msg ...any) {
Expand Down
8 changes: 4 additions & 4 deletions assert/Asserter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -303,23 +303,23 @@ func TestAsserter_Equal(t *testing.T) {
},
{
Desc: "when value implements equalable and the two value is equal by IsEqual",
Expected: ExampleEqualable{
Expected: ExampleEqualableWithIsEqual{
relevantUnexportedValue: 42,
IrrelevantExportedField: 42,
},
Actual: ExampleEqualable{
Actual: ExampleEqualableWithIsEqual{
relevantUnexportedValue: 42,
IrrelevantExportedField: 24,
},
IsFailed: false,
},
{
Desc: "when value implements equalable and the two value is not equal by IsEqual",
Expected: ExampleEqualable{
Expected: ExampleEqualableWithIsEqual{
relevantUnexportedValue: 24,
IrrelevantExportedField: 42,
},
Actual: ExampleEqualable{
Actual: ExampleEqualableWithIsEqual{
relevantUnexportedValue: 42,
IrrelevantExportedField: 42,
},
Expand Down
35 changes: 30 additions & 5 deletions assert/example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -266,24 +266,49 @@ func ExampleAsserter_ErrorIs() {
assert.Must(tb).ErrorIs(errors.New("boom"), fmt.Errorf("wrapped error: %w", actualErr)) // passes for wrapped errors
}

type ExampleEqualable struct {
type ExampleEqualableWithIsEqual struct {
IrrelevantExportedField int
relevantUnexportedValue int
}

func (es ExampleEqualable) IsEqual(oth ExampleEqualable) bool {
func (es ExampleEqualableWithIsEqual) IsEqual(oth ExampleEqualableWithIsEqual) bool {
return es.relevantUnexportedValue == oth.relevantUnexportedValue
}

func ExampleAsserter_Equal_isEqualFunctionUsedForComparison() {
type ExampleEqualableWithEqual struct {
IrrelevantExportedField int
relevantUnexportedValue int
}

func (es ExampleEqualableWithEqual) IsEqual(oth ExampleEqualableWithEqual) bool {
return es.relevantUnexportedValue == oth.relevantUnexportedValue
}

func ExampleAsserter_Equal_withIsEqualMethod() {
var tb testing.TB

expected := ExampleEqualableWithIsEqual{
IrrelevantExportedField: 42,
relevantUnexportedValue: 24,
}

actual := ExampleEqualableWithIsEqual{
IrrelevantExportedField: 4242,
relevantUnexportedValue: 24,
}

assert.Must(tb).Equal(expected, actual) // passes as by IsEqual terms the two value is equal
}

func ExampleAsserter_Equal_withEqualMethod() {
var tb testing.TB

expected := ExampleEqualable{
expected := ExampleEqualableWithEqual{
IrrelevantExportedField: 42,
relevantUnexportedValue: 24,
}

actual := ExampleEqualable{
actual := ExampleEqualableWithEqual{
IrrelevantExportedField: 4242,
relevantUnexportedValue: 24,
}
Expand Down

0 comments on commit 0db1ad8

Please sign in to comment.