diff --git a/pkg/pipe/swarm/swarm_test.go b/pkg/pipe/swarm/swarm_test.go index 76d857931..839c2230c 100644 --- a/pkg/pipe/swarm/swarm_test.go +++ b/pkg/pipe/swarm/swarm_test.go @@ -6,6 +6,7 @@ package swarm import ( "context" "errors" + "log/slog" "sync" "sync/atomic" "testing" @@ -15,6 +16,8 @@ import ( "github.com/stretchr/testify/require" "go.opentelemetry.io/obi/pkg/internal/testutil" + "go.opentelemetry.io/obi/pkg/pipe/msg" + "go.opentelemetry.io/obi/pkg/pipe/swarm/swarms" ) func TestSwarm_BuildWithError(t *testing.T) { @@ -191,6 +194,63 @@ func TestSwarm_CancelTimeout_DontExit(t *testing.T) { assert.Contains(t, cerr.runningIDs, "zombieRunner") } +func TestSwarm_MutuallyExclusiveNodes(t *testing.T) { + // in the test graph, two parallel channels share an input channel and an output channel, + // but only one of the paths are enabled (halver or doubler, according to the "enabledNode" value) + // This test just checks that the messages flow accordingly and no channel is blocked + testGraph := func(ctx context.Context, enabledNode string) <-chan int { + inst := Instancer{} + inQueue := msg.NewQueue[int](msg.ChannelBufferLen(1), msg.Name("inQueue")) + outQueue := msg.NewQueue[int](msg.ChannelBufferLen(1), msg.Name("outQueue")) + inst.Add(func(_ context.Context) (RunFunc, error) { + if enabledNode != "doubler" { + return EmptyRunFunc() + } + in := inQueue.Subscribe() + return func(ctx context.Context) { + swarms.ForEachInput(ctx, in, slog.Debug, func(i int) { + outQueue.SendCtx(ctx, i*2) + }) + }, nil + }, WithID("doubler")) + inst.Add(func(_ context.Context) (RunFunc, error) { + if enabledNode != "halver" { + return EmptyRunFunc() + } + in := inQueue.Subscribe() + return func(ctx context.Context) { + swarms.ForEachInput(ctx, in, slog.Debug, func(i int) { + outQueue.SendCtx(ctx, i/2) + }) + }, nil + }, WithID("halver")) + outCh := outQueue.Subscribe(msg.SubscriberName("outputReader")) + runner, err := inst.Instance(ctx) + require.NoError(t, err) + runner.Start(ctx) + go func() { + inQueue.SendCtx(ctx, 2) + inQueue.SendCtx(ctx, 4) + inQueue.SendCtx(ctx, 6) + }() + return outCh + } + + t.Run("enable doubler", func(t *testing.T) { + out := testGraph(t.Context(), "doubler") + assert.Equal(t, 4, testutil.ReadChannel(t, out, 5*time.Second)) + assert.Equal(t, 8, testutil.ReadChannel(t, out, 5*time.Second)) + assert.Equal(t, 12, testutil.ReadChannel(t, out, 5*time.Second)) + }) + + t.Run("enable halver", func(t *testing.T) { + out := testGraph(t.Context(), "halver") + assert.Equal(t, 1, testutil.ReadChannel(t, out, 5*time.Second)) + assert.Equal(t, 2, testutil.ReadChannel(t, out, 5*time.Second)) + assert.Equal(t, 3, testutil.ReadChannel(t, out, 5*time.Second)) + }) +} + func assertDone(t *testing.T, s *Runner) { timeout := time.After(5 * time.Second) select {