diff --git a/demo/pkg/subgraphs/availability/subgraph/schema.resolvers.go b/demo/pkg/subgraphs/availability/subgraph/schema.resolvers.go index 8e52ec96c5..97ef578631 100644 --- a/demo/pkg/subgraphs/availability/subgraph/schema.resolvers.go +++ b/demo/pkg/subgraphs/availability/subgraph/schema.resolvers.go @@ -20,7 +20,7 @@ func (r *mutationResolver) UpdateAvailability(ctx context.Context, employeeID in conf := &nats.PublishAndRequestEventConfiguration{ Subject: r.GetPubSubName(fmt.Sprintf("employeeUpdated.%d", employeeID)), } - evt := &nats.Event{Data: []byte(fmt.Sprintf(`{"id":%d,"__typename": "Employee"}`, employeeID))} + evt := &nats.MutableEvent{Data: []byte(fmt.Sprintf(`{"id":%d,"__typename": "Employee"}`, employeeID))} err := r.NatsPubSubByProviderID["default"].Publish(ctx, conf, []datasource.StreamEvent{evt}) if err != nil { @@ -30,7 +30,7 @@ func (r *mutationResolver) UpdateAvailability(ctx context.Context, employeeID in conf2 := &nats.PublishAndRequestEventConfiguration{ Subject: r.GetPubSubName(fmt.Sprintf("employeeUpdatedMyNats.%d", employeeID)), } - evt2 := &nats.Event{Data: []byte(fmt.Sprintf(`{"id":%d,"__typename": "Employee"}`, employeeID))} + evt2 := &nats.MutableEvent{Data: []byte(fmt.Sprintf(`{"id":%d,"__typename": "Employee"}`, employeeID))} err = r.NatsPubSubByProviderID["my-nats"].Publish(ctx, conf2, []datasource.StreamEvent{evt2}) if err != nil { diff --git a/demo/pkg/subgraphs/mood/subgraph/schema.resolvers.go b/demo/pkg/subgraphs/mood/subgraph/schema.resolvers.go index b9b426593c..8941ac7ac1 100644 --- a/demo/pkg/subgraphs/mood/subgraph/schema.resolvers.go +++ b/demo/pkg/subgraphs/mood/subgraph/schema.resolvers.go @@ -22,7 +22,7 @@ func (r *mutationResolver) UpdateMood(ctx context.Context, employeeID int, mood if r.NatsPubSubByProviderID["default"] != nil { err := r.NatsPubSubByProviderID["default"].Publish(ctx, &nats.PublishAndRequestEventConfiguration{ Subject: myNatsTopic, - }, []datasource.StreamEvent{&nats.Event{Data: []byte(payload)}}) 
+ }, []datasource.StreamEvent{(&nats.MutableEvent{Data: []byte(payload)})}) if err != nil { return nil, err } @@ -34,7 +34,7 @@ func (r *mutationResolver) UpdateMood(ctx context.Context, employeeID int, mood if r.NatsPubSubByProviderID["my-nats"] != nil { err := r.NatsPubSubByProviderID["my-nats"].Publish(ctx, &nats.PublishAndRequestEventConfiguration{ Subject: defaultTopic, - }, []datasource.StreamEvent{&nats.Event{Data: []byte(payload)}}) + }, []datasource.StreamEvent{(&nats.MutableEvent{Data: []byte(payload)})}) if err != nil { return nil, err } diff --git a/router-tests/modules/start_subscription_test.go b/router-tests/modules/start_subscription_test.go index b9d5e2f0ac..aaee9e27db 100644 --- a/router-tests/modules/start_subscription_test.go +++ b/router-tests/modules/start_subscription_test.go @@ -94,10 +94,10 @@ func TestStartSubscriptionHook(t *testing.T) { if ctx.SubscriptionEventConfiguration().RootFieldName() != "employeeUpdatedMyKafka" { return nil } - ctx.WriteEvent(&kafka.Event{ + ctx.WriteEvent((&kafka.MutableEvent{ Key: []byte("1"), Data: []byte(`{"id": 1, "__typename": "Employee"}`), - }) + })) return nil }, }, @@ -266,9 +266,9 @@ func TestStartSubscriptionHook(t *testing.T) { if employeeId != 1 { return nil } - ctx.WriteEvent(&kafka.Event{ + ctx.WriteEvent((&kafka.MutableEvent{ Data: []byte(`{"id": 1, "__typename": "Employee"}`), - }) + })) return nil }, }, @@ -510,9 +510,7 @@ func TestStartSubscriptionHook(t *testing.T) { Modules: map[string]interface{}{ "startSubscriptionModule": start_subscription.StartSubscriptionModule{ Callback: func(ctx core.SubscriptionOnStartHandlerContext) error { - ctx.WriteEvent(&core.EngineEvent{ - Data: []byte(`{"data":{"countEmp":1000}}`), - }) + ctx.WriteEvent(core.MutableEngineEvent([]byte(`{"data":{"countEmp":1000}}`))) return nil }, }, diff --git a/router-tests/modules/stream-publish/module.go b/router-tests/modules/stream-publish/module.go index e5553058ea..ef5c24277b 100644 --- 
a/router-tests/modules/stream-publish/module.go +++ b/router-tests/modules/stream-publish/module.go @@ -11,7 +11,7 @@ const myModuleID = "publishModule" type PublishModule struct { Logger *zap.Logger - Callback func(ctx core.StreamPublishEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) + Callback func(ctx core.StreamPublishEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) } func (m *PublishModule) Provision(ctx *core.ModuleContext) error { @@ -21,7 +21,7 @@ func (m *PublishModule) Provision(ctx *core.ModuleContext) error { return nil } -func (m *PublishModule) OnPublishEvents(ctx core.StreamPublishEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { +func (m *PublishModule) OnPublishEvents(ctx core.StreamPublishEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { m.Logger.Info("Publish Hook has been run") if m.Callback != nil { diff --git a/router-tests/modules/stream-receive/module.go b/router-tests/modules/stream-receive/module.go index 640218ad00..51d2b22a33 100644 --- a/router-tests/modules/stream-receive/module.go +++ b/router-tests/modules/stream-receive/module.go @@ -11,7 +11,7 @@ const myModuleID = "streamReceiveModule" type StreamReceiveModule struct { Logger *zap.Logger - Callback func(ctx core.StreamReceiveEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) + Callback func(ctx core.StreamReceiveEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) } func (m *StreamReceiveModule) Provision(ctx *core.ModuleContext) error { @@ -21,7 +21,7 @@ func (m *StreamReceiveModule) Provision(ctx *core.ModuleContext) error { return nil } -func (m *StreamReceiveModule) OnReceiveEvents(ctx core.StreamReceiveEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { +func (m *StreamReceiveModule) OnReceiveEvents(ctx 
core.StreamReceiveEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { m.Logger.Info("Stream Hook has been run") if m.Callback != nil { diff --git a/router-tests/modules/stream_publish_test.go b/router-tests/modules/stream_publish_test.go index 6fb7485dc3..ddaf982029 100644 --- a/router-tests/modules/stream_publish_test.go +++ b/router-tests/modules/stream_publish_test.go @@ -4,6 +4,7 @@ import ( "encoding/json" "net/http" "strconv" + "sync/atomic" "testing" "time" @@ -23,6 +24,53 @@ import ( func TestPublishHook(t *testing.T) { t.Parallel() + t.Run("Test Publish hook can't assert to mutable types", func(t *testing.T) { + t.Parallel() + + var taPossible atomic.Bool + taPossible.Store(true) + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "publishModule": stream_publish.PublishModule{ + Callback: func(ctx core.StreamPublishEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { + for _, evt := range events.All() { + _, ok := evt.(datasource.MutableStreamEvent) + if !ok { + taPossible.Store(false) + } + } + return events, nil + }, + }, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, + EnableKafka: true, + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&stream_publish.PublishModule{}), + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.InfoLevel, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + resOne := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `mutation { updateEmployeeMyKafka(employeeID: 3, update: {name: "name test"}) { success } }`, + }) + require.JSONEq(t, `{"data":{"updateEmployeeMyKafka":{"success":false}}}`, resOne.Body) + + requestLog := xEnv.Observer().FilterMessage("Publish Hook has been run") + assert.Len(t, requestLog.All(), 1) + + assert.False(t, taPossible.Load(), 
"invalid type assertion was possible") + }) + }) + t.Run("Test Publish hook is called", func(t *testing.T) { t.Parallel() @@ -55,25 +103,13 @@ func TestPublishHook(t *testing.T) { }) }) - t.Run("Test Publish kafka hook allows to set headers", func(t *testing.T) { + t.Run("Test Publish hook is called with mutable event", func(t *testing.T) { t.Parallel() cfg := config.Config{ Graph: config.Graph{}, Modules: map[string]interface{}{ - "publishModule": stream_publish.PublishModule{ - Callback: func(ctx core.StreamPublishEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { - for _, event := range events { - evt, ok := event.(*kafka.Event) - if !ok { - continue - } - evt.Headers["x-test"] = []byte("test") - } - - return events, nil - }, - }, + "publishModule": stream_publish.PublishModule{}, }, } @@ -89,21 +125,13 @@ func TestPublishHook(t *testing.T) { LogLevel: zapcore.InfoLevel, }, }, func(t *testing.T, xEnv *testenv.Environment) { - events.KafkaEnsureTopicExists(t, xEnv, time.Second, "employeeUpdated") resOne := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ Query: `mutation { updateEmployeeMyKafka(employeeID: 3, update: {name: "name test"}) { success } }`, }) - require.JSONEq(t, `{"data":{"updateEmployeeMyKafka":{"success":true}}}`, resOne.Body) + require.JSONEq(t, `{"data":{"updateEmployeeMyKafka":{"success":false}}}`, resOne.Body) requestLog := xEnv.Observer().FilterMessage("Publish Hook has been run") assert.Len(t, requestLog.All(), 1) - - records, err := events.ReadKafkaMessages(xEnv, time.Second, "employeeUpdated", 1) - require.NoError(t, err) - require.Len(t, records, 1) - header := records[0].Headers[0] - require.Equal(t, "x-test", header.Key) - require.Equal(t, []byte("test"), header.Value) }) }) @@ -114,7 +142,7 @@ func TestPublishHook(t *testing.T) { Graph: config.Graph{}, Modules: map[string]interface{}{ "publishModule": stream_publish.PublishModule{ - Callback: func(ctx core.StreamPublishEventHandlerContext, events 
[]datasource.StreamEvent) ([]datasource.StreamEvent, error) { + Callback: func(ctx core.StreamPublishEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { return events, core.NewHttpGraphqlError("test", http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) }, }, @@ -159,7 +187,7 @@ func TestPublishHook(t *testing.T) { Graph: config.Graph{}, Modules: map[string]interface{}{ "publishModule": stream_publish.PublishModule{ - Callback: func(ctx core.StreamPublishEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { + Callback: func(ctx core.StreamPublishEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { return events, core.NewHttpGraphqlError("test", http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) }, }, @@ -213,7 +241,7 @@ func TestPublishHook(t *testing.T) { Graph: config.Graph{}, Modules: map[string]interface{}{ "publishModule": stream_publish.PublishModule{ - Callback: func(ctx core.StreamPublishEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { + Callback: func(ctx core.StreamPublishEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { return events, core.NewHttpGraphqlError("test", http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) }, }, @@ -257,26 +285,28 @@ func TestPublishHook(t *testing.T) { Graph: config.Graph{}, Modules: map[string]interface{}{ "publishModule": stream_publish.PublishModule{ - Callback: func(ctx core.StreamPublishEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { + Callback: func(ctx core.StreamPublishEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { if ctx.PublishEventConfiguration().RootFieldName() != "updateEmployeeMyKafka" { return events, nil } employeeID := 
ctx.Operation().Variables().GetInt("employeeID") - newEvents := []datasource.StreamEvent{} - for _, event := range events { - evt, ok := event.(*kafka.Event) + newEvents := make([]datasource.StreamEvent, 0, events.Len()) + for _, event := range events.All() { + newEvt, ok := event.Clone().(*kafka.MutableEvent) if !ok { continue } - if evt.Headers == nil { - evt.Headers = map[string][]byte{} + newEvt.SetData([]byte(`{"__typename":"Employee","id": 3,"update":{"name":"foo"}}`)) + if newEvt.Headers == nil { + newEvt.Headers = map[string][]byte{} } - evt.Headers["x-employee-id"] = []byte(strconv.Itoa(employeeID)) - newEvents = append(newEvents, event) + newEvt.Headers["x-employee-id"] = []byte(strconv.Itoa(employeeID)) + newEvents = append(newEvents, newEvt) } - return newEvents, nil + + return datasource.NewStreamEvents(newEvents), nil }, }, }, diff --git a/router-tests/modules/stream_receive_test.go b/router-tests/modules/stream_receive_test.go index 21f62a9b53..a30efd23f1 100644 --- a/router-tests/modules/stream_receive_test.go +++ b/router-tests/modules/stream_receive_test.go @@ -20,7 +20,6 @@ import ( "github.com/wundergraph/cosmo/router/pkg/authentication" "github.com/wundergraph/cosmo/router/pkg/config" "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" - "github.com/wundergraph/cosmo/router/pkg/pubsub/kafka" "go.uber.org/zap" "go.uber.org/zap/zapcore" ) @@ -115,16 +114,15 @@ func TestReceiveHook(t *testing.T) { Graph: config.Graph{}, Modules: map[string]interface{}{ "streamReceiveModule": stream_receive.StreamReceiveModule{ - Callback: func(ctx core.StreamReceiveEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { - for _, event := range events { - evt, ok := event.(*kafka.Event) - if !ok { - continue - } - evt.Data = []byte(`{"__typename":"Employee","id": 3,"update":{"name":"foo"}}`) + Callback: func(ctx core.StreamReceiveEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { + 
newEvents := make([]datasource.StreamEvent, 0, events.Len()) + for _, event := range events.All() { + eventCopy := event.Clone() + eventCopy.SetData([]byte(`{"__typename":"Employee","id": 3,"update":{"name":"foo"}}`)) + newEvents = append(newEvents, eventCopy) } - return events, nil + return datasource.NewStreamEvents(newEvents), nil }, }, }, @@ -200,22 +198,22 @@ func TestReceiveHook(t *testing.T) { Graph: config.Graph{}, Modules: map[string]interface{}{ "streamReceiveModule": stream_receive.StreamReceiveModule{ - Callback: func(ctx core.StreamReceiveEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { + Callback: func(ctx core.StreamReceiveEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { if ctx.Authentication() == nil { return events, nil } if val, ok := ctx.Authentication().Claims()["sub"]; !ok || val != "user-2" { return events, nil } - for _, event := range events { - evt, ok := event.(*kafka.Event) - if !ok { - continue - } - evt.Data = []byte(`{"__typename":"Employee","id": 3,"update":{"name":"foo"}}`) + + newEvents := make([]datasource.StreamEvent, 0, events.Len()) + for _, event := range events.All() { + eventCopy := event.Clone() + eventCopy.SetData([]byte(`{"__typename":"Employee","id": 3,"update":{"name":"foo"}}`)) + newEvents = append(newEvents, eventCopy) } - return events, nil + return datasource.NewStreamEvents(newEvents), nil }, }, }, @@ -356,19 +354,19 @@ func TestReceiveHook(t *testing.T) { Graph: config.Graph{}, Modules: map[string]interface{}{ "streamReceiveModule": stream_receive.StreamReceiveModule{ - Callback: func(ctx core.StreamReceiveEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { + Callback: func(ctx core.StreamReceiveEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { if val, ok := ctx.Request().Header[customHeader]; !ok || val[0] != "Test" { return events, nil } - for _, event := 
range events { - evt, ok := event.(*kafka.Event) - if !ok { - continue - } - evt.Data = []byte(`{"__typename":"Employee","id": 3,"update":{"name":"foo"}}`) + + newEvents := make([]datasource.StreamEvent, 0, events.Len()) + for _, event := range events.All() { + eventCopy := event.Clone() + eventCopy.SetData([]byte(`{"__typename":"Employee","id": 3,"update":{"name":"foo"}}`)) + newEvents = append(newEvents, eventCopy) } - return events, nil + return datasource.NewStreamEvents(newEvents), nil }, }, }, @@ -452,8 +450,8 @@ func TestReceiveHook(t *testing.T) { Graph: config.Graph{}, Modules: map[string]interface{}{ "streamReceiveModule": stream_receive.StreamReceiveModule{ - Callback: func(ctx core.StreamReceiveEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { - return nil, errors.New("test error from streamevents hook") + Callback: func(ctx core.StreamReceiveEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { + return datasource.NewStreamEvents(nil), errors.New("test error from streamevents hook") }, }, }, @@ -526,8 +524,8 @@ func TestReceiveHook(t *testing.T) { Graph: config.Graph{}, Modules: map[string]interface{}{ "streamReceiveModule": stream_receive.StreamReceiveModule{ - Callback: func(ctx core.StreamReceiveEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { - return nil, errors.New("deduplicated error") + Callback: func(ctx core.StreamReceiveEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { + return datasource.NewStreamEvents(nil), errors.New("deduplicated error") }, }, }, @@ -621,9 +619,9 @@ func TestReceiveHook(t *testing.T) { Graph: config.Graph{}, Modules: map[string]interface{}{ "streamReceiveModule": stream_receive.StreamReceiveModule{ - Callback: func(ctx core.StreamReceiveEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { + Callback: func(ctx 
core.StreamReceiveEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { count := errorCounter.Add(1) - return nil, fmt.Errorf("unique error %d", count) + return datasource.NewStreamEvents(nil), fmt.Errorf("unique error %d", count) }, }, }, @@ -764,7 +762,7 @@ func TestReceiveHook(t *testing.T) { Graph: config.Graph{}, Modules: map[string]interface{}{ "streamReceiveModule": stream_receive.StreamReceiveModule{ - Callback: func(ctx core.StreamReceiveEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { + Callback: func(ctx core.StreamReceiveEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { currentHandlers.Add(1) // wait for other handlers in the batch diff --git a/router-tests/modules/streams_hooks_combined_test.go b/router-tests/modules/streams_hooks_combined_test.go index 78639dd052..47a25b48c6 100644 --- a/router-tests/modules/streams_hooks_combined_test.go +++ b/router-tests/modules/streams_hooks_combined_test.go @@ -36,36 +36,42 @@ func TestStreamsHooksCombined(t *testing.T) { Graph: config.Graph{}, Modules: map[string]interface{}{ "streamReceiveModule": stream_receive.StreamReceiveModule{ - Callback: func(ctx core.StreamReceiveEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { - for _, event := range events { - evt, ok := event.(*kafka.Event) + Callback: func(ctx core.StreamReceiveEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { + newEvents := make([]datasource.StreamEvent, 0, events.Len()) + for _, event := range events.All() { + newEvt, ok := event.Clone().(*kafka.MutableEvent) if !ok { continue } - - if string(evt.Headers["x-publishModule"]) == "i_was_here" { - evt.Data = []byte(`{"__typename":"Employee","id": 2,"update":{"name":"irrelevant"}}`) + if string(newEvt.Headers["x-publishModule"]) == "i_was_here" { + newEvt.SetData([]byte(`{"__typename":"Employee","id": 
2,"update":{"name":"irrelevant"}}`)) } + newEvents = append(newEvents, newEvt) } - return events, nil + return datasource.NewStreamEvents(newEvents), nil }, }, "publishModule": stream_publish.PublishModule{ - Callback: func(ctx core.StreamPublishEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { + Callback: func(ctx core.StreamPublishEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { if ctx.PublishEventConfiguration().RootFieldName() != "updateEmployeeMyKafka" { return events, nil } - for _, event := range events { - evt, ok := event.(*kafka.Event) + newEvents := make([]datasource.StreamEvent, 0, events.Len()) + for _, event := range events.All() { + newEvt, ok := event.Clone().(*kafka.MutableEvent) if !ok { continue } - evt.Headers["x-publishModule"] = []byte("i_was_here") + if newEvt.Headers == nil { + newEvt.Headers = make(map[string][]byte) + } + newEvt.Headers["x-publishModule"] = []byte("i_was_here") + newEvents = append(newEvents, newEvt) } - return events, nil + return datasource.NewStreamEvents(newEvents), nil }, }, }, diff --git a/router/core/router_config.go b/router/core/router_config.go index ba24294104..c921687f66 100644 --- a/router/core/router_config.go +++ b/router/core/router_config.go @@ -28,8 +28,8 @@ import ( type subscriptionHooks struct { onStart []func(ctx SubscriptionOnStartHandlerContext) error - onPublishEvents []func(ctx StreamPublishEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) - onReceiveEvents []func(ctx StreamReceiveEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) + onPublishEvents []func(ctx StreamPublishEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) + onReceiveEvents []func(ctx StreamReceiveEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) maxConcurrentOnReceiveHooks int } diff --git 
a/router/core/subscriptions_modules.go b/router/core/subscriptions_modules.go index e3279c811d..8f5ca8490e 100644 --- a/router/core/subscriptions_modules.go +++ b/router/core/subscriptions_modules.go @@ -91,19 +91,35 @@ func (c *pubSubSubscriptionOnStartHookContext) WriteEvent(event datasource.Strea return true } +type MutableEngineEvent []byte + +func (e MutableEngineEvent) GetData() []byte { + return e +} + +func (e MutableEngineEvent) SetData(data []byte) { + copy(e, data) +} + +func (e MutableEngineEvent) Clone() datasource.MutableStreamEvent { + return slices.Clone(e) +} + // EngineEvent is the event used to write to the engine subscription type EngineEvent struct { - Data []byte + data MutableEngineEvent } func (e *EngineEvent) GetData() []byte { - return e.Data + return e.data } -func (e *EngineEvent) Clone() datasource.StreamEvent { - return &EngineEvent{ - Data: slices.Clone(e.Data), - } +func (e *EngineEvent) WriteCopy() datasource.MutableStreamEvent { + return e.data.Clone() +} + +func (e *EngineEvent) Clone() datasource.MutableStreamEvent { + return slices.Clone(e.data) } type engineSubscriptionOnStartHookContext struct { @@ -201,13 +217,15 @@ type StreamReceiveEventHandlerContext interface { } type StreamReceiveEventHandler interface { - // OnReceiveEvents is called each time a batch of events is received from the provider before delivering them to the - // client. So for a single batch of events received from the provider, this hook will be called one time for each - // active subscription. - // It is important to optimize the logic inside this hook to avoid performance issues. - // Returning an error will result in a GraphQL error being returned to the client, could be customized returning a - // StreamHookError. 
- OnReceiveEvents(ctx StreamReceiveEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) + // OnReceiveEvents is called whenever a batch of events is received from a provider, + // before delivering them to clients. + // The hook will be called once for each active subscription, therefore it is advised to + // avoid resource-heavy computation or blocking tasks whenever possible. + // The events argument contains all events from a batch and is shared between + // all active subscribers of these events. + // Use events.All() to iterate through them and event.Clone() to create mutable copies, when needed. + // Returning an error will result in the subscription being closed and the error being logged. + OnReceiveEvents(ctx StreamReceiveEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) } type StreamPublishEventHandlerContext interface { @@ -224,13 +242,15 @@ type StreamPublishEventHandlerContext interface { } type StreamPublishEventHandler interface { - // OnPublishEvents is called each time a batch of events is going to be sent to the provider + // OnPublishEvents is called each time a batch of events is going to be sent to a provider. + // The events argument contains all events from a batch. + // Use events.All() to iterate through them and event.Clone() to create mutable copies, when needed. // Returning an error will result in a GraphQL error being returned to the client, could be customized returning a // StreamHookError. 
- OnPublishEvents(ctx StreamPublishEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) + OnPublishEvents(ctx StreamPublishEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) } -func NewPubSubOnPublishEventsHook(fn func(ctx StreamPublishEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error)) datasource.OnPublishEventsFn { +func NewPubSubOnPublishEventsHook(fn func(ctx StreamPublishEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error)) datasource.OnPublishEventsFn { if fn == nil { return nil } @@ -245,7 +265,9 @@ func NewPubSubOnPublishEventsHook(fn func(ctx StreamPublishEventHandlerContext, publishEventConfiguration: pubConf, } - return fn(hookCtx, evts) + newEvts, err := fn(hookCtx, datasource.NewStreamEvents(evts)) + + return newEvts.Unsafe(), err } } @@ -277,7 +299,7 @@ func (c *pubSubStreamReceiveEventHookContext) SubscriptionEventConfiguration() d return c.subscriptionEventConfiguration } -func NewPubSubOnReceiveEventsHook(fn func(ctx StreamReceiveEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error)) datasource.OnReceiveEventsFn { +func NewPubSubOnReceiveEventsHook(fn func(ctx StreamReceiveEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error)) datasource.OnReceiveEventsFn { if fn == nil { return nil } @@ -291,7 +313,7 @@ func NewPubSubOnReceiveEventsHook(fn func(ctx StreamReceiveEventHandlerContext, authentication: requestContext.Authentication(), subscriptionEventConfiguration: subConf, } - - return fn(hookCtx, evts) + newEvts, err := fn(hookCtx, datasource.NewStreamEvents(evts)) + return newEvts.Unsafe(), err } } diff --git a/router/pkg/pubsub/datasource/provider.go b/router/pkg/pubsub/datasource/provider.go index 57bbb70ed7..fd02ffccf6 100644 --- a/router/pkg/pubsub/datasource/provider.go +++ b/router/pkg/pubsub/datasource/provider.go @@ -2,6 +2,8 @@ package 
datasource import ( "context" + "iter" + "slices" "github.com/wundergraph/cosmo/router/pkg/metric" ) @@ -46,7 +48,7 @@ type ProviderBuilder[P, E any] interface { BuildEngineDataSourceFactory(data E, providers map[string]Provider) (EngineDataSourceFactory, error) } -// ProviderType represents the type of pubsub provider +// ProviderType represents the type of pubsub provider. type ProviderType string const ( @@ -55,12 +57,44 @@ const ( ProviderTypeRedis ProviderType = "redis" ) -// StreamEvent is a generic interface for all stream events -// Each provider will have its own event type that implements this interface -// there could be other common fields in the future, but for now we only have data +// StreamEvents is a list of stream events coming from or going to event providers. +type StreamEvents struct { + evts []StreamEvent +} + +// All is an iterator, which can be used to iterate through all events. +func (e StreamEvents) All() iter.Seq2[int, StreamEvent] { + return slices.All(e.evts) +} + +// Len returns the number of events. +func (e StreamEvents) Len() int { + return len(e.evts) +} + +// Unsafe returns the underlying slice of stream events. +// This slice is not thread safe and should not be modified directly. +func (e StreamEvents) Unsafe() []StreamEvent { + return e.evts +} + +func NewStreamEvents(evts []StreamEvent) StreamEvents { + return StreamEvents{evts: evts} +} + +// A StreamEvent is a single event coming from or going to an event provider. type StreamEvent interface { + // GetData returns the payload data of the event. GetData() []byte - Clone() StreamEvent + // Clone returns a mutable copy of the event. + Clone() MutableStreamEvent +} + +// A MutableStreamEvent is a stream event that can be modified. +type MutableStreamEvent interface { + StreamEvent + // SetData sets the data of the event. 
+ SetData([]byte) } // SubscriptionEventConfiguration is the interface that all subscription event configurations must implement diff --git a/router/pkg/pubsub/datasource/pubsubprovider_test.go b/router/pkg/pubsub/datasource/pubsubprovider_test.go index 6ef41c56a5..939d15ad3f 100644 --- a/router/pkg/pubsub/datasource/pubsubprovider_test.go +++ b/router/pkg/pubsub/datasource/pubsubprovider_test.go @@ -1,9 +1,9 @@ package datasource import ( - "bytes" "context" "errors" + "slices" "testing" "github.com/stretchr/testify/assert" @@ -12,18 +12,32 @@ import ( ) // Test helper types +type mutableTestEvent []byte + +func (e mutableTestEvent) Clone() MutableStreamEvent { + var evt mutableTestEvent = make([]byte, len(e)) + copy(evt, e) + return evt +} + +func (e mutableTestEvent) GetData() []byte { + return e +} + +func (e mutableTestEvent) SetData(data []byte) { + copy(e, data) +} + type testEvent struct { - data []byte + evt mutableTestEvent } func (e *testEvent) GetData() []byte { - return e.data + return slices.Clone(e.evt.GetData()) } -func (e *testEvent) Clone() StreamEvent { - return &testEvent{ - data: bytes.Clone(e.data), - } +func (e *testEvent) Clone() MutableStreamEvent { + return e.evt.Clone() } type testSubscriptionConfig struct { @@ -158,8 +172,8 @@ func TestProvider_Publish_NoHooks_Success(t *testing.T) { fieldName: "testField", } events := []StreamEvent{ - &testEvent{data: []byte("test data 1")}, - &testEvent{data: []byte("test data 2")}, + &testEvent{mutableTestEvent("test data 1")}, + &testEvent{mutableTestEvent("test data 2")}, } mockAdapter.On("Publish", mock.Anything, config, events).Return(nil) @@ -181,7 +195,7 @@ func TestProvider_Publish_NoHooks_Error(t *testing.T) { fieldName: "testField", } events := []StreamEvent{ - &testEvent{data: []byte("test data")}, + &testEvent{mutableTestEvent("test data")}, } expectedError := errors.New("publish error") @@ -205,10 +219,10 @@ func TestProvider_Publish_WithHooks_Success(t *testing.T) { fieldName: 
"testField", } originalEvents := []StreamEvent{ - &testEvent{data: []byte("original data")}, + &testEvent{mutableTestEvent("original data")}, } modifiedEvents := []StreamEvent{ - &testEvent{data: []byte("modified data")}, + &testEvent{mutableTestEvent("modified data")}, } // Define hook that modifies events @@ -237,7 +251,7 @@ func TestProvider_Publish_WithHooks_HookError(t *testing.T) { fieldName: "testField", } events := []StreamEvent{ - &testEvent{data: []byte("test data")}, + &testEvent{mutableTestEvent("test data")}, } hookError := errors.New("hook processing error") @@ -270,10 +284,10 @@ func TestProvider_Publish_WithHooks_AdapterError(t *testing.T) { fieldName: "testField", } originalEvents := []StreamEvent{ - &testEvent{data: []byte("original data")}, + &testEvent{mutableTestEvent("original data")}, } processedEvents := []StreamEvent{ - &testEvent{data: []byte("processed data")}, + &testEvent{mutableTestEvent("processed data")}, } adapterError := errors.New("adapter publish error") @@ -304,15 +318,15 @@ func TestProvider_Publish_WithMultipleHooks_Success(t *testing.T) { fieldName: "testField", } originalEvents := []StreamEvent{ - &testEvent{data: []byte("original")}, + &testEvent{mutableTestEvent("original")}, } // Chain of hooks that modify the data hook1 := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { - return []StreamEvent{&testEvent{data: []byte("modified by hook1")}}, nil + return []StreamEvent{&testEvent{mutableTestEvent("modified by hook1")}}, nil } hook2 := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { - return []StreamEvent{&testEvent{data: []byte("modified by hook2")}}, nil + return []StreamEvent{&testEvent{mutableTestEvent("modified by hook2")}}, nil } mockAdapter.On("Publish", mock.Anything, config, mock.MatchedBy(func(events []StreamEvent) bool { @@ -370,7 +384,7 @@ func TestApplyPublishEventHooks_NoHooks(t *testing.T) { 
fieldName: "testField", } originalEvents := []StreamEvent{ - &testEvent{data: []byte("test data")}, + &testEvent{mutableTestEvent("test data")}, } result, err := applyPublishEventHooks(ctx, config, originalEvents, []OnPublishEventsFn{}) @@ -387,10 +401,10 @@ func TestApplyPublishEventHooks_SingleHook_Success(t *testing.T) { fieldName: "testField", } originalEvents := []StreamEvent{ - &testEvent{data: []byte("original")}, + &testEvent{mutableTestEvent("original")}, } modifiedEvents := []StreamEvent{ - &testEvent{data: []byte("modified")}, + &testEvent{mutableTestEvent("modified")}, } hook := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { @@ -411,7 +425,7 @@ func TestApplyPublishEventHooks_SingleHook_Error(t *testing.T) { fieldName: "testField", } originalEvents := []StreamEvent{ - &testEvent{data: []byte("original")}, + &testEvent{mutableTestEvent("original")}, } hookError := errors.New("hook processing failed") @@ -434,17 +448,17 @@ func TestApplyPublishEventHooks_MultipleHooks_Success(t *testing.T) { fieldName: "testField", } originalEvents := []StreamEvent{ - &testEvent{data: []byte("original")}, + &testEvent{mutableTestEvent("original")}, } hook1 := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { - return []StreamEvent{&testEvent{data: []byte("step1")}}, nil + return []StreamEvent{&testEvent{mutableTestEvent("step1")}}, nil } hook2 := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { - return []StreamEvent{&testEvent{data: []byte("step2")}}, nil + return []StreamEvent{&testEvent{mutableTestEvent("step2")}}, nil } hook3 := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { - return []StreamEvent{&testEvent{data: []byte("final")}}, nil + return []StreamEvent{&testEvent{mutableTestEvent("final")}}, nil } result, err := applyPublishEventHooks(ctx, 
config, originalEvents, []OnPublishEventsFn{hook1, hook2, hook3}) @@ -462,18 +476,18 @@ func TestApplyPublishEventHooks_MultipleHooks_MiddleHookError(t *testing.T) { fieldName: "testField", } originalEvents := []StreamEvent{ - &testEvent{data: []byte("original")}, + &testEvent{mutableTestEvent("original")}, } middleHookError := errors.New("middle hook failed") hook1 := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { - return []StreamEvent{&testEvent{data: []byte("step1")}}, nil + return []StreamEvent{&testEvent{mutableTestEvent("step1")}}, nil } hook2 := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { return nil, middleHookError } hook3 := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { - return []StreamEvent{&testEvent{data: []byte("final")}}, nil + return []StreamEvent{&testEvent{mutableTestEvent("final")}}, nil } result, err := applyPublishEventHooks(ctx, config, originalEvents, []OnPublishEventsFn{hook1, hook2, hook3}) diff --git a/router/pkg/pubsub/datasource/subscription_event_updater.go b/router/pkg/pubsub/datasource/subscription_event_updater.go index b0ef4dbd71..1c920f7222 100644 --- a/router/pkg/pubsub/datasource/subscription_event_updater.go +++ b/router/pkg/pubsub/datasource/subscription_event_updater.go @@ -41,9 +41,8 @@ func (s *subscriptionEventUpdater) Update(events []StreamEvent) { for ctx, subId := range subscriptions { semaphore <- struct{}{} // Acquire a slot - eventsCopy := copyEvents(events) wg.Add(1) - go s.updateSubscription(ctx, &wg, errCh, semaphore, subId, eventsCopy) + go s.updateSubscription(ctx, &wg, errCh, semaphore, subId, events) } doneLogging := make(chan struct{}) @@ -70,38 +69,6 @@ func (s *subscriptionEventUpdater) SetHooks(hooks Hooks) { s.hooks = hooks } -// applyReceiveEventHooks processes events through a chain of hook functions -// Each hook receives the 
result from the previous hook, creating a proper middleware pipeline -func applyReceiveEventHooks( - ctx context.Context, - cfg SubscriptionEventConfiguration, - events []StreamEvent, - hooks []OnReceiveEventsFn) ([]StreamEvent, error) { - // Copy the events to avoid modifying the original slice - currentEvents := make([]StreamEvent, len(events)) - for i, event := range events { - currentEvents[i] = event.Clone() - } - // Apply each hook in sequence, passing the result of one as the input to the next - // If any hook returns an error, stop processing and return the error - for _, hook := range hooks { - var err error - currentEvents, err = hook(ctx, cfg, currentEvents) - if err != nil { - return currentEvents, err - } - } - return currentEvents, nil -} - -func copyEvents(in []StreamEvent) []StreamEvent { - out := make([]StreamEvent, len(in)) - for i := range in { - out[i] = in[i].Clone() - } - return out -} - func (s *subscriptionEventUpdater) updateSubscription(ctx context.Context, wg *sync.WaitGroup, errCh chan error, semaphore chan struct{}, subID resolve.SubscriptionIdentifier, events []StreamEvent) { defer wg.Done() defer func() { diff --git a/router/pkg/pubsub/datasource/subscription_event_updater_test.go b/router/pkg/pubsub/datasource/subscription_event_updater_test.go index d5ba1fcd90..693d9d5da0 100644 --- a/router/pkg/pubsub/datasource/subscription_event_updater_test.go +++ b/router/pkg/pubsub/datasource/subscription_event_updater_test.go @@ -3,6 +3,7 @@ package datasource import ( "context" "errors" + "sync" "testing" "time" @@ -44,8 +45,8 @@ func TestSubscriptionEventUpdater_Update_NoHooks(t *testing.T) { fieldName: "testField", } events := []StreamEvent{ - &testEvent{data: []byte("test data 1")}, - &testEvent{data: []byte("test data 2")}, + &testEvent{mutableTestEvent("test data 1")}, + &testEvent{mutableTestEvent("test data 2")}, } // Expect calls to Update for each event @@ -69,10 +70,10 @@ func 
TestSubscriptionEventUpdater_UpdateSubscription_WithHooks_Success(t *testin fieldName: "testField", } originalEvents := []StreamEvent{ - &testEvent{data: []byte("original data")}, + &testEvent{mutableTestEvent("original data")}, } modifiedEvents := []StreamEvent{ - &testEvent{data: []byte("modified data")}, + &testEvent{mutableTestEvent("modified data")}, } // Create wrapper function for the mock @@ -116,7 +117,7 @@ func TestSubscriptionEventUpdater_UpdateSubscriptions_WithHooks_Error(t *testing fieldName: "testField", } events := []StreamEvent{ - &testEvent{data: []byte("test data")}, + &testEvent{mutableTestEvent("test data")}, } hookError := errors.New("hook processing error") @@ -157,20 +158,20 @@ func TestSubscriptionEventUpdater_Update_WithMultipleHooks_Success(t *testing.T) fieldName: "testField", } originalEvents := []StreamEvent{ - &testEvent{data: []byte("original")}, + &testEvent{mutableTestEvent("original")}, } // Chain of hooks that modify the data receivedArgs1 := make(chan receivedHooksArgs, 1) hook1 := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { receivedArgs1 <- receivedHooksArgs{events: events, cfg: cfg} - return []StreamEvent{&testEvent{data: []byte("modified by hook1")}}, nil + return []StreamEvent{&testEvent{mutableTestEvent("modified by hook1")}}, nil } receivedArgs2 := make(chan receivedHooksArgs, 1) hook2 := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { receivedArgs2 <- receivedHooksArgs{events: events, cfg: cfg} - return []StreamEvent{&testEvent{data: []byte("modified by hook2")}}, nil + return []StreamEvent{&testEvent{mutableTestEvent("modified by hook2")}}, nil } // Expect call to UpdateSubscription with modified data @@ -200,7 +201,7 @@ func TestSubscriptionEventUpdater_Update_WithMultipleHooks_Success(t *testing.T) select { case receivedArgs2 := <-receivedArgs2: - assert.Equal(t, []StreamEvent{&testEvent{data: 
[]byte("modified by hook1")}}, receivedArgs2.events) + assert.Equal(t, []StreamEvent{&testEvent{mutableTestEvent("modified by hook1")}}, receivedArgs2.events) assert.Equal(t, config, receivedArgs2.cfg) case <-time.After(1 * time.Second): t.Fatal("timeout waiting for events") @@ -302,179 +303,260 @@ func TestNewSubscriptionEventUpdater(t *testing.T) { assert.Equal(t, mockUpdater, concreteUpdater.eventUpdater) } -func TestApplyReceiveEventHooks_NoHooks(t *testing.T) { - ctx := context.Background() +func TestSubscriptionEventUpdater_Update_PassthroughWithNoHooks(t *testing.T) { + mockUpdater := NewMockSubscriptionUpdater(t) config := &testSubscriptionEventConfig{ providerID: "test-provider", providerType: ProviderTypeNats, fieldName: "testField", } - originalEvents := []StreamEvent{ - &testEvent{data: []byte("test data")}, + events := []StreamEvent{ + &testEvent{mutableTestEvent("event data 1")}, + &testEvent{mutableTestEvent("event data 2")}, + &testEvent{mutableTestEvent("event data 3")}, } - result, err := applyReceiveEventHooks(ctx, config, originalEvents, []OnReceiveEventsFn{}) + // With no hooks, Update should call the underlying eventUpdater.Update for each event + mockUpdater.On("Update", []byte("event data 1")).Return() + mockUpdater.On("Update", []byte("event data 2")).Return() + mockUpdater.On("Update", []byte("event data 3")).Return() - assert.NoError(t, err) - assert.Equal(t, originalEvents, result) + updater := &subscriptionEventUpdater{ + eventUpdater: mockUpdater, + subscriptionEventConfiguration: config, + hooks: Hooks{}, // No hooks + } + + updater.Update(events) + + // Verify all events were passed through without modification + mockUpdater.AssertCalled(t, "Update", []byte("event data 1")) + mockUpdater.AssertCalled(t, "Update", []byte("event data 2")) + mockUpdater.AssertCalled(t, "Update", []byte("event data 3")) + mockUpdater.AssertNumberOfCalls(t, "Update", 3) } -func TestApplyReceiveEventHooks_SingleHook_Success(t *testing.T) { - ctx := 
context.Background() +func TestSubscriptionEventUpdater_Update_WithSingleHookModification(t *testing.T) { + mockUpdater := NewMockSubscriptionUpdater(t) config := &testSubscriptionEventConfig{ providerID: "test-provider", providerType: ProviderTypeNats, fieldName: "testField", } originalEvents := []StreamEvent{ - &testEvent{data: []byte("original")}, - } - modifiedEvents := []StreamEvent{ - &testEvent{data: []byte("modified")}, + &testEvent{mutableTestEvent("original data 1")}, + &testEvent{mutableTestEvent("original data 2")}, } + // Hook that modifies events by adding a prefix hook := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + modifiedEvents := make([]StreamEvent, len(events)) + for i, event := range events { + modifiedData := "modified: " + string(event.GetData()) + modifiedEvents[i] = &testEvent{mutableTestEvent(modifiedData)} + } return modifiedEvents, nil } - result, err := applyReceiveEventHooks(ctx, config, originalEvents, []OnReceiveEventsFn{hook}) + subId := resolve.SubscriptionIdentifier{ConnectionID: 1, SubscriptionID: 1} + mockUpdater.On("Subscriptions").Return(map[context.Context]resolve.SubscriptionIdentifier{ + context.Background(): subId, + }) + + // With hooks, UpdateSubscription should be called with modified data + mockUpdater.On("UpdateSubscription", subId, []byte("modified: original data 1")).Return() + mockUpdater.On("UpdateSubscription", subId, []byte("modified: original data 2")).Return() + + updater := &subscriptionEventUpdater{ + eventUpdater: mockUpdater, + subscriptionEventConfiguration: config, + hooks: Hooks{ + OnReceiveEvents: []OnReceiveEventsFn{hook}, + }, + } - assert.NoError(t, err) - assert.Equal(t, modifiedEvents, result) + updater.Update(originalEvents) + + // Verify modified events were sent to UpdateSubscription, not the original events + mockUpdater.AssertCalled(t, "UpdateSubscription", subId, []byte("modified: original data 1")) + 
mockUpdater.AssertCalled(t, "UpdateSubscription", subId, []byte("modified: original data 2")) + mockUpdater.AssertNumberOfCalls(t, "UpdateSubscription", 2) + // Update should NOT be called when hooks are present + mockUpdater.AssertNotCalled(t, "Update") } -func TestApplyReceiveEventHooks_SingleHook_Error(t *testing.T) { - ctx := context.Background() +func TestSubscriptionEventUpdater_Update_WithSingleHookError_ClosesSubscriptionAndLogsError(t *testing.T) { + mockUpdater := NewMockSubscriptionUpdater(t) config := &testSubscriptionEventConfig{ providerID: "test-provider", providerType: ProviderTypeNats, fieldName: "testField", } - originalEvents := []StreamEvent{ - &testEvent{data: []byte("original")}, + events := []StreamEvent{ + &testEvent{mutableTestEvent("test data")}, } hookError := errors.New("hook processing failed") + // Hook that returns an error hook := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { - return nil, hookError + // Return the events but also return an error + return events, hookError } - result, err := applyReceiveEventHooks(ctx, config, originalEvents, []OnReceiveEventsFn{hook}) + // Set up logger with observer to verify error logging + zCore, logObserver := observer.New(zap.InfoLevel) + logger := zap.New(zCore) + + subId := resolve.SubscriptionIdentifier{ConnectionID: 1, SubscriptionID: 1} + mockUpdater.On("Subscriptions").Return(map[context.Context]resolve.SubscriptionIdentifier{ + context.Background(): subId, + }) + // Events are still sent even when hook returns error + mockUpdater.On("UpdateSubscription", subId, []byte("test data")).Return() + // Subscription should be closed due to the error + mockUpdater.On("CloseSubscription", resolve.SubscriptionCloseKindNormal, subId).Return() + + updater := NewSubscriptionEventUpdater(config, Hooks{ + OnReceiveEvents: []OnReceiveEventsFn{hook}, + }, mockUpdater, logger) + + updater.Update(events) - assert.Error(t, err) - assert.Equal(t, 
hookError, err) - assert.Nil(t, result) + // Verify events were still sent despite the error + mockUpdater.AssertCalled(t, "UpdateSubscription", subId, []byte("test data")) + // Verify subscription was closed due to the error + mockUpdater.AssertCalled(t, "CloseSubscription", resolve.SubscriptionCloseKindNormal, subId) + // Update should NOT be called when hooks are present + mockUpdater.AssertNotCalled(t, "Update") + + // Verify error was logged (logging happens asynchronously) + assert.Eventually(t, func() bool { + logs := logObserver.FilterMessageSnippet("some handlers have thrown an error").TakeAll() + if len(logs) != 1 { + return false + } + // Verify the logged error message contains our error + return logs[0].ContextMap()["error"] == hookError.Error() + }, time.Second, 10*time.Millisecond, "expected error to be logged") } -func TestApplyReceiveEventHooks_MultipleHooks_Success(t *testing.T) { - ctx := context.Background() +func TestSubscriptionEventUpdater_Update_WithMultipleHooksChaining(t *testing.T) { + mockUpdater := NewMockSubscriptionUpdater(t) config := &testSubscriptionEventConfig{ providerID: "test-provider", providerType: ProviderTypeNats, fieldName: "testField", } originalEvents := []StreamEvent{ - &testEvent{data: []byte("original")}, + &testEvent{mutableTestEvent("original")}, } + // Track what each hook receives and when it's called + hookCallOrder := make([]int, 0, 3) + var mu sync.Mutex + + // Hook 1: Adds "step1: " prefix receivedArgs1 := make(chan receivedHooksArgs, 1) hook1 := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + mu.Lock() + hookCallOrder = append(hookCallOrder, 1) + mu.Unlock() receivedArgs1 <- receivedHooksArgs{events: events, cfg: cfg} - return []StreamEvent{&testEvent{data: []byte("step1")}}, nil + modifiedEvents := make([]StreamEvent, len(events)) + for i, event := range events { + modifiedData := "step1: " + string(event.GetData()) + modifiedEvents[i] = 
&testEvent{mutableTestEvent(modifiedData)} + } + return modifiedEvents, nil } + + // Hook 2: Adds "step2: " prefix receivedArgs2 := make(chan receivedHooksArgs, 1) hook2 := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + mu.Lock() + hookCallOrder = append(hookCallOrder, 2) + mu.Unlock() receivedArgs2 <- receivedHooksArgs{events: events, cfg: cfg} - return []StreamEvent{&testEvent{data: []byte("step2")}}, nil + modifiedEvents := make([]StreamEvent, len(events)) + for i, event := range events { + modifiedData := "step2: " + string(event.GetData()) + modifiedEvents[i] = &testEvent{mutableTestEvent(modifiedData)} + } + return modifiedEvents, nil } + + // Hook 3: Adds "step3: " prefix receivedArgs3 := make(chan receivedHooksArgs, 1) hook3 := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + mu.Lock() + hookCallOrder = append(hookCallOrder, 3) + mu.Unlock() receivedArgs3 <- receivedHooksArgs{events: events, cfg: cfg} - return []StreamEvent{&testEvent{data: []byte("final")}}, nil + modifiedEvents := make([]StreamEvent, len(events)) + for i, event := range events { + modifiedData := "step3: " + string(event.GetData()) + modifiedEvents[i] = &testEvent{mutableTestEvent(modifiedData)} + } + return modifiedEvents, nil } - result, err := applyReceiveEventHooks(ctx, config, originalEvents, []OnReceiveEventsFn{hook1, hook2, hook3}) + subId := resolve.SubscriptionIdentifier{ConnectionID: 1, SubscriptionID: 1} + mockUpdater.On("Subscriptions").Return(map[context.Context]resolve.SubscriptionIdentifier{ + context.Background(): subId, + }) + // Final modified data should have all three transformations applied + mockUpdater.On("UpdateSubscription", subId, []byte("step3: step2: step1: original")).Return() - select { - case receivedArgs1 := <-receivedArgs1: - assert.Equal(t, originalEvents, receivedArgs1.events) - assert.Equal(t, config, receivedArgs1.cfg) - case 
<-time.After(1 * time.Second): - t.Fatal("timeout waiting for events") + updater := &subscriptionEventUpdater{ + eventUpdater: mockUpdater, + subscriptionEventConfiguration: config, + hooks: Hooks{ + OnReceiveEvents: []OnReceiveEventsFn{hook1, hook2, hook3}, + }, } - select { - case receivedArgs2 := <-receivedArgs2: - assert.Equal(t, []StreamEvent{&testEvent{data: []byte("step1")}}, receivedArgs2.events) - assert.Equal(t, config, receivedArgs2.cfg) - case <-time.After(1 * time.Second): - t.Fatal("timeout waiting for events") - } + updater.Update(originalEvents) + // Verify hook 1 received original events select { - case receivedArgs3 := <-receivedArgs3: - assert.Equal(t, []StreamEvent{&testEvent{data: []byte("step2")}}, receivedArgs3.events) - assert.Equal(t, config, receivedArgs3.cfg) + case args1 := <-receivedArgs1: + assert.Equal(t, originalEvents, args1.events, "Hook 1 should receive original events") + assert.Equal(t, config, args1.cfg) + assert.Len(t, args1.events, 1) + assert.Equal(t, "original", string(args1.events[0].GetData())) case <-time.After(1 * time.Second): - t.Fatal("timeout waiting for events") + t.Fatal("timeout waiting for hook 1") } - assert.NoError(t, err) - assert.Len(t, result, 1) - assert.Equal(t, "final", string(result[0].GetData())) -} - -func TestApplyReceiveEventHooks_MultipleHooks_MiddleHookError(t *testing.T) { - ctx := context.Background() - config := &testSubscriptionEventConfig{ - providerID: "test-provider", - providerType: ProviderTypeNats, - fieldName: "testField", - } - originalEvents := []StreamEvent{ - &testEvent{data: []byte("original")}, - } - middleHookError := errors.New("middle hook failed") - - receivedArgs1 := make(chan receivedHooksArgs, 1) - hook1 := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { - receivedArgs1 <- receivedHooksArgs{events: events, cfg: cfg} - return []StreamEvent{&testEvent{data: []byte("step1")}}, nil - } - receivedArgs2 := make(chan 
receivedHooksArgs, 1) - hook2 := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { - receivedArgs2 <- receivedHooksArgs{events: events, cfg: cfg} - return nil, middleHookError - } - receivedArgs3 := make(chan receivedHooksArgs, 1) - hook3 := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { - receivedArgs3 <- receivedHooksArgs{events: events, cfg: cfg} - return []StreamEvent{&testEvent{data: []byte("final")}}, nil - } - - result, err := applyReceiveEventHooks(ctx, config, originalEvents, []OnReceiveEventsFn{hook1, hook2, hook3}) - - assert.Error(t, err) - assert.Equal(t, middleHookError, err) - assert.Nil(t, result) - + // Verify hook 2 received events modified by hook 1 select { - case receivedArgs1 := <-receivedArgs1: - assert.Equal(t, originalEvents, receivedArgs1.events) - assert.Equal(t, config, receivedArgs1.cfg) + case args2 := <-receivedArgs2: + assert.Equal(t, config, args2.cfg) + assert.Len(t, args2.events, 1) + assert.Equal(t, "step1: original", string(args2.events[0].GetData()), "Hook 2 should receive output from hook 1") case <-time.After(1 * time.Second): - t.Fatal("timeout waiting for events") + t.Fatal("timeout waiting for hook 2") } + // Verify hook 3 received events modified by hook 2 select { - case receivedArgs2 := <-receivedArgs2: - assert.Equal(t, []StreamEvent{&testEvent{data: []byte("step1")}}, receivedArgs2.events) - assert.Equal(t, config, receivedArgs2.cfg) + case args3 := <-receivedArgs3: + assert.Equal(t, config, args3.cfg) + assert.Len(t, args3.events, 1) + assert.Equal(t, "step2: step1: original", string(args3.events[0].GetData()), "Hook 3 should receive output from hook 2") case <-time.After(1 * time.Second): - t.Fatal("timeout waiting for events") + t.Fatal("timeout waiting for hook 3") } - assert.Empty(t, receivedArgs3) + // Verify hooks were called in correct order + mu.Lock() + assert.Equal(t, []int{1, 2, 3}, 
hookCallOrder, "Hooks should be called in order") + mu.Unlock() + + // Verify final modified events were sent to UpdateSubscription + mockUpdater.AssertCalled(t, "UpdateSubscription", subId, []byte("step3: step2: step1: original")) + mockUpdater.AssertNumberOfCalls(t, "UpdateSubscription", 1) + mockUpdater.AssertNotCalled(t, "Update") } // Test the updateEvents method indirectly through Update method @@ -555,7 +637,7 @@ func TestSubscriptionEventUpdater_UpdateSubscription_WithHookError_ClosesSubscri fieldName: "testField", } events := []StreamEvent{ - &testEvent{data: []byte("test data")}, + &testEvent{mutableTestEvent("test data")}, } testHook := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { @@ -592,7 +674,7 @@ func TestSubscriptionEventUpdater_UpdateSubscription_WithHooks_Error_LoggerWrite fieldName: "testField", } events := []StreamEvent{ - &testEvent{data: []byte("test data")}, + &testEvent{mutableTestEvent("test data")}, } hookError := errors.New("hook processing error") diff --git a/router/pkg/pubsub/kafka/adapter.go b/router/pkg/pubsub/kafka/adapter.go index 7f61a242b9..fcd1cc0c70 100644 --- a/router/pkg/pubsub/kafka/adapter.go +++ b/router/pkg/pubsub/kafka/adapter.go @@ -108,11 +108,15 @@ func (p *ProviderAdapter) topicPoller(ctx context.Context, client *kgo.Client, u DestinationName: r.Topic, }) - updater.Update([]datasource.StreamEvent{&Event{ - Data: r.Value, - Headers: headers, - Key: r.Key, - }}) + updater.Update([]datasource.StreamEvent{ + &Event{ + evt: &MutableEvent{ + Data: r.Value, + Headers: headers, + Key: r.Key, + }, + }, + }) } } } @@ -212,7 +216,7 @@ func (p *ProviderAdapter) Publish(ctx context.Context, conf datasource.PublishEv var errMutex sync.Mutex for _, streamEvent := range events { - kafkaEvent, ok := streamEvent.(*Event) + kafkaEvent, ok := streamEvent.Clone().(*MutableEvent) if !ok { return datasource.NewError("invalid event type for Kafka adapter", nil) } diff --git 
a/router/pkg/pubsub/kafka/engine_datasource.go b/router/pkg/pubsub/kafka/engine_datasource.go index 00a38023ea..9d48fd0db0 100644 --- a/router/pkg/pubsub/kafka/engine_datasource.go +++ b/router/pkg/pubsub/kafka/engine_datasource.go @@ -15,18 +15,66 @@ import ( "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" ) -// Event represents an event from Kafka +// Event implements datasource.StreamEvent type Event struct { + evt *MutableEvent +} + +func (e *Event) GetData() []byte { + if e.evt == nil { + return nil + } + return slices.Clone(e.evt.Data) +} + +func (e *Event) GetKey() []byte { + if e.evt == nil { + return nil + } + return slices.Clone(e.evt.Key) +} + +func (e *Event) GetHeaders() map[string][]byte { + if e.evt == nil { + return nil + } + return cloneHeaders(e.evt.Headers) +} + +func (e Event) Clone() datasource.MutableStreamEvent { + return e.evt.Clone() +} + +func cloneHeaders(src map[string][]byte) map[string][]byte { + if src == nil { + return nil + } + dst := make(map[string][]byte, len(src)) + for k, v := range src { + dst[k] = slices.Clone(v) + } + return dst +} + +// MutableEvent implements datasource.MutableEvent +type MutableEvent struct { Key []byte `json:"key"` Data json.RawMessage `json:"data"` Headers map[string][]byte `json:"headers"` } -func (e *Event) GetData() []byte { +func (e *MutableEvent) GetData() []byte { return e.Data } -func (e *Event) Clone() datasource.StreamEvent { +func (e *MutableEvent) SetData(data []byte) { + if e == nil { + return + } + e.Data = data +} + +func (e *MutableEvent) Clone() datasource.MutableStreamEvent { e2 := *e e2.Data = slices.Clone(e.Data) e2.Headers = make(map[string][]byte, len(e.Headers)) @@ -61,10 +109,10 @@ func (s *SubscriptionEventConfiguration) RootFieldName() string { // publishData is a private type that is used to pass data from the engine to the provider type publishData struct { - Provider string `json:"providerId"` - Topic string `json:"topic"` - Event Event `json:"event"` - 
FieldName string `json:"rootFieldName"` + Provider string `json:"providerId"` + Topic string `json:"topic"` + Event MutableEvent `json:"event"` + FieldName string `json:"rootFieldName"` } // PublishEventConfiguration returns the publish event configuration from the publishData type @@ -172,7 +220,7 @@ func (s *PublishDataSource) Load(ctx context.Context, input []byte, out *bytes.B return err } - if err := s.pubSub.Publish(ctx, publishData.PublishEventConfiguration(), []datasource.StreamEvent{&publishData.Event}); err != nil { + if err := s.pubSub.Publish(ctx, publishData.PublishEventConfiguration(), []datasource.StreamEvent{&Event{&publishData.Event}}); err != nil { // err will not be returned but only logged inside PubSubProvider.Publish to avoid a "unable to fetch from subgraph" error _, errWrite := io.WriteString(out, `{"success": false}`) return errWrite @@ -192,3 +240,4 @@ func (s *PublishDataSource) LoadWithFiles(ctx context.Context, input []byte, fil var _ datasource.SubscriptionEventConfiguration = (*SubscriptionEventConfiguration)(nil) var _ datasource.PublishEventConfiguration = (*PublishEventConfiguration)(nil) var _ datasource.StreamEvent = (*Event)(nil) +var _ datasource.MutableStreamEvent = (*MutableEvent)(nil) diff --git a/router/pkg/pubsub/kafka/engine_datasource_factory.go b/router/pkg/pubsub/kafka/engine_datasource_factory.go index d89eb408b0..b4e1356714 100644 --- a/router/pkg/pubsub/kafka/engine_datasource_factory.go +++ b/router/pkg/pubsub/kafka/engine_datasource_factory.go @@ -55,7 +55,7 @@ func (c *EngineDataSourceFactory) ResolveDataSourceInput(eventData []byte) (stri evtCfg := publishData{ Provider: c.providerId, Topic: c.topics[0], - Event: Event{Data: eventData}, + Event: MutableEvent{Data: eventData}, FieldName: c.fieldName, } diff --git a/router/pkg/pubsub/kafka/engine_datasource_test.go b/router/pkg/pubsub/kafka/engine_datasource_test.go index 846203d6e0..5fb5808173 100644 --- a/router/pkg/pubsub/kafka/engine_datasource_test.go +++ 
b/router/pkg/pubsub/kafka/engine_datasource_test.go @@ -23,9 +23,9 @@ func TestPublishData_MarshalJSONTemplate(t *testing.T) { { name: "simple configuration", config: publishData{ - Provider: "test-provider", - Topic: "test-topic", - Event: Event{Data: json.RawMessage(`{"message":"hello"}`)}, + Provider: "test-provider", + Topic: "test-topic", + Event: MutableEvent{Data: json.RawMessage(`{"message":"hello"}`)}, FieldName: "test-field", }, wantPattern: `{"topic":"test-topic", "event": {"data": {"message":"hello"}, "key": "", "headers": {}}, "providerId":"test-provider", "rootFieldName":"test-field"}`, @@ -33,9 +33,9 @@ func TestPublishData_MarshalJSONTemplate(t *testing.T) { { name: "with special characters", config: publishData{ - Provider: "test-provider-id", - Topic: "topic-with-hyphens", - Event: Event{Data: json.RawMessage(`{"message":"special \"quotes\" here"}`)}, + Provider: "test-provider-id", + Topic: "topic-with-hyphens", + Event: MutableEvent{Data: json.RawMessage(`{"message":"special \"quotes\" here"}`)}, FieldName: "test-field", }, wantPattern: `{"topic":"topic-with-hyphens", "event": {"data": {"message":"special \"quotes\" here"}, "key": "", "headers": {}}, "providerId":"test-provider-id", "rootFieldName":"test-field"}`, @@ -43,9 +43,9 @@ func TestPublishData_MarshalJSONTemplate(t *testing.T) { { name: "with key", config: publishData{ - Provider: "test-provider-id", - Topic: "topic-with-hyphens", - Event: Event{Key: []byte("blablabla"), Data: json.RawMessage(`{}`)}, + Provider: "test-provider-id", + Topic: "topic-with-hyphens", + Event: MutableEvent{Key: []byte("blablabla"), Data: json.RawMessage(`{}`)}, FieldName: "test-field", }, wantPattern: `{"topic":"topic-with-hyphens", "event": {"data": {}, "key": "blablabla", "headers": {}}, "providerId":"test-provider-id", "rootFieldName":"test-field"}`, @@ -53,9 +53,9 @@ func TestPublishData_MarshalJSONTemplate(t *testing.T) { { name: "with headers", config: publishData{ - Provider: "test-provider-id", - 
Topic: "topic-with-hyphens", - Event: Event{Headers: map[string][]byte{"key": []byte(`blablabla`)}, Data: json.RawMessage(`{}`)}, + Provider: "test-provider-id", + Topic: "topic-with-hyphens", + Event: MutableEvent{Headers: map[string][]byte{"key": []byte(`blablabla`)}, Data: json.RawMessage(`{}`)}, FieldName: "test-field", }, wantPattern: `{"topic":"topic-with-hyphens", "event": {"data": {}, "key": "", "headers": {"key":"YmxhYmxhYmxh"}}, "providerId":"test-provider-id", "rootFieldName":"test-field"}`, diff --git a/router/pkg/pubsub/nats/adapter.go b/router/pkg/pubsub/nats/adapter.go index 13628db1f6..e32368c658 100644 --- a/router/pkg/pubsub/nats/adapter.go +++ b/router/pkg/pubsub/nats/adapter.go @@ -149,10 +149,12 @@ func (p *ProviderAdapter) Subscribe(ctx context.Context, cfg datasource.Subscrip DestinationName: msg.Subject(), }) - updater.Update([]datasource.StreamEvent{&Event{ - Data: msg.Data(), - Headers: msg.Headers(), - }}) + updater.Update([]datasource.StreamEvent{ + Event{evt: &MutableEvent{ + Data: msg.Data(), + Headers: map[string][]string(msg.Headers()), + }}, + }) // Acknowledge the message after it has been processed ackErr := msg.Ack() @@ -195,10 +197,12 @@ func (p *ProviderAdapter) Subscribe(ctx context.Context, cfg datasource.Subscrip ProviderType: metric.ProviderTypeNats, DestinationName: msg.Subject, }) - updater.Update([]datasource.StreamEvent{&Event{ - Data: msg.Data, - Headers: msg.Header, - }}) + updater.Update([]datasource.StreamEvent{ + Event{evt: &MutableEvent{ + Data: msg.Data, + Headers: map[string][]string(msg.Header), + }}, + }) case <-p.ctx.Done(): // When the application context is done, we stop the subscriptions for _, subscription := range subscriptions { @@ -245,7 +249,7 @@ func (p *ProviderAdapter) Publish(ctx context.Context, conf datasource.PublishEv log.Debug("publish", zap.Int("event_count", len(events))) for _, streamEvent := range events { - natsEvent, ok := streamEvent.(*Event) + natsEvent, ok := 
streamEvent.Clone().(*MutableEvent) if !ok { return datasource.NewError("invalid event type for NATS adapter", nil) } @@ -296,7 +300,7 @@ func (p *ProviderAdapter) Request(ctx context.Context, cfg datasource.PublishEve return datasource.NewError("nats client not initialized", nil) } - natsEvent, ok := event.(*Event) + natsEvent, ok := event.Clone().(*MutableEvent) if !ok { return datasource.NewError("invalid event type for NATS adapter", nil) } diff --git a/router/pkg/pubsub/nats/engine_datasource.go b/router/pkg/pubsub/nats/engine_datasource.go index 3b2014a71a..f0b8c7b57a 100644 --- a/router/pkg/pubsub/nats/engine_datasource.go +++ b/router/pkg/pubsub/nats/engine_datasource.go @@ -15,24 +15,70 @@ import ( "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" ) -// Event represents an event from NATS type Event struct { + evt *MutableEvent +} + +func (e Event) GetData() []byte { + if e.evt == nil { + return nil + } + return slices.Clone(e.evt.Data) +} + +func (e Event) GetHeaders() map[string][]string { + if e.evt == nil || e.evt.Headers == nil { + return nil + } + return cloneHeaders(e.evt.Headers) +} + +func (e Event) Clone() datasource.MutableStreamEvent { + return e.evt.Clone() +} + +type MutableEvent struct { Data json.RawMessage `json:"data"` Headers map[string][]string `json:"headers"` } -func (e *Event) GetData() []byte { +func (e *MutableEvent) GetData() []byte { + if e == nil { + return nil + } return e.Data } -func (e *Event) Clone() datasource.StreamEvent { - e2 := *e - e2.Data = slices.Clone(e.Data) - e2.Headers = make(map[string][]string, len(e.Headers)) - for k, v := range e.Headers { - e2.Headers[k] = slices.Clone(v) +func (e *MutableEvent) SetData(data []byte) { + if e == nil { + return } - return &e2 + e.Data = slices.Clone(data) +} + +func (e *MutableEvent) Clone() datasource.MutableStreamEvent { + if e == nil { + return nil + } + return &MutableEvent{ + Data: slices.Clone(e.Data), + Headers: cloneHeaders(e.Headers), + } +} + +func (e 
*MutableEvent) ToStreamEvent() datasource.StreamEvent { + return &Event{evt: e} +} + +func cloneHeaders(src map[string][]string) map[string][]string { + if src == nil { + return nil + } + dst := make(map[string][]string, len(src)) + for k, v := range src { + dst[k] = slices.Clone(v) + } + return dst } type StreamConfiguration struct { @@ -65,10 +111,10 @@ func (s *SubscriptionEventConfiguration) RootFieldName() string { // publishData is a private type that is used to pass data from the engine to the provider type publishData struct { - Provider string `json:"providerId"` - Subject string `json:"subject"` - Event Event `json:"event"` - FieldName string `json:"rootFieldName"` + Provider string `json:"providerId"` + Subject string `json:"subject"` + Event MutableEvent `json:"event"` + FieldName string `json:"rootFieldName"` } func (p *publishData) PublishEventConfiguration() datasource.PublishEventConfiguration { @@ -164,7 +210,7 @@ func (s *NatsPublishDataSource) Load(ctx context.Context, input []byte, out *byt return err } - if err := s.pubSub.Publish(ctx, publishData.PublishEventConfiguration(), []datasource.StreamEvent{&publishData.Event}); err != nil { + if err := s.pubSub.Publish(ctx, publishData.PublishEventConfiguration(), []datasource.StreamEvent{Event{evt: &publishData.Event}}); err != nil { // err will not be returned but only logged inside PubSubProvider.Publish to avoid a "unable to fetch from subgraph" error _, errWrite := io.WriteString(out, `{"success": false}`) return errWrite @@ -197,7 +243,7 @@ func (s *NatsRequestDataSource) Load(ctx context.Context, input []byte, out *byt return fmt.Errorf("adapter for provider %s is not of the right type", publishData.Provider) } - return adapter.Request(ctx, publishData.PublishEventConfiguration(), &publishData.Event, out) + return adapter.Request(ctx, publishData.PublishEventConfiguration(), Event{evt: &publishData.Event}, out) } func (s *NatsRequestDataSource) LoadWithFiles(ctx context.Context, input []byte, 
files []*httpclient.FileUpload, out *bytes.Buffer) error { @@ -208,3 +254,4 @@ func (s *NatsRequestDataSource) LoadWithFiles(ctx context.Context, input []byte, var _ datasource.SubscriptionEventConfiguration = (*SubscriptionEventConfiguration)(nil) var _ datasource.PublishEventConfiguration = (*PublishAndRequestEventConfiguration)(nil) var _ datasource.StreamEvent = (*Event)(nil) +var _ datasource.MutableStreamEvent = (*MutableEvent)(nil) diff --git a/router/pkg/pubsub/nats/engine_datasource_factory.go b/router/pkg/pubsub/nats/engine_datasource_factory.go index d88d25b868..f4006448dd 100644 --- a/router/pkg/pubsub/nats/engine_datasource_factory.go +++ b/router/pkg/pubsub/nats/engine_datasource_factory.go @@ -69,7 +69,7 @@ func (c *EngineDataSourceFactory) ResolveDataSourceInput(eventData []byte) (stri Provider: c.providerId, Subject: subject, FieldName: c.fieldName, - Event: Event{Data: eventData}, + Event: MutableEvent{Data: eventData}, } return evtCfg.MarshalJSONTemplate() diff --git a/router/pkg/pubsub/nats/engine_datasource_test.go b/router/pkg/pubsub/nats/engine_datasource_test.go index 8665f42181..183179c083 100644 --- a/router/pkg/pubsub/nats/engine_datasource_test.go +++ b/router/pkg/pubsub/nats/engine_datasource_test.go @@ -25,9 +25,9 @@ func TestPublishAndRequestEventConfiguration_MarshalJSONTemplate(t *testing.T) { { name: "simple configuration", config: publishData{ - Provider: "test-provider", - Subject: "test-subject", - Event: Event{Data: json.RawMessage(`{"message":"hello"}`)}, + Provider: "test-provider", + Subject: "test-subject", + Event: MutableEvent{Data: json.RawMessage(`{"message":"hello"}`)}, FieldName: "test-field", }, wantPattern: `{"subject":"test-subject", "event": {"data": {"message":"hello"}}, "providerId":"test-provider", "rootFieldName":"test-field"}`, @@ -35,9 +35,9 @@ func TestPublishAndRequestEventConfiguration_MarshalJSONTemplate(t *testing.T) { { name: "with special characters", config: publishData{ - Provider: 
"test-provider-id", - Subject: "subject-with-hyphens", - Event: Event{Data: json.RawMessage(`{"message":"special \"quotes\" here"}`)}, + Provider: "test-provider-id", + Subject: "subject-with-hyphens", + Event: MutableEvent{Data: json.RawMessage(`{"message":"special \"quotes\" here"}`)}, FieldName: "test-field", }, wantPattern: `{"subject":"subject-with-hyphens", "event": {"data": {"message":"special \"quotes\" here"}}, "providerId":"test-provider-id", "rootFieldName":"test-field"}`, diff --git a/router/pkg/pubsub/redis/adapter.go b/router/pkg/pubsub/redis/adapter.go index 8c65bc3413..8c056fe6c1 100644 --- a/router/pkg/pubsub/redis/adapter.go +++ b/router/pkg/pubsub/redis/adapter.go @@ -128,9 +128,11 @@ func (p *ProviderAdapter) Subscribe(ctx context.Context, conf datasource.Subscri ProviderType: metric.ProviderTypeRedis, DestinationName: msg.Channel, }) - updater.Update([]datasource.StreamEvent{&Event{ - Data: []byte(msg.Payload), - }}) + updater.Update([]datasource.StreamEvent{ + Event{evt: &MutableEvent{ + Data: []byte(msg.Payload), + }}, + }) case <-p.ctx.Done(): // When the application context is done, we stop the subscription if it is not already done log.Debug("application context done, stopping subscription") @@ -171,7 +173,7 @@ func (p *ProviderAdapter) Publish(ctx context.Context, conf datasource.PublishEv log.Debug("publish", zap.Int("event_count", len(events))) for _, streamEvent := range events { - redisEvent, ok := streamEvent.(*Event) + redisEvent, ok := streamEvent.Clone().(*MutableEvent) if !ok { return datasource.NewError("invalid event type for Redis adapter", nil) } diff --git a/router/pkg/pubsub/redis/engine_datasource.go b/router/pkg/pubsub/redis/engine_datasource.go index e796b60e66..56f00a4841 100644 --- a/router/pkg/pubsub/redis/engine_datasource.go +++ b/router/pkg/pubsub/redis/engine_datasource.go @@ -15,17 +15,45 @@ import ( "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" ) -// Event represents an event from Redis type 
Event struct { + evt *MutableEvent +} + +func (e Event) GetData() []byte { + if e.evt == nil { + return nil + } + return slices.Clone(e.evt.Data) +} + +func (e Event) Clone() datasource.MutableStreamEvent { + return e.evt.Clone() +} + +type MutableEvent struct { Data json.RawMessage `json:"data"` } -func (e *Event) GetData() []byte { +func (e *MutableEvent) GetData() []byte { + if e == nil { + return nil + } return e.Data } -func (e *Event) Clone() datasource.StreamEvent { - return &Event{ +func (e *MutableEvent) SetData(data []byte) { + if e == nil { + return + } + e.Data = data +} + +func (e *MutableEvent) Clone() datasource.MutableStreamEvent { + if e == nil { + return nil + } + + return &MutableEvent{ Data: slices.Clone(e.Data), } } @@ -55,10 +83,10 @@ func (s *SubscriptionEventConfiguration) RootFieldName() string { // publishData is a private type that is used to pass data from the engine to the provider type publishData struct { - Provider string `json:"providerId"` - Channel string `json:"channel"` - Event Event `json:"event"` - FieldName string `json:"rootFieldName"` + Provider string `json:"providerId"` + Channel string `json:"channel"` + Event MutableEvent `json:"event"` + FieldName string `json:"rootFieldName"` } func (p *publishData) PublishEventConfiguration() datasource.PublishEventConfiguration { @@ -162,7 +190,7 @@ func (s *PublishDataSource) Load(ctx context.Context, input []byte, out *bytes.B return err } - if err := s.pubSub.Publish(ctx, publishData.PublishEventConfiguration(), []datasource.StreamEvent{&publishData.Event}); err != nil { + if err := s.pubSub.Publish(ctx, publishData.PublishEventConfiguration(), []datasource.StreamEvent{Event{evt: &publishData.Event}}); err != nil { // err will not be returned but only logged inside PubSubProvider.Publish to avoid a "unable to fetch from subgraph" error _, errWrite := io.WriteString(out, `{"success": false}`) return errWrite @@ -180,3 +208,4 @@ func (s *PublishDataSource) LoadWithFiles(ctx 
context.Context, input []byte, fil var _ datasource.SubscriptionEventConfiguration = (*SubscriptionEventConfiguration)(nil) var _ datasource.PublishEventConfiguration = (*PublishEventConfiguration)(nil) var _ datasource.StreamEvent = (*Event)(nil) +var _ datasource.MutableStreamEvent = (*MutableEvent)(nil) diff --git a/router/pkg/pubsub/redis/engine_datasource_factory.go b/router/pkg/pubsub/redis/engine_datasource_factory.go index 46f22e29b9..1e9f9866e4 100644 --- a/router/pkg/pubsub/redis/engine_datasource_factory.go +++ b/router/pkg/pubsub/redis/engine_datasource_factory.go @@ -66,7 +66,7 @@ func (c *EngineDataSourceFactory) ResolveDataSourceInput(eventData []byte) (stri Provider: providerId, Channel: channel, FieldName: c.fieldName, - Event: Event{Data: eventData}, + Event: MutableEvent{Data: eventData}, } return evtCfg.MarshalJSONTemplate() diff --git a/router/pkg/pubsub/redis/engine_datasource_test.go b/router/pkg/pubsub/redis/engine_datasource_test.go index b322c8a60c..cc59d240f3 100644 --- a/router/pkg/pubsub/redis/engine_datasource_test.go +++ b/router/pkg/pubsub/redis/engine_datasource_test.go @@ -23,9 +23,9 @@ func TestPublishEventConfiguration_MarshalJSONTemplate(t *testing.T) { { name: "simple configuration", config: publishData{ - Provider: "test-provider", - Channel: "test-channel", - Event: Event{Data: json.RawMessage(`{"message":"hello"}`)}, + Provider: "test-provider", + Channel: "test-channel", + Event: MutableEvent{Data: json.RawMessage(`{"message":"hello"}`)}, FieldName: "test-field", }, wantPattern: `{"channel":"test-channel", "event": {"data": {"message":"hello"}}, "providerId":"test-provider", "rootFieldName":"test-field"}`, @@ -33,9 +33,9 @@ func TestPublishEventConfiguration_MarshalJSONTemplate(t *testing.T) { { name: "with special characters", config: publishData{ - Provider: "test-provider-id", - Channel: "channel-with-hyphens", - Event: Event{Data: json.RawMessage(`{"message":"special \"quotes\" here"}`)}, + Provider: "test-provider-id", 
+ Channel: "channel-with-hyphens", + Event: MutableEvent{Data: json.RawMessage(`{"message":"special \"quotes\" here"}`)}, FieldName: "test-field", }, wantPattern: `{"channel":"channel-with-hyphens", "event": {"data": {"message":"special \"quotes\" here"}}, "providerId":"test-provider-id", "rootFieldName":"test-field"}`,