Skip to content

Commit

Permalink
feat: receiver matcher accepting (POINTER, MATCHER), includes unit tests
Browse files Browse the repository at this point in the history
Signed-off-by: thediveo <[email protected]>
  • Loading branch information
thediveo authored and onsi committed Apr 18, 2024
1 parent 9999deb commit ec1f186
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 28 deletions.
15 changes: 8 additions & 7 deletions matchers.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,20 +194,21 @@ func BeClosed() types.GomegaMatcher {
//
// will repeatedly attempt to pull values out of `c` until a value matching "bar" is received.
//
// Finally, if you want to have a reference to the value *sent* to the channel you can pass the `Receive` matcher a pointer to a variable of the appropriate type:
// Furthermore, if you want to have a reference to the value *sent* to the channel you can pass the `Receive` matcher a pointer to a variable of the appropriate type:
//
// var myThing thing
// Eventually(thingChan).Should(Receive(&myThing))
// Expect(myThing.Sprocket).Should(Equal("foo"))
// Expect(myThing.IsValid()).Should(BeTrue())
//
// Finally, if you want to match the received object as well as get the actual received value into a variable, so you can reason further about the value received,
// you can pass a pointer to a variable of the approriate type first, and second a matcher:
//
// var myThing thing
// Eventually(thingChan).Should(Receive(&myThing, ContainSubstring("bar")))
func Receive(args ...interface{}) types.GomegaMatcher {
var arg interface{}
if len(args) > 0 {
arg = args[0]
}

return &matchers.ReceiveMatcher{
Arg: arg,
Args: args,
}
}

Expand Down
70 changes: 53 additions & 17 deletions matchers/receive_matcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
package matchers

import (
"errors"
"fmt"
"reflect"

"github.com/onsi/gomega/format"
)

type ReceiveMatcher struct {
Arg interface{}
Args []interface{}
receivedValue reflect.Value
channelClosed bool
}
Expand All @@ -29,15 +30,38 @@ func (matcher *ReceiveMatcher) Match(actual interface{}) (success bool, err erro

var subMatcher omegaMatcher
var hasSubMatcher bool

if matcher.Arg != nil {
subMatcher, hasSubMatcher = (matcher.Arg).(omegaMatcher)
var resultReference interface{}

// Valid arg formats are as follows, always with optional POINTER before
// optional MATCHER:
// - Receive()
// - Receive(POINTER)
// - Receive(MATCHER)
// - Receive(POINTER, MATCHER)
args := matcher.Args
if len(args) > 0 {
arg := args[0]
_, isSubMatcher := arg.(omegaMatcher)
if !isSubMatcher && reflect.ValueOf(arg).Kind() == reflect.Ptr {
// Consume optional POINTER arg first, if it ain't no matcher ;)
resultReference = arg
args = args[1:]
}
}
if len(args) > 0 {
arg := args[0]
subMatcher, hasSubMatcher = arg.(omegaMatcher)
if !hasSubMatcher {
argType := reflect.TypeOf(matcher.Arg)
if argType.Kind() != reflect.Ptr {
return false, fmt.Errorf("Cannot assign a value from the channel:\n%s\nTo:\n%s\nYou need to pass a pointer!", format.Object(actual, 1), format.Object(matcher.Arg, 1))
}
// At this point we assume the dev user wanted to assign a received
// value, so [POINTER,]MATCHER.
return false, fmt.Errorf("Cannot assign a value from the channel:\n%s\nTo:\n%s\nYou need to pass a pointer!", format.Object(actual, 1), format.Object(arg, 1))
}
// Consume optional MATCHER arg.
args = args[1:]
}
if len(args) > 0 {
// If there are still args present, reject all.
return false, errors.New("Receive matcher expects at most an optional pointer and/or an optional matcher")
}

winnerIndex, value, open := reflect.Select([]reflect.SelectCase{
Expand All @@ -58,16 +82,20 @@ func (matcher *ReceiveMatcher) Match(actual interface{}) (success bool, err erro
}

if hasSubMatcher {
if didReceive {
matcher.receivedValue = value
return subMatcher.Match(matcher.receivedValue.Interface())
if !didReceive {
return false, nil
}
return false, nil
matcher.receivedValue = value
if match, err := subMatcher.Match(matcher.receivedValue.Interface()); err != nil || !match {
return match, err
}
// if we received a match, then fall through in order to handle an
// optional assignment of the received value to the specified reference.
}

if didReceive {
if matcher.Arg != nil {
outValue := reflect.ValueOf(matcher.Arg)
if resultReference != nil {
outValue := reflect.ValueOf(resultReference)

if value.Type().AssignableTo(outValue.Elem().Type()) {
outValue.Elem().Set(value)
Expand All @@ -77,7 +105,7 @@ func (matcher *ReceiveMatcher) Match(actual interface{}) (success bool, err erro
outValue.Elem().Set(value.Elem())
return true, nil
} else {
return false, fmt.Errorf("Cannot assign a value from the channel:\n%s\nType:\n%s\nTo:\n%s", format.Object(actual, 1), format.Object(value.Interface(), 1), format.Object(matcher.Arg, 1))
return false, fmt.Errorf("Cannot assign a value from the channel:\n%s\nType:\n%s\nTo:\n%s", format.Object(actual, 1), format.Object(value.Interface(), 1), format.Object(resultReference, 1))
}

}
Expand All @@ -88,7 +116,11 @@ func (matcher *ReceiveMatcher) Match(actual interface{}) (success bool, err erro
}

func (matcher *ReceiveMatcher) FailureMessage(actual interface{}) (message string) {
subMatcher, hasSubMatcher := (matcher.Arg).(omegaMatcher)
var matcherArg interface{}
if len(matcher.Args) > 0 {
matcherArg = matcher.Args[len(matcher.Args)-1]
}
subMatcher, hasSubMatcher := (matcherArg).(omegaMatcher)

closedAddendum := ""
if matcher.channelClosed {
Expand All @@ -105,7 +137,11 @@ func (matcher *ReceiveMatcher) FailureMessage(actual interface{}) (message strin
}

func (matcher *ReceiveMatcher) NegatedFailureMessage(actual interface{}) (message string) {
subMatcher, hasSubMatcher := (matcher.Arg).(omegaMatcher)
var matcherArg interface{}
if len(matcher.Args) > 0 {
matcherArg = matcher.Args[len(matcher.Args)-1]
}
subMatcher, hasSubMatcher := (matcherArg).(omegaMatcher)

closedAddendum := ""
if matcher.channelClosed {
Expand Down
53 changes: 49 additions & 4 deletions matchers/receive_matcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,39 @@ var _ = Describe("ReceiveMatcher", func() {
})
})

Context("with too many arguments", func() {
It("should error", func() {
channel := make(chan bool, 1)
var actual bool

channel <- true

success, err := (&ReceiveMatcher{Args: []interface{}{
&actual,
Equal(true),
42,
}}).Match(channel)
Expect(success).To(BeFalse())
Expect(err).To(HaveOccurred())
})
})

Context("with swapped arguments", func() {
It("should error", func() {
channel := make(chan bool, 1)
var actual bool

channel <- true

success, err := (&ReceiveMatcher{Args: []interface{}{
Equal(true),
&actual,
}}).Match(channel)
Expect(success).To(BeFalse())
Expect(err).To(HaveOccurred())
})
})

Context("with a pointer argument", func() {
Context("of the correct type", func() {
When("the channel has an interface type", func() {
Expand Down Expand Up @@ -134,12 +167,12 @@ var _ = Describe("ReceiveMatcher", func() {

var incorrectType bool

success, err := (&ReceiveMatcher{Arg: &incorrectType}).Match(channel)
success, err := (&ReceiveMatcher{Args: []interface{}{&incorrectType}}).Match(channel)
Expect(success).Should(BeFalse())
Expect(err).Should(HaveOccurred())

var notAPointer int
success, err = (&ReceiveMatcher{Arg: notAPointer}).Match(channel)
success, err = (&ReceiveMatcher{Args: []interface{}{notAPointer}}).Match(channel)
Expect(success).Should(BeFalse())
Expect(err).Should(HaveOccurred())
})
Expand Down Expand Up @@ -192,7 +225,7 @@ var _ = Describe("ReceiveMatcher", func() {
It("should error", func() {
channel := make(chan int, 1)
channel <- 3
success, err := (&ReceiveMatcher{Arg: ContainSubstring("three")}).Match(channel)
success, err := (&ReceiveMatcher{Args: []interface{}{ContainSubstring("three")}}).Match(channel)
Expect(success).Should(BeFalse())
Expect(err).Should(HaveOccurred())
})
Expand All @@ -201,13 +234,25 @@ var _ = Describe("ReceiveMatcher", func() {
Context("if nothing is received", func() {
It("should fail", func() {
channel := make(chan int, 1)
success, err := (&ReceiveMatcher{Arg: Equal(1)}).Match(channel)
success, err := (&ReceiveMatcher{Args: []interface{}{Equal(1)}}).Match(channel)
Expect(success).Should(BeFalse())
Expect(err).ShouldNot(HaveOccurred())
})
})
})

Context("with a pointer and a matcher argument", func() {
It("should succeed", func() {
channel := make(chan bool, 1)
channel <- true

var received bool

Expect(channel).Should(Receive(&received, Equal(true)))
Expect(received).Should(BeTrue())
})
})

Context("When actual is a *closed* channel", func() {
Context("for a buffered channel", func() {
It("should work until it hits the end of the buffer", func() {
Expand Down

0 comments on commit ec1f186

Please sign in to comment.