diff --git a/command/cmdbus/bus.go b/command/cmdbus/bus.go index f034e890..d0ee3899 100644 --- a/command/cmdbus/bus.go +++ b/command/cmdbus/bus.go @@ -543,23 +543,21 @@ func (b *Bus) commandAssigned(evt event.Of[CommandAssignedData]) { timeout = timer.C } - go func() { + select { + case <-b.Context().Done(): + case <-timeout: select { case <-b.Context().Done(): - case <-timeout: - select { - case <-b.Context().Done(): - case sub.errs <- fmt.Errorf("dropping %q command: %w", cmd.Name(), ErrReceiveTimeout): - } - case sub.commands <- command.NewContext[any]( - b.Context(), - cmd, - command.WhenDone(func(ctx context.Context, cfg finish.Config) error { - return b.markDone(ctx, cmd, cfg) - }), - ): + case sub.errs <- fmt.Errorf("dropping %q command: %w", cmd.Name(), ErrReceiveTimeout): } - }() + case sub.commands <- command.NewContext[any]( + b.Context(), + cmd, + command.WhenDone(func(ctx context.Context, cfg finish.Config) error { + return b.markDone(ctx, cmd, cfg) + }), + ): + } } func (b *Bus) markDone(ctx context.Context, cmd command.Command, cfg finish.Config) error { diff --git a/command/cmdbus/bus_test.go b/command/cmdbus/bus_test.go index d73c14f0..2ea5b045 100644 --- a/command/cmdbus/bus_test.go +++ b/command/cmdbus/bus_test.go @@ -244,6 +244,64 @@ func TestAssignTimeout_0(t *testing.T) { } } +func TestReceiveTimeout(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + bus, _, _ := newBus(ctx, cmdbus.ReceiveTimeout(100*time.Millisecond)) + + commands, errs, err := bus.Subscribe(ctx, "foo-cmd") + if err != nil { + t.Fatalf("failed to subscribe: %v", err) + } + + newCmd := func() command.Command { return command.New("foo-cmd", mockPayload{}).Any() } + dispatchErrc := make(chan error) + + go func() { + for i := 0; i < 3; i++ { + if err := bus.Dispatch(context.Background(), newCmd()); err != nil { + dispatchErrc <- err + } + } + }() + + var count int +L: + for { + select { + case err, ok := <-errs: + if !ok { + errs = nil + break + } + if !errors.Is(err, cmdbus.ErrReceiveTimeout) { + t.Fatal(err) + } + case _, ok := <-commands: + if !ok { + t.Fatalf("command channel should not be closed") + } + + <-time.After(200 * time.Millisecond) + count++ + if count == 2 { + break L + } + } + } + + select { + case _, ok := <-commands: + if !ok { + t.Fatalf("command channel should not be closed") + } + count++ + t.Fatalf("command channel should only send 2 commands; got %d", count) + case <-time.After(100 * time.Millisecond): + } +} + func TestReceiveTimeout_0(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel()