Skip to content

Commit

Permalink
assert: allow comparing time.Time
Browse files Browse the repository at this point in the history
  • Loading branch information
torkelrogstad authored and boyan-soubachov committed Feb 15, 2022
1 parent 7bcf74e commit 087b655
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 1 deletion.
24 changes: 24 additions & 0 deletions assert/assertion_compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package assert
import (
"fmt"
"reflect"
"time"
)

type CompareType int
Expand Down Expand Up @@ -30,6 +31,8 @@ var (
float64Type = reflect.TypeOf(float64(1))

stringType = reflect.TypeOf("")

timeType = reflect.TypeOf(time.Time{})
)

func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) {
Expand Down Expand Up @@ -299,6 +302,27 @@ func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) {
return compareLess, true
}
}
// Check for known struct types we can check for compare results.
case reflect.Struct:
{
// All structs enter here. We're not interested in most types.
if !obj1Value.CanConvert(timeType) {
break
}

// time.Time can compared!
timeObj1, ok := obj1.(time.Time)
if !ok {
timeObj1 = obj1Value.Convert(timeType).Interface().(time.Time)
}

timeObj2, ok := obj2.(time.Time)
if !ok {
timeObj2 = obj2Value.Convert(timeType).Interface().(time.Time)
}

return compare(timeObj1.UnixNano(), timeObj2.UnixNano(), reflect.Int64)
}
}

return compareEqual, false
Expand Down
7 changes: 6 additions & 1 deletion assert/assertion_compare_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"reflect"
"runtime"
"testing"
"time"
)

func TestCompare(t *testing.T) {
Expand All @@ -22,6 +23,7 @@ func TestCompare(t *testing.T) {
type customFloat32 float32
type customFloat64 float64
type customString string
type customTime time.Time
for _, currCase := range []struct {
less interface{}
greater interface{}
Expand Down Expand Up @@ -52,14 +54,17 @@ func TestCompare(t *testing.T) {
{less: customFloat32(1.23), greater: customFloat32(2.23), cType: "float32"},
{less: float64(1.23), greater: float64(2.34), cType: "float64"},
{less: customFloat64(1.23), greater: customFloat64(2.34), cType: "float64"},
{less: time.Now(), greater: time.Now().Add(time.Hour), cType: "time.Time"},
{less: customTime(time.Now()), greater: customTime(time.Now().Add(time.Hour)), cType: "time.Time"},
} {
resLess, isComparable := compare(currCase.less, currCase.greater, reflect.ValueOf(currCase.less).Kind())
if !isComparable {
t.Error("object should be comparable for type " + currCase.cType)
}

if resLess != compareLess {
t.Errorf("object less should be less than greater for type " + currCase.cType)
t.Errorf("object less (%v) should be less than greater (%v) for type "+currCase.cType,
currCase.less, currCase.greater)
}

resGreater, isComparable := compare(currCase.greater, currCase.less, reflect.ValueOf(currCase.less).Kind())
Expand Down

0 comments on commit 087b655

Please sign in to comment.