From b01afa12dc648ab361f92704c05e02976f00e305 Mon Sep 17 00:00:00 2001 From: Adam Luzsi Date: Fri, 27 Sep 2024 21:09:23 +0200 Subject: [PATCH] add syntax sugar to wait easily on the NotWithin assertion block --- assert/Asserter.go | 24 ++++++++++++++++++------ assert/Asserter_test.go | 23 +++++++++++++++++++++++ assert/example_test.go | 10 ++++++++++ assert/pkgfunc.go | 6 ++++-- 4 files changed, 55 insertions(+), 8 deletions(-) diff --git a/assert/Asserter.go b/assert/Asserter.go index f80cbcb..e472fbc 100644 --- a/assert/Asserter.go +++ b/assert/Asserter.go @@ -8,6 +8,7 @@ import ( "reflect" "regexp" "strings" + "sync" "sync/atomic" "testing" "time" @@ -972,9 +973,10 @@ func (a Asserter) ReadAll(r io.Reader, msg ...Message) []byte { return bs } -func (a Asserter) Within(timeout time.Duration, blk func(context.Context), msg ...Message) { +func (a Asserter) Within(timeout time.Duration, blk func(context.Context), msg ...Message) *Async { a.TB.Helper() - if !a.within(timeout, blk) { + async, ok := a.within(timeout, blk) + if !ok { a.failWith(fmterror.Message{ Method: "Within", Cause: "Expected to finish within the timeout duration.", @@ -987,11 +989,17 @@ func (a Asserter) Within(timeout time.Duration, blk func(context.Context), msg . }, }) } + return async } -func (a Asserter) NotWithin(timeout time.Duration, blk func(context.Context), msg ...Message) { +type Async struct{ wg sync.WaitGroup } + +func (a *Async) Wait() { a.wg.Wait() } + +func (a Asserter) NotWithin(timeout time.Duration, blk func(context.Context), msg ...Message) *Async { a.TB.Helper() - if a.within(timeout, blk) { + async, ok := a.within(timeout, blk) + if ok { a.failWith(fmterror.Message{ Method: "NotWithin", Cause: `Expected to not finish within the timeout duration.`, @@ -1004,14 +1012,18 @@ func (a Asserter) NotWithin(timeout time.Duration, blk func(context.Context), ms }, }) } + return async } -func (a Asserter) within(timeout time.Duration, blk func(context.Context)) bool { +func (a Asserter) within(timeout time.Duration, blk func(context.Context)) (*Async, bool) { a.TB.Helper() + var async Async ctx, cancel := context.WithCancel(context.Background()) defer cancel() var done, isFailNow uint32 + async.wg.Add(1) go func() { + defer async.wg.Done() ro := sandbox.Run(func() { blk(ctx) atomic.AddUint32(&done, 1) @@ -1026,7 +1038,7 @@ func (a Asserter) within(timeout time.Duration, blk func(context.Context)) bool if atomic.LoadUint32(&isFailNow) != 0 { a.TB.FailNow() } - return atomic.LoadUint32(&done) == 1 + return &async, atomic.LoadUint32(&done) == 1 } func (a Asserter) Eventually(durationOrCount any, blk func(it It)) { diff --git a/assert/Asserter_test.go b/assert/Asserter_test.go index 05bb530..25ba177 100644 --- a/assert/Asserter_test.go +++ b/assert/Asserter_test.go @@ -2308,3 +2308,26 @@ type SampleStruct struct { Bar int Baz bool } + +func TestAsserter_NotWithin_join(t *testing.T) { + var done = make(chan struct{}) + nw := assert.NotWithin(t, time.Nanosecond, func(context.Context) { + <-time.After(500 * time.Millisecond) + close(done) + }) + nw.Wait() + _, ok := <-done + assert.False(t, ok) +} + +func TestAsserter_Within_join(t *testing.T) { + var done = make(chan struct{}) + stub := &doubles.TB{} + w := assert.Should(stub).Within(time.Nanosecond, func(context.Context) { + <-time.After(500 * time.Millisecond) + close(done) + }) + w.Wait() + _, ok := <-done + assert.False(t, ok) +} diff --git a/assert/example_test.go b/assert/example_test.go index 3af9809..898d144 100644 --- a/assert/example_test.go +++ b/assert/example_test.go @@ -701,6 +701,16 @@ func ExampleNotWithin() { }) } +func ExampleNotWithin_withWait() { + var tb testing.TB + + nw := assert.NotWithin(tb, time.Nanosecond, func(context.Context) { // we intentionally don't use the context from here + time.Sleep(time.Second) // OK + }) + + nw.Wait() // will wait until the NotWithin assertion's block finish +} + func ExampleAsserter_NotWithin() { var tb testing.TB a := assert.Must(tb) diff --git a/assert/pkgfunc.go b/assert/pkgfunc.go index e2a5017..c35b9a1 100644 --- a/assert/pkgfunc.go +++ b/assert/pkgfunc.go @@ -105,11 +105,13 @@ func ReadAll(tb testing.TB, r io.Reader, msg ...Message) []byte { func Within(tb testing.TB, timeout time.Duration, blk func(context.Context), msg ...Message) { tb.Helper() Must(tb).Within(timeout, blk, msg...) + // Returning *Async here doesn’t make sense because if the assertion fails, + // FailNow will terminate the current goroutine regardless. } -func NotWithin(tb testing.TB, timeout time.Duration, blk func(context.Context), msg ...Message) { +func NotWithin(tb testing.TB, timeout time.Duration, blk func(context.Context), msg ...Message) *Async { tb.Helper() - Must(tb).NotWithin(timeout, blk, msg...) + return Must(tb).NotWithin(timeout, blk, msg...) } func MatchRegexp[T ~string | []byte](tb testing.TB, v T, expr string, msg ...Message) {