Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions pkg/pipe/swarm/swarm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package swarm
import (
"context"
"errors"
"log/slog"
"sync"
"sync/atomic"
"testing"
Expand All @@ -15,6 +16,8 @@ import (
"github.com/stretchr/testify/require"

"go.opentelemetry.io/obi/pkg/internal/testutil"
"go.opentelemetry.io/obi/pkg/pipe/msg"
"go.opentelemetry.io/obi/pkg/pipe/swarm/swarms"
)

func TestSwarm_BuildWithError(t *testing.T) {
Expand Down Expand Up @@ -191,6 +194,63 @@ func TestSwarm_CancelTimeout_DontExit(t *testing.T) {
assert.Contains(t, cerr.runningIDs, "zombieRunner")
}

func TestSwarm_MutuallyExclusiveNodes(t *testing.T) {
// in the test graph, two parallel channels share an input channel and an output channel,
// but only one of the paths are enabled (halver or doubler, according to the "enabledNode" value)
// This test just checks that the messages flow accordingly and no channel is blocked
testGraph := func(ctx context.Context, enabledNode string) <-chan int {
inst := Instancer{}
inQueue := msg.NewQueue[int](msg.ChannelBufferLen(1), msg.Name("inQueue"))
outQueue := msg.NewQueue[int](msg.ChannelBufferLen(1), msg.Name("outQueue"))
inst.Add(func(_ context.Context) (RunFunc, error) {
if enabledNode != "doubler" {
return EmptyRunFunc()
}
in := inQueue.Subscribe()
return func(ctx context.Context) {
swarms.ForEachInput(ctx, in, slog.Debug, func(i int) {
outQueue.SendCtx(ctx, i*2)
})
}, nil
}, WithID("doubler"))
inst.Add(func(_ context.Context) (RunFunc, error) {
if enabledNode != "halver" {
return EmptyRunFunc()
}
in := inQueue.Subscribe()
return func(ctx context.Context) {
swarms.ForEachInput(ctx, in, slog.Debug, func(i int) {
outQueue.SendCtx(ctx, i/2)
})
}, nil
}, WithID("halver"))
outCh := outQueue.Subscribe(msg.SubscriberName("outputReader"))
runner, err := inst.Instance(ctx)
require.NoError(t, err)
runner.Start(ctx)
go func() {
inQueue.SendCtx(ctx, 2)
inQueue.SendCtx(ctx, 4)
inQueue.SendCtx(ctx, 6)
}()
return outCh
}

t.Run("enable doubler", func(t *testing.T) {
out := testGraph(t.Context(), "doubler")
assert.Equal(t, 4, testutil.ReadChannel(t, out, 5*time.Second))
assert.Equal(t, 8, testutil.ReadChannel(t, out, 5*time.Second))
assert.Equal(t, 12, testutil.ReadChannel(t, out, 5*time.Second))
})

t.Run("enable halver", func(t *testing.T) {
out := testGraph(t.Context(), "halver")
assert.Equal(t, 1, testutil.ReadChannel(t, out, 5*time.Second))
assert.Equal(t, 2, testutil.ReadChannel(t, out, 5*time.Second))
assert.Equal(t, 3, testutil.ReadChannel(t, out, 5*time.Second))
})
}

func assertDone(t *testing.T, s *Runner) {
timeout := time.After(5 * time.Second)
select {
Expand Down
Loading