diff --git a/go/core/flow.go b/go/core/flow.go index 2ca42be658..b5311bbbf3 100644 --- a/go/core/flow.go +++ b/go/core/flow.go @@ -48,7 +48,9 @@ type flowContext struct { // DefineFlow creates a Flow that runs fn, and registers it as an action. fn takes an input of type In and returns an output of type Out. func DefineFlow[In, Out any](r api.Registry, name string, fn Func[In, Out]) *Flow[In, Out, struct{}] { return (*Flow[In, Out, struct{}])(DefineAction(r, name, api.ActionTypeFlow, nil, nil, func(ctx context.Context, input In) (Out, error) { - fc := &flowContext{} + fc := &flowContext{ + flowName: name, + } ctx = flowContextKey.NewContext(ctx, fc) return fn(ctx, input) })) @@ -65,7 +67,9 @@ func DefineFlow[In, Out any](r api.Registry, name string, fn Func[In, Out]) *Flo // Otherwise, it should ignore the callback and just return a result. func DefineStreamingFlow[In, Out, Stream any](r api.Registry, name string, fn StreamingFunc[In, Out, Stream]) *Flow[In, Out, Stream] { return (*Flow[In, Out, Stream])(DefineStreamingAction(r, name, api.ActionTypeFlow, nil, nil, func(ctx context.Context, input In, cb func(context.Context, Stream) error) (Out, error) { - fc := &flowContext{} + fc := &flowContext{ + flowName: name, + } ctx = flowContextKey.NewContext(ctx, fc) return fn(ctx, input, cb) })) diff --git a/go/core/flow_test.go b/go/core/flow_test.go index b3d912d169..3a6a797bce 100644 --- a/go/core/flow_test.go +++ b/go/core/flow_test.go @@ -66,3 +66,26 @@ func TestRunFlow(t *testing.T) { t.Errorf("got %d, want %d", got, want) } } + +func TestFlowNameFromContext(t *testing.T) { + r := registry.New() + flows := []*Flow[struct{}, string, struct{}]{ + DefineFlow(r, "DefineFlow", func(ctx context.Context, _ struct{}) (string, error) { + return FlowNameFromContext(ctx), nil + }), + DefineStreamingFlow(r, "DefineStreamingFlow", func(ctx context.Context, _ struct{}, s StreamCallback[struct{}]) (string, error) { + return FlowNameFromContext(ctx), nil + }), + } + for _, flow := range flows { + t.Run(flow.Name(), func(t *testing.T) { + got, err := flow.Run(context.Background(), struct{}{}) + if err != nil { + t.Fatal(err) + } + if want := flow.Name(); got != want { + t.Errorf("got '%s', want '%s'", got, want) + } + }) + } +} \ No newline at end of file