diff --git a/adr/cosmo-streams-v1.md b/adr/cosmo-streams-v1.md index 436dafe45b..21b035ff0b 100644 --- a/adr/cosmo-streams-v1.md +++ b/adr/cosmo-streams-v1.md @@ -21,24 +21,18 @@ The following interfaces will extend the existing logic in the custom modules. These provide additional control over subscriptions by providing hooks, which are invoked during specific events. - `SubscriptionOnStartHandler`: Called once at subscription start. -- `StreamBatchEventHook`: Called each time a batch of events is received from the provider. -- `StreamPublishEventHook`: Called each time a batch of events is going to be sent to the provider. +- `StreamReceiveEventHandler`: Triggered for each client/subscription when a batch of events is received from the provider, prior to delivery. +- `StreamPublishEventHandler`: Called each time a batch of events is going to be sent to the provider. ```go // STRUCTURES TO BE ADDED TO PUBSUB PACKAGE type ProviderType string const ( - ProviderTypeNats ProviderType = "nats" + ProviderTypeNats ProviderType = "nats" ProviderTypeKafka ProviderType = "kafka" ProviderTypeRedis ProviderType = "redis" } -// StreamHookError is used to customize the error messages and the behavior -type StreamHookError struct { - HttpError core.HttpError - CloseSubscription bool -} - // OperationContext already exists, we just have to add the Variables() method type OperationContext interface { Name() string @@ -48,8 +42,9 @@ type OperationContext interface { // each provider will have its own event type with custom fields // the StreamEvent interface is used to allow the hooks system to be provider-agnostic -// there could be common fields in future, but for now we don't need them -type StreamEvent interface {} +type StreamEvent interface { + GetData() []byte +} // SubscriptionEventConfiguration is the common interface for the subscription event configuration type SubscriptionEventConfiguration interface { @@ -67,7 +62,7 @@ type PublishEventConfiguration interface { RootFieldName() string } -type SubscriptionOnStartHookContext interface { +type SubscriptionOnStartHandlerContext interface { // Request is the original request received by the router. Request() *http.Request // Logger is the logger for the request @@ -85,34 +80,48 @@ type SubscriptionOnStartHookContext interface { type SubscriptionOnStartHandler interface { // OnSubscriptionOnStart is called once at subscription start - // Returning an error will result in a GraphQL error being returned to the client, could be customized returning a StreamHookError. - SubscriptionOnStart(ctx SubscriptionOnStartHookContext) error + // Returning an error will result in a GraphQL error being returned to the client + SubscriptionOnStart(ctx SubscriptionOnStartHandlerContext) error } -type StreamBatchEventHookContext interface { - // the request context - RequestContext() RequestContext - // the subscription event configuration +type StreamReceiveEventHandlerContext interface { + // Request is the initial client request that started the subscription + Request() *http.Request + // Logger is the logger for the request + Logger() *zap.Logger + // Operation is the GraphQL operation + Operation() OperationContext + // Authentication is the authentication for the request + Authentication() authentication.Authentication + // SubscriptionEventConfiguration is the subscription event configuration SubscriptionEventConfiguration() SubscriptionEventConfiguration } -type StreamBatchEventHook interface { - // OnStreamEvents is called each time a batch of events is received from the provider - // Returning an error will result in a GraphQL error being returned to the client, could be customized returning a StreamHookError. - OnStreamEvents(ctx StreamBatchEventHookContext, events []StreamEvent) ([]StreamEvent, error) +type StreamReceiveEventHandler interface { + // OnReceiveEvents is called each time a batch of events is received from the provider before delivering them to the client + // So for a single batch of events received from the provider, this hook will be called one time for each active subscription. + // It is important to optimize the logic inside this hook to avoid performance issues. + // Returning an error will result in a GraphQL error being returned to the client + OnReceiveEvents(ctx StreamReceiveEventHandlerContext, events []StreamEvent) ([]StreamEvent, error) } -type StreamPublishEventHookContext interface { - // the request context - RequestContext() RequestContext - // the publish event configuration +type StreamPublishEventHandlerContext interface { + // Request is the original request received by the router. + Request() *http.Request + // Logger is the logger for the request + Logger() *zap.Logger + // Operation is the GraphQL operation + Operation() OperationContext + // Authentication is the authentication for the request + Authentication() authentication.Authentication + // PublishEventConfiguration is the publish event configuration PublishEventConfiguration() PublishEventConfiguration } -type StreamPublishEventHook interface { +type StreamPublishEventHandler interface { // OnPublishEvents is called each time a batch of events is going to be sent to the provider - // Returning an error will result in a GraphQL error being returned to the client, could be customized returning a StreamHookError. - OnPublishEvents(ctx StreamPublishEventHookContext, events []StreamEvent) ([]StreamEvent, error) + // Returning an error will result in an error being returned and the client will see the mutation failing + OnPublishEvents(ctx StreamPublishEventHandlerContext, events []StreamEvent) ([]StreamEvent, error) } ``` @@ -154,7 +163,7 @@ type Employee @key(fields: "id", resolvable: false) { id: Int! @external } ``` -After publishing the schema, the developer will need to add the module to the cosmo streams engine. +After publishing the schema, the developer will need to add the module to the cosmo router. ### 2. Write the custom module @@ -177,39 +186,38 @@ func init() { type MyModule struct {} -func (m *MyModule) OnStreamEvents(ctx StreamBatchEventHookContext, events []core.StreamEvent) ([]core.StreamEvent, error) { +func (m *MyModule) OnReceiveEvents(ctx StreamReceiveEventHandlerContext, events []core.StreamEvent) ([]core.StreamEvent, error) { // check if the provider is nats - if ctx.StreamContext().ProviderType() != pubsub.ProviderTypeNats { + if ctx.SubscriptionEventConfiguration().ProviderType() != pubsub.ProviderTypeNats { return events, nil } // check if the provider id is the one expected by the module - if ctx.StreamContext().ProviderID() != "my-nats" { + if ctx.SubscriptionEventConfiguration().ProviderID() != "my-nats" { return events, nil } - // check if the subject is the one expected by the module - natsConfig := ctx.SubscriptionEventConfiguration().(*nats.SubscriptionEventConfiguration) - if natsConfig.Subjects[0] != "employeeUpdates" { - return events, nil - } + // check if the subscription is the one expected by the module + if ctx.SubscriptionEventConfiguration().RootFieldName() != "employeeUpdates" { + return events, nil + } + + newEvents := make([]core.StreamEvent, 0, len(events)) // check if the client is authenticated - if ctx.RequestContext().Authentication() == nil { + if ctx.Authentication() == nil { // if the client is not authenticated, return no events - return events, nil + return newEvents, nil } // check if the client is allowed to subscribe to the stream - clientAllowedEntitiesIds, found := ctx.RequestContext().Authentication().Claims()["allowedEntitiesIds"] + clientAllowedEntitiesIds, found := ctx.Authentication().Claims()["allowedEntitiesIds"] if !found { - return events, fmt.Errorf("client is not allowed to subscribe to the stream") + return newEvents, fmt.Errorf("client is not allowed to subscribe to the stream") } - newEvents := make([]core.StreamEvent, 0, len(events)) - for _, evt := range events { - natsEvent, ok := evt.(*nats.NatsEvent); + natsEvent, ok := evt.(*nats.NatsEvent) if !ok { newEvents = append(newEvents, evt) continue @@ -266,7 +274,7 @@ func (m *MyModule) Module() core.ModuleInfo { // Interface guards var ( - _ core.StreamBatchEventHook = (*MyModule)(nil) + _ core.StreamReceiveEventHandler = (*MyModule)(nil) ) ``` @@ -321,7 +329,7 @@ func init() { type MyModule struct {} -func (m *MyModule) SubscriptionOnStart(ctx SubscriptionOnStartHookContext) error { +func (m *MyModule) SubscriptionOnStart(ctx SubscriptionOnStartHandlerContext) error { // check if the provider is nats if ctx.SubscriptionEventConfiguration().ProviderType() != pubsub.ProviderTypeNats { return nil @@ -332,20 +340,17 @@ func (m *MyModule) SubscriptionOnStart(ctx SubscriptionOnStartHookContext) error return nil } - // check if the subject is the one expected by the module - natsConfig := ctx.SubscriptionEventConfiguration().(*nats.SubscriptionEventConfiguration) - if natsConfig.Subjects[0] != "employeeUpdates" { - return nil - } + // check if the subscription is the one expected by the module + if ctx.SubscriptionEventConfiguration().RootFieldName() != "employeeUpdates" { + return nil + } // check if the client is authenticated if ctx.Authentication() == nil { // if the client is not authenticated, return an error - return &StreamHookError{ - HttpError: core.HttpError{ - Code: http.StatusUnauthorized, - Message: "client is not authenticated", - }, + return &core.HttpError{ + Code: http.StatusUnauthorized, + Message: "client is not authenticated", CloseSubscription: true, } } @@ -353,11 +358,9 @@ func (m *MyModule) SubscriptionOnStart(ctx SubscriptionOnStartHookContext) error // check if the client is allowed to subscribe to the stream clientAllowedEntitiesIds, found := ctx.Authentication().Claims()["readEmployee"] if !found { - return &StreamHookError{ - HttpError: core.HttpError{ - Code: http.StatusForbidden, - Message: "client is not allowed to read employees", - }, + return &core.HttpError{ + Code: http.StatusForbidden, + Message: "client is not allowed to read employees", CloseSubscription: true, } } @@ -405,4 +408,4 @@ We could also generate the AsyncAPI specification from the schema and the events ## Generate hooks from AsyncAPI specifications -Building on the AsyncAPI integration, we could allow the user to define their streams using AsyncAPI and generate fully typesafe hooks with all events structures generated from the AsyncAPI specification. \ No newline at end of file +Building on the AsyncAPI integration, we could allow the user to define their streams using AsyncAPI and generate fully typesafe hooks with all events structures generated from the AsyncAPI specification. diff --git a/demo/pkg/subgraphs/availability/subgraph/schema.resolvers.go b/demo/pkg/subgraphs/availability/subgraph/schema.resolvers.go index 6abb2c062e..8e52ec96c5 100644 --- a/demo/pkg/subgraphs/availability/subgraph/schema.resolvers.go +++ b/demo/pkg/subgraphs/availability/subgraph/schema.resolvers.go @@ -10,24 +10,28 @@ import ( "github.com/wundergraph/cosmo/demo/pkg/subgraphs/availability/subgraph/generated" "github.com/wundergraph/cosmo/demo/pkg/subgraphs/availability/subgraph/model" + "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" "github.com/wundergraph/cosmo/router/pkg/pubsub/nats" ) // UpdateAvailability is the resolver for the updateAvailability field. func (r *mutationResolver) UpdateAvailability(ctx context.Context, employeeID int, isAvailable bool) (*model.Employee, error) { storage.Set(employeeID, isAvailable) - err := r.NatsPubSubByProviderID["default"].Publish(ctx, nats.PublishAndRequestEventConfiguration{ + conf := &nats.PublishAndRequestEventConfiguration{ Subject: r.GetPubSubName(fmt.Sprintf("employeeUpdated.%d", employeeID)), - Event: nats.Event{Data: []byte(fmt.Sprintf(`{"id":%d,"__typename": "Employee"}`, employeeID))}, - }) + } + evt := &nats.Event{Data: []byte(fmt.Sprintf(`{"id":%d,"__typename": "Employee"}`, employeeID))} + err := r.NatsPubSubByProviderID["default"].Publish(ctx, conf, []datasource.StreamEvent{evt}) if err != nil { return nil, err } - err = r.NatsPubSubByProviderID["my-nats"].Publish(ctx, nats.PublishAndRequestEventConfiguration{ + + conf2 := &nats.PublishAndRequestEventConfiguration{ Subject: r.GetPubSubName(fmt.Sprintf("employeeUpdatedMyNats.%d", employeeID)), - Event: nats.Event{Data: []byte(fmt.Sprintf(`{"id":%d,"__typename": "Employee"}`, employeeID))}, - }) + } + evt2 := &nats.Event{Data: []byte(fmt.Sprintf(`{"id":%d,"__typename": "Employee"}`, employeeID))} + err = r.NatsPubSubByProviderID["my-nats"].Publish(ctx, conf2, []datasource.StreamEvent{evt2}) if err != nil { return nil, err diff --git a/demo/pkg/subgraphs/mood/subgraph/schema.resolvers.go b/demo/pkg/subgraphs/mood/subgraph/schema.resolvers.go index 82a0a7e9f2..b9b426593c 100644 --- a/demo/pkg/subgraphs/mood/subgraph/schema.resolvers.go +++ b/demo/pkg/subgraphs/mood/subgraph/schema.resolvers.go @@ -10,6 +10,7 @@ import ( "github.com/wundergraph/cosmo/demo/pkg/subgraphs/mood/subgraph/generated" "github.com/wundergraph/cosmo/demo/pkg/subgraphs/mood/subgraph/model" + "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" "github.com/wundergraph/cosmo/router/pkg/pubsub/nats" ) @@ -19,10 +20,9 @@ func (r *mutationResolver) UpdateMood(ctx context.Context, employeeID int, mood myNatsTopic := r.GetPubSubName(fmt.Sprintf("employeeUpdated.%d", employeeID)) payload := fmt.Sprintf(`{"id":%d,"__typename": "Employee"}`, employeeID) if r.NatsPubSubByProviderID["default"] != nil { - err := r.NatsPubSubByProviderID["default"].Publish(ctx, nats.PublishAndRequestEventConfiguration{ + err := r.NatsPubSubByProviderID["default"].Publish(ctx, &nats.PublishAndRequestEventConfiguration{ Subject: myNatsTopic, - Event: nats.Event{Data: []byte(payload)}, - }) + }, []datasource.StreamEvent{&nats.Event{Data: []byte(payload)}}) if err != nil { return nil, err } @@ -32,10 +32,9 @@ func (r *mutationResolver) UpdateMood(ctx context.Context, employeeID int, mood defaultTopic := r.GetPubSubName(fmt.Sprintf("employeeUpdatedMyNats.%d", employeeID)) if r.NatsPubSubByProviderID["my-nats"] != nil { - err := r.NatsPubSubByProviderID["my-nats"].Publish(ctx, nats.PublishAndRequestEventConfiguration{ + err := r.NatsPubSubByProviderID["my-nats"].Publish(ctx, &nats.PublishAndRequestEventConfiguration{ Subject: defaultTopic, - Event: nats.Event{Data: []byte(payload)}, - }) + }, []datasource.StreamEvent{&nats.Event{Data: []byte(payload)}}) if err != nil { return nil, err } diff --git a/router-tests/events/events_config_test.go b/router-tests/events/events_config_test.go index f7e0739e1c..50d19dbaed 100644 --- a/router-tests/events/events_config_test.go +++ b/router-tests/events/events_config_test.go @@ -1,4 +1,4 @@ -package events +package events_test import ( "testing" diff --git a/router-tests/events/kafka_events_test.go b/router-tests/events/kafka_events_test.go index 3ad51a592c..37e54109a8 100644 --- a/router-tests/events/kafka_events_test.go +++ b/router-tests/events/kafka_events_test.go @@ -1,9 +1,8 @@ -package events +package events_test import ( "bufio" "bytes" - "context" "encoding/json" "fmt" "net/http" @@ -11,13 +10,13 @@ import ( "testing" "time" - "github.com/stretchr/testify/assert" - "github.com/wundergraph/cosmo/router/core" - "github.com/hasura/go-graphql-client" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/twmb/franz-go/pkg/kgo" + + "github.com/wundergraph/cosmo/router-tests/events" "github.com/wundergraph/cosmo/router-tests/testenv" + "github.com/wundergraph/cosmo/router/core" "github.com/wundergraph/cosmo/router/pkg/config" ) @@ -74,7 +73,7 @@ func TestKafkaEvents(t *testing.T) { RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, EnableKafka: true, }, func(t *testing.T, xEnv *testenv.Environment) { - EnsureTopicExists(t, xEnv, topics...) + events.KafkaEnsureTopicExists(t, xEnv, KafkaWaitTimeout, topics...) var subscriptionOne struct { employeeUpdatedMyKafka struct { @@ -107,7 +106,7 @@ func TestKafkaEvents(t *testing.T) { xEnv.WaitForSubscriptionCount(1, KafkaWaitTimeout) - ProduceKafkaMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) + events.ProduceKafkaMessage(t, xEnv, KafkaWaitTimeout, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) testenv.AwaitChannelWithT(t, KafkaWaitTimeout, subscriptionArgsCh, func(t *testing.T, args kafkaSubscriptionArgs) { require.NoError(t, args.errValue) @@ -130,7 +129,7 @@ func TestKafkaEvents(t *testing.T) { RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, EnableKafka: true, }, func(t *testing.T, xEnv *testenv.Environment) { - EnsureTopicExists(t, xEnv, topics...) + events.KafkaEnsureTopicExists(t, xEnv, KafkaWaitTimeout, topics...) var subscriptionOne struct { employeeUpdatedMyKafka struct { @@ -164,23 +163,23 @@ func TestKafkaEvents(t *testing.T) { xEnv.WaitForSubscriptionCount(1, KafkaWaitTimeout) - ProduceKafkaMessage(t, xEnv, topics[0], ``) // Empty message + events.ProduceKafkaMessage(t, xEnv, KafkaWaitTimeout, topics[0], ``) // Empty message testenv.AwaitChannelWithT(t, KafkaWaitTimeout, subscriptionArgsCh, func(t *testing.T, args kafkaSubscriptionArgs) { require.ErrorContains(t, args.errValue, "Invalid message received") }) - ProduceKafkaMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) // Correct message + events.ProduceKafkaMessage(t, xEnv, KafkaWaitTimeout, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) // Correct message testenv.AwaitChannelWithT(t, KafkaWaitTimeout, subscriptionArgsCh, func(t *testing.T, args kafkaSubscriptionArgs) { require.NoError(t, args.errValue) require.JSONEq(t, `{"employeeUpdatedMyKafka":{"id":1,"details":{"forename":"Jens","surname":"Neuse"}}}`, string(args.dataValue)) }) - ProduceKafkaMessage(t, xEnv, topics[0], `{"__typename":"Employee","update":{"name":"foo"}}`) // Missing entity = Resolver error + events.ProduceKafkaMessage(t, xEnv, KafkaWaitTimeout, topics[0], `{"__typename":"Employee","update":{"name":"foo"}}`) // Missing entity = Resolver error testenv.AwaitChannelWithT(t, KafkaWaitTimeout, subscriptionArgsCh, func(t *testing.T, args kafkaSubscriptionArgs) { require.ErrorContains(t, args.errValue, "Cannot return null for non-nullable field 'Subscription.employeeUpdatedMyKafka.id'.") }) - ProduceKafkaMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) // Correct message + events.ProduceKafkaMessage(t, xEnv, KafkaWaitTimeout, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) // Correct message testenv.AwaitChannelWithT(t, KafkaWaitTimeout, subscriptionArgsCh, func(t *testing.T, args kafkaSubscriptionArgs) { require.NoError(t, args.errValue) require.JSONEq(t, `{"employeeUpdatedMyKafka":{"id":1,"details":{"forename":"Jens","surname":"Neuse"}}}`, string(args.dataValue)) @@ -204,7 +203,7 @@ func TestKafkaEvents(t *testing.T) { RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, EnableKafka: true, }, func(t *testing.T, xEnv *testenv.Environment) { - EnsureTopicExists(t, xEnv, topics...) + events.KafkaEnsureTopicExists(t, xEnv, KafkaWaitTimeout, topics...) var subscriptionOne struct { employeeUpdatedMyKafka struct { @@ -248,7 +247,7 @@ func TestKafkaEvents(t *testing.T) { xEnv.WaitForSubscriptionCount(2, KafkaWaitTimeout) - ProduceKafkaMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) + events.ProduceKafkaMessage(t, xEnv, KafkaWaitTimeout, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) testenv.AwaitChannelWithT(t, KafkaWaitTimeout, subscriptionOneArgsCh, func(t *testing.T, args kafkaSubscriptionArgs) { require.NoError(t, args.errValue) @@ -277,7 +276,7 @@ func TestKafkaEvents(t *testing.T) { RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, EnableKafka: true, }, func(t *testing.T, xEnv *testenv.Environment) { - EnsureTopicExists(t, xEnv, topics...) + events.KafkaEnsureTopicExists(t, xEnv, KafkaWaitTimeout, topics...) var subscriptionOne struct { employeeUpdatedMyKafka struct { @@ -321,7 +320,7 @@ func TestKafkaEvents(t *testing.T) { xEnv.WaitForSubscriptionCount(2, KafkaWaitTimeout) - ProduceKafkaMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) + events.ProduceKafkaMessage(t, xEnv, KafkaWaitTimeout, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) testenv.AwaitChannelWithT(t, KafkaWaitTimeout, subscriptionOneArgsCh, func(t *testing.T, args kafkaSubscriptionArgs) { require.NoError(t, args.errValue) @@ -333,7 +332,7 @@ func TestKafkaEvents(t *testing.T) { require.JSONEq(t, `{"employeeUpdatedMyKafka":{"id":1,"details":{"forename":"Jens","surname":"Neuse"}}}`, string(args.dataValue)) }) - ProduceKafkaMessage(t, xEnv, topics[1], `{"__typename":"Employee","id": 2,"update":{"name":"foo"}}`) + events.ProduceKafkaMessage(t, xEnv, KafkaWaitTimeout, topics[1], `{"__typename":"Employee","id": 2,"update":{"name":"foo"}}`) testenv.AwaitChannelWithT(t, KafkaWaitTimeout, subscriptionOneArgsCh, func(t *testing.T, args kafkaSubscriptionArgs) { require.NoError(t, args.errValue) @@ -366,7 +365,7 @@ func TestKafkaEvents(t *testing.T) { engineExecutionConfiguration.WebSocketClientReadTimeout = time.Millisecond * 100 }, }, func(t *testing.T, xEnv *testenv.Environment) { - EnsureTopicExists(t, xEnv, topics...) + events.KafkaEnsureTopicExists(t, xEnv, KafkaWaitTimeout, topics...) var subscriptionOne struct { employeeUpdatedMyKafka struct { @@ -399,7 +398,7 @@ func TestKafkaEvents(t *testing.T) { xEnv.WaitForSubscriptionCount(1, KafkaWaitTimeout) - ProduceKafkaMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) + events.ProduceKafkaMessage(t, xEnv, KafkaWaitTimeout, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) testenv.AwaitChannelWithT(t, KafkaWaitTimeout, subscriptionOneArgsCh, func(t *testing.T, args kafkaSubscriptionArgs) { require.NoError(t, args.errValue) @@ -431,7 +430,7 @@ func TestKafkaEvents(t *testing.T) { core.WithSubscriptionHeartbeatInterval(subscriptionHeartbeatInterval), }, }, func(t *testing.T, xEnv *testenv.Environment) { - EnsureTopicExists(t, xEnv, topics...) + events.KafkaEnsureTopicExists(t, xEnv, KafkaWaitTimeout, topics...) subscribePayload := []byte(`{"query":"subscription { employeeUpdatedMyKafka(employeeID: 1) { id details { forename surname } }}"}`) @@ -447,10 +446,10 @@ func TestKafkaEvents(t *testing.T) { xEnv.WaitForSubscriptionCount(1, KafkaWaitTimeout) - ProduceKafkaMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) + events.ProduceKafkaMessage(t, xEnv, KafkaWaitTimeout, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) assertKafkaMultipartValueEventually(t, reader, "{\"payload\":{\"data\":{\"employeeUpdatedMyKafka\":{\"id\":1,\"details\":{\"forename\":\"Jens\",\"surname\":\"Neuse\"}}}}}") - ProduceKafkaMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) + events.ProduceKafkaMessage(t, xEnv, KafkaWaitTimeout, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) assertKafkaMultipartValueEventually(t, reader, "{\"payload\":{\"data\":{\"employeeUpdatedMyKafka\":{\"id\":1,\"details\":{\"forename\":\"Jens\",\"surname\":\"Neuse\"}}}}}") }) }) @@ -497,7 +496,7 @@ func TestKafkaEvents(t *testing.T) { RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, EnableKafka: true, }, func(t *testing.T, xEnv *testenv.Environment) { - EnsureTopicExists(t, xEnv, topics...) + events.KafkaEnsureTopicExists(t, xEnv, KafkaWaitTimeout, topics...) subscribePayload := []byte(`{"query":"subscription { employeeUpdatedMyKafka(employeeID: 1) { id details { forename surname } }}"}`) @@ -530,7 +529,7 @@ func TestKafkaEvents(t *testing.T) { xEnv.WaitForSubscriptionCount(1, KafkaWaitTimeout) - ProduceKafkaMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) + events.ProduceKafkaMessage(t, xEnv, KafkaWaitTimeout, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) testenv.AwaitChannelWithT(t, KafkaWaitTimeout, responseCh, func(t *testing.T, response struct { response *http.Response @@ -562,7 +561,7 @@ func TestKafkaEvents(t *testing.T) { RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, EnableKafka: true, }, func(t *testing.T, xEnv *testenv.Environment) { - EnsureTopicExists(t, xEnv, topics...) + events.KafkaEnsureTopicExists(t, xEnv, KafkaWaitTimeout, topics...) subscribePayload := []byte(`{"query":"subscription { employeeUpdatedMyKafka(employeeID: 1) { id details { forename surname } }}"}`) @@ -595,7 +594,7 @@ func TestKafkaEvents(t *testing.T) { xEnv.WaitForSubscriptionCount(1, KafkaWaitTimeout) - ProduceKafkaMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) + events.ProduceKafkaMessage(t, xEnv, KafkaWaitTimeout, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) testenv.AwaitChannelWithT(t, KafkaWaitTimeout, responseCh, func(t *testing.T, resp struct { response *http.Response @@ -672,7 +671,7 @@ func TestKafkaEvents(t *testing.T) { RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, EnableKafka: true, }, func(t *testing.T, xEnv *testenv.Environment) { - EnsureTopicExists(t, xEnv, topics...) + events.KafkaEnsureTopicExists(t, xEnv, KafkaWaitTimeout, topics...) type subscriptionPayload struct { Data struct { @@ -713,7 +712,7 @@ func TestKafkaEvents(t *testing.T) { // Events 1, 2, 11, and 12 should be included for i := uint32(1); i < 13; i++ { - ProduceKafkaMessage(t, xEnv, topics[0], fmt.Sprintf(`{"__typename":"Employee","id":%d}`, i)) + events.ProduceKafkaMessage(t, xEnv, KafkaWaitTimeout, topics[0], fmt.Sprintf(`{"__typename":"Employee","id":%d}`, i)) if i == 1 || i == 2 || i == 11 || i == 12 { conn.SetReadDeadline(time.Now().Add(KafkaWaitTimeout)) gErr := conn.ReadJSON(&msg) @@ -739,7 +738,7 @@ func TestKafkaEvents(t *testing.T) { RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, EnableKafka: true, }, func(t *testing.T, xEnv *testenv.Environment) { - EnsureTopicExists(t, xEnv, topics...) + events.KafkaEnsureTopicExists(t, xEnv, KafkaWaitTimeout, topics...) type subscriptionPayload struct { Data struct { @@ -780,7 +779,7 @@ func TestKafkaEvents(t *testing.T) { // Events 1, 2, 11, and 12 should be included for i := uint32(1); i < 13; i++ { - ProduceKafkaMessage(t, xEnv, topics[0], fmt.Sprintf(`{"__typename":"Employee","id":%d}`, i)) + events.ProduceKafkaMessage(t, xEnv, KafkaWaitTimeout, topics[0], fmt.Sprintf(`{"__typename":"Employee","id":%d}`, i)) if i == 1 || i == 2 || i == 11 || i == 12 { conn.SetReadDeadline(time.Now().Add(KafkaWaitTimeout)) gErr := conn.ReadJSON(&msg) @@ -806,7 +805,7 @@ func TestKafkaEvents(t *testing.T) { RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, EnableKafka: true, }, func(t *testing.T, xEnv *testenv.Environment) { - EnsureTopicExists(t, xEnv, topics...) + events.KafkaEnsureTopicExists(t, xEnv, KafkaWaitTimeout, topics...) type subscriptionPayload struct { Data struct { @@ -835,10 +834,10 @@ func TestKafkaEvents(t *testing.T) { xEnv.WaitForSubscriptionCount(1, KafkaWaitTimeout) // The message should be ignored because "1" does not equal 1 - ProduceKafkaMessage(t, xEnv, topics[0], `{"__typename":"Employee","id":1}`) + events.ProduceKafkaMessage(t, xEnv, KafkaWaitTimeout, topics[0], `{"__typename":"Employee","id":1}`) // This message should be delivered because it matches the filter - ProduceKafkaMessage(t, xEnv, topics[0], `{"__typename":"Employee","id":12}`) + events.ProduceKafkaMessage(t, xEnv, KafkaWaitTimeout, topics[0], `{"__typename":"Employee","id":12}`) conn.SetReadDeadline(time.Now().Add(KafkaWaitTimeout)) readErr := conn.ReadJSON(&msg) require.NoError(t, readErr) @@ -861,7 +860,7 @@ func TestKafkaEvents(t *testing.T) { RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, EnableKafka: true, }, func(t *testing.T, xEnv *testenv.Environment) { - EnsureTopicExists(t, xEnv, topics...) + events.KafkaEnsureTopicExists(t, xEnv, KafkaWaitTimeout, topics...) var subscriptionOne struct { employeeUpdatedMyKafka struct { @@ -894,23 +893,23 @@ func TestKafkaEvents(t *testing.T) { xEnv.WaitForSubscriptionCount(1, KafkaWaitTimeout) - ProduceKafkaMessage(t, xEnv, topics[0], `{asas`) // Invalid message + events.ProduceKafkaMessage(t, xEnv, KafkaWaitTimeout, topics[0], `{asas`) // Invalid message testenv.AwaitChannelWithT(t, KafkaWaitTimeout, subscriptionOneArgsCh, func(t *testing.T, args kafkaSubscriptionArgs) { require.ErrorContains(t, args.errValue, "Invalid message received") }) - ProduceKafkaMessage(t, xEnv, topics[0], `{"__typename":"Employee","id":1}`) // Correct message + events.ProduceKafkaMessage(t, xEnv, KafkaWaitTimeout, topics[0], `{"__typename":"Employee","id":1}`) // Correct message testenv.AwaitChannelWithT(t, KafkaWaitTimeout, subscriptionOneArgsCh, func(t *testing.T, args kafkaSubscriptionArgs) { require.NoError(t, args.errValue) require.JSONEq(t, `{"employeeUpdatedMyKafka":{"id":1,"details":{"forename":"Jens","surname":"Neuse"}}}`, string(args.dataValue)) }) - ProduceKafkaMessage(t, xEnv, topics[0], `{"__typename":"Employee","update":{"name":"foo"}}`) // Missing entity = Resolver error + events.ProduceKafkaMessage(t, xEnv, KafkaWaitTimeout, topics[0], `{"__typename":"Employee","update":{"name":"foo"}}`) // Missing entity = Resolver error testenv.AwaitChannelWithT(t, KafkaWaitTimeout, subscriptionOneArgsCh, func(t *testing.T, args kafkaSubscriptionArgs) { require.ErrorContains(t, args.errValue, "Cannot return null for non-nullable field 'Subscription.employeeUpdatedMyKafka.id'.") }) - ProduceKafkaMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) // Correct message + events.ProduceKafkaMessage(t, xEnv, KafkaWaitTimeout, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) // Correct message testenv.AwaitChannelWithT(t, KafkaWaitTimeout, subscriptionOneArgsCh, func(t *testing.T, args kafkaSubscriptionArgs) { require.NoError(t, args.errValue) require.JSONEq(t, `{"employeeUpdatedMyKafka":{"id":1,"details":{"forename":"Jens","surname":"Neuse"}}}`, string(args.dataValue)) @@ -932,7 +931,7 @@ func TestKafkaEvents(t *testing.T) { RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, EnableKafka: true, }, func(t *testing.T, xEnv *testenv.Environment) { - EnsureTopicExists(t, xEnv, topics...) + events.KafkaEnsureTopicExists(t, xEnv, KafkaWaitTimeout, topics...) // Send a mutation to trigger the first subscription resOne := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ @@ -940,7 +939,7 @@ func TestKafkaEvents(t *testing.T) { }) require.JSONEq(t, `{"data":{"updateEmployeeMyKafka":{"success":true}}}`, resOne.Body) - records, err := readKafkaMessages(xEnv, topics[0], 1) + records, err := events.ReadKafkaMessages(xEnv, KafkaWaitTimeout, topics[0], 1) require.NoError(t, err) require.Equal(t, 1, len(records)) require.Equal(t, `{"employeeID":3,"update":{"name":"name test"}}`, string(records[0].Value)) @@ -980,7 +979,7 @@ func TestKafkaEvents(t *testing.T) { RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, EnableKafka: true, }, func(t *testing.T, xEnv *testenv.Environment) { - EnsureTopicExists(t, xEnv, topics...) + events.KafkaEnsureTopicExists(t, xEnv, KafkaWaitTimeout, topics...) type subscriptionPayload struct { Data struct { @@ -1024,7 +1023,7 @@ func TestKafkaEvents(t *testing.T) { // Events 1, 3, 4, 7, and 11 should be included for i := int(MsgCount); i > 0; i-- { - ProduceKafkaMessage(t, xEnv, topics[0], fmt.Sprintf(`{"__typename":"Employee","id":%d}`, i)) + events.ProduceKafkaMessage(t, xEnv, KafkaWaitTimeout, topics[0], fmt.Sprintf(`{"__typename":"Employee","id":%d}`, i)) if i == 1 || i == 3 || i == 4 || i == 7 || i == 11 { conn.SetReadDeadline(time.Now().Add(KafkaWaitTimeout)) jsonErr := conn.ReadJSON(&msg) @@ -1041,20 +1040,3 @@ func TestKafkaEvents(t *testing.T) { }) }) } - -func readKafkaMessages(xEnv *testenv.Environment, topicName string, msgs int) ([]*kgo.Record, error) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - - client, err := kgo.NewClient( - kgo.SeedBrokers(xEnv.GetKafkaSeeds()...), - kgo.ConsumeTopics(xEnv.GetPubSubName(topicName)), - ) - if err != nil { - return nil, err - } - - fetchs := client.PollRecords(ctx, msgs) - - return fetchs.Records(), nil -} diff --git a/router-tests/events/nats_events_test.go b/router-tests/events/nats_events_test.go index 9e1558db24..0add1e361e 100644 --- a/router-tests/events/nats_events_test.go +++ b/router-tests/events/nats_events_test.go @@ -1,4 +1,4 @@ -package events +package events_test import ( "bufio" diff --git a/router-tests/events/redis_events_test.go b/router-tests/events/redis_events_test.go index f6c9e54d13..2980ae61d5 100644 --- a/router-tests/events/redis_events_test.go +++ b/router-tests/events/redis_events_test.go @@ -1,22 +1,20 @@ -package events +package events_test import ( "bufio" "bytes" - "context" "encoding/json" "fmt" "net/http" - "net/url" "testing" "time" - "github.com/redis/go-redis/v9" "github.com/stretchr/testify/assert" "github.com/wundergraph/cosmo/router/core" "github.com/hasura/go-graphql-client" "github.com/stretchr/testify/require" + "github.com/wundergraph/cosmo/router-tests/events" "github.com/wundergraph/cosmo/router-tests/testenv" "github.com/wundergraph/cosmo/router/pkg/config" ) @@ -104,7 +102,7 @@ func TestRedisEvents(t *testing.T) { xEnv.WaitForSubscriptionCount(1, RedisWaitTimeout) // produce a message - ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) + events.ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) // process the message select { @@ -170,7 +168,7 @@ func TestRedisEvents(t *testing.T) { xEnv.WaitForSubscriptionCount(1, RedisWaitTimeout) // produce an empty message - ProduceRedisMessage(t, xEnv, topics[0], ``) + events.ProduceRedisMessage(t, xEnv, topics[0], ``) // process the message select { case subscriptionArgs := <-subscriptionArgsCh: @@ -181,7 +179,7 @@ func TestRedisEvents(t *testing.T) { t.Fatal("timeout waiting for first message error") } - ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) // Correct message + events.ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) // Correct message select { case subscriptionArgs := <-subscriptionArgsCh: require.NoError(t, subscriptionArgs.errValue) @@ -191,7 +189,7 @@ func TestRedisEvents(t *testing.T) { } // Missing entity = Resolver error - ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","update":{"name":"foo"}}`) + events.ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","update":{"name":"foo"}}`) select { case subscriptionArgs := <-subscriptionArgsCh: var gqlErr graphql.Errors @@ -202,7 +200,7 @@ func TestRedisEvents(t *testing.T) { } // Correct message - ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) + events.ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) select { case subscriptionArgs := <-subscriptionArgsCh: require.NoError(t, subscriptionArgs.errValue) @@ -273,7 +271,7 @@ func TestRedisEvents(t *testing.T) { xEnv.WaitForSubscriptionCount(2, RedisWaitTimeout) // produce a message - ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) + events.ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) // read the message from the first subscription select { @@ -354,7 +352,7 @@ func TestRedisEvents(t *testing.T) { xEnv.WaitForSubscriptionCount(2, RedisWaitTimeout) // produce a message - ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) + events.ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) // read the message from the first subscription select { @@ -375,7 +373,7 @@ func TestRedisEvents(t *testing.T) { } // produce a message - ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 2,"update":{"name":"foo"}}`) + events.ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 2,"update":{"name":"foo"}}`) // read the message from the first subscription select { @@ -451,7 +449,7 @@ func TestRedisEvents(t *testing.T) { xEnv.WaitForSubscriptionCount(1, RedisWaitTimeout) // produce a message - ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) + events.ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) // read the message from the subscription select { @@ -509,12 +507,12 @@ func TestRedisEvents(t *testing.T) { xEnv.WaitForSubscriptionCount(1, RedisWaitTimeout) // produce a message - ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) + events.ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) // read the message from the subscription assertRedisMultipartValueEventually(t, reader, "{\"payload\":{\"data\":{\"employeeUpdates\":{\"id\":1,\"details\":{\"forename\":\"Jens\",\"surname\":\"Neuse\"}}}}}") // produce a message - ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) + events.ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) // read the message from the subscription assertRedisMultipartValueEventually(t, reader, "{\"payload\":{\"data\":{\"employeeUpdates\":{\"id\":1,\"details\":{\"forename\":\"Jens\",\"surname\":\"Neuse\"}}}}}") }) @@ -590,7 +588,7 @@ func TestRedisEvents(t *testing.T) { xEnv.WaitForSubscriptionCount(1, RedisWaitTimeout) // produce a message so that the subscription is triggered - ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) + events.ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) // get the client response var clientRet struct { @@ -663,7 +661,7 @@ func TestRedisEvents(t *testing.T) { xEnv.WaitForSubscriptionCount(1, RedisWaitTimeout) // produce a message so that the subscription is triggered - ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) + events.ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) // get the client response var clientRet struct { @@ -792,7 +790,7 @@ func TestRedisEvents(t *testing.T) { // Events 1, 3, 4, 7, and 11 should be included for i := MsgCount; i > 0; i-- { - ProduceRedisMessage(t, xEnv, topics[0], fmt.Sprintf(`{"__typename":"Employee","id":%d}`, i)) + events.ProduceRedisMessage(t, xEnv, topics[0], fmt.Sprintf(`{"__typename":"Employee","id":%d}`, i)) if i == 11 || i == 7 || i == 4 || i == 3 || i == 1 { gErr := conn.ReadJSON(&msg) @@ -853,7 +851,7 @@ func TestRedisEvents(t *testing.T) { xEnv.WaitForSubscriptionCount(1, RedisWaitTimeout) // produce an invalid message - ProduceRedisMessage(t, xEnv, topics[0], `{asas`) + events.ProduceRedisMessage(t, xEnv, topics[0], `{asas`) // get the client response select { case args := <-subscriptionOneArgsCh: @@ -865,7 +863,7 @@ func TestRedisEvents(t *testing.T) { } // produce a correct message - ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","id":1}`) + events.ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","id":1}`) // get the client response select { case args := <-subscriptionOneArgsCh: @@ -876,7 +874,7 @@ func TestRedisEvents(t *testing.T) { } // produce a message with a missing entity - ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","update":{"name":"foo"}}`) + events.ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","update":{"name":"foo"}}`) // get the client response select { case args := <-subscriptionOneArgsCh: @@ -888,7 +886,7 @@ func TestRedisEvents(t *testing.T) { } // produce a correct message - ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) + events.ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) // get the client response select { case args := <-subscriptionOneArgsCh: @@ -920,7 +918,7 @@ func TestRedisEvents(t *testing.T) { NoRetryClient: true, }, func(t *testing.T, xEnv *testenv.Environment) { // start reading the messages from the channel - msgCh, err := readRedisMessages(t, xEnv, channels[0]) + msgCh, err := events.ReadRedisMessages(t, xEnv, channels[0]) require.NoError(t, err) // send a mutation to trigger the first subscription @@ -991,7 +989,7 @@ func TestRedisClusterEvents(t *testing.T) { xEnv.WaitForSubscriptionCount(1, RedisWaitTimeout) // produce a message - ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) + events.ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) // read the message select { @@ -1026,7 +1024,7 @@ func TestRedisClusterEvents(t *testing.T) { NoRetryClient: true, }, func(t *testing.T, xEnv *testenv.Environment) { // start reading the messages from the channel - msgCh, err := readRedisMessages(t, xEnv, channels[0]) + msgCh, err := events.ReadRedisMessages(t, xEnv, channels[0]) require.NoError(t, err) // send a mutation to produce a message @@ -1046,30 +1044,3 @@ func TestRedisClusterEvents(t *testing.T) { }) } - -func readRedisMessages(t *testing.T, xEnv *testenv.Environment, channelName string) (<-chan *redis.Message, error) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - - parsedURL, err := url.Parse(xEnv.RedisHosts[0]) - if err != nil { - return nil, err - } - var redisConn redis.UniversalClient - if !xEnv.RedisWithClusterMode { - redisConn = redis.NewClient(&redis.Options{ - Addr: parsedURL.Host, - }) - } else { - redisConn = redis.NewClusterClient(&redis.ClusterOptions{ - Addrs: []string{parsedURL.Host}, - }) - } - sub := redisConn.Subscribe(ctx, xEnv.GetPubSubName(channelName)) - t.Cleanup(func() { - sub.Close() - redisConn.Close() - }) - - return sub.Channel(), nil -} diff --git a/router-tests/events/event_helpers.go b/router-tests/events/utils.go similarity index 55% rename from router-tests/events/event_helpers.go rename to router-tests/events/utils.go index 48d97e90c4..b8619c3368 100644 --- a/router-tests/events/event_helpers.go +++ b/router-tests/events/utils.go @@ -2,19 +2,37 @@ package events import ( "context" + "net/url" + "testing" + "time" + "github.com/redis/go-redis/v9" "github.com/stretchr/testify/require" "github.com/twmb/franz-go/pkg/kgo" "github.com/wundergraph/cosmo/router-tests/testenv" - "net/url" - "testing" - "time" ) -const waitTimeout = time.Second * 30 +func KafkaEnsureTopicExists(t *testing.T, xEnv *testenv.Environment, timeout time.Duration, topics ...string) { + // Delete topic for idempotency + deleteCtx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + prefixedTopics := make([]string, 0, len(topics)) + for _, topic := range topics { + prefixedTopics = append(prefixedTopics, xEnv.GetPubSubName(topic)) + } -func ProduceKafkaMessage(t *testing.T, xEnv *testenv.Environment, topicName string, message string) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + _, err := xEnv.KafkaAdminClient.DeleteTopics(deleteCtx, prefixedTopics...) + require.NoError(t, err) + + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + _, err = xEnv.KafkaAdminClient.CreateTopics(ctx, 1, 1, nil, prefixedTopics...) + require.NoError(t, err) +} + +func ProduceKafkaMessage(t *testing.T, xEnv *testenv.Environment, timeout time.Duration, topicName string, message string) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() pErrCh := make(chan error) @@ -26,7 +44,7 @@ func ProduceKafkaMessage(t *testing.T, xEnv *testenv.Environment, topicName stri pErrCh <- err }) - testenv.AwaitChannelWithT(t, waitTimeout, pErrCh, func(t *testing.T, pErr error) { + testenv.AwaitChannelWithT(t, timeout, pErrCh, func(t *testing.T, pErr error) { require.NoError(t, pErr) }) @@ -34,23 +52,22 @@ func ProduceKafkaMessage(t *testing.T, xEnv *testenv.Environment, topicName stri require.NoError(t, fErr) } -func EnsureTopicExists(t *testing.T, xEnv *testenv.Environment, topics ...string) { - // Delete topic for idempotency - deleteCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) +func ReadKafkaMessages(xEnv *testenv.Environment, timeout time.Duration, topicName string, msgs int) ([]*kgo.Record, error) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() - prefixedTopics := make([]string, 0, len(topics)) - for _, topic := range topics { - prefixedTopics = append(prefixedTopics, xEnv.GetPubSubName(topic)) - } - _, err := xEnv.KafkaAdminClient.DeleteTopics(deleteCtx, prefixedTopics...) - require.NoError(t, err) + client, err := kgo.NewClient( + kgo.SeedBrokers(xEnv.GetKafkaSeeds()...), + kgo.ConsumeTopics(xEnv.GetPubSubName(topicName)), + ) + if err != nil { + return nil, err + } + defer client.Close() - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() + fetchs := client.PollRecords(ctx, msgs) - _, err = xEnv.KafkaAdminClient.CreateTopics(ctx, 1, 1, nil, prefixedTopics...) - require.NoError(t, err) + return fetchs.Records(), nil } func ProduceRedisMessage(t *testing.T, xEnv *testenv.Environment, topicName string, message string) { @@ -72,10 +89,33 @@ func ProduceRedisMessage(t *testing.T, xEnv *testenv.Environment, topicName stri }) } - defer func() { - _ = redisConn.Close() - }() - intCmd := redisConn.Publish(ctx, xEnv.GetPubSubName(topicName), message) require.NoError(t, intCmd.Err()) } + +func ReadRedisMessages(t *testing.T, xEnv *testenv.Environment, channelName string) (<-chan *redis.Message, error) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + parsedURL, err := url.Parse(xEnv.RedisHosts[0]) + if err != nil { + return nil, err + } + var redisConn redis.UniversalClient + if !xEnv.RedisWithClusterMode { + redisConn = redis.NewClient(&redis.Options{ + Addr: parsedURL.Host, + }) + } else { + redisConn = redis.NewClusterClient(&redis.ClusterOptions{ + Addrs: []string{parsedURL.Host}, + }) + } + sub := redisConn.Subscribe(ctx, xEnv.GetPubSubName(channelName)) + t.Cleanup(func() { + sub.Close() + redisConn.Close() + }) + + return sub.Channel(), nil +} diff --git a/router-tests/go.mod b/router-tests/go.mod index 479d44590c..69af587c0b 100644 --- a/router-tests/go.mod +++ b/router-tests/go.mod @@ -27,7 +27,7 @@ require ( github.com/wundergraph/cosmo/demo/pkg/subgraphs/projects v0.0.0-20250715110703-10f2e5f9c79e github.com/wundergraph/cosmo/router v0.0.0-20250912064154-106e871ee32e github.com/wundergraph/cosmo/router-plugin v0.0.0-20250808194725-de123ba1c65e - github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.229.0.20250930144208-ddc652f78bbb + github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.229.0.20251001132016-1d6b66867259 go.opentelemetry.io/otel v1.36.0 go.opentelemetry.io/otel/sdk v1.36.0 go.opentelemetry.io/otel/sdk/metric v1.36.0 diff --git a/router-tests/go.sum b/router-tests/go.sum index 947f5ba76c..07c3dbd780 100644 --- a/router-tests/go.sum +++ b/router-tests/go.sum @@ -352,10 +352,8 @@ github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083 h1:8/D7f8gKxTB github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083/go.mod h1:eOTL6acwctsN4F3b7YE+eE2t8zcJ/doLm9sZzsxxxrE= github.com/wundergraph/consul/sdk v0.0.0-20250204115147-ed842a8fd301 h1:EzfKHQoTjFDDcgaECCCR2aTePqMu9QBmPbyhqIYOhV0= github.com/wundergraph/consul/sdk v0.0.0-20250204115147-ed842a8fd301/go.mod h1:wxI0Nak5dI5RvJuzGyiEK4nZj0O9X+Aw6U0tC1wPKq0= -github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.229 h1:VCfCX/xmpBGQLhTHJMHLugzJrXJk/smjLRAEruCI0HY= -github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.229/go.mod h1:g1IFIylu5Fd9pKjzq0mDvpaKhEB/vkwLAIbGdX2djXU= -github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.229.0.20250930144208-ddc652f78bbb h1:stBTAle5FyytsTNxYeCwNzYlyhKzlS4he6f7/y6O3qE= -github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.229.0.20250930144208-ddc652f78bbb/go.mod h1:g1IFIylu5Fd9pKjzq0mDvpaKhEB/vkwLAIbGdX2djXU= +github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.229.0.20251001132016-1d6b66867259 h1:PhKYGyTBFM0JIihHLQa6tD5Al6GVFIPuJxi2T+DEiB0= +github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.229.0.20251001132016-1d6b66867259/go.mod h1:g1IFIylu5Fd9pKjzq0mDvpaKhEB/vkwLAIbGdX2djXU= github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 h1:gEOO8jv9F4OT7lGCjxCBTO/36wtF6j2nSip77qHd4x4= github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1/go.mod h1:Ohn+xnUBiLI6FVj/9LpzZWtj1/D6lUovWYBkxHVV3aM= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= diff --git a/router-tests/modules/start-subscription/module.go b/router-tests/modules/start-subscription/module.go index fd5a9e0088..ffa94ef1f0 100644 --- a/router-tests/modules/start-subscription/module.go +++ b/router-tests/modules/start-subscription/module.go @@ -12,7 +12,7 @@ const myModuleID = "startSubscriptionModule" type StartSubscriptionModule struct { Logger *zap.Logger - Callback func(ctx core.SubscriptionOnStartHookContext) error + Callback func(ctx core.SubscriptionOnStartHandlerContext) error CallbackOnOriginResponse func(response *http.Response, ctx core.RequestContext) *http.Response } @@ -23,7 +23,7 @@ func (m *StartSubscriptionModule) Provision(ctx *core.ModuleContext) error { return nil } -func (m *StartSubscriptionModule) SubscriptionOnStart(ctx core.SubscriptionOnStartHookContext) error { +func (m *StartSubscriptionModule) SubscriptionOnStart(ctx core.SubscriptionOnStartHandlerContext) error { m.Logger.Info("SubscriptionOnStart Hook has been run") diff --git a/router-tests/modules/start_subscription_test.go b/router-tests/modules/start_subscription_test.go index ad286d54ef..b9d5e2f0ac 100644 --- a/router-tests/modules/start_subscription_test.go +++ b/router-tests/modules/start_subscription_test.go @@ -90,7 +90,7 @@ func TestStartSubscriptionHook(t *testing.T) { Graph: config.Graph{}, Modules: map[string]interface{}{ "startSubscriptionModule": start_subscription.StartSubscriptionModule{ - Callback: func(ctx core.SubscriptionOnStartHookContext) error { + Callback: func(ctx core.SubscriptionOnStartHandlerContext) error { if ctx.SubscriptionEventConfiguration().RootFieldName() != "employeeUpdatedMyKafka" { return nil } @@ -179,9 +179,9 @@ func TestStartSubscriptionHook(t *testing.T) { Graph: config.Graph{}, Modules: map[string]interface{}{ "startSubscriptionModule": start_subscription.StartSubscriptionModule{ - Callback: func(ctx core.SubscriptionOnStartHookContext) error { + Callback: func(ctx core.SubscriptionOnStartHandlerContext) error { callbackCalled <- true - return core.NewStreamHookError(nil, "subscription closed", http.StatusOK, "") + return core.NewHttpGraphqlError("subscription closed", http.StatusText(http.StatusOK), http.StatusOK) }, }, }, @@ -261,7 +261,7 @@ func TestStartSubscriptionHook(t *testing.T) { Graph: config.Graph{}, Modules: map[string]interface{}{ "startSubscriptionModule": start_subscription.StartSubscriptionModule{ - Callback: func(ctx core.SubscriptionOnStartHookContext) error { + Callback: func(ctx core.SubscriptionOnStartHandlerContext) error { employeeId := ctx.Operation().Variables().GetInt64("employeeID") if employeeId != 1 { return nil @@ -365,8 +365,8 @@ func TestStartSubscriptionHook(t *testing.T) { Graph: config.Graph{}, Modules: map[string]interface{}{ "startSubscriptionModule": start_subscription.StartSubscriptionModule{ - Callback: func(ctx core.SubscriptionOnStartHookContext) error { - return core.NewStreamHookError(errors.New("test error"), "test error", http.StatusLoopDetected, http.StatusText(http.StatusLoopDetected)) + Callback: func(ctx core.SubscriptionOnStartHandlerContext) error { + return core.NewHttpGraphqlError("test error", http.StatusText(http.StatusLoopDetected), http.StatusLoopDetected) }, }, }, @@ -509,7 +509,7 @@ func TestStartSubscriptionHook(t *testing.T) { Graph: config.Graph{}, Modules: map[string]interface{}{ "startSubscriptionModule": start_subscription.StartSubscriptionModule{ - Callback: func(ctx core.SubscriptionOnStartHookContext) error { + Callback: func(ctx core.SubscriptionOnStartHandlerContext) error { ctx.WriteEvent(&core.EngineEvent{ Data: []byte(`{"data":{"countEmp":1000}}`), }) @@ -593,8 +593,8 @@ func TestStartSubscriptionHook(t *testing.T) { Graph: config.Graph{}, Modules: map[string]interface{}{ "startSubscriptionModule": start_subscription.StartSubscriptionModule{ - Callback: func(ctx core.SubscriptionOnStartHookContext) error { - return core.NewStreamHookError(errors.New("subscription closed"), "subscription closed", http.StatusOK, "NotFound") + Callback: func(ctx core.SubscriptionOnStartHandlerContext) error { + return core.NewHttpGraphqlError("subscription closed", http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) }, CallbackOnOriginResponse: func(response *http.Response, ctx core.RequestContext) *http.Response { originResponseCalled <- response diff --git a/router-tests/modules/stream-publish/module.go b/router-tests/modules/stream-publish/module.go new file mode 100644 index 0000000000..e5553058ea --- /dev/null +++ b/router-tests/modules/stream-publish/module.go @@ -0,0 +1,49 @@ +package publish + +import ( + "go.uber.org/zap" + + "github.com/wundergraph/cosmo/router/core" + "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" +) + +const myModuleID = "publishModule" + +type PublishModule struct { + Logger *zap.Logger + Callback func(ctx core.StreamPublishEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) +} + +func (m *PublishModule) Provision(ctx *core.ModuleContext) error { + // Assign the logger to the module for non-request related logging + m.Logger = ctx.Logger + + return nil +} + +func (m *PublishModule) OnPublishEvents(ctx core.StreamPublishEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { + m.Logger.Info("Publish Hook has been run") + + if m.Callback != nil { + return m.Callback(ctx, events) + } + + return events, nil +} + +func (m *PublishModule) Module() core.ModuleInfo { + return core.ModuleInfo{ + // This is the ID of your module, it must be unique + ID: myModuleID, + // The priority of your module, lower numbers are executed first + Priority: 1, + New: func() core.Module { + return &PublishModule{} + }, + } +} + +// Interface guard +var ( + _ core.StreamPublishEventHandler = (*PublishModule)(nil) +) diff --git a/router-tests/modules/stream-receive/module.go b/router-tests/modules/stream-receive/module.go new file mode 100644 index 0000000000..640218ad00 --- /dev/null +++ b/router-tests/modules/stream-receive/module.go @@ -0,0 +1,49 @@ +package batch + +import ( + "go.uber.org/zap" + + "github.com/wundergraph/cosmo/router/core" + "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" +) + +const myModuleID = "streamReceiveModule" + +type StreamReceiveModule struct { + Logger *zap.Logger + Callback func(ctx core.StreamReceiveEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) +} + +func (m *StreamReceiveModule) Provision(ctx *core.ModuleContext) error { + // Assign the logger to the module for non-request related logging + m.Logger = ctx.Logger + + return nil +} + +func (m *StreamReceiveModule) OnReceiveEvents(ctx core.StreamReceiveEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { + m.Logger.Info("Stream Hook has been run") + + if m.Callback != nil { + return m.Callback(ctx, events) + } + + return events, nil +} + +func (m *StreamReceiveModule) Module() core.ModuleInfo { + return core.ModuleInfo{ + // This is the ID of your module, it must be unique + ID: myModuleID, + // The priority of your module, lower numbers are executed first + Priority: 1, + New: func() core.Module { + return &StreamReceiveModule{} + }, + } +} + +// Interface guard +var ( + _ core.StreamReceiveEventHandler = (*StreamReceiveModule)(nil) +) diff --git a/router-tests/modules/stream_publish_test.go b/router-tests/modules/stream_publish_test.go new file mode 100644 index 0000000000..6fb7485dc3 --- /dev/null +++ b/router-tests/modules/stream_publish_test.go @@ -0,0 +1,315 @@ +package module_test + +import ( + "encoding/json" + "net/http" + "strconv" + "testing" + "time" + + "go.uber.org/zap/zapcore" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/wundergraph/cosmo/router-tests/events" + stream_publish "github.com/wundergraph/cosmo/router-tests/modules/stream-publish" + "github.com/wundergraph/cosmo/router-tests/testenv" + "github.com/wundergraph/cosmo/router/core" + "github.com/wundergraph/cosmo/router/pkg/config" + "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" + "github.com/wundergraph/cosmo/router/pkg/pubsub/kafka" +) + +func TestPublishHook(t *testing.T) { + t.Parallel() + + t.Run("Test Publish hook is called", func(t *testing.T) { + t.Parallel() + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "publishModule": stream_publish.PublishModule{}, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, + EnableKafka: true, + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&stream_publish.PublishModule{}), + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.InfoLevel, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + resOne := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `mutation { updateEmployeeMyKafka(employeeID: 3, update: {name: "name test"}) { success } }`, + }) + require.JSONEq(t, `{"data":{"updateEmployeeMyKafka":{"success":false}}}`, resOne.Body) + + requestLog := xEnv.Observer().FilterMessage("Publish Hook has been run") + assert.Len(t, requestLog.All(), 1) + }) + }) + + t.Run("Test Publish kafka hook allows to set headers", func(t *testing.T) { + t.Parallel() + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "publishModule": stream_publish.PublishModule{ + Callback: func(ctx core.StreamPublishEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { + for _, event := range events { + evt, ok := event.(*kafka.Event) + if !ok { + continue + } + evt.Headers["x-test"] = []byte("test") + } + + return events, nil + }, + }, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, + EnableKafka: true, + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&stream_publish.PublishModule{}), + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.InfoLevel, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + events.KafkaEnsureTopicExists(t, xEnv, time.Second, "employeeUpdated") + resOne := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `mutation { updateEmployeeMyKafka(employeeID: 3, update: {name: "name test"}) { success } }`, + }) + require.JSONEq(t, `{"data":{"updateEmployeeMyKafka":{"success":true}}}`, resOne.Body) + + requestLog := xEnv.Observer().FilterMessage("Publish Hook has been run") + assert.Len(t, requestLog.All(), 1) + + records, err := events.ReadKafkaMessages(xEnv, time.Second, "employeeUpdated", 1) + require.NoError(t, err) + require.Len(t, records, 1) + header := records[0].Headers[0] + require.Equal(t, "x-test", header.Key) + require.Equal(t, []byte("test"), header.Value) + }) + }) + + t.Run("Test kafka publish error is returned and messages sent", func(t *testing.T) { + t.Parallel() + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "publishModule": stream_publish.PublishModule{ + Callback: func(ctx core.StreamPublishEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { + return events, core.NewHttpGraphqlError("test", http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + }, + }, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, + EnableKafka: true, + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&stream_publish.PublishModule{}), + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.InfoLevel, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + events.KafkaEnsureTopicExists(t, xEnv, time.Second, "employeeUpdated") + resOne := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `mutation { updateEmployeeMyKafka(employeeID: 3, update: {name: "name test"}) { success } }`, + }) + require.JSONEq(t, `{"data": {"updateEmployeeMyKafka": {"success": false}}}`, resOne.Body) + require.Equal(t, resOne.Response.StatusCode, 200) + + requestLog := xEnv.Observer().FilterMessage("Publish Hook has been run") + assert.Len(t, requestLog.All(), 1) + + requestLog2 := xEnv.Observer().FilterMessage("error applying publish event hooks") + assert.Len(t, requestLog2.All(), 1) + + records, err := events.ReadKafkaMessages(xEnv, time.Second, "employeeUpdated", 1) + require.NoError(t, err) + require.Len(t, records, 1) + }) + }) + + t.Run("Test nats publish error is returned and messages sent", func(t *testing.T) { + t.Parallel() + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "publishModule": stream_publish.PublishModule{ + Callback: func(ctx core.StreamPublishEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { + return events, core.NewHttpGraphqlError("test", http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + }, + }, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithEdfsNatsJSONTemplate, + EnableNats: true, + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&stream_publish.PublishModule{}), + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.InfoLevel, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + firstSub, err := xEnv.NatsConnectionDefault.SubscribeSync(xEnv.GetPubSubName("employeeUpdatedMyNats.3")) + require.NoError(t, err) + t.Cleanup(func() { + _ = firstSub.Unsubscribe() + }) + require.NoError(t, xEnv.NatsConnectionDefault.Flush()) + resOne := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `mutation UpdateEmployeeNats($update: UpdateEmployeeInput!) { + updateEmployeeMyNats(id: 3, update: $update) {success} + }`, + Variables: json.RawMessage(`{"update":{"name":"Stefan Avramovic","email":"avramovic@wundergraph.com"}}`), + }) + assert.JSONEq(t, `{"data": {"updateEmployeeMyNats": {"success": false}}}`, resOne.Body) + + requestLog := xEnv.Observer().FilterMessage("Publish Hook has been run") + assert.Len(t, requestLog.All(), 1) + + requestLog2 := xEnv.Observer().FilterMessage("error applying publish event hooks") + assert.Len(t, requestLog2.All(), 1) + + msgOne, err := firstSub.NextMsg(5 * time.Second) + require.NoError(t, err) + require.Equal(t, xEnv.GetPubSubName("employeeUpdatedMyNats.3"), msgOne.Subject) + require.Equal(t, `{"id":3,"update":{"name":"Stefan Avramovic","email":"avramovic@wundergraph.com"}}`, string(msgOne.Data)) + require.NoError(t, err) + }) + }) + + t.Run("Test redis publish error is returned and messages sent", func(t *testing.T) { + t.Parallel() + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "publishModule": stream_publish.PublishModule{ + Callback: func(ctx core.StreamPublishEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { + return events, core.NewHttpGraphqlError("test", http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + }, + }, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithEdfsRedisJSONTemplate, + EnableRedis: true, + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&stream_publish.PublishModule{}), + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.InfoLevel, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + records, err := events.ReadRedisMessages(t, xEnv, "employeeUpdatedMyRedis") + require.NoError(t, err) + + resOne := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `mutation { updateEmployeeMyRedis(id: 3, update: {name: "name test"}) { success } }`, + }) + require.JSONEq(t, `{"data": {"updateEmployeeMyRedis": {"success": false}}}`, resOne.Body) + + requestLog := xEnv.Observer().FilterMessage("Publish Hook has been run") + assert.Len(t, requestLog.All(), 1) + + requestLog2 := xEnv.Observer().FilterMessage("error applying publish event hooks") + assert.Len(t, requestLog2.All(), 1) + + require.Len(t, records, 1) + }) + }) + + t.Run("Test kafka module publish with argument in header", func(t *testing.T) { + t.Parallel() + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "publishModule": stream_publish.PublishModule{ + Callback: func(ctx core.StreamPublishEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { + if ctx.PublishEventConfiguration().RootFieldName() != "updateEmployeeMyKafka" { + return events, nil + } + + employeeID := ctx.Operation().Variables().GetInt("employeeID") + + newEvents := []datasource.StreamEvent{} + for _, event := range events { + evt, ok := event.(*kafka.Event) + if !ok { + continue + } + if evt.Headers == nil { + evt.Headers = map[string][]byte{} + } + evt.Headers["x-employee-id"] = []byte(strconv.Itoa(employeeID)) + newEvents = append(newEvents, event) + } + return newEvents, nil + }, + }, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, + EnableKafka: true, + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&stream_publish.PublishModule{}), + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.InfoLevel, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + events.KafkaEnsureTopicExists(t, xEnv, time.Second, "employeeUpdated") + resOne := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `mutation UpdateEmployeeKafka($employeeID: Int!) { updateEmployeeMyKafka(employeeID: $employeeID, update: {name: "name test"}) { success } }`, + Variables: json.RawMessage(`{"employeeID": 3}`), + }) + require.JSONEq(t, `{"data": {"updateEmployeeMyKafka": {"success": true}}}`, resOne.Body) + + requestLog := xEnv.Observer().FilterMessage("Publish Hook has been run") + assert.Len(t, requestLog.All(), 1) + + records, err := events.ReadKafkaMessages(xEnv, time.Second, "employeeUpdated", 1) + require.NoError(t, err) + require.Len(t, records, 1) + header := records[0].Headers[0] + require.Equal(t, "x-employee-id", header.Key) + require.Equal(t, []byte("3"), header.Value) + }) + }) +} diff --git a/router-tests/modules/stream_receive_test.go b/router-tests/modules/stream_receive_test.go new file mode 100644 index 0000000000..a1658dc35c --- /dev/null +++ b/router-tests/modules/stream_receive_test.go @@ -0,0 +1,521 @@ +package module_test + +import ( + "errors" + "net/http" + "testing" + "time" + + "github.com/hasura/go-graphql-client" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + integration "github.com/wundergraph/cosmo/router-tests" + "github.com/wundergraph/cosmo/router-tests/events" + "github.com/wundergraph/cosmo/router-tests/jwks" + stream_receive "github.com/wundergraph/cosmo/router-tests/modules/stream-receive" + "github.com/wundergraph/cosmo/router-tests/testenv" + "github.com/wundergraph/cosmo/router/core" + "github.com/wundergraph/cosmo/router/pkg/authentication" + "github.com/wundergraph/cosmo/router/pkg/config" + "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" + "github.com/wundergraph/cosmo/router/pkg/pubsub/kafka" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +func TestReceiveHook(t *testing.T) { + t.Parallel() + + const Timeout = time.Second * 10 + + type kafkaSubscriptionArgs struct { + dataValue []byte + errValue error + } + + t.Run("Test Receive hook is called", func(t *testing.T) { + t.Parallel() + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "streamReceiveModule": stream_receive.StreamReceiveModule{}, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, + EnableKafka: true, + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&stream_receive.StreamReceiveModule{}), + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.InfoLevel, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + topics := []string{"employeeUpdated"} + events.KafkaEnsureTopicExists(t, xEnv, time.Second, topics...) + + var subscriptionOne struct { + employeeUpdatedMyKafka struct { + ID float64 `graphql:"id"` + Details struct { + Forename string `graphql:"forename"` + Surname string `graphql:"surname"` + } `graphql:"details"` + } `graphql:"employeeUpdatedMyKafka(employeeID: 3)"` + } + + surl := xEnv.GraphQLWebSocketSubscriptionURL() + client := graphql.NewSubscriptionClient(surl) + + subscriptionArgsCh := make(chan kafkaSubscriptionArgs) + subscriptionOneID, err := client.Subscribe(&subscriptionOne, nil, func(dataValue []byte, errValue error) error { + subscriptionArgsCh <- kafkaSubscriptionArgs{ + dataValue: dataValue, + errValue: errValue, + } + return nil + }) + require.NoError(t, err) + require.NotEmpty(t, subscriptionOneID) + + clientRunCh := make(chan error) + go func() { + clientRunCh <- client.Run() + }() + + xEnv.WaitForSubscriptionCount(1, Timeout) + + events.ProduceKafkaMessage(t, xEnv, Timeout, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) + + testenv.AwaitChannelWithT(t, Timeout, subscriptionArgsCh, func(t *testing.T, args kafkaSubscriptionArgs) { + require.NoError(t, args.errValue) + require.JSONEq(t, `{"employeeUpdatedMyKafka":{"id":1,"details":{"forename":"Jens","surname":"Neuse"}}}`, string(args.dataValue)) + }) + + require.NoError(t, client.Close()) + testenv.AwaitChannelWithT(t, Timeout, clientRunCh, func(t *testing.T, err error) { + require.NoError(t, err) + }, "unable to close client before timeout") + + requestLog := xEnv.Observer().FilterMessage("Stream Hook has been run") + assert.Len(t, requestLog.All(), 1) + }) + }) + + t.Run("Test Receive hook could change events", func(t *testing.T) { + t.Parallel() + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "streamReceiveModule": stream_receive.StreamReceiveModule{ + Callback: func(ctx core.StreamReceiveEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { + for _, event := range events { + evt, ok := event.(*kafka.Event) + if !ok { + continue + } + evt.Data = []byte(`{"__typename":"Employee","id": 3,"update":{"name":"foo"}}`) + } + + return events, nil + }, + }, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, + EnableKafka: true, + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&stream_receive.StreamReceiveModule{}), + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.InfoLevel, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + topics := []string{"employeeUpdated"} + events.KafkaEnsureTopicExists(t, xEnv, time.Second, topics...) + + var subscriptionOne struct { + employeeUpdatedMyKafka struct { + ID float64 `graphql:"id"` + Details struct { + Forename string `graphql:"forename"` + Surname string `graphql:"surname"` + } `graphql:"details"` + } `graphql:"employeeUpdatedMyKafka(employeeID: 3)"` + } + + surl := xEnv.GraphQLWebSocketSubscriptionURL() + client := graphql.NewSubscriptionClient(surl) + + subscriptionArgsCh := make(chan kafkaSubscriptionArgs) + subscriptionOneID, err := client.Subscribe(&subscriptionOne, nil, func(dataValue []byte, errValue error) error { + subscriptionArgsCh <- kafkaSubscriptionArgs{ + dataValue: dataValue, + errValue: errValue, + } + return nil + }) + require.NoError(t, err) + require.NotEmpty(t, subscriptionOneID) + + clientRunCh := make(chan error) + go func() { + clientRunCh <- client.Run() + }() + + xEnv.WaitForSubscriptionCount(1, Timeout) + + events.ProduceKafkaMessage(t, xEnv, Timeout, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) + + testenv.AwaitChannelWithT(t, Timeout, subscriptionArgsCh, func(t *testing.T, args kafkaSubscriptionArgs) { + require.NoError(t, args.errValue) + require.JSONEq(t, `{"employeeUpdatedMyKafka":{"id":3,"details":{"forename":"Stefan","surname":"Avram"}}}`, string(args.dataValue)) + }) + + require.NoError(t, client.Close()) + testenv.AwaitChannelWithT(t, Timeout, clientRunCh, func(t *testing.T, err error) { + require.NoError(t, err) + }, "unable to close client before timeout") + + requestLog := xEnv.Observer().FilterMessage("Stream Hook has been run") + assert.Len(t, requestLog.All(), 1) + }) + }) + + t.Run("Test Receive hook change events of one of multiple subscriptions", func(t *testing.T) { + t.Parallel() + + cfg := config.Config{ + Graph: config.Graph{}, + + Modules: map[string]interface{}{ + "streamReceiveModule": stream_receive.StreamReceiveModule{ + Callback: func(ctx core.StreamReceiveEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { + if ctx.Authentication() == nil { + return events, nil + } + if val, ok := ctx.Authentication().Claims()["sub"]; !ok || val != "user-2" { + return events, nil + } + for _, event := range events { + evt, ok := event.(*kafka.Event) + if !ok { + continue + } + evt.Data = []byte(`{"__typename":"Employee","id": 3,"update":{"name":"foo"}}`) + } + + return events, nil + }, + }, + }, + } + + authServer, err := jwks.NewServer(t) + require.NoError(t, err) + defer authServer.Close() + + JwksName := "my-jwks-server" + + tokenDecoder, _ := authentication.NewJwksTokenDecoder(integration.NewContextWithCancel(t), zap.NewNop(), []authentication.JWKSConfig{{ + URL: authServer.JWKSURL(), + RefreshInterval: time.Second * 5, + }}) + jwksOpts := authentication.HttpHeaderAuthenticatorOptions{ + Name: JwksName, + TokenDecoder: tokenDecoder, + } + + authenticator, err := authentication.NewHttpHeaderAuthenticator(jwksOpts) + require.NoError(t, err) + authenticators := []authentication.Authenticator{authenticator} + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, + EnableKafka: true, + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&stream_receive.StreamReceiveModule{}), + core.WithAccessController(core.NewAccessController(authenticators, false)), + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.InfoLevel, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + topics := []string{"employeeUpdated"} + events.KafkaEnsureTopicExists(t, xEnv, time.Second, topics...) + + var subscriptionOne struct { + employeeUpdatedMyKafka struct { + ID float64 `graphql:"id"` + Details struct { + Forename string `graphql:"forename"` + Surname string `graphql:"surname"` + } `graphql:"details"` + } `graphql:"employeeUpdatedMyKafka(employeeID: 3)"` + } + + token, err := authServer.Token(map[string]interface{}{ + "sub": "user-2", + }) + require.NoError(t, err) + + headers := http.Header{ + "Authorization": []string{"Bearer " + token}, + } + + surl := xEnv.GraphQLWebSocketSubscriptionURL() + client := graphql.NewSubscriptionClient(surl) + client2 := graphql.NewSubscriptionClient(surl) + client2.WithWebSocketOptions(graphql.WebsocketOptions{ + HTTPHeader: headers, + }) + + subscriptionArgsCh := make(chan kafkaSubscriptionArgs) + subscriptionOneID, err := client.Subscribe(&subscriptionOne, nil, func(dataValue []byte, errValue error) error { + subscriptionArgsCh <- kafkaSubscriptionArgs{ + dataValue: dataValue, + errValue: errValue, + } + return nil + }) + require.NoError(t, err) + require.NotEmpty(t, subscriptionOneID) + + clientRunCh := make(chan error) + go func() { + clientRunCh <- client.Run() + }() + + subscriptionArgsCh2 := make(chan kafkaSubscriptionArgs) + subscriptionTwoID, err := client2.Subscribe(&subscriptionOne, nil, func(dataValue []byte, errValue error) error { + subscriptionArgsCh2 <- kafkaSubscriptionArgs{ + dataValue: dataValue, + errValue: errValue, + } + return nil + }) + require.NoError(t, err) + require.NotEmpty(t, subscriptionTwoID) + + clientRunCh2 := make(chan error) + go func() { + clientRunCh2 <- client2.Run() + }() + + xEnv.WaitForSubscriptionCount(2, Timeout) + + events.ProduceKafkaMessage(t, xEnv, Timeout, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) + + testenv.AwaitChannelWithT(t, Timeout, subscriptionArgsCh, func(t *testing.T, args kafkaSubscriptionArgs) { + require.NoError(t, args.errValue) + assert.JSONEq(t, `{"employeeUpdatedMyKafka":{"id":1,"details":{"forename":"Jens","surname":"Neuse"}}}`, string(args.dataValue)) + }) + + testenv.AwaitChannelWithT(t, Timeout, subscriptionArgsCh2, func(t *testing.T, args kafkaSubscriptionArgs) { + require.NoError(t, args.errValue) + assert.JSONEq(t, `{"employeeUpdatedMyKafka":{"id":3,"details":{"forename":"Stefan","surname":"Avram"}}}`, string(args.dataValue)) + }) + + unSub1Err := client.Unsubscribe(subscriptionOneID) + require.NoError(t, unSub1Err) + require.NoError(t, client.Close()) + testenv.AwaitChannelWithT(t, Timeout, clientRunCh, func(t *testing.T, err error) { + require.NoError(t, err) + }, "unable to close client before timeout") + + unSub2Err := client2.Unsubscribe(subscriptionTwoID) + require.NoError(t, unSub2Err) + require.NoError(t, client2.Close()) + testenv.AwaitChannelWithT(t, Timeout, clientRunCh2, func(t *testing.T, err error) { + require.NoError(t, err) + }, "unable to close client before timeout") + + requestLog := xEnv.Observer().FilterMessage("Stream Hook has been run") + assert.Len(t, requestLog.All(), 2) + }) + }) + + t.Run("Test Receive hook can access custom header", func(t *testing.T) { + t.Parallel() + + customHeader := http.CanonicalHeaderKey("X-Custom-Header") + + cfg := config.Config{ + Graph: config.Graph{}, + + Modules: map[string]interface{}{ + "streamReceiveModule": stream_receive.StreamReceiveModule{ + Callback: func(ctx core.StreamReceiveEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { + if val, ok := ctx.Request().Header[customHeader]; !ok || val[0] != "Test" { + return events, nil + } + for _, event := range events { + evt, ok := event.(*kafka.Event) + if !ok { + continue + } + evt.Data = []byte(`{"__typename":"Employee","id": 3,"update":{"name":"foo"}}`) + } + + return events, nil + }, + }, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, + EnableKafka: true, + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&stream_receive.StreamReceiveModule{}), + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.InfoLevel, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + topics := []string{"employeeUpdated"} + events.KafkaEnsureTopicExists(t, xEnv, time.Second, topics...) + + var subscriptionOne struct { + employeeUpdatedMyKafka struct { + ID float64 `graphql:"id"` + Details struct { + Forename string `graphql:"forename"` + Surname string `graphql:"surname"` + } `graphql:"details"` + } `graphql:"employeeUpdatedMyKafka(employeeID: 3)"` + } + headers := http.Header{ + customHeader: []string{"Test"}, + } + + surl := xEnv.GraphQLWebSocketSubscriptionURL() + client := graphql.NewSubscriptionClient(surl) + client.WithWebSocketOptions(graphql.WebsocketOptions{ + HTTPHeader: headers, + }) + + subscriptionArgsCh := make(chan kafkaSubscriptionArgs) + subscriptionOneID, err := client.Subscribe(&subscriptionOne, nil, func(dataValue []byte, errValue error) error { + subscriptionArgsCh <- kafkaSubscriptionArgs{ + dataValue: dataValue, + errValue: errValue, + } + return nil + }) + require.NoError(t, err) + require.NotEmpty(t, subscriptionOneID) + + clientRunCh := make(chan error) + go func() { + clientRunCh <- client.Run() + }() + + xEnv.WaitForSubscriptionCount(1, Timeout) + + events.ProduceKafkaMessage(t, xEnv, Timeout, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) + + testenv.AwaitChannelWithT(t, Timeout, subscriptionArgsCh, func(t *testing.T, args kafkaSubscriptionArgs) { + require.NoError(t, args.errValue) + assert.JSONEq(t, `{"employeeUpdatedMyKafka":{"id":3,"details":{"forename":"Stefan","surname":"Avram"}}}`, string(args.dataValue)) + }) + + unSub1Err := client.Unsubscribe(subscriptionOneID) + require.NoError(t, unSub1Err) + require.NoError(t, client.Close()) + testenv.AwaitChannelWithT(t, Timeout, clientRunCh, func(t *testing.T, err error) { + require.NoError(t, err) + }, "unable to close client before timeout") + + requestLog := xEnv.Observer().FilterMessage("Stream Hook has been run") + assert.Len(t, requestLog.All(), 1) + }) + }) + + t.Run("Test Batch hook error should close Kafka clients and subscriptions", func(t *testing.T) { + t.Parallel() + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "streamReceiveModule": stream_receive.StreamReceiveModule{ + Callback: func(ctx core.StreamReceiveEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { + return nil, errors.New("test error from streamevents hook") + }, + }, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, + EnableKafka: true, + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&stream_receive.StreamReceiveModule{}), + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.InfoLevel, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + topics := []string{"employeeUpdated"} + events.KafkaEnsureTopicExists(t, xEnv, time.Second, topics...) + + var subscriptionOne struct { + employeeUpdatedMyKafka struct { + ID float64 `graphql:"id"` + Details struct { + Forename string `graphql:"forename"` + Surname string `graphql:"surname"` + } `graphql:"details"` + } `graphql:"employeeUpdatedMyKafka(employeeID: 3)"` + } + + surl := xEnv.GraphQLWebSocketSubscriptionURL() + client := graphql.NewSubscriptionClient(surl) + + subscriptionArgsCh := make(chan kafkaSubscriptionArgs) + subscriptionOneID, err := client.Subscribe(&subscriptionOne, nil, func(dataValue []byte, errValue error) error { + subscriptionArgsCh <- kafkaSubscriptionArgs{ + dataValue: dataValue, + errValue: errValue, + } + return nil + }) + require.NoError(t, err) + require.NotEmpty(t, subscriptionOneID) + + clientRunCh := make(chan error) + go func() { + clientRunCh <- client.Run() + }() + + xEnv.WaitForSubscriptionCount(1, Timeout) + + events.ProduceKafkaMessage(t, xEnv, Timeout, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) + + // Wait for server to close the subscription connection + xEnv.WaitForSubscriptionCount(0, Timeout) + + // Verify that client.Run() completed when server closed the connection + testenv.AwaitChannelWithT(t, Timeout, clientRunCh, func(t *testing.T, err error) { + require.NoError(t, err) + }, "client should have completed when server closed connection") + + xEnv.WaitForTriggerCount(0, Timeout) + }) + }) +} diff --git a/router-tests/modules/streams_hooks_combined_test.go b/router-tests/modules/streams_hooks_combined_test.go new file mode 100644 index 0000000000..78639dd052 --- /dev/null +++ b/router-tests/modules/streams_hooks_combined_test.go @@ -0,0 +1,149 @@ +package module + +import ( + "encoding/json" + "testing" + "time" + + "github.com/hasura/go-graphql-client" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/wundergraph/cosmo/router-tests/events" + stream_publish "github.com/wundergraph/cosmo/router-tests/modules/stream-publish" + stream_receive "github.com/wundergraph/cosmo/router-tests/modules/stream-receive" + "github.com/wundergraph/cosmo/router-tests/testenv" + "github.com/wundergraph/cosmo/router/core" + "github.com/wundergraph/cosmo/router/pkg/config" + "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" + "github.com/wundergraph/cosmo/router/pkg/pubsub/kafka" + "go.uber.org/zap/zapcore" +) + +func TestStreamsHooksCombined(t *testing.T) { + t.Parallel() + + t.Run("Test kafka modules can depend on each other", func(t *testing.T) { + t.Parallel() + + type event struct { + data []byte + err error + } + + const Timeout = time.Second * 10 + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "streamReceiveModule": stream_receive.StreamReceiveModule{ + Callback: func(ctx core.StreamReceiveEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { + for _, event := range events { + evt, ok := event.(*kafka.Event) + if !ok { + continue + } + + if string(evt.Headers["x-publishModule"]) == "i_was_here" { + evt.Data = []byte(`{"__typename":"Employee","id": 2,"update":{"name":"irrelevant"}}`) + } + } + + return events, nil + }, + }, + "publishModule": stream_publish.PublishModule{ + Callback: func(ctx core.StreamPublishEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { + if ctx.PublishEventConfiguration().RootFieldName() != "updateEmployeeMyKafka" { + return events, nil + } + + for _, event := range events { + evt, ok := event.(*kafka.Event) + if !ok { + continue + } + evt.Headers["x-publishModule"] = []byte("i_was_here") + } + + return events, nil + }, + }, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, + EnableKafka: true, + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&stream_publish.PublishModule{}, &stream_receive.StreamReceiveModule{}), + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.InfoLevel, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + topics := []string{"employeeUpdated"} + events.KafkaEnsureTopicExists(t, xEnv, time.Second, topics...) + + // start a subscriber + var subscriptionPayload struct { + employeeUpdatedMyKafka struct { + ID float64 `graphql:"id"` + Details struct { + Forename string `graphql:"forename"` + Surname string `graphql:"surname"` + } `graphql:"details"` + } `graphql:"employeeUpdatedMyKafka(employeeID: 3)"` + } + + surl := xEnv.GraphQLWebSocketSubscriptionURL() + client := graphql.NewSubscriptionClient(surl) + + subscriptionEventsChan := make(chan event) + subscriptionID, err := client.Subscribe(&subscriptionPayload, nil, func(dataValue []byte, errValue error) error { + subscriptionEventsChan <- event{ + data: dataValue, + err: errValue, + } + return nil + }) + require.NoError(t, err) + require.NotEmpty(t, subscriptionID) + + clientRunChan := make(chan error) + go func() { + clientRunChan <- client.Run() + }() + + xEnv.WaitForSubscriptionCount(1, Timeout) + + // publish a message to broker via mutation + // and let publish hook modify the message + resOne := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `mutation UpdateEmployeeKafka($employeeID: Int!) { updateEmployeeMyKafka(employeeID: $employeeID, update: {name: "name test"}) { success } }`, + Variables: json.RawMessage(`{"employeeID": 3}`), + }) + require.JSONEq(t, `{"data": {"updateEmployeeMyKafka": {"success": true}}}`, resOne.Body) + + requestLog := xEnv.Observer().FilterMessage("Publish Hook has been run") + assert.Len(t, requestLog.All(), 1) + + // wait for the message to be received by the subscriber + testenv.AwaitChannelWithT(t, Timeout, subscriptionEventsChan, func(t *testing.T, args event) { + require.NoError(t, args.err) + // verify that the stream batch hook modified the message, + // which it only does if the publish hook was run before it + require.JSONEq(t, `{"employeeUpdatedMyKafka":{"id":2,"details":{"forename":"Dustin","surname":"Deus"}}}`, string(args.data)) + }) + + require.NoError(t, client.Close()) + testenv.AwaitChannelWithT(t, Timeout, clientRunChan, func(t *testing.T, err error) { + require.NoError(t, err) + }, "unable to close client before timeout") + + requestLog = xEnv.Observer().FilterMessage("Stream Hook has been run") + assert.Len(t, requestLog.All(), 1) + }) + }) +} diff --git a/router-tests/prometheus_stream_metrics_test.go b/router-tests/prometheus_stream_metrics_test.go index 30fa87fe16..ac6d23d767 100644 --- a/router-tests/prometheus_stream_metrics_test.go +++ b/router-tests/prometheus_stream_metrics_test.go @@ -44,7 +44,7 @@ func TestFlakyEventMetrics(t *testing.T) { EnablePrometheusStreamMetrics: true, }, }, func(t *testing.T, xEnv *testenv.Environment) { - events.EnsureTopicExists(t, xEnv, "employeeUpdated") + events.KafkaEnsureTopicExists(t, xEnv, time.Second, "employeeUpdated") xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{Query: `mutation { updateEmployeeMyKafka(employeeID: 3, update: {name: "name test"}) { success } }`}) xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{Query: `mutation { updateEmployeeMyKafka(employeeID: 3, update: {name: "name test"}) { success } }`}) @@ -91,7 +91,7 @@ func TestFlakyEventMetrics(t *testing.T) { EnablePrometheusStreamMetrics: true, }, }, func(t *testing.T, xEnv *testenv.Environment) { - events.EnsureTopicExists(t, xEnv, topic) + events.KafkaEnsureTopicExists(t, xEnv, time.Second, topic) var subscriptionOne struct { employeeUpdatedMyKafka struct { @@ -115,7 +115,7 @@ func TestFlakyEventMetrics(t *testing.T) { go func() { clientRunCh <- client.Run() }() xEnv.WaitForSubscriptionCount(1, WaitTimeout) - events.ProduceKafkaMessage(t, xEnv, topic, `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) + events.ProduceKafkaMessage(t, xEnv, time.Second, topic, `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) testenv.AwaitChannelWithT(t, WaitTimeout, subscriptionArgsCh, func(t *testing.T, args subscriptionArgs) { require.NoError(t, args.errValue) diff --git a/router-tests/telemetry/stream_metrics_test.go b/router-tests/telemetry/stream_metrics_test.go index 72ac7c654f..bac9aee748 100644 --- a/router-tests/telemetry/stream_metrics_test.go +++ b/router-tests/telemetry/stream_metrics_test.go @@ -45,7 +45,7 @@ func TestFlakyEventMetrics(t *testing.T) { EnableOTLPStreamMetrics: true, }, }, func(t *testing.T, xEnv *testenv.Environment) { - events.EnsureTopicExists(t, xEnv, "employeeUpdated") + events.KafkaEnsureTopicExists(t, xEnv, time.Second, "employeeUpdated") xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{Query: `mutation { updateEmployeeMyKafka(employeeID: 3, update: {name: "name test"}) { success } }`}) xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{Query: `mutation { updateEmployeeMyKafka(employeeID: 3, update: {name: "name test"}) { success } }`}) @@ -96,7 +96,7 @@ func TestFlakyEventMetrics(t *testing.T) { EnableOTLPStreamMetrics: true, }, }, func(t *testing.T, xEnv *testenv.Environment) { - events.EnsureTopicExists(t, xEnv, topic) + events.KafkaEnsureTopicExists(t, xEnv, time.Second, topic) var subscriptionOne struct { employeeUpdatedMyKafka struct { @@ -120,7 +120,7 @@ func TestFlakyEventMetrics(t *testing.T) { go func() { clientRunCh <- client.Run() }() xEnv.WaitForSubscriptionCount(1, WaitTimeout) - events.ProduceKafkaMessage(t, xEnv, topic, `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) + events.ProduceKafkaMessage(t, xEnv, time.Second, topic, `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) testenv.AwaitChannelWithT(t, WaitTimeout, subscriptionArgsCh, func(t *testing.T, args subscriptionArgs) { require.NoError(t, args.errValue) diff --git a/router/.mockery.yml b/router/.mockery.yml index 558bca2185..436ed0eb14 100644 --- a/router/.mockery.yml +++ b/router/.mockery.yml @@ -21,12 +21,6 @@ packages: github.com/wundergraph/cosmo/router/pkg/pubsub/nats: interfaces: Adapter: - github.com/wundergraph/cosmo/router/pkg/pubsub/kafka: - interfaces: - Adapter: - github.com/wundergraph/cosmo/router/pkg/pubsub/redis: - interfaces: - Adapter: github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve: config: dir: 'pkg/pubsub/datasource' diff --git a/router/core/errors.go b/router/core/errors.go index 44e05f327b..2ce688bbef 100644 --- a/router/core/errors.go +++ b/router/core/errors.go @@ -35,7 +35,7 @@ const ( errorTypeInvalidWsSubprotocol errorTypeEDFSInvalidMessage errorTypeMergeResult - errorTypeStreamHookError + errorTypeHttpError ) type ( @@ -90,9 +90,9 @@ func getErrorType(err error) errorType { if errors.As(err, &mergeResultErr) { return errorTypeMergeResult } - var streamHookErr *StreamHookError - if errors.As(err, &streamHookErr) { - return errorTypeStreamHookError + var httpError *httpGraphqlError + if errors.As(err, &httpError) { + return errorTypeHttpError } return errorTypeUnknown } diff --git a/router/core/factoryresolver.go b/router/core/factoryresolver.go index 70f3e7917c..d7c72fe579 100644 --- a/router/core/factoryresolver.go +++ b/router/core/factoryresolver.go @@ -481,6 +481,17 @@ func (l *Loader) Load(engineConfig *nodev1.EngineConfiguration, subgraphs []*nod for i, fn := range l.subscriptionHooks.onStart { subscriptionOnStartFns[i] = NewPubSubSubscriptionOnStartHook(fn) } + + onPublishEventsFns := make([]pubsub_datasource.OnPublishEventsFn, len(l.subscriptionHooks.onPublishEvents)) + for i, fn := range l.subscriptionHooks.onPublishEvents { + onPublishEventsFns[i] = NewPubSubOnPublishEventsHook(fn) + } + + onReceiveEventsFns := make([]pubsub_datasource.OnReceiveEventsFn, len(l.subscriptionHooks.onReceiveEvents)) + for i, fn := range l.subscriptionHooks.onReceiveEvents { + onReceiveEventsFns[i] = NewPubSubOnReceiveEventsHook(fn) + } + factoryProviders, factoryDataSources, err := pubsub.BuildProvidersAndDataSources( l.ctx, routerEngineConfig.Events, @@ -489,8 +500,10 @@ func (l *Loader) Load(engineConfig *nodev1.EngineConfiguration, subgraphs []*nod pubSubDS, l.resolver.InstanceData().HostName, l.resolver.InstanceData().ListenAddress, - pubsub.Hooks{ + pubsub_datasource.Hooks{ SubscriptionOnStart: subscriptionOnStartFns, + OnReceiveEvents: onReceiveEventsFns, + OnPublishEvents: onPublishEventsFns, }, ) if err != nil { diff --git a/router/core/graphql_handler.go b/router/core/graphql_handler.go index f387d73e6c..845b8bdac0 100644 --- a/router/core/graphql_handler.go +++ b/router/core/graphql_handler.go @@ -400,21 +400,21 @@ func (h *GraphQLHandler) WriteError(ctx *resolve.Context, err error, res *resolv if isHttpResponseWriter { httpWriter.WriteHeader(http.StatusInternalServerError) } - case errorTypeStreamHookError: - var streamHookErr *StreamHookError - if !errors.As(err, &streamHookErr) { + case errorTypeHttpError: + var httpErr *httpGraphqlError + if !errors.As(err, &httpErr) { response.Errors[0].Message = "Internal server error" return } - response.Errors[0].Message = streamHookErr.Message() - if streamHookErr.Code() != "" || streamHookErr.StatusCode() != 0 { + response.Errors[0].Message = httpErr.Message() + if httpErr.ExtensionCode() != "" || httpErr.StatusCode() != 0 { response.Errors[0].Extensions = &Extensions{ - Code: streamHookErr.Code(), - StatusCode: streamHookErr.StatusCode(), + Code: httpErr.ExtensionCode(), + StatusCode: httpErr.StatusCode(), } } if isHttpResponseWriter { - httpWriter.WriteHeader(streamHookErr.StatusCode()) + httpWriter.WriteHeader(httpErr.StatusCode()) } } diff --git a/router/core/router.go b/router/core/router.go index 9f07bd723a..04303023f0 100644 --- a/router/core/router.go +++ b/router/core/router.go @@ -670,6 +670,14 @@ func (r *Router) initModules(ctx context.Context) error { r.subscriptionHooks.onStart = append(r.subscriptionHooks.onStart, handler.SubscriptionOnStart) } + if handler, ok := moduleInstance.(StreamPublishEventHandler); ok { + r.subscriptionHooks.onPublishEvents = append(r.subscriptionHooks.onPublishEvents, handler.OnPublishEvents) + } + + if handler, ok := moduleInstance.(StreamReceiveEventHandler); ok { + r.subscriptionHooks.onReceiveEvents = append(r.subscriptionHooks.onReceiveEvents, handler.OnReceiveEvents) + } + r.modules = append(r.modules, moduleInstance) r.logger.Info("Module registered", diff --git a/router/core/router_config.go b/router/core/router_config.go index ac4f26d4c7..3e282d3c65 100644 --- a/router/core/router_config.go +++ b/router/core/router_config.go @@ -17,6 +17,7 @@ import ( "github.com/wundergraph/cosmo/router/pkg/health" "github.com/wundergraph/cosmo/router/pkg/mcpserver" rmetric "github.com/wundergraph/cosmo/router/pkg/metric" + "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" rtrace "github.com/wundergraph/cosmo/router/pkg/trace" "go.opentelemetry.io/otel/propagation" sdkmetric "go.opentelemetry.io/otel/sdk/metric" @@ -26,7 +27,9 @@ import ( ) type subscriptionHooks struct { - onStart []func(ctx SubscriptionOnStartHookContext) error + onStart []func(ctx SubscriptionOnStartHandlerContext) error + onPublishEvents []func(ctx StreamPublishEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) + onReceiveEvents []func(ctx StreamReceiveEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) } type Config struct { diff --git a/router/core/subscriptions_modules.go b/router/core/subscriptions_modules.go index 505bbfc44f..e3279c811d 100644 --- a/router/core/subscriptions_modules.go +++ b/router/core/subscriptions_modules.go @@ -1,7 +1,9 @@ package core import ( + "context" "net/http" + "slices" "github.com/wundergraph/cosmo/router/pkg/authentication" "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" @@ -10,43 +12,7 @@ import ( "go.uber.org/zap" ) -// StreamHookError is used to customize the error messages and the behavior -type StreamHookError struct { - err error - message string - statusCode int - code string -} - -func (e *StreamHookError) Error() string { - if e.err != nil { - return e.err.Error() - } - return e.message -} - -func (e *StreamHookError) Message() string { - return e.message -} - -func (e *StreamHookError) StatusCode() int { - return e.statusCode -} - -func (e *StreamHookError) Code() string { - return e.code -} - -func NewStreamHookError(err error, message string, statusCode int, code string) *StreamHookError { - return &StreamHookError{ - err: err, - message: message, - statusCode: statusCode, - code: code, - } -} - -type SubscriptionOnStartHookContext interface { +type SubscriptionOnStartHandlerContext interface { // Request is the original request received by the router. Request() *http.Request // Logger is the logger for the request @@ -62,11 +28,39 @@ type SubscriptionOnStartHookContext interface { WriteEvent(event datasource.StreamEvent) bool } -type pubSubSubscriptionOnStartHookContext struct { - request *http.Request +type pubSubPublishEventHookContext struct { + request *http.Request logger *zap.Logger operation OperationContext authentication authentication.Authentication + publishEventConfiguration datasource.PublishEventConfiguration +} + +func (c *pubSubPublishEventHookContext) Request() *http.Request { + return c.request +} + +func (c *pubSubPublishEventHookContext) Logger() *zap.Logger { + return c.logger +} + +func (c *pubSubPublishEventHookContext) Operation() OperationContext { + return c.operation +} + +func (c *pubSubPublishEventHookContext) Authentication() authentication.Authentication { + return c.authentication +} + +func (c *pubSubPublishEventHookContext) PublishEventConfiguration() datasource.PublishEventConfiguration { + return c.publishEventConfiguration +} + +type pubSubSubscriptionOnStartHookContext struct { + request *http.Request + logger *zap.Logger + operation OperationContext + authentication authentication.Authentication subscriptionEventConfiguration datasource.SubscriptionEventConfiguration writeEventHook func(data []byte) } @@ -106,11 +100,17 @@ func (e *EngineEvent) GetData() []byte { return e.Data } +func (e *EngineEvent) Clone() datasource.StreamEvent { + return &EngineEvent{ + Data: slices.Clone(e.Data), + } +} + type engineSubscriptionOnStartHookContext struct { - request *http.Request - logger *zap.Logger - operation OperationContext - authentication authentication.Authentication + request *http.Request + logger *zap.Logger + operation OperationContext + authentication authentication.Authentication writeEventHook func(data []byte) } @@ -143,11 +143,11 @@ func (c *engineSubscriptionOnStartHookContext) SubscriptionEventConfiguration() type SubscriptionOnStartHandler interface { // SubscriptionOnStart is called once at subscription start // The error is propagated to the client. - SubscriptionOnStart(ctx SubscriptionOnStartHookContext) error + SubscriptionOnStart(ctx SubscriptionOnStartHandlerContext) error } // NewPubSubSubscriptionOnStartHook converts a SubscriptionOnStartHandler to a pubsub.SubscriptionOnStartFn -func NewPubSubSubscriptionOnStartHook(fn func(ctx SubscriptionOnStartHookContext) error) datasource.SubscriptionOnStartFn { +func NewPubSubSubscriptionOnStartHook(fn func(ctx SubscriptionOnStartHandlerContext) error) datasource.SubscriptionOnStartFn { if fn == nil { return nil } @@ -155,10 +155,10 @@ func NewPubSubSubscriptionOnStartHook(fn func(ctx SubscriptionOnStartHookContext return func(resolveCtx resolve.StartupHookContext, subConf datasource.SubscriptionEventConfiguration) error { requestContext := getRequestContext(resolveCtx.Context) hookCtx := &pubSubSubscriptionOnStartHookContext{ - request: requestContext.Request(), - logger: requestContext.Logger(), - operation: requestContext.Operation(), - authentication: requestContext.Authentication(), + request: requestContext.Request(), + logger: requestContext.Logger(), + operation: requestContext.Operation(), + authentication: requestContext.Authentication(), subscriptionEventConfiguration: subConf, writeEventHook: resolveCtx.Updater, } @@ -168,7 +168,7 @@ func NewPubSubSubscriptionOnStartHook(fn func(ctx SubscriptionOnStartHookContext } // NewEngineSubscriptionOnStartHook converts a SubscriptionOnStartHandler to a graphql_datasource.SubscriptionOnStartFn -func NewEngineSubscriptionOnStartHook(fn func(ctx SubscriptionOnStartHookContext) error) graphql_datasource.SubscriptionOnStartFn { +func NewEngineSubscriptionOnStartHook(fn func(ctx SubscriptionOnStartHandlerContext) error) graphql_datasource.SubscriptionOnStartFn { if fn == nil { return nil } @@ -176,9 +176,9 @@ func NewEngineSubscriptionOnStartHook(fn func(ctx SubscriptionOnStartHookContext return func(resolveCtx resolve.StartupHookContext, input []byte) error { requestContext := getRequestContext(resolveCtx.Context) hookCtx := &engineSubscriptionOnStartHookContext{ - request: requestContext.Request(), - logger: requestContext.Logger(), - operation: requestContext.Operation(), + request: requestContext.Request(), + logger: requestContext.Logger(), + operation: requestContext.Operation(), authentication: requestContext.Authentication(), writeEventHook: resolveCtx.Updater, } @@ -186,3 +186,112 @@ func NewEngineSubscriptionOnStartHook(fn func(ctx SubscriptionOnStartHookContext return fn(hookCtx) } } + +type StreamReceiveEventHandlerContext interface { + // Request is the initial client request that started the subscription + Request() *http.Request + // Logger is the logger for the request + Logger() *zap.Logger + // Operation is the GraphQL operation + Operation() OperationContext + // Authentication is the authentication for the request + Authentication() authentication.Authentication + // SubscriptionEventConfiguration the subscription event configuration + SubscriptionEventConfiguration() datasource.SubscriptionEventConfiguration +} + +type StreamReceiveEventHandler interface { + // OnReceiveEvents is called each time a batch of events is received from the provider before delivering them to the + // client. So for a single batch of events received from the provider, this hook will be called one time for each + // active subscription. + // It is important to optimize the logic inside this hook to avoid performance issues. + // Returning an error will result in a GraphQL error being returned to the client, could be customized returning a + // StreamHookError. + OnReceiveEvents(ctx StreamReceiveEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) +} + +type StreamPublishEventHandlerContext interface { + // Request is the original request received by the router. + Request() *http.Request + // Logger is the logger for the request + Logger() *zap.Logger + // Operation is the GraphQL operation + Operation() OperationContext + // Authentication is the authentication for the request + Authentication() authentication.Authentication + // PublishEventConfiguration the publish event configuration + PublishEventConfiguration() datasource.PublishEventConfiguration +} + +type StreamPublishEventHandler interface { + // OnPublishEvents is called each time a batch of events is going to be sent to the provider + // Returning an error will result in a GraphQL error being returned to the client, could be customized returning a + // StreamHookError. + OnPublishEvents(ctx StreamPublishEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) +} + +func NewPubSubOnPublishEventsHook(fn func(ctx StreamPublishEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error)) datasource.OnPublishEventsFn { + if fn == nil { + return nil + } + + return func(ctx context.Context, pubConf datasource.PublishEventConfiguration, evts []datasource.StreamEvent) ([]datasource.StreamEvent, error) { + requestContext := getRequestContext(ctx) + hookCtx := &pubSubPublishEventHookContext{ + request: requestContext.Request(), + logger: requestContext.Logger(), + operation: requestContext.Operation(), + authentication: requestContext.Authentication(), + publishEventConfiguration: pubConf, + } + + return fn(hookCtx, evts) + } +} + +type pubSubStreamReceiveEventHookContext struct { + request *http.Request + logger *zap.Logger + operation OperationContext + authentication authentication.Authentication + subscriptionEventConfiguration datasource.SubscriptionEventConfiguration +} + +func (c *pubSubStreamReceiveEventHookContext) Request() *http.Request { + return c.request +} + +func (c *pubSubStreamReceiveEventHookContext) Logger() *zap.Logger { + return c.logger +} + +func (c *pubSubStreamReceiveEventHookContext) Operation() OperationContext { + return c.operation +} + +func (c *pubSubStreamReceiveEventHookContext) Authentication() authentication.Authentication { + return c.authentication +} + +func (c *pubSubStreamReceiveEventHookContext) SubscriptionEventConfiguration() datasource.SubscriptionEventConfiguration { + return c.subscriptionEventConfiguration +} + +func NewPubSubOnReceiveEventsHook(fn func(ctx StreamReceiveEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error)) datasource.OnReceiveEventsFn { + if fn == nil { + return nil + } + + return func(ctx context.Context, subConf datasource.SubscriptionEventConfiguration, evts []datasource.StreamEvent) ([]datasource.StreamEvent, error) { + requestContext := getRequestContext(ctx) + hookCtx := &pubSubStreamReceiveEventHookContext{ + request: requestContext.Request(), + logger: requestContext.Logger(), + operation: requestContext.Operation(), + authentication: requestContext.Authentication(), + subscriptionEventConfiguration: subConf, + } + + return fn(hookCtx, evts) + } +} diff --git a/router/go.mod b/router/go.mod index 82ff4f2e73..35b04f0033 100644 --- a/router/go.mod +++ b/router/go.mod @@ -31,7 +31,7 @@ require ( github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 github.com/twmb/franz-go v1.16.1 - github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.229.0.20250930144208-ddc652f78bbb + github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.229.0.20251001132016-1d6b66867259 // Do not upgrade, it renames attributes we rely on go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.58.0 go.opentelemetry.io/contrib/propagators/b3 v1.23.0 diff --git a/router/go.sum b/router/go.sum index 1a0bc0afe5..09d82bed26 100644 --- a/router/go.sum +++ b/router/go.sum @@ -321,8 +321,8 @@ github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/ github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083 h1:8/D7f8gKxTBjW+SZK4mhxTTBVpxcqeBgWF1Rfmltbfk= github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083/go.mod h1:eOTL6acwctsN4F3b7YE+eE2t8zcJ/doLm9sZzsxxxrE= -github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.229.0.20250930144208-ddc652f78bbb h1:stBTAle5FyytsTNxYeCwNzYlyhKzlS4he6f7/y6O3qE= -github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.229.0.20250930144208-ddc652f78bbb/go.mod h1:g1IFIylu5Fd9pKjzq0mDvpaKhEB/vkwLAIbGdX2djXU= +github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.229.0.20251001132016-1d6b66867259 h1:PhKYGyTBFM0JIihHLQa6tD5Al6GVFIPuJxi2T+DEiB0= +github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.229.0.20251001132016-1d6b66867259/go.mod h1:g1IFIylu5Fd9pKjzq0mDvpaKhEB/vkwLAIbGdX2djXU= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M= diff --git a/router/pkg/pubsub/datasource/datasource.go b/router/pkg/pubsub/datasource/datasource.go index 3a3018b745..b186041388 100644 --- a/router/pkg/pubsub/datasource/datasource.go +++ b/router/pkg/pubsub/datasource/datasource.go @@ -9,7 +9,7 @@ type SubscriptionDataSource interface { SubscriptionEventConfiguration(input []byte) (SubscriptionEventConfiguration, error) Start(ctx *resolve.Context, input []byte, updater resolve.SubscriptionUpdater) error UniqueRequestID(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) (err error) - SetSubscriptionOnStartFns(fns ...SubscriptionOnStartFn) + SetHooks(hooks Hooks) } // EngineDataSourceFactory is the interface that all pubsub data sources must implement. diff --git a/router/pkg/pubsub/datasource/factory.go b/router/pkg/pubsub/datasource/factory.go index 5c42161776..ae25e6dbcf 100644 --- a/router/pkg/pubsub/datasource/factory.go +++ b/router/pkg/pubsub/datasource/factory.go @@ -9,16 +9,18 @@ import ( ) type PlannerConfig[PB ProviderBuilder[P, E], P any, E any] struct { - ProviderBuilder PB - Event E - SubscriptionOnStartFns []SubscriptionOnStartFn + Providers map[string]Provider + ProviderBuilder PB + Event E + Hooks Hooks } -func NewPlannerConfig[PB ProviderBuilder[P, E], P any, E any](providerBuilder PB, event E, subscriptionOnStartFns []SubscriptionOnStartFn) *PlannerConfig[PB, P, E] { +func NewPlannerConfig[PB ProviderBuilder[P, E], P any, E any](providerBuilder PB, event E, providers map[string]Provider, hooks Hooks) *PlannerConfig[PB, P, E] { return &PlannerConfig[PB, P, E]{ - ProviderBuilder: providerBuilder, - Event: event, - SubscriptionOnStartFns: subscriptionOnStartFns, + Providers: providers, + ProviderBuilder: providerBuilder, + Event: event, + Hooks: hooks, } } diff --git a/router/pkg/pubsub/datasource/hooks.go b/router/pkg/pubsub/datasource/hooks.go new file mode 100644 index 0000000000..abab8b8ef1 --- /dev/null +++ b/router/pkg/pubsub/datasource/hooks.go @@ -0,0 +1,20 @@ +package datasource + +import ( + "context" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" +) + +type SubscriptionOnStartFn func(ctx resolve.StartupHookContext, subConf SubscriptionEventConfiguration) error + +type OnPublishEventsFn func(ctx context.Context, pubConf PublishEventConfiguration, evts []StreamEvent) ([]StreamEvent, error) + +type OnReceiveEventsFn func(ctx context.Context, subConf SubscriptionEventConfiguration, evts []StreamEvent) ([]StreamEvent, error) + +// Hooks contains hooks for the pubsub providers and data sources +type Hooks struct { + SubscriptionOnStart []SubscriptionOnStartFn + OnReceiveEvents []OnReceiveEventsFn + OnPublishEvents []OnPublishEventsFn +} diff --git a/router/pkg/pubsub/datasource/mocks.go b/router/pkg/pubsub/datasource/mocks.go index 861beb3987..3c56f09919 100644 --- a/router/pkg/pubsub/datasource/mocks.go +++ b/router/pkg/pubsub/datasource/mocks.go @@ -556,6 +556,109 @@ func (_c *MockProvider_ID_Call) RunAndReturn(run func() string) *MockProvider_ID return _c } +// Publish provides a mock function for the type MockProvider +func (_mock *MockProvider) Publish(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) error { + ret := _mock.Called(ctx, cfg, events) + + if len(ret) == 0 { + panic("no return value specified for Publish") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, PublishEventConfiguration, []StreamEvent) error); ok { + r0 = returnFunc(ctx, cfg, events) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// MockProvider_Publish_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Publish' +type MockProvider_Publish_Call struct { + *mock.Call +} + +// Publish is a helper method to define mock.On call +// - ctx context.Context +// - cfg PublishEventConfiguration +// - events []StreamEvent +func (_e *MockProvider_Expecter) Publish(ctx interface{}, cfg interface{}, events interface{}) *MockProvider_Publish_Call { + return &MockProvider_Publish_Call{Call: _e.mock.On("Publish", ctx, cfg, events)} +} + +func (_c *MockProvider_Publish_Call) Run(run func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent)) *MockProvider_Publish_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 PublishEventConfiguration + if args[1] != nil { + arg1 = args[1].(PublishEventConfiguration) + } + var arg2 []StreamEvent + if args[2] != nil { + arg2 = args[2].([]StreamEvent) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *MockProvider_Publish_Call) Return(err error) *MockProvider_Publish_Call { + _c.Call.Return(err) + return _c +} + +func (_c *MockProvider_Publish_Call) RunAndReturn(run func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) error) *MockProvider_Publish_Call { + _c.Call.Return(run) + return _c +} + +// SetHooks provides a mock function for the type MockProvider +func (_mock *MockProvider) SetHooks(hooks Hooks) { + _mock.Called(hooks) + return +} + +// MockProvider_SetHooks_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetHooks' +type MockProvider_SetHooks_Call struct { + *mock.Call +} + +// SetHooks is a helper method to define mock.On call +// - hooks Hooks +func (_e *MockProvider_Expecter) SetHooks(hooks interface{}) *MockProvider_SetHooks_Call { + return &MockProvider_SetHooks_Call{Call: _e.mock.On("SetHooks", hooks)} +} + +func (_c *MockProvider_SetHooks_Call) Run(run func(hooks Hooks)) *MockProvider_SetHooks_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 Hooks + if args[0] != nil { + arg0 = args[0].(Hooks) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *MockProvider_SetHooks_Call) Return() *MockProvider_SetHooks_Call { + _c.Call.Return() + return _c +} + +func (_c *MockProvider_SetHooks_Call) RunAndReturn(run func(hooks Hooks)) *MockProvider_SetHooks_Call { + _c.Run(run) + return _c +} + // Shutdown provides a mock function for the type MockProvider func (_mock *MockProvider) Shutdown(ctx context.Context) error { ret := _mock.Called(ctx) @@ -793,8 +896,8 @@ func (_m *MockProviderBuilder[P, E]) EXPECT() *MockProviderBuilder_Expecter[P, E } // BuildEngineDataSourceFactory provides a mock function for the type MockProviderBuilder -func (_mock *MockProviderBuilder[P, E]) BuildEngineDataSourceFactory(data E) (EngineDataSourceFactory, error) { - ret := _mock.Called(data) +func (_mock *MockProviderBuilder[P, E]) BuildEngineDataSourceFactory(data E, providers map[string]Provider) (EngineDataSourceFactory, error) { + ret := _mock.Called(data, providers) if len(ret) == 0 { panic("no return value specified for BuildEngineDataSourceFactory") @@ -802,18 +905,18 @@ func (_mock *MockProviderBuilder[P, E]) BuildEngineDataSourceFactory(data E) (En var r0 EngineDataSourceFactory var r1 error - if returnFunc, ok := ret.Get(0).(func(E) (EngineDataSourceFactory, error)); ok { - return returnFunc(data) + if returnFunc, ok := ret.Get(0).(func(E, map[string]Provider) (EngineDataSourceFactory, error)); ok { + return returnFunc(data, providers) } - if returnFunc, ok := ret.Get(0).(func(E) EngineDataSourceFactory); ok { - r0 = returnFunc(data) + if returnFunc, ok := ret.Get(0).(func(E, map[string]Provider) EngineDataSourceFactory); ok { + r0 = returnFunc(data, providers) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(EngineDataSourceFactory) } } - if returnFunc, ok := ret.Get(1).(func(E) error); ok { - r1 = returnFunc(data) + if returnFunc, ok := ret.Get(1).(func(E, map[string]Provider) error); ok { + r1 = returnFunc(data, providers) } else { r1 = ret.Error(1) } @@ -827,18 +930,24 @@ type MockProviderBuilder_BuildEngineDataSourceFactory_Call[P any, E any] struct // BuildEngineDataSourceFactory is a helper method to define mock.On call // - data E -func (_e *MockProviderBuilder_Expecter[P, E]) BuildEngineDataSourceFactory(data interface{}) *MockProviderBuilder_BuildEngineDataSourceFactory_Call[P, E] { - return &MockProviderBuilder_BuildEngineDataSourceFactory_Call[P, E]{Call: _e.mock.On("BuildEngineDataSourceFactory", data)} +// - providers map[string]Provider +func (_e *MockProviderBuilder_Expecter[P, E]) BuildEngineDataSourceFactory(data interface{}, providers interface{}) *MockProviderBuilder_BuildEngineDataSourceFactory_Call[P, E] { + return &MockProviderBuilder_BuildEngineDataSourceFactory_Call[P, E]{Call: _e.mock.On("BuildEngineDataSourceFactory", data, providers)} } -func (_c *MockProviderBuilder_BuildEngineDataSourceFactory_Call[P, E]) Run(run func(data E)) *MockProviderBuilder_BuildEngineDataSourceFactory_Call[P, E] { +func (_c *MockProviderBuilder_BuildEngineDataSourceFactory_Call[P, E]) Run(run func(data E, providers map[string]Provider)) *MockProviderBuilder_BuildEngineDataSourceFactory_Call[P, E] { _c.Call.Run(func(args mock.Arguments) { var arg0 E if args[0] != nil { arg0 = args[0].(E) } + var arg1 map[string]Provider + if args[1] != nil { + arg1 = args[1].(map[string]Provider) + } run( arg0, + arg1, ) }) return _c @@ -849,7 +958,7 @@ func (_c *MockProviderBuilder_BuildEngineDataSourceFactory_Call[P, E]) Return(en return _c } -func (_c *MockProviderBuilder_BuildEngineDataSourceFactory_Call[P, E]) RunAndReturn(run func(data E) (EngineDataSourceFactory, error)) *MockProviderBuilder_BuildEngineDataSourceFactory_Call[P, E] { +func (_c *MockProviderBuilder_BuildEngineDataSourceFactory_Call[P, E]) RunAndReturn(run func(data E, providers map[string]Provider) (EngineDataSourceFactory, error)) *MockProviderBuilder_BuildEngineDataSourceFactory_Call[P, E] { _c.Call.Return(run) return _c } @@ -1060,9 +1169,49 @@ func (_c *MockSubscriptionEventUpdater_Complete_Call) RunAndReturn(run func()) * return _c } +// SetHooks provides a mock function for the type MockSubscriptionEventUpdater +func (_mock *MockSubscriptionEventUpdater) SetHooks(hooks Hooks) { + _mock.Called(hooks) + return +} + +// MockSubscriptionEventUpdater_SetHooks_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetHooks' +type MockSubscriptionEventUpdater_SetHooks_Call struct { + *mock.Call +} + +// SetHooks is a helper method to define mock.On call +// - hooks Hooks +func (_e *MockSubscriptionEventUpdater_Expecter) SetHooks(hooks interface{}) *MockSubscriptionEventUpdater_SetHooks_Call { + return &MockSubscriptionEventUpdater_SetHooks_Call{Call: _e.mock.On("SetHooks", hooks)} +} + +func (_c *MockSubscriptionEventUpdater_SetHooks_Call) Run(run func(hooks Hooks)) *MockSubscriptionEventUpdater_SetHooks_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 Hooks + if args[0] != nil { + arg0 = args[0].(Hooks) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *MockSubscriptionEventUpdater_SetHooks_Call) Return() *MockSubscriptionEventUpdater_SetHooks_Call { + _c.Call.Return() + return _c +} + +func (_c *MockSubscriptionEventUpdater_SetHooks_Call) RunAndReturn(run func(hooks Hooks)) *MockSubscriptionEventUpdater_SetHooks_Call { + _c.Run(run) + return _c +} + // Update provides a mock function for the type MockSubscriptionEventUpdater -func (_mock *MockSubscriptionEventUpdater) Update(event StreamEvent) { - _mock.Called(event) +func (_mock *MockSubscriptionEventUpdater) Update(events []StreamEvent) { + _mock.Called(events) return } @@ -1072,16 +1221,16 @@ type MockSubscriptionEventUpdater_Update_Call struct { } // Update is a helper method to define mock.On call -// - event StreamEvent -func (_e *MockSubscriptionEventUpdater_Expecter) Update(event interface{}) *MockSubscriptionEventUpdater_Update_Call { - return &MockSubscriptionEventUpdater_Update_Call{Call: _e.mock.On("Update", event)} +// - events []StreamEvent +func (_e *MockSubscriptionEventUpdater_Expecter) Update(events interface{}) *MockSubscriptionEventUpdater_Update_Call { + return &MockSubscriptionEventUpdater_Update_Call{Call: _e.mock.On("Update", events)} } -func (_c *MockSubscriptionEventUpdater_Update_Call) Run(run func(event StreamEvent)) *MockSubscriptionEventUpdater_Update_Call { +func (_c *MockSubscriptionEventUpdater_Update_Call) Run(run func(events []StreamEvent)) *MockSubscriptionEventUpdater_Update_Call { _c.Call.Run(func(args mock.Arguments) { - var arg0 StreamEvent + var arg0 []StreamEvent if args[0] != nil { - arg0 = args[0].(StreamEvent) + arg0 = args[0].([]StreamEvent) } run( arg0, @@ -1095,7 +1244,7 @@ func (_c *MockSubscriptionEventUpdater_Update_Call) Return() *MockSubscriptionEv return _c } -func (_c *MockSubscriptionEventUpdater_Update_Call) RunAndReturn(run func(event StreamEvent)) *MockSubscriptionEventUpdater_Update_Call { +func (_c *MockSubscriptionEventUpdater_Update_Call) RunAndReturn(run func(events []StreamEvent)) *MockSubscriptionEventUpdater_Update_Call { _c.Run(run) return _c } diff --git a/router/pkg/pubsub/datasource/mocks_resolve.go b/router/pkg/pubsub/datasource/mocks_resolve.go index 3efc24b405..19bad89c16 100644 --- a/router/pkg/pubsub/datasource/mocks_resolve.go +++ b/router/pkg/pubsub/datasource/mocks_resolve.go @@ -5,6 +5,8 @@ package datasource import ( + "context" + mock "github.com/stretchr/testify/mock" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" ) @@ -76,6 +78,52 @@ func (_c *MockSubscriptionUpdater_Close_Call) RunAndReturn(run func(kind resolve return _c } +// CloseSubscription provides a mock function for the type MockSubscriptionUpdater +func (_mock *MockSubscriptionUpdater) CloseSubscription(kind resolve.SubscriptionCloseKind, id resolve.SubscriptionIdentifier) { + _mock.Called(kind, id) + return +} + +// MockSubscriptionUpdater_CloseSubscription_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CloseSubscription' +type MockSubscriptionUpdater_CloseSubscription_Call struct { + *mock.Call +} + +// CloseSubscription is a helper method to define mock.On call +// - kind resolve.SubscriptionCloseKind +// - id resolve.SubscriptionIdentifier +func (_e *MockSubscriptionUpdater_Expecter) CloseSubscription(kind interface{}, id interface{}) *MockSubscriptionUpdater_CloseSubscription_Call { + return &MockSubscriptionUpdater_CloseSubscription_Call{Call: _e.mock.On("CloseSubscription", kind, id)} +} + +func (_c *MockSubscriptionUpdater_CloseSubscription_Call) Run(run func(kind resolve.SubscriptionCloseKind, id resolve.SubscriptionIdentifier)) *MockSubscriptionUpdater_CloseSubscription_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 resolve.SubscriptionCloseKind + if args[0] != nil { + arg0 = args[0].(resolve.SubscriptionCloseKind) + } + var arg1 resolve.SubscriptionIdentifier + if args[1] != nil { + arg1 = args[1].(resolve.SubscriptionIdentifier) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *MockSubscriptionUpdater_CloseSubscription_Call) Return() *MockSubscriptionUpdater_CloseSubscription_Call { + _c.Call.Return() + return _c +} + +func (_c *MockSubscriptionUpdater_CloseSubscription_Call) RunAndReturn(run func(kind resolve.SubscriptionCloseKind, id resolve.SubscriptionIdentifier)) *MockSubscriptionUpdater_CloseSubscription_Call { + _c.Run(run) + return _c +} + // Complete provides a mock function for the type MockSubscriptionUpdater func (_mock *MockSubscriptionUpdater) Complete() { _mock.Called() @@ -109,6 +157,52 @@ func (_c *MockSubscriptionUpdater_Complete_Call) RunAndReturn(run func()) *MockS return _c } +// Subscriptions provides a mock function for the type MockSubscriptionUpdater +func (_mock *MockSubscriptionUpdater) Subscriptions() map[context.Context]resolve.SubscriptionIdentifier { + ret := _mock.Called() + + if len(ret) == 0 { + panic("no return value specified for Subscriptions") + } + + var r0 map[context.Context]resolve.SubscriptionIdentifier + if returnFunc, ok := ret.Get(0).(func() map[context.Context]resolve.SubscriptionIdentifier); ok { + r0 = returnFunc() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[context.Context]resolve.SubscriptionIdentifier) + } + } + return r0 +} + +// MockSubscriptionUpdater_Subscriptions_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Subscriptions' +type MockSubscriptionUpdater_Subscriptions_Call struct { + *mock.Call +} + +// Subscriptions is a helper method to define mock.On call +func (_e *MockSubscriptionUpdater_Expecter) Subscriptions() *MockSubscriptionUpdater_Subscriptions_Call { + return &MockSubscriptionUpdater_Subscriptions_Call{Call: _e.mock.On("Subscriptions")} +} + +func (_c *MockSubscriptionUpdater_Subscriptions_Call) Run(run func()) *MockSubscriptionUpdater_Subscriptions_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockSubscriptionUpdater_Subscriptions_Call) Return(contextToSubscriptionIdentifier map[context.Context]resolve.SubscriptionIdentifier) *MockSubscriptionUpdater_Subscriptions_Call { + _c.Call.Return(contextToSubscriptionIdentifier) + return _c +} + +func (_c *MockSubscriptionUpdater_Subscriptions_Call) RunAndReturn(run func() map[context.Context]resolve.SubscriptionIdentifier) *MockSubscriptionUpdater_Subscriptions_Call { + _c.Call.Return(run) + return _c +} + // Update provides a mock function for the type MockSubscriptionUpdater func (_mock *MockSubscriptionUpdater) Update(data []byte) { _mock.Called(data) @@ -148,3 +242,49 @@ func (_c *MockSubscriptionUpdater_Update_Call) RunAndReturn(run func(data []byte _c.Run(run) return _c } + +// UpdateSubscription provides a mock function for the type MockSubscriptionUpdater +func (_mock *MockSubscriptionUpdater) UpdateSubscription(id resolve.SubscriptionIdentifier, data []byte) { + _mock.Called(id, data) + return +} + +// MockSubscriptionUpdater_UpdateSubscription_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateSubscription' +type MockSubscriptionUpdater_UpdateSubscription_Call struct { + *mock.Call +} + +// UpdateSubscription is a helper method to define mock.On call +// - id resolve.SubscriptionIdentifier +// - data []byte +func (_e *MockSubscriptionUpdater_Expecter) UpdateSubscription(id interface{}, data interface{}) *MockSubscriptionUpdater_UpdateSubscription_Call { + return &MockSubscriptionUpdater_UpdateSubscription_Call{Call: _e.mock.On("UpdateSubscription", id, data)} +} + +func (_c *MockSubscriptionUpdater_UpdateSubscription_Call) Run(run func(id resolve.SubscriptionIdentifier, data []byte)) *MockSubscriptionUpdater_UpdateSubscription_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 resolve.SubscriptionIdentifier + if args[0] != nil { + arg0 = args[0].(resolve.SubscriptionIdentifier) + } + var arg1 []byte + if args[1] != nil { + arg1 = args[1].([]byte) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *MockSubscriptionUpdater_UpdateSubscription_Call) Return() *MockSubscriptionUpdater_UpdateSubscription_Call { + _c.Call.Return() + return _c +} + +func (_c *MockSubscriptionUpdater_UpdateSubscription_Call) RunAndReturn(run func(id resolve.SubscriptionIdentifier, data []byte)) *MockSubscriptionUpdater_UpdateSubscription_Call { + _c.Run(run) + return _c +} diff --git a/router/pkg/pubsub/datasource/planner.go b/router/pkg/pubsub/datasource/planner.go index a480f8270e..f0378caa88 100644 --- a/router/pkg/pubsub/datasource/planner.go +++ b/router/pkg/pubsub/datasource/planner.go @@ -48,7 +48,7 @@ func (p *Planner[PB, P, E]) ConfigureFetch() resolve.FetchConfiguration { return resolve.FetchConfiguration{} } - pubSubDataSource, err := p.config.ProviderBuilder.BuildEngineDataSourceFactory(p.config.Event) + pubSubDataSource, err := p.config.ProviderBuilder.BuildEngineDataSourceFactory(p.config.Event, p.config.Providers) if err != nil { p.visitor.Walker.StopWithInternalErr(fmt.Errorf("failed to build data source: %w", err)) return resolve.FetchConfiguration{} @@ -93,7 +93,7 @@ func (p *Planner[PB, P, E]) ConfigureSubscription() plan.SubscriptionConfigurati return plan.SubscriptionConfiguration{} } - pubSubDataSource, err := p.config.ProviderBuilder.BuildEngineDataSourceFactory(p.config.Event) + pubSubDataSource, err := p.config.ProviderBuilder.BuildEngineDataSourceFactory(p.config.Event, p.config.Providers) if err != nil { p.visitor.Walker.StopWithInternalErr(fmt.Errorf("failed to get resolve data source subscription: %w", err)) return plan.SubscriptionConfiguration{} @@ -109,7 +109,7 @@ func (p *Planner[PB, P, E]) ConfigureSubscription() plan.SubscriptionConfigurati p.visitor.Walker.StopWithInternalErr(fmt.Errorf("failed to get resolve data source subscription: %w", err)) return plan.SubscriptionConfiguration{} } - dataSource.SetSubscriptionOnStartFns(p.config.SubscriptionOnStartFns...) + dataSource.SetHooks(p.config.Hooks) input, err := pubSubDataSource.ResolveDataSourceSubscriptionInput() if err != nil { diff --git a/router/pkg/pubsub/datasource/provider.go b/router/pkg/pubsub/datasource/provider.go index 33cac33782..57bbb70ed7 100644 --- a/router/pkg/pubsub/datasource/provider.go +++ b/router/pkg/pubsub/datasource/provider.go @@ -4,7 +4,6 @@ import ( "context" "github.com/wundergraph/cosmo/router/pkg/metric" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" ) type ArgumentTemplateCallback func(tpl string) (string, error) @@ -23,6 +22,7 @@ type Lifecycle interface { type Adapter interface { Lifecycle Subscribe(ctx context.Context, cfg SubscriptionEventConfiguration, updater SubscriptionEventUpdater) error + Publish(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) error } // Provider is the interface that the PubSub provider must implement @@ -32,6 +32,8 @@ type Provider interface { ID() string // TypeID Get the provider type id (e.g. "kafka", "nats") TypeID() string + // SetHooks Set the hooks + SetHooks(Hooks) } // ProviderBuilder is the interface that the provider builder must implement. @@ -41,7 +43,7 @@ type ProviderBuilder[P, E any] interface { // BuildProvider Build the provider and the adapter BuildProvider(options P, providerOpts ProviderOpts) (Provider, error) // BuildEngineDataSourceFactory Build the data source for the given provider and event configuration - BuildEngineDataSourceFactory(data E) (EngineDataSourceFactory, error) + BuildEngineDataSourceFactory(data E, providers map[string]Provider) (EngineDataSourceFactory, error) } // ProviderType represents the type of pubsub provider @@ -58,10 +60,9 @@ const ( // there could be other common fields in the future, but for now we only have data type StreamEvent interface { GetData() []byte + Clone() StreamEvent } -type SubscriptionOnStartFn func(ctx resolve.StartupHookContext, subConf SubscriptionEventConfiguration) error - // SubscriptionEventConfiguration is the interface that all subscription event configurations must implement type SubscriptionEventConfiguration interface { ProviderID() string diff --git a/router/pkg/pubsub/datasource/pubsubprovider.go b/router/pkg/pubsub/datasource/pubsubprovider.go index 84561b06db..e234ebfb73 100644 --- a/router/pkg/pubsub/datasource/pubsubprovider.go +++ b/router/pkg/pubsub/datasource/pubsubprovider.go @@ -11,6 +11,21 @@ type PubSubProvider struct { typeID string Adapter Adapter Logger *zap.Logger + hooks Hooks +} + +// applyPublishEventHooks processes events through a chain of hook functions +// Each hook receives the result from the previous hook, creating a proper middleware pipeline +func applyPublishEventHooks(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent, hooks []OnPublishEventsFn) ([]StreamEvent, error) { + currentEvents := events + for _, hook := range hooks { + var err error + currentEvents, err = hook(ctx, cfg, currentEvents) + if err != nil { + return currentEvents, err + } + } + return currentEvents, nil } func (p *PubSubProvider) ID() string { @@ -39,6 +54,34 @@ func (p *PubSubProvider) Subscribe(ctx context.Context, cfg SubscriptionEventCon return p.Adapter.Subscribe(ctx, cfg, updater) } +func (p *PubSubProvider) Publish(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) error { + if len(p.hooks.OnPublishEvents) == 0 { + return p.Adapter.Publish(ctx, cfg, events) + } + + processedEvents, hooksErr := applyPublishEventHooks(ctx, cfg, events, p.hooks.OnPublishEvents) + if hooksErr != nil { + p.Logger.Error( + "error applying publish event hooks", + zap.Error(hooksErr), + zap.String("provider_id", cfg.ProviderID()), + zap.String("provider_type_id", string(cfg.ProviderType())), + zap.String("field_name", cfg.RootFieldName()), + ) + } + + errPublish := p.Adapter.Publish(ctx, cfg, processedEvents) + if errPublish != nil { + return errPublish + } + + return hooksErr +} + +func (p *PubSubProvider) SetHooks(hooks Hooks) { + p.hooks = hooks +} + func NewPubSubProvider(id string, typeID string, adapter Adapter, logger *zap.Logger) *PubSubProvider { return &PubSubProvider{ id: id, diff --git a/router/pkg/pubsub/datasource/pubsubprovider_test.go b/router/pkg/pubsub/datasource/pubsubprovider_test.go index 134bfbd6bb..6ef41c56a5 100644 --- a/router/pkg/pubsub/datasource/pubsubprovider_test.go +++ b/router/pkg/pubsub/datasource/pubsubprovider_test.go @@ -1,14 +1,67 @@ package datasource import ( + "bytes" "context" "errors" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "go.uber.org/zap" ) +// Test helper types +type testEvent struct { + data []byte +} + +func (e *testEvent) GetData() []byte { + return e.data +} + +func (e *testEvent) Clone() StreamEvent { + return &testEvent{ + data: bytes.Clone(e.data), + } +} + +type testSubscriptionConfig struct { + providerID string + providerType ProviderType + fieldName string +} + +func (c *testSubscriptionConfig) ProviderID() string { + return c.providerID +} + +func (c *testSubscriptionConfig) ProviderType() ProviderType { + return c.providerType +} + +func (c *testSubscriptionConfig) RootFieldName() string { + return c.fieldName +} + +type testPublishConfig struct { + providerID string + providerType ProviderType + fieldName string +} + +func (c *testPublishConfig) ProviderID() string { + return c.providerID +} + +func (c *testPublishConfig) ProviderType() ProviderType { + return c.providerType +} + +func (c *testPublishConfig) RootFieldName() string { + return c.fieldName +} + func TestProvider_Startup_Success(t *testing.T) { mockAdapter := NewMockProvider(t) mockAdapter.On("Startup", mock.Anything).Return(nil) @@ -57,18 +110,375 @@ func TestProvider_Shutdown_Error(t *testing.T) { assert.Error(t, err) } -func TestProvider_ID(t *testing.T) { - const testID = "test-id" +func TestProvider_Subscribe_Success(t *testing.T) { + mockAdapter := NewMockProvider(t) + mockUpdater := NewMockSubscriptionEventUpdater(t) + config := &testSubscriptionConfig{ + providerID: "test-provider", + providerType: ProviderTypeNats, + fieldName: "testField", + } + + mockAdapter.On("Subscribe", mock.Anything, config, mockUpdater).Return(nil) + provider := PubSubProvider{ - id: testID, + Adapter: mockAdapter, } - assert.Equal(t, testID, provider.ID()) + err := provider.Subscribe(context.Background(), config, mockUpdater) + + assert.NoError(t, err) } -func TestProvider_TypeID(t *testing.T) { - const providerTypeID = "test-type-id" +func TestProvider_Subscribe_Error(t *testing.T) { + mockAdapter := NewMockProvider(t) + mockUpdater := NewMockSubscriptionEventUpdater(t) + config := &testSubscriptionConfig{ + providerID: "test-provider", + providerType: ProviderTypeNats, + fieldName: "testField", + } + expectedError := errors.New("subscription error") + + mockAdapter.On("Subscribe", mock.Anything, config, mockUpdater).Return(expectedError) + provider := PubSubProvider{ - typeID: providerTypeID, + Adapter: mockAdapter, } - assert.Equal(t, providerTypeID, provider.TypeID()) + err := provider.Subscribe(context.Background(), config, mockUpdater) + + assert.Error(t, err) + assert.Equal(t, expectedError, err) +} + +func TestProvider_Publish_NoHooks_Success(t *testing.T) { + mockAdapter := NewMockProvider(t) + config := &testPublishConfig{ + providerID: "test-provider", + providerType: ProviderTypeKafka, + fieldName: "testField", + } + events := []StreamEvent{ + &testEvent{data: []byte("test data 1")}, + &testEvent{data: []byte("test data 2")}, + } + + mockAdapter.On("Publish", mock.Anything, config, events).Return(nil) + + provider := PubSubProvider{ + Adapter: mockAdapter, + hooks: Hooks{}, // No hooks + } + err := provider.Publish(context.Background(), config, events) + + assert.NoError(t, err) +} + +func TestProvider_Publish_NoHooks_Error(t *testing.T) { + mockAdapter := NewMockProvider(t) + config := &testPublishConfig{ + providerID: "test-provider", + providerType: ProviderTypeKafka, + fieldName: "testField", + } + events := []StreamEvent{ + &testEvent{data: []byte("test data")}, + } + expectedError := errors.New("publish error") + + mockAdapter.On("Publish", mock.Anything, config, events).Return(expectedError) + + provider := PubSubProvider{ + Adapter: mockAdapter, + hooks: Hooks{}, // No hooks + } + err := provider.Publish(context.Background(), config, events) + + assert.Error(t, err) + assert.Equal(t, expectedError, err) +} + +func TestProvider_Publish_WithHooks_Success(t *testing.T) { + mockAdapter := NewMockProvider(t) + config := &testPublishConfig{ + providerID: "test-provider", + providerType: ProviderTypeKafka, + fieldName: "testField", + } + originalEvents := []StreamEvent{ + &testEvent{data: []byte("original data")}, + } + modifiedEvents := []StreamEvent{ + &testEvent{data: []byte("modified data")}, + } + + // Define hook that modifies events + testHook := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + return modifiedEvents, nil + } + + mockAdapter.On("Publish", mock.Anything, config, modifiedEvents).Return(nil) + + provider := PubSubProvider{ + Adapter: mockAdapter, + hooks: Hooks{ + OnPublishEvents: []OnPublishEventsFn{testHook}, + }, + } + err := provider.Publish(context.Background(), config, originalEvents) + + assert.NoError(t, err) +} + +func TestProvider_Publish_WithHooks_HookError(t *testing.T) { + mockAdapter := NewMockProvider(t) + config := &testPublishConfig{ + providerID: "test-provider", + providerType: ProviderTypeKafka, + fieldName: "testField", + } + events := []StreamEvent{ + &testEvent{data: []byte("test data")}, + } + hookError := errors.New("hook processing error") + + // Define hook that returns an error + testHook := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + return nil, hookError + } + + mockAdapter.On("Publish", mock.Anything, config, []StreamEvent(nil)).Return(nil) + + // Should call Publish on adapter also if hook fails + provider := PubSubProvider{ + Adapter: mockAdapter, + hooks: Hooks{ + OnPublishEvents: []OnPublishEventsFn{testHook}, + }, + Logger: zap.NewNop(), + } + err := provider.Publish(context.Background(), config, events) + + assert.Error(t, err) + assert.Equal(t, hookError, err) +} + +func TestProvider_Publish_WithHooks_AdapterError(t *testing.T) { + mockAdapter := NewMockProvider(t) + config := &testPublishConfig{ + providerID: "test-provider", + providerType: ProviderTypeKafka, + fieldName: "testField", + } + originalEvents := []StreamEvent{ + &testEvent{data: []byte("original data")}, + } + processedEvents := []StreamEvent{ + &testEvent{data: []byte("processed data")}, + } + adapterError := errors.New("adapter publish error") + + // Define hook that processes events successfully + testHook := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + return processedEvents, nil + } + + mockAdapter.On("Publish", mock.Anything, config, processedEvents).Return(adapterError) + + provider := PubSubProvider{ + Adapter: mockAdapter, + hooks: Hooks{ + OnPublishEvents: []OnPublishEventsFn{testHook}, + }, + } + err := provider.Publish(context.Background(), config, originalEvents) + + assert.Error(t, err) + assert.Equal(t, adapterError, err) +} + +func TestProvider_Publish_WithMultipleHooks_Success(t *testing.T) { + mockAdapter := NewMockProvider(t) + config := &testPublishConfig{ + providerID: "test-provider", + providerType: ProviderTypeKafka, + fieldName: "testField", + } + originalEvents := []StreamEvent{ + &testEvent{data: []byte("original")}, + } + + // Chain of hooks that modify the data + hook1 := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + return []StreamEvent{&testEvent{data: []byte("modified by hook1")}}, nil + } + hook2 := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + return []StreamEvent{&testEvent{data: []byte("modified by hook2")}}, nil + } + + mockAdapter.On("Publish", mock.Anything, config, mock.MatchedBy(func(events []StreamEvent) bool { + return len(events) == 1 && string(events[0].GetData()) == "modified by hook2" + })).Return(nil) + + provider := PubSubProvider{ + Adapter: mockAdapter, + hooks: Hooks{ + OnPublishEvents: []OnPublishEventsFn{hook1, hook2}, + }, + } + err := provider.Publish(context.Background(), config, originalEvents) + + assert.NoError(t, err) +} + +func TestProvider_SetHooks(t *testing.T) { + provider := &PubSubProvider{} + + testHook := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + return events, nil + } + + hooks := Hooks{ + OnPublishEvents: []OnPublishEventsFn{testHook}, + } + + provider.SetHooks(hooks) + + assert.Equal(t, hooks, provider.hooks) +} + +func TestNewPubSubProvider(t *testing.T) { + mockAdapter := NewMockProvider(t) + logger := zap.NewNop() + id := "test-provider-id" + typeID := "test-type-id" + + provider := NewPubSubProvider(id, typeID, mockAdapter, logger) + + assert.NotNil(t, provider) + assert.Equal(t, id, provider.ID()) + assert.Equal(t, typeID, provider.TypeID()) + assert.Equal(t, mockAdapter, provider.Adapter) + assert.Equal(t, logger, provider.Logger) + assert.Empty(t, provider.hooks.OnPublishEvents) +} + +func TestApplyPublishEventHooks_NoHooks(t *testing.T) { + ctx := context.Background() + config := &testPublishConfig{ + providerID: "test-provider", + providerType: ProviderTypeKafka, + fieldName: "testField", + } + originalEvents := []StreamEvent{ + &testEvent{data: []byte("test data")}, + } + + result, err := applyPublishEventHooks(ctx, config, originalEvents, []OnPublishEventsFn{}) + + assert.NoError(t, err) + assert.Equal(t, originalEvents, result) +} + +func TestApplyPublishEventHooks_SingleHook_Success(t *testing.T) { + ctx := context.Background() + config := &testPublishConfig{ + providerID: "test-provider", + providerType: ProviderTypeKafka, + fieldName: "testField", + } + originalEvents := []StreamEvent{ + &testEvent{data: []byte("original")}, + } + modifiedEvents := []StreamEvent{ + &testEvent{data: []byte("modified")}, + } + + hook := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + return modifiedEvents, nil + } + + result, err := applyPublishEventHooks(ctx, config, originalEvents, []OnPublishEventsFn{hook}) + + assert.NoError(t, err) + assert.Equal(t, modifiedEvents, result) +} + +func TestApplyPublishEventHooks_SingleHook_Error(t *testing.T) { + ctx := context.Background() + config := &testPublishConfig{ + providerID: "test-provider", + providerType: ProviderTypeKafka, + fieldName: "testField", + } + originalEvents := []StreamEvent{ + &testEvent{data: []byte("original")}, + } + hookError := errors.New("hook processing failed") + + hook := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + return nil, hookError + } + + result, err := applyPublishEventHooks(ctx, config, originalEvents, []OnPublishEventsFn{hook}) + + assert.Error(t, err) + assert.Equal(t, hookError, err) + assert.Nil(t, result) +} + +func TestApplyPublishEventHooks_MultipleHooks_Success(t *testing.T) { + ctx := context.Background() + config := &testPublishConfig{ + providerID: "test-provider", + providerType: ProviderTypeKafka, + fieldName: "testField", + } + originalEvents := []StreamEvent{ + &testEvent{data: []byte("original")}, + } + + hook1 := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + return []StreamEvent{&testEvent{data: []byte("step1")}}, nil + } + hook2 := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + return []StreamEvent{&testEvent{data: []byte("step2")}}, nil + } + hook3 := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + return []StreamEvent{&testEvent{data: []byte("final")}}, nil + } + + result, err := applyPublishEventHooks(ctx, config, originalEvents, []OnPublishEventsFn{hook1, hook2, hook3}) + + assert.NoError(t, err) + assert.Len(t, result, 1) + assert.Equal(t, "final", string(result[0].GetData())) +} + +func TestApplyPublishEventHooks_MultipleHooks_MiddleHookError(t *testing.T) { + ctx := context.Background() + config := &testPublishConfig{ + providerID: "test-provider", + providerType: ProviderTypeKafka, + fieldName: "testField", + } + originalEvents := []StreamEvent{ + &testEvent{data: []byte("original")}, + } + middleHookError := errors.New("middle hook failed") + + hook1 := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + return []StreamEvent{&testEvent{data: []byte("step1")}}, nil + } + hook2 := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + return nil, middleHookError + } + hook3 := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + return []StreamEvent{&testEvent{data: []byte("final")}}, nil + } + + result, err := applyPublishEventHooks(ctx, config, originalEvents, []OnPublishEventsFn{hook1, hook2, hook3}) + + assert.Error(t, err) + assert.Equal(t, middleHookError, err) + assert.Nil(t, result) } diff --git a/router/pkg/pubsub/datasource/subscription_datasource.go b/router/pkg/pubsub/datasource/subscription_datasource.go index e5c9c26ab6..16ec03171a 100644 --- a/router/pkg/pubsub/datasource/subscription_datasource.go +++ b/router/pkg/pubsub/datasource/subscription_datasource.go @@ -6,6 +6,7 @@ import ( "github.com/cespare/xxhash/v2" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" + "go.uber.org/zap" ) type uniqueRequestIdFn func(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error @@ -13,9 +14,10 @@ type uniqueRequestIdFn func(ctx *resolve.Context, input []byte, xxh *xxhash.Dige // PubSubSubscriptionDataSource is a data source for handling subscriptions using a Pub/Sub mechanism. // It implements the SubscriptionDataSource interface and HookableSubscriptionDataSource type PubSubSubscriptionDataSource[C SubscriptionEventConfiguration] struct { - pubSub Adapter - uniqueRequestID uniqueRequestIdFn - subscriptionOnStartFns []SubscriptionOnStartFn + pubSub Adapter + uniqueRequestID uniqueRequestIdFn + hooks Hooks + logger *zap.Logger } func (s *PubSubSubscriptionDataSource[C]) SubscriptionEventConfiguration(input []byte) (SubscriptionEventConfiguration, error) { @@ -39,11 +41,11 @@ func (s *PubSubSubscriptionDataSource[C]) Start(ctx *resolve.Context, input []by return errors.New("invalid subscription configuration") } - return s.pubSub.Subscribe(ctx.Context(), conf, NewSubscriptionEventUpdater(updater)) + return s.pubSub.Subscribe(ctx.Context(), conf, NewSubscriptionEventUpdater(conf, s.hooks, updater, s.logger)) } func (s *PubSubSubscriptionDataSource[C]) SubscriptionOnStart(ctx resolve.StartupHookContext, input []byte) (err error) { - for _, fn := range s.subscriptionOnStartFns { + for _, fn := range s.hooks.SubscriptionOnStart { conf, errConf := s.SubscriptionEventConfiguration(input) if errConf != nil { return err @@ -57,16 +59,20 @@ func (s *PubSubSubscriptionDataSource[C]) SubscriptionOnStart(ctx resolve.Startu return nil } -func (s *PubSubSubscriptionDataSource[C]) SetSubscriptionOnStartFns(fns ...SubscriptionOnStartFn) { - s.subscriptionOnStartFns = append(s.subscriptionOnStartFns, fns...) +func (s *PubSubSubscriptionDataSource[C]) SetHooks(hooks Hooks) { + s.hooks = hooks } var _ SubscriptionDataSource = (*PubSubSubscriptionDataSource[SubscriptionEventConfiguration])(nil) var _ resolve.HookableSubscriptionDataSource = (*PubSubSubscriptionDataSource[SubscriptionEventConfiguration])(nil) -func NewPubSubSubscriptionDataSource[C SubscriptionEventConfiguration](pubSub Adapter, uniqueRequestIdFn uniqueRequestIdFn) *PubSubSubscriptionDataSource[C] { +func NewPubSubSubscriptionDataSource[C SubscriptionEventConfiguration](pubSub Adapter, uniqueRequestIdFn uniqueRequestIdFn, logger *zap.Logger) *PubSubSubscriptionDataSource[C] { + if logger == nil { + logger = zap.NewNop() + } return &PubSubSubscriptionDataSource[C]{ pubSub: pubSub, uniqueRequestID: uniqueRequestIdFn, + logger: logger, } } diff --git a/router/pkg/pubsub/datasource/subscription_datasource_test.go b/router/pkg/pubsub/datasource/subscription_datasource_test.go index a9170d7edd..c82b339faa 100644 --- a/router/pkg/pubsub/datasource/subscription_datasource_test.go +++ b/router/pkg/pubsub/datasource/subscription_datasource_test.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" + "go.uber.org/zap" ) // testSubscriptionEventConfiguration implements SubscriptionEventConfiguration for testing @@ -36,7 +37,7 @@ func TestPubSubSubscriptionDataSource_SubscriptionEventConfiguration_Success(t * return nil } - dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn) + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop()) testConfig := testSubscriptionEventConfiguration{ Topic: "test-topic", @@ -61,7 +62,7 @@ func TestPubSubSubscriptionDataSource_SubscriptionEventConfiguration_InvalidJSON return nil } - dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn) + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop()) invalidInput := []byte(`{"invalid": json}`) result, err := dataSource.SubscriptionEventConfiguration(invalidInput) @@ -75,7 +76,7 @@ func TestPubSubSubscriptionDataSource_UniqueRequestID_Success(t *testing.T) { return nil } - dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn) + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop()) ctx := &resolve.Context{} input := []byte(`{"test": "data"}`) @@ -92,7 +93,7 @@ func TestPubSubSubscriptionDataSource_UniqueRequestID_Error(t *testing.T) { return expectedError } - dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn) + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop()) ctx := &resolve.Context{} input := []byte(`{"test": "data"}`) @@ -109,7 +110,7 @@ func TestPubSubSubscriptionDataSource_Start_Success(t *testing.T) { return nil } - dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn) + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop()) testConfig := testSubscriptionEventConfiguration{ Topic: "test-topic", @@ -134,7 +135,7 @@ func TestPubSubSubscriptionDataSource_Start_NoConfiguration(t *testing.T) { return nil } - dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn) + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop()) invalidInput := []byte(`{"invalid": json}`) ctx := resolve.NewContext(context.Background()) @@ -151,7 +152,7 @@ func TestPubSubSubscriptionDataSource_Start_SubscribeError(t *testing.T) { return nil } - dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn) + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop()) testConfig := testSubscriptionEventConfiguration{ Topic: "test-topic", @@ -178,7 +179,7 @@ func TestPubSubSubscriptionDataSource_SubscriptionOnStart_Success(t *testing.T) return nil } - dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn) + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop()) testConfig := testSubscriptionEventConfiguration{ Topic: "test-topic", @@ -202,7 +203,7 @@ func TestPubSubSubscriptionDataSource_SubscriptionOnStart_WithHooks(t *testing.T return nil } - dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn) + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop()) // Add subscription start hooks hook1Called := false @@ -218,7 +219,9 @@ func TestPubSubSubscriptionDataSource_SubscriptionOnStart_WithHooks(t *testing.T return nil } - dataSource.SetSubscriptionOnStartFns(hook1, hook2) + dataSource.SetHooks(Hooks{ + SubscriptionOnStart: []SubscriptionOnStartFn{hook1, hook2}, + }) testConfig := testSubscriptionEventConfiguration{ Topic: "test-topic", @@ -238,13 +241,46 @@ func TestPubSubSubscriptionDataSource_SubscriptionOnStart_WithHooks(t *testing.T assert.True(t, hook2Called) } +func TestPubSubSubscriptionDataSource_SubscriptionOnStart_HookReturnsClose(t *testing.T) { + mockAdapter := NewMockProvider(t) + uniqueRequestIDFn := func(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { + return nil + } + + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop()) + + // Add hook that returns close=true + hook := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration) error { + return nil + } + + dataSource.SetHooks(Hooks{ + SubscriptionOnStart: []SubscriptionOnStartFn{hook}, + }) + + testConfig := testSubscriptionEventConfiguration{ + Topic: "test-topic", + Subject: "test-subject", + } + input, err := json.Marshal(testConfig) + assert.NoError(t, err) + + ctx := resolve.StartupHookContext{ + Context: context.Background(), + Updater: func(data []byte) {}, + } + + errSubStart := dataSource.SubscriptionOnStart(ctx, input) + assert.NoError(t, errSubStart) +} + func TestPubSubSubscriptionDataSource_SubscriptionOnStart_HookReturnsError(t *testing.T) { mockAdapter := NewMockProvider(t) uniqueRequestIDFn := func(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { return nil } - dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn) + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop()) expectedError := errors.New("hook error") // Add hook that returns an error @@ -252,7 +288,9 @@ func TestPubSubSubscriptionDataSource_SubscriptionOnStart_HookReturnsError(t *te return expectedError } - dataSource.SetSubscriptionOnStartFns(hook) + dataSource.SetHooks(Hooks{ + SubscriptionOnStart: []SubscriptionOnStartFn{hook}, + }) testConfig := testSubscriptionEventConfiguration{ Topic: "test-topic", @@ -277,10 +315,10 @@ func TestPubSubSubscriptionDataSource_SetSubscriptionOnStartFns(t *testing.T) { return nil } - dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn) + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop()) // Initially should have no hooks - assert.Len(t, dataSource.subscriptionOnStartFns, 0) + assert.Len(t, dataSource.hooks.SubscriptionOnStart, 0) // Add hooks hook1 := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration) error { @@ -290,11 +328,15 @@ func TestPubSubSubscriptionDataSource_SetSubscriptionOnStartFns(t *testing.T) { return nil } - dataSource.SetSubscriptionOnStartFns(hook1) - assert.Len(t, dataSource.subscriptionOnStartFns, 1) + dataSource.SetHooks(Hooks{ + SubscriptionOnStart: []SubscriptionOnStartFn{hook1}, + }) + assert.Len(t, dataSource.hooks.SubscriptionOnStart, 1) - dataSource.SetSubscriptionOnStartFns(hook2) - assert.Len(t, dataSource.subscriptionOnStartFns, 2) + dataSource.SetHooks(Hooks{ + SubscriptionOnStart: []SubscriptionOnStartFn{hook2}, + }) + assert.Len(t, dataSource.hooks.SubscriptionOnStart, 1) } func TestNewPubSubSubscriptionDataSource(t *testing.T) { @@ -303,12 +345,12 @@ func TestNewPubSubSubscriptionDataSource(t *testing.T) { return nil } - dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn) + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop()) assert.NotNil(t, dataSource) assert.Equal(t, mockAdapter, dataSource.pubSub) assert.NotNil(t, dataSource.uniqueRequestID) - assert.Empty(t, dataSource.subscriptionOnStartFns) + assert.Empty(t, dataSource.hooks.SubscriptionOnStart) } func TestPubSubSubscriptionDataSource_InterfaceCompliance(t *testing.T) { @@ -317,7 +359,7 @@ func TestPubSubSubscriptionDataSource_InterfaceCompliance(t *testing.T) { return nil } - dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn) + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop()) // Test that it implements SubscriptionDataSource interface var _ SubscriptionDataSource = dataSource diff --git a/router/pkg/pubsub/datasource/subscription_event_updater.go b/router/pkg/pubsub/datasource/subscription_event_updater.go index 9332d10f7a..95289bb313 100644 --- a/router/pkg/pubsub/datasource/subscription_event_updater.go +++ b/router/pkg/pubsub/datasource/subscription_event_updater.go @@ -1,34 +1,112 @@ package datasource -import "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" +import ( + "context" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" + "go.uber.org/zap" +) // SubscriptionEventUpdater is a wrapper around the SubscriptionUpdater interface // that provides a way to send the event struct instead of the raw data // It is used to give access to the event additional fields to the hooks. type SubscriptionEventUpdater interface { - Update(event StreamEvent) + Update(events []StreamEvent) Complete() Close(kind resolve.SubscriptionCloseKind) + SetHooks(hooks Hooks) } type subscriptionEventUpdater struct { - eventUpdater resolve.SubscriptionUpdater + eventUpdater resolve.SubscriptionUpdater + subscriptionEventConfiguration SubscriptionEventConfiguration + hooks Hooks + logger *zap.Logger +} + +func (s *subscriptionEventUpdater) Update(events []StreamEvent) { + if len(s.hooks.OnReceiveEvents) == 0 { + for _, event := range events { + s.eventUpdater.Update(event.GetData()) + } + return + } + // If there are hooks, we should apply them separated for each subscription + for ctx, subId := range s.eventUpdater.Subscriptions() { + processedEvents, err := applyStreamEventHooks( + ctx, + s.subscriptionEventConfiguration, + events, + s.hooks.OnReceiveEvents, + ) + // updates the events even if the hooks fail + // if a hook doesn't want to send the events, it should return no events! + for _, event := range processedEvents { + s.eventUpdater.UpdateSubscription(subId, event.GetData()) + } + if err != nil { + // For all errors, log them + if s.logger != nil { + s.logger.Error( + "An error occurred while processing stream events hooks", + zap.Error(err), + zap.String("provider_type", string(s.subscriptionEventConfiguration.ProviderType())), + zap.String("provider_id", s.subscriptionEventConfiguration.ProviderID()), + zap.String("field_name", s.subscriptionEventConfiguration.RootFieldName()), + ) + } + // Always close the subscription when a hook reports an error to avoid inconsistent state. + s.eventUpdater.CloseSubscription(resolve.SubscriptionCloseKindNormal, subId) + } + } } -func (h *subscriptionEventUpdater) Update(event StreamEvent) { - h.eventUpdater.Update(event.GetData()) +func (s *subscriptionEventUpdater) Complete() { + s.eventUpdater.Complete() } -func (h *subscriptionEventUpdater) Complete() { - h.eventUpdater.Complete() +func (s *subscriptionEventUpdater) Close(kind resolve.SubscriptionCloseKind) { + s.eventUpdater.Close(kind) } -func (h *subscriptionEventUpdater) Close(kind resolve.SubscriptionCloseKind) { - h.eventUpdater.Close(kind) +func (s *subscriptionEventUpdater) SetHooks(hooks Hooks) { + s.hooks = hooks +} + +// applyStreamEventHooks processes events through a chain of hook functions +// Each hook receives the result from the previous hook, creating a proper middleware pipeline +func applyStreamEventHooks( + ctx context.Context, + cfg SubscriptionEventConfiguration, + events []StreamEvent, + hooks []OnReceiveEventsFn) ([]StreamEvent, error) { + // Copy the events to avoid modifying the original slice + currentEvents := make([]StreamEvent, len(events)) + for i, event := range events { + currentEvents[i] = event.Clone() + } + // Apply each hook in sequence, passing the result of one as the input to the next + // If any hook returns an error, stop processing and return the error + for _, hook := range hooks { + var err error + currentEvents, err = hook(ctx, cfg, currentEvents) + if err != nil { + return currentEvents, err + } + } + return currentEvents, nil } -func NewSubscriptionEventUpdater(eventUpdater resolve.SubscriptionUpdater) SubscriptionEventUpdater { +func NewSubscriptionEventUpdater( + cfg SubscriptionEventConfiguration, + hooks Hooks, + eventUpdater resolve.SubscriptionUpdater, + logger *zap.Logger, +) SubscriptionEventUpdater { return &subscriptionEventUpdater{ - eventUpdater: eventUpdater, + subscriptionEventConfiguration: cfg, + hooks: hooks, + eventUpdater: eventUpdater, + logger: logger, } } diff --git a/router/pkg/pubsub/datasource/subscription_event_updater_test.go b/router/pkg/pubsub/datasource/subscription_event_updater_test.go new file mode 100644 index 0000000000..79fd140a51 --- /dev/null +++ b/router/pkg/pubsub/datasource/subscription_event_updater_test.go @@ -0,0 +1,627 @@ +package datasource + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" + "go.uber.org/zap" + "go.uber.org/zap/zaptest/observer" +) + +// Test helper type for subscription event configuration +type testSubscriptionEventConfig struct { + providerID string + providerType ProviderType + fieldName string +} + +func (c *testSubscriptionEventConfig) ProviderID() string { + return c.providerID +} + +func (c *testSubscriptionEventConfig) ProviderType() ProviderType { + return c.providerType +} + +func (c *testSubscriptionEventConfig) RootFieldName() string { + return c.fieldName +} + +type receivedHooksArgs struct { + events []StreamEvent + cfg SubscriptionEventConfiguration +} + +func TestSubscriptionEventUpdater_Update_NoHooks(t *testing.T) { + mockUpdater := NewMockSubscriptionUpdater(t) + config := &testSubscriptionEventConfig{ + providerID: "test-provider", + providerType: ProviderTypeNats, + fieldName: "testField", + } + events := []StreamEvent{ + &testEvent{data: []byte("test data 1")}, + &testEvent{data: []byte("test data 2")}, + } + + // Expect calls to Update for each event + mockUpdater.On("Update", []byte("test data 1")).Return() + mockUpdater.On("Update", []byte("test data 2")).Return() + + updater := &subscriptionEventUpdater{ + eventUpdater: mockUpdater, + subscriptionEventConfiguration: config, + hooks: Hooks{}, // No hooks + } + + updater.Update(events) +} + +func TestSubscriptionEventUpdater_UpdateSubscription_WithHooks_Success(t *testing.T) { + mockUpdater := NewMockSubscriptionUpdater(t) + config := &testSubscriptionEventConfig{ + providerID: "test-provider", + providerType: ProviderTypeNats, + fieldName: "testField", + } + originalEvents := []StreamEvent{ + &testEvent{data: []byte("original data")}, + } + modifiedEvents := []StreamEvent{ + &testEvent{data: []byte("modified data")}, + } + + // Create wrapper function for the mock + receivedArgs := make(chan receivedHooksArgs, 1) + testHook := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + receivedArgs <- receivedHooksArgs{events: events, cfg: cfg} + return modifiedEvents, nil + } + + // Expect call to UpdateSubscription with modified data + subId := resolve.SubscriptionIdentifier{ConnectionID: 1, SubscriptionID: 1} + mockUpdater.On("UpdateSubscription", subId, []byte("modified data")).Return() + mockUpdater.On("Subscriptions").Return(map[context.Context]resolve.SubscriptionIdentifier{ + context.Background(): subId, + }) + + updater := &subscriptionEventUpdater{ + eventUpdater: mockUpdater, + subscriptionEventConfiguration: config, + hooks: Hooks{ + OnReceiveEvents: []OnReceiveEventsFn{testHook}, + }, + } + + updater.Update(originalEvents) + + select { + case receivedArgs := <-receivedArgs: + assert.Equal(t, originalEvents, receivedArgs.events) + assert.Equal(t, config, receivedArgs.cfg) + case <-time.After(1 * time.Second): + t.Fatal("timeout waiting for events") + } +} + +func TestSubscriptionEventUpdater_UpdateSubscriptions_WithHooks_Error(t *testing.T) { + mockUpdater := NewMockSubscriptionUpdater(t) + config := &testSubscriptionEventConfig{ + providerID: "test-provider", + providerType: ProviderTypeNats, + fieldName: "testField", + } + events := []StreamEvent{ + &testEvent{data: []byte("test data")}, + } + hookError := errors.New("hook processing error") + + // Define hook that returns an error + testHook := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + return nil, hookError + } + + // Expect call to UpdateSubscription with modified data + subId := resolve.SubscriptionIdentifier{ConnectionID: 1, SubscriptionID: 1} + mockUpdater.On("Subscriptions").Return(map[context.Context]resolve.SubscriptionIdentifier{ + context.Background(): subId, + }) + mockUpdater.On("CloseSubscription", resolve.SubscriptionCloseKindNormal, subId).Return() + + // Should not call Update or UpdateSubscription on eventUpdater since hook fails + updater := &subscriptionEventUpdater{ + eventUpdater: mockUpdater, + subscriptionEventConfiguration: config, + hooks: Hooks{ + OnReceiveEvents: []OnReceiveEventsFn{testHook}, + }, + } + + updater.Update(events) + + // Assert that Update and UpdateSubscription were not called on the eventUpdater + mockUpdater.AssertNotCalled(t, "Update") + mockUpdater.AssertNotCalled(t, "UpdateSubscription") + mockUpdater.AssertCalled(t, "CloseSubscription", resolve.SubscriptionCloseKindNormal, subId) +} + +func TestSubscriptionEventUpdater_Update_WithMultipleHooks_Success(t *testing.T) { + mockUpdater := NewMockSubscriptionUpdater(t) + config := &testSubscriptionEventConfig{ + providerID: "test-provider", + providerType: ProviderTypeNats, + fieldName: "testField", + } + originalEvents := []StreamEvent{ + &testEvent{data: []byte("original")}, + } + + // Chain of hooks that modify the data + receivedArgs1 := make(chan receivedHooksArgs, 1) + hook1 := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + receivedArgs1 <- receivedHooksArgs{events: events, cfg: cfg} + return []StreamEvent{&testEvent{data: []byte("modified by hook1")}}, nil + } + + receivedArgs2 := make(chan receivedHooksArgs, 1) + hook2 := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + receivedArgs2 <- receivedHooksArgs{events: events, cfg: cfg} + return []StreamEvent{&testEvent{data: []byte("modified by hook2")}}, nil + } + + // Expect call to UpdateSubscription with modified data + subId := resolve.SubscriptionIdentifier{ConnectionID: 1, SubscriptionID: 1} + mockUpdater.On("UpdateSubscription", subId, []byte("modified by hook2")).Return() + mockUpdater.On("Subscriptions").Return(map[context.Context]resolve.SubscriptionIdentifier{ + context.Background(): subId, + }) + + updater := &subscriptionEventUpdater{ + eventUpdater: mockUpdater, + subscriptionEventConfiguration: config, + hooks: Hooks{ + OnReceiveEvents: []OnReceiveEventsFn{hook1, hook2}, + }, + } + + updater.Update(originalEvents) + + select { + case receivedArgs1 := <-receivedArgs1: + assert.Equal(t, originalEvents, receivedArgs1.events) + assert.Equal(t, config, receivedArgs1.cfg) + case <-time.After(1 * time.Second): + t.Fatal("timeout waiting for events") + } + + select { + case receivedArgs2 := <-receivedArgs2: + assert.Equal(t, []StreamEvent{&testEvent{data: []byte("modified by hook1")}}, receivedArgs2.events) + assert.Equal(t, config, receivedArgs2.cfg) + case <-time.After(1 * time.Second): + t.Fatal("timeout waiting for events") + } +} + +func TestSubscriptionEventUpdater_Complete(t *testing.T) { + mockUpdater := NewMockSubscriptionUpdater(t) + config := &testSubscriptionEventConfig{ + providerID: "test-provider", + providerType: ProviderTypeNats, + fieldName: "testField", + } + + mockUpdater.On("Complete").Return() + + updater := &subscriptionEventUpdater{ + eventUpdater: mockUpdater, + subscriptionEventConfiguration: config, + hooks: Hooks{}, + } + + updater.Complete() +} + +func TestSubscriptionEventUpdater_Close(t *testing.T) { + mockUpdater := NewMockSubscriptionUpdater(t) + config := &testSubscriptionEventConfig{ + providerID: "test-provider", + providerType: ProviderTypeNats, + fieldName: "testField", + } + closeKind := resolve.SubscriptionCloseKindNormal + + mockUpdater.On("Close", closeKind).Return() + + updater := &subscriptionEventUpdater{ + eventUpdater: mockUpdater, + subscriptionEventConfiguration: config, + hooks: Hooks{}, + } + + updater.Close(closeKind) +} + +func TestSubscriptionEventUpdater_SetHooks(t *testing.T) { + mockUpdater := NewMockSubscriptionUpdater(t) + config := &testSubscriptionEventConfig{ + providerID: "test-provider", + providerType: ProviderTypeNats, + fieldName: "testField", + } + + testHook := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + return events, nil + } + + hooks := Hooks{ + OnReceiveEvents: []OnReceiveEventsFn{testHook}, + } + + updater := &subscriptionEventUpdater{ + eventUpdater: mockUpdater, + subscriptionEventConfiguration: config, + hooks: Hooks{}, + } + + updater.SetHooks(hooks) + + assert.Equal(t, hooks, updater.hooks) +} + +func TestNewSubscriptionEventUpdater(t *testing.T) { + mockUpdater := NewMockSubscriptionUpdater(t) + config := &testSubscriptionEventConfig{ + providerID: "test-provider", + providerType: ProviderTypeNats, + fieldName: "testField", + } + + testHook := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + return events, nil + } + + hooks := Hooks{ + OnReceiveEvents: []OnReceiveEventsFn{testHook}, + } + + updater := NewSubscriptionEventUpdater(config, hooks, mockUpdater, zap.NewNop()) + + assert.NotNil(t, updater) + + // Type assertion to access private fields for testing + var concreteUpdater *subscriptionEventUpdater + assert.IsType(t, concreteUpdater, updater) + concreteUpdater = updater.(*subscriptionEventUpdater) + assert.Equal(t, config, concreteUpdater.subscriptionEventConfiguration) + assert.Equal(t, hooks, concreteUpdater.hooks) + assert.Equal(t, mockUpdater, concreteUpdater.eventUpdater) +} + +func TestApplyStreamEventHooks_NoHooks(t *testing.T) { + ctx := context.Background() + config := &testSubscriptionEventConfig{ + providerID: "test-provider", + providerType: ProviderTypeNats, + fieldName: "testField", + } + originalEvents := []StreamEvent{ + &testEvent{data: []byte("test data")}, + } + + result, err := applyStreamEventHooks(ctx, config, originalEvents, []OnReceiveEventsFn{}) + + assert.NoError(t, err) + assert.Equal(t, originalEvents, result) +} + +func TestApplyStreamEventHooks_SingleHook_Success(t *testing.T) { + ctx := context.Background() + config := &testSubscriptionEventConfig{ + providerID: "test-provider", + providerType: ProviderTypeNats, + fieldName: "testField", + } + originalEvents := []StreamEvent{ + &testEvent{data: []byte("original")}, + } + modifiedEvents := []StreamEvent{ + &testEvent{data: []byte("modified")}, + } + + hook := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + return modifiedEvents, nil + } + + result, err := applyStreamEventHooks(ctx, config, originalEvents, []OnReceiveEventsFn{hook}) + + assert.NoError(t, err) + assert.Equal(t, modifiedEvents, result) +} + +func TestApplyStreamEventHooks_SingleHook_Error(t *testing.T) { + ctx := context.Background() + config := &testSubscriptionEventConfig{ + providerID: "test-provider", + providerType: ProviderTypeNats, + fieldName: "testField", + } + originalEvents := []StreamEvent{ + &testEvent{data: []byte("original")}, + } + hookError := errors.New("hook processing failed") + + hook := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + return nil, hookError + } + + result, err := applyStreamEventHooks(ctx, config, originalEvents, []OnReceiveEventsFn{hook}) + + assert.Error(t, err) + assert.Equal(t, hookError, err) + assert.Nil(t, result) +} + +func TestApplyStreamEventHooks_MultipleHooks_Success(t *testing.T) { + ctx := context.Background() + config := &testSubscriptionEventConfig{ + providerID: "test-provider", + providerType: ProviderTypeNats, + fieldName: "testField", + } + originalEvents := []StreamEvent{ + &testEvent{data: []byte("original")}, + } + + receivedArgs1 := make(chan receivedHooksArgs, 1) + hook1 := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + receivedArgs1 <- receivedHooksArgs{events: events, cfg: cfg} + return []StreamEvent{&testEvent{data: []byte("step1")}}, nil + } + receivedArgs2 := make(chan receivedHooksArgs, 1) + hook2 := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + receivedArgs2 <- receivedHooksArgs{events: events, cfg: cfg} + return []StreamEvent{&testEvent{data: []byte("step2")}}, nil + } + receivedArgs3 := make(chan receivedHooksArgs, 1) + hook3 := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + receivedArgs3 <- receivedHooksArgs{events: events, cfg: cfg} + return []StreamEvent{&testEvent{data: []byte("final")}}, nil + } + + result, err := applyStreamEventHooks(ctx, config, originalEvents, []OnReceiveEventsFn{hook1, hook2, hook3}) + + select { + case receivedArgs1 := <-receivedArgs1: + assert.Equal(t, originalEvents, receivedArgs1.events) + assert.Equal(t, config, receivedArgs1.cfg) + case <-time.After(1 * time.Second): + t.Fatal("timeout waiting for events") + } + + select { + case receivedArgs2 := <-receivedArgs2: + assert.Equal(t, []StreamEvent{&testEvent{data: []byte("step1")}}, receivedArgs2.events) + assert.Equal(t, config, receivedArgs2.cfg) + case <-time.After(1 * time.Second): + t.Fatal("timeout waiting for events") + } + + select { + case receivedArgs3 := <-receivedArgs3: + assert.Equal(t, []StreamEvent{&testEvent{data: []byte("step2")}}, receivedArgs3.events) + assert.Equal(t, config, receivedArgs3.cfg) + case <-time.After(1 * time.Second): + t.Fatal("timeout waiting for events") + } + + assert.NoError(t, err) + assert.Len(t, result, 1) + assert.Equal(t, "final", string(result[0].GetData())) +} + +func TestApplyStreamEventHooks_MultipleHooks_MiddleHookError(t *testing.T) { + ctx := context.Background() + config := &testSubscriptionEventConfig{ + providerID: "test-provider", + providerType: ProviderTypeNats, + fieldName: "testField", + } + originalEvents := []StreamEvent{ + &testEvent{data: []byte("original")}, + } + middleHookError := errors.New("middle hook failed") + + receivedArgs1 := make(chan receivedHooksArgs, 1) + hook1 := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + receivedArgs1 <- receivedHooksArgs{events: events, cfg: cfg} + return []StreamEvent{&testEvent{data: []byte("step1")}}, nil + } + receivedArgs2 := make(chan receivedHooksArgs, 1) + hook2 := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + receivedArgs2 <- receivedHooksArgs{events: events, cfg: cfg} + return nil, middleHookError + } + receivedArgs3 := make(chan receivedHooksArgs, 1) + hook3 := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + receivedArgs3 <- receivedHooksArgs{events: events, cfg: cfg} + return []StreamEvent{&testEvent{data: []byte("final")}}, nil + } + + result, err := applyStreamEventHooks(ctx, config, originalEvents, []OnReceiveEventsFn{hook1, hook2, hook3}) + + assert.Error(t, err) + assert.Equal(t, middleHookError, err) + assert.Nil(t, result) + + select { + case receivedArgs1 := <-receivedArgs1: + assert.Equal(t, originalEvents, receivedArgs1.events) + assert.Equal(t, config, receivedArgs1.cfg) + case <-time.After(1 * time.Second): + t.Fatal("timeout waiting for events") + } + + select { + case receivedArgs2 := <-receivedArgs2: + assert.Equal(t, []StreamEvent{&testEvent{data: []byte("step1")}}, receivedArgs2.events) + assert.Equal(t, config, receivedArgs2.cfg) + case <-time.After(1 * time.Second): + t.Fatal("timeout waiting for events") + } + + assert.Empty(t, receivedArgs3) +} + +// Test the updateEvents method indirectly through Update method +func TestSubscriptionEventUpdater_UpdateEvents_EmptyEvents(t *testing.T) { + mockUpdater := NewMockSubscriptionUpdater(t) + config := &testSubscriptionEventConfig{ + providerID: "test-provider", + providerType: ProviderTypeNats, + fieldName: "testField", + } + events := []StreamEvent{} // Empty events + + updater := &subscriptionEventUpdater{ + eventUpdater: mockUpdater, + subscriptionEventConfiguration: config, + hooks: Hooks{}, // No hooks + } + + updater.Update(events) + + // No calls to Update should be made for empty events + mockUpdater.AssertNotCalled(t, "Update") +} + +func TestSubscriptionEventUpdater_Close_WithDifferentCloseKinds(t *testing.T) { + testCases := []struct { + name string + closeKind resolve.SubscriptionCloseKind + }{ + {"Normal", resolve.SubscriptionCloseKindNormal}, + {"DownstreamServiceError", resolve.SubscriptionCloseKindDownstreamServiceError}, + {"GoingAway", resolve.SubscriptionCloseKindGoingAway}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + mockUpdater := NewMockSubscriptionUpdater(t) + config := &testSubscriptionEventConfig{ + providerID: "test-provider", + providerType: ProviderTypeNats, + fieldName: "testField", + } + + mockUpdater.On("Close", tc.closeKind).Return() + + updater := &subscriptionEventUpdater{ + eventUpdater: mockUpdater, + subscriptionEventConfiguration: config, + hooks: Hooks{}, + } + + updater.Close(tc.closeKind) + }) + } +} + +func TestSubscriptionEventUpdater_UpdateSubscription_WithHookError_ClosesSubscription(t *testing.T) { + testCases := []struct { + name string + hookError error + }{ + { + name: "generic error", + hookError: errors.New("subscription should close"), + }, + { + name: "error implementing CloseSubscription false", + hookError: errors.New("subscription should still close"), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + mockUpdater := NewMockSubscriptionUpdater(t) + config := &testSubscriptionEventConfig{ + providerID: "test-provider", + providerType: ProviderTypeNats, + fieldName: "testField", + } + events := []StreamEvent{ + &testEvent{data: []byte("test data")}, + } + + testHook := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + return events, tc.hookError + } + + updater := &subscriptionEventUpdater{ + eventUpdater: mockUpdater, + subscriptionEventConfiguration: config, + hooks: Hooks{ + OnReceiveEvents: []OnReceiveEventsFn{testHook}, + }, + } + + subId := resolve.SubscriptionIdentifier{ConnectionID: 1, SubscriptionID: 1} + mockUpdater.On("UpdateSubscription", subId, []byte("test data")).Return() + mockUpdater.On("Subscriptions").Return(map[context.Context]resolve.SubscriptionIdentifier{ + context.Background(): subId, + }) + mockUpdater.On("CloseSubscription", resolve.SubscriptionCloseKindNormal, subId).Return() + + updater.Update(events) + + mockUpdater.AssertCalled(t, "CloseSubscription", resolve.SubscriptionCloseKindNormal, subId) + }) + } +} + +func TestSubscriptionEventUpdater_UpdateSubscription_WithHooks_Error_LoggerWritesError(t *testing.T) { + mockUpdater := NewMockSubscriptionUpdater(t) + config := &testSubscriptionEventConfig{ + providerID: "test-provider", + providerType: ProviderTypeNats, + fieldName: "testField", + } + events := []StreamEvent{ + &testEvent{data: []byte("test data")}, + } + hookError := errors.New("hook processing error") + + // Define hook that returns an error + testHook := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + return nil, hookError + } + + zCore, logObserver := observer.New(zap.InfoLevel) + logger := zap.New(zCore) + + // Test with a real zap logger to verify error logging behavior + // The logger.Error() call should be executed when an error occurs + updater := NewSubscriptionEventUpdater(config, Hooks{ + OnReceiveEvents: []OnReceiveEventsFn{testHook}, + }, mockUpdater, logger) + + subId := resolve.SubscriptionIdentifier{ConnectionID: 1, SubscriptionID: 1} + mockUpdater.On("Subscriptions").Return(map[context.Context]resolve.SubscriptionIdentifier{ + context.Background(): subId, + }) + mockUpdater.On("CloseSubscription", resolve.SubscriptionCloseKindNormal, subId).Return() + + updater.Update(events) + + // Assert that Update was not called on the eventUpdater + mockUpdater.AssertNotCalled(t, "UpdateSubscription") + mockUpdater.AssertCalled(t, "CloseSubscription", resolve.SubscriptionCloseKindNormal, subId) + + msgs := logObserver.FilterMessageSnippet("An error occurred while processing stream events hooks").TakeAll() + assert.Equal(t, 1, len(msgs)) +} diff --git a/router/pkg/pubsub/kafka/adapter.go b/router/pkg/pubsub/kafka/adapter.go index fa906370ab..7f61a242b9 100644 --- a/router/pkg/pubsub/kafka/adapter.go +++ b/router/pkg/pubsub/kafka/adapter.go @@ -13,6 +13,7 @@ import ( "github.com/twmb/franz-go/pkg/kerr" "github.com/twmb/franz-go/pkg/kgo" "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" "go.uber.org/zap" ) @@ -20,19 +21,14 @@ var ( errClientClosed = errors.New("client closed") ) +// Ensure ProviderAdapter implements Adapter +var _ datasource.Adapter = (*ProviderAdapter)(nil) + const ( kafkaReceive = "receive" kafkaProduce = "produce" ) -// Adapter defines the interface for Kafka adapter operations -type Adapter interface { - Subscribe(ctx context.Context, event datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater) error - Publish(ctx context.Context, event PublishEventConfiguration) error - Startup(ctx context.Context) error - Shutdown(ctx context.Context) error -} - // ProviderAdapter is a Kafka pubsub implementation. // It uses the franz-go Kafka client to consume and produce messages. // The pubsub is stateless and does not store any messages. @@ -112,12 +108,11 @@ func (p *ProviderAdapter) topicPoller(ctx context.Context, client *kgo.Client, u DestinationName: r.Topic, }) - updater.Update(&Event{ + updater.Update([]datasource.StreamEvent{&Event{ Data: r.Value, Headers: headers, Key: r.Key, - }) - + }}) } } } @@ -132,7 +127,7 @@ func (p *ProviderAdapter) Subscribe(ctx context.Context, conf datasource.Subscri } log := p.logger.With( - zap.String("provider_id", subConf.ProviderID()), + zap.String("provider_id", conf.ProviderID()), zap.String("method", "subscribe"), zap.Strings("topics", subConf.Topics), ) @@ -159,15 +154,24 @@ func (p *ProviderAdapter) Subscribe(ctx context.Context, conf datasource.Subscri go func() { - defer p.closeWg.Done() + defer func() { + client.Close() + updater.Close(resolve.SubscriptionCloseKindNormal) + p.closeWg.Done() + }() err := p.topicPoller(ctx, client, updater, PollerOpts{providerId: conf.ProviderID()}) if err != nil { if errors.Is(err, errClientClosed) || errors.Is(err, context.Canceled) { log.Debug("poller canceled", zap.Error(err)) } else { - log.Error("poller error", zap.Error(err)) - + log.Error( + "poller error", + zap.Error(err), + zap.String("provider_id", conf.ProviderID()), + zap.String("provider_type", string(conf.ProviderType())), + zap.String("field_name", conf.RootFieldName()), + ) } return } @@ -176,67 +180,85 @@ func (p *ProviderAdapter) Subscribe(ctx context.Context, conf datasource.Subscri return nil } -// Publish publishes the given event to the Kafka topic in a non-blocking way. +// Publish publishes the given events to the Kafka topic in a non-blocking way. // Publish errors are logged and returned as a pubsub error. -// The event is written with a dedicated write client. -func (p *ProviderAdapter) Publish(ctx context.Context, event PublishEventConfiguration) error { +// The events are written with a dedicated write client. +func (p *ProviderAdapter) Publish(ctx context.Context, conf datasource.PublishEventConfiguration, events []datasource.StreamEvent) error { + pubConf, ok := conf.(*PublishEventConfiguration) + if !ok { + return datasource.NewError("invalid event type for Kafka adapter", nil) + } + log := p.logger.With( - zap.String("provider_id", event.ProviderID()), + zap.String("provider_id", conf.ProviderID()), zap.String("method", "publish"), - zap.String("topic", event.Topic), + zap.String("topic", pubConf.Topic), ) if p.writeClient == nil { return datasource.NewError("kafka write client not initialized", nil) } - log.Debug("publish", zap.ByteString("data", event.Event.Data)) + if len(events) == 0 { + return nil + } + + log.Debug("publish", zap.Int("event_count", len(events))) var wg sync.WaitGroup - wg.Add(1) + wg.Add(len(events)) var pErr error + var errMutex sync.Mutex - headers := make([]kgo.RecordHeader, 0, len(event.Event.Headers)) - for key, value := range event.Event.Headers { - headers = append(headers, kgo.RecordHeader{ - Key: key, - Value: value, - }) - } + for _, streamEvent := range events { + kafkaEvent, ok := streamEvent.(*Event) + if !ok { + return datasource.NewError("invalid event type for Kafka adapter", nil) + } - p.writeClient.Produce(ctx, &kgo.Record{ - Key: event.Event.Key, - Topic: event.Topic, - Value: event.Event.Data, - Headers: headers, - }, func(record *kgo.Record, err error) { - defer wg.Done() - if err != nil { - pErr = err + headers := make([]kgo.RecordHeader, 0, len(kafkaEvent.Headers)) + for key, value := range kafkaEvent.Headers { + headers = append(headers, kgo.RecordHeader{ + Key: key, + Value: value, + }) } - }) + + p.writeClient.Produce(ctx, &kgo.Record{ + Key: kafkaEvent.Key, + Topic: pubConf.Topic, + Value: kafkaEvent.Data, + Headers: headers, + }, func(record *kgo.Record, err error) { + defer wg.Done() + if err != nil { + errMutex.Lock() + pErr = err + errMutex.Unlock() + } + }) + } wg.Wait() if pErr != nil { log.Error("publish error", zap.Error(pErr)) - // failure emission: include error.type generic p.streamMetricStore.Produce(ctx, metric.StreamsEvent{ - ProviderId: event.ProviderID(), + ProviderId: pubConf.ProviderID(), StreamOperationName: kafkaProduce, ProviderType: metric.ProviderTypeKafka, ErrorType: "publish_error", - DestinationName: event.Topic, + DestinationName: pubConf.Topic, }) - return datasource.NewError(fmt.Sprintf("error publishing to Kafka topic %s", event.Topic), pErr) + return datasource.NewError(fmt.Sprintf("error publishing to Kafka topic %s", pubConf.Topic), pErr) } p.streamMetricStore.Produce(ctx, metric.StreamsEvent{ - ProviderId: event.ProviderID(), + ProviderId: pubConf.ProviderID(), StreamOperationName: kafkaProduce, ProviderType: metric.ProviderTypeKafka, - DestinationName: event.Topic, + DestinationName: pubConf.Topic, }) return nil } diff --git a/router/pkg/pubsub/kafka/engine_datasource.go b/router/pkg/pubsub/kafka/engine_datasource.go index 723c0d0bd0..00a38023ea 100644 --- a/router/pkg/pubsub/kafka/engine_datasource.go +++ b/router/pkg/pubsub/kafka/engine_datasource.go @@ -6,9 +6,13 @@ import ( "encoding/json" "fmt" "io" + "slices" + "github.com/buger/jsonparser" + "github.com/cespare/xxhash/v2" "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" ) // Event represents an event from Kafka @@ -22,6 +26,18 @@ func (e *Event) GetData() []byte { return e.Data } +func (e *Event) Clone() datasource.StreamEvent { + e2 := *e + e2.Data = slices.Clone(e.Data) + e2.Headers = make(map[string][]byte, len(e.Headers)) + for k, v := range e.Headers { + e2.Headers[k] = slices.Clone(v) + } + return &e2 +} + +// SubscriptionEventConfiguration is a public type that is used to allow access to custom fields +// of the provider type SubscriptionEventConfiguration struct { Provider string `json:"providerId"` Topics []string `json:"topics"` @@ -43,13 +59,47 @@ func (s *SubscriptionEventConfiguration) RootFieldName() string { return s.FieldName } -type PublishEventConfiguration struct { +// publishData is a private type that is used to pass data from the engine to the provider +type publishData struct { Provider string `json:"providerId"` Topic string `json:"topic"` Event Event `json:"event"` FieldName string `json:"rootFieldName"` } +// PublishEventConfiguration returns the publish event configuration from the publishData type +func (p *publishData) PublishEventConfiguration() datasource.PublishEventConfiguration { + return &PublishEventConfiguration{ + Provider: p.Provider, + Topic: p.Topic, + FieldName: p.FieldName, + } +} + +func (p *publishData) MarshalJSONTemplate() (string, error) { + // The content of the data field could be not valid JSON, so we can't use json.Marshal + // e.g. {"id":$$0$$,"update":$$1$$} + headers := p.Event.Headers + if headers == nil { + headers = make(map[string][]byte) + } + + headersBytes, err := json.Marshal(headers) + if err != nil { + return "", err + } + + return fmt.Sprintf(`{"topic":"%s", "event": {"data": %s, "key": "%s", "headers": %s}, "providerId":"%s", "rootFieldName":"%s"}`, p.Topic, p.Event.Data, p.Event.Key, headersBytes, p.Provider, p.FieldName), nil +} + +// PublishEventConfiguration is a public type that is used to allow access to custom fields +// of the provider +type PublishEventConfiguration struct { + Provider string `json:"providerId"` + Topic string `json:"topic"` + FieldName string `json:"rootFieldName"` +} + // ProviderID returns the provider ID func (p *PublishEventConfiguration) ProviderID() string { return p.Provider @@ -65,38 +115,73 @@ func (p *PublishEventConfiguration) RootFieldName() string { return p.FieldName } -func (s *PublishEventConfiguration) MarshalJSONTemplate() (string, error) { - // The content of the data field could be not valid JSON, so we can't use json.Marshal - // e.g. {"id":$$0$$,"update":$$1$$} - headers := s.Event.Headers - if headers == nil { - headers = make(map[string][]byte) +type SubscriptionDataSource struct { + pubSub datasource.Adapter +} + +func (s *SubscriptionDataSource) SubscriptionEventConfiguration(input []byte) datasource.SubscriptionEventConfiguration { + var subscriptionConfiguration SubscriptionEventConfiguration + err := json.Unmarshal(input, &subscriptionConfiguration) + if err != nil { + return nil } + return &subscriptionConfiguration +} - headersBytes, err := json.Marshal(headers) +func (s *SubscriptionDataSource) UniqueRequestID(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { + val, _, _, err := jsonparser.Get(input, "topics") if err != nil { - return "", err + return err } - return fmt.Sprintf(`{"topic":"%s", "event": {"data": %s, "key": "%s", "headers": %s}, "providerId":"%s"}`, s.Topic, s.Event.Data, s.Event.Key, headersBytes, s.ProviderID()), nil + _, err = xxh.Write(val) + if err != nil { + return err + } + + val, _, _, err = jsonparser.Get(input, "providerId") + if err != nil { + return err + } + + _, err = xxh.Write(val) + return err +} + +func (s *SubscriptionDataSource) Start(ctx *resolve.Context, input []byte, updater datasource.SubscriptionEventUpdater) error { + subConf := s.SubscriptionEventConfiguration(input) + if subConf == nil { + return fmt.Errorf("no subscription configuration found") + } + + conf, ok := subConf.(*SubscriptionEventConfiguration) + if !ok { + return fmt.Errorf("invalid subscription configuration") + } + + return s.pubSub.Subscribe(ctx.Context(), conf, updater) } type PublishDataSource struct { - pubSub Adapter + pubSub datasource.Adapter } func (s *PublishDataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) error { - var publishConfiguration PublishEventConfiguration - if err := json.Unmarshal(input, &publishConfiguration); err != nil { + var publishData publishData + if err := json.Unmarshal(input, &publishData); err != nil { return err } - if err := s.pubSub.Publish(ctx, publishConfiguration); err != nil { - _, err = io.WriteString(out, `{"success": false}`) - return err + if err := s.pubSub.Publish(ctx, publishData.PublishEventConfiguration(), []datasource.StreamEvent{&publishData.Event}); err != nil { + // err will not be returned but only logged inside PubSubProvider.Publish to avoid a "unable to fetch from subgraph" error + _, errWrite := io.WriteString(out, `{"success": false}`) + return errWrite } - _, err := io.WriteString(out, `{"success": true}`) - return err + _, errWrite := io.WriteString(out, `{"success": true}`) + if errWrite != nil { + return errWrite + } + return nil } func (s *PublishDataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload, out *bytes.Buffer) (err error) { diff --git a/router/pkg/pubsub/kafka/engine_datasource_factory.go b/router/pkg/pubsub/kafka/engine_datasource_factory.go index 30507bc13b..d89eb408b0 100644 --- a/router/pkg/pubsub/kafka/engine_datasource_factory.go +++ b/router/pkg/pubsub/kafka/engine_datasource_factory.go @@ -8,6 +8,7 @@ import ( "github.com/cespare/xxhash/v2" "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" + "go.uber.org/zap" ) type EventType int @@ -22,8 +23,9 @@ type EngineDataSourceFactory struct { eventType EventType topics []string providerId string + logger *zap.Logger - KafkaAdapter Adapter + KafkaAdapter datasource.Adapter } func (c *EngineDataSourceFactory) GetFieldName() string { @@ -50,7 +52,7 @@ func (c *EngineDataSourceFactory) ResolveDataSourceInput(eventData []byte) (stri return "", fmt.Errorf("publish events should define one topic but received %d", len(c.topics)) } - evtCfg := PublishEventConfiguration{ + evtCfg := publishData{ Provider: c.providerId, Topic: c.topics[0], Event: Event{Data: eventData}, @@ -81,7 +83,7 @@ func (c *EngineDataSourceFactory) ResolveDataSourceSubscription() (datasource.Su _, err = xxh.Write(val) return err - }), nil + }, c.logger), nil } func (c *EngineDataSourceFactory) ResolveDataSourceSubscriptionInput() (string, error) { diff --git a/router/pkg/pubsub/kafka/engine_datasource_factory_test.go b/router/pkg/pubsub/kafka/engine_datasource_factory_test.go index 0b4ea9c59c..5ceab4ae69 100644 --- a/router/pkg/pubsub/kafka/engine_datasource_factory_test.go +++ b/router/pkg/pubsub/kafka/engine_datasource_factory_test.go @@ -5,12 +5,14 @@ import ( "context" "encoding/json" "errors" + "strings" "testing" "github.com/cespare/xxhash/v2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" "github.com/wundergraph/cosmo/router/pkg/pubsub/pubsubtest" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" ) @@ -33,11 +35,13 @@ func TestKafkaEngineDataSourceFactory(t *testing.T) { // TestEngineDataSourceFactoryWithMockAdapter tests the EngineDataSourceFactory with a mocked adapter func TestEngineDataSourceFactoryWithMockAdapter(t *testing.T) { // Create mock adapter - mockAdapter := NewMockAdapter(t) + mockAdapter := datasource.NewMockProvider(t) // Configure mock expectations for Publish - mockAdapter.On("Publish", mock.Anything, mock.MatchedBy(func(event PublishEventConfiguration) bool { + mockAdapter.On("Publish", mock.Anything, mock.MatchedBy(func(event *PublishEventConfiguration) bool { return event.ProviderID() == "test-provider" && event.Topic == "test-topic" + }), mock.MatchedBy(func(events []datasource.StreamEvent) bool { + return len(events) == 1 && strings.EqualFold(string(events[0].GetData()), `{"test":"data"}`) })).Return(nil) // Create the data source with mock adapter @@ -67,7 +71,7 @@ func TestEngineDataSourceFactoryWithMockAdapter(t *testing.T) { // TestEngineDataSourceFactory_GetResolveDataSource_WrongType tests the EngineDataSourceFactory with a mocked adapter func TestEngineDataSourceFactory_GetResolveDataSource_WrongType(t *testing.T) { // Create mock adapter - mockAdapter := NewMockAdapter(t) + mockAdapter := datasource.NewMockProvider(t) // Create the data source with mock adapter pubsub := &EngineDataSourceFactory{ @@ -171,7 +175,7 @@ func TestKafkaEngineDataSourceFactory_UniqueRequestID(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { factory := &EngineDataSourceFactory{ - KafkaAdapter: NewMockAdapter(t), + KafkaAdapter: datasource.NewMockProvider(t), } source, err := factory.ResolveDataSourceSubscription() require.NoError(t, err) diff --git a/router/pkg/pubsub/kafka/engine_datasource_test.go b/router/pkg/pubsub/kafka/engine_datasource_test.go index eed485b246..846203d6e0 100644 --- a/router/pkg/pubsub/kafka/engine_datasource_test.go +++ b/router/pkg/pubsub/kafka/engine_datasource_test.go @@ -5,54 +5,60 @@ import ( "context" "encoding/json" "errors" + "strings" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" ) -func TestPublishEventConfiguration_MarshalJSONTemplate(t *testing.T) { +func TestPublishData_MarshalJSONTemplate(t *testing.T) { tests := []struct { name string - config PublishEventConfiguration + config publishData wantPattern string }{ { name: "simple configuration", - config: PublishEventConfiguration{ + config: publishData{ Provider: "test-provider", Topic: "test-topic", Event: Event{Data: json.RawMessage(`{"message":"hello"}`)}, + FieldName: "test-field", }, - wantPattern: `{"topic":"test-topic", "event": {"data": {"message":"hello"}, "key": "", "headers": {}}, "providerId":"test-provider"}`, + wantPattern: `{"topic":"test-topic", "event": {"data": {"message":"hello"}, "key": "", "headers": {}}, "providerId":"test-provider", "rootFieldName":"test-field"}`, }, { name: "with special characters", - config: PublishEventConfiguration{ + config: publishData{ Provider: "test-provider-id", Topic: "topic-with-hyphens", Event: Event{Data: json.RawMessage(`{"message":"special \"quotes\" here"}`)}, + FieldName: "test-field", }, - wantPattern: `{"topic":"topic-with-hyphens", "event": {"data": {"message":"special \"quotes\" here"}, "key": "", "headers": {}}, "providerId":"test-provider-id"}`, + wantPattern: `{"topic":"topic-with-hyphens", "event": {"data": {"message":"special \"quotes\" here"}, "key": "", "headers": {}}, "providerId":"test-provider-id", "rootFieldName":"test-field"}`, }, { name: "with key", - config: PublishEventConfiguration{ + config: publishData{ Provider: "test-provider-id", Topic: "topic-with-hyphens", Event: Event{Key: []byte("blablabla"), Data: json.RawMessage(`{}`)}, + FieldName: "test-field", }, - wantPattern: `{"topic":"topic-with-hyphens", "event": {"data": {}, "key": "blablabla", "headers": {}}, "providerId":"test-provider-id"}`, + wantPattern: `{"topic":"topic-with-hyphens", "event": {"data": {}, "key": "blablabla", "headers": {}}, "providerId":"test-provider-id", "rootFieldName":"test-field"}`, }, { name: "with headers", - config: PublishEventConfiguration{ + config: publishData{ Provider: "test-provider-id", Topic: "topic-with-hyphens", Event: Event{Headers: map[string][]byte{"key": []byte(`blablabla`)}, Data: json.RawMessage(`{}`)}, + FieldName: "test-field", }, - wantPattern: `{"topic":"topic-with-hyphens", "event": {"data": {}, "key": "", "headers": {"key":"YmxhYmxhYmxh"}}, "providerId":"test-provider-id"}`, + wantPattern: `{"topic":"topic-with-hyphens", "event": {"data": {}, "key": "", "headers": {"key":"YmxhYmxhYmxh"}}, "providerId":"test-provider-id", "rootFieldName":"test-field"}`, }, } @@ -65,11 +71,27 @@ func TestPublishEventConfiguration_MarshalJSONTemplate(t *testing.T) { } } +func TestPublishData_PublishEventConfiguration(t *testing.T) { + data := publishData{ + Provider: "test-provider", + Topic: "test-topic", + FieldName: "test-field", + } + + evtCfg := &PublishEventConfiguration{ + Provider: data.Provider, + Topic: data.Topic, + FieldName: data.FieldName, + } + + assert.Equal(t, evtCfg, data.PublishEventConfiguration()) +} + func TestKafkaPublishDataSource_Load(t *testing.T) { tests := []struct { name string input string - mockSetup func(*MockAdapter) + mockSetup func(*datasource.MockProvider) expectError bool expectedOutput string expectPublished bool @@ -77,11 +99,12 @@ func TestKafkaPublishDataSource_Load(t *testing.T) { { name: "successful publish", input: `{"topic":"test-topic", "event": {"data":{"message":"hello"}}, "providerId":"test-provider"}`, - mockSetup: func(m *MockAdapter) { - m.On("Publish", mock.Anything, mock.MatchedBy(func(event PublishEventConfiguration) bool { + mockSetup: func(m *datasource.MockProvider) { + m.On("Publish", mock.Anything, mock.MatchedBy(func(event *PublishEventConfiguration) bool { return event.ProviderID() == "test-provider" && - event.Topic == "test-topic" && - string(event.Event.Data) == `{"message":"hello"}` + event.Topic == "test-topic" + }), mock.MatchedBy(func(events []datasource.StreamEvent) bool { + return len(events) == 1 && strings.EqualFold(string(events[0].GetData()), `{"message":"hello"}`) })).Return(nil) }, expectError: false, @@ -91,8 +114,8 @@ func TestKafkaPublishDataSource_Load(t *testing.T) { { name: "publish error", input: `{"topic":"test-topic", "event": {"data":{"message":"hello"}}, "providerId":"test-provider"}`, - mockSetup: func(m *MockAdapter) { - m.On("Publish", mock.Anything, mock.Anything).Return(errors.New("publish error")) + mockSetup: func(m *datasource.MockProvider) { + m.On("Publish", mock.Anything, mock.Anything, mock.Anything).Return(errors.New("publish error")) }, expectError: false, // The Load method doesn't return the publish error directly expectedOutput: `{"success": false}`, @@ -101,7 +124,7 @@ func TestKafkaPublishDataSource_Load(t *testing.T) { { name: "invalid input json", input: `{"invalid json":`, - mockSetup: func(m *MockAdapter) {}, + mockSetup: func(m *datasource.MockProvider) {}, expectError: true, expectPublished: false, }, @@ -109,7 +132,7 @@ func TestKafkaPublishDataSource_Load(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - mockAdapter := NewMockAdapter(t) + mockAdapter := datasource.NewMockProvider(t) tt.mockSetup(mockAdapter) dataSource := &PublishDataSource{ @@ -134,7 +157,7 @@ func TestKafkaPublishDataSource_Load(t *testing.T) { func TestKafkaPublishDataSource_LoadWithFiles(t *testing.T) { t.Run("panic on not implemented", func(t *testing.T) { dataSource := &PublishDataSource{ - pubSub: NewMockAdapter(t), + pubSub: datasource.NewMockProvider(t), } assert.Panics(t, func() { diff --git a/router/pkg/pubsub/kafka/mocks.go b/router/pkg/pubsub/kafka/mocks.go deleted file mode 100644 index 08faa08eb2..0000000000 --- a/router/pkg/pubsub/kafka/mocks.go +++ /dev/null @@ -1,261 +0,0 @@ -// Code generated by mockery; DO NOT EDIT. -// github.com/vektra/mockery -// template: testify - -package kafka - -import ( - "context" - - mock "github.com/stretchr/testify/mock" - "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" -) - -// NewMockAdapter creates a new instance of MockAdapter. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -// The first argument is typically a *testing.T value. -func NewMockAdapter(t interface { - mock.TestingT - Cleanup(func()) -}) *MockAdapter { - mock := &MockAdapter{} - mock.Mock.Test(t) - - t.Cleanup(func() { mock.AssertExpectations(t) }) - - return mock -} - -// MockAdapter is an autogenerated mock type for the Adapter type -type MockAdapter struct { - mock.Mock -} - -type MockAdapter_Expecter struct { - mock *mock.Mock -} - -func (_m *MockAdapter) EXPECT() *MockAdapter_Expecter { - return &MockAdapter_Expecter{mock: &_m.Mock} -} - -// Publish provides a mock function for the type MockAdapter -func (_mock *MockAdapter) Publish(ctx context.Context, event PublishEventConfiguration) error { - ret := _mock.Called(ctx, event) - - if len(ret) == 0 { - panic("no return value specified for Publish") - } - - var r0 error - if returnFunc, ok := ret.Get(0).(func(context.Context, PublishEventConfiguration) error); ok { - r0 = returnFunc(ctx, event) - } else { - r0 = ret.Error(0) - } - return r0 -} - -// MockAdapter_Publish_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Publish' -type MockAdapter_Publish_Call struct { - *mock.Call -} - -// Publish is a helper method to define mock.On call -// - ctx context.Context -// - event PublishEventConfiguration -func (_e *MockAdapter_Expecter) Publish(ctx interface{}, event interface{}) *MockAdapter_Publish_Call { - return &MockAdapter_Publish_Call{Call: _e.mock.On("Publish", ctx, event)} -} - -func (_c *MockAdapter_Publish_Call) Run(run func(ctx context.Context, event PublishEventConfiguration)) *MockAdapter_Publish_Call { - _c.Call.Run(func(args mock.Arguments) { - var arg0 context.Context - if args[0] != nil { - arg0 = args[0].(context.Context) - } - var arg1 PublishEventConfiguration - if args[1] != nil { - arg1 = args[1].(PublishEventConfiguration) - } - run( - arg0, - arg1, - ) - }) - return _c -} - -func (_c *MockAdapter_Publish_Call) Return(err error) *MockAdapter_Publish_Call { - _c.Call.Return(err) - return _c -} - -func (_c *MockAdapter_Publish_Call) RunAndReturn(run func(ctx context.Context, event PublishEventConfiguration) error) *MockAdapter_Publish_Call { - _c.Call.Return(run) - return _c -} - -// Shutdown provides a mock function for the type MockAdapter -func (_mock *MockAdapter) Shutdown(ctx context.Context) error { - ret := _mock.Called(ctx) - - if len(ret) == 0 { - panic("no return value specified for Shutdown") - } - - var r0 error - if returnFunc, ok := ret.Get(0).(func(context.Context) error); ok { - r0 = returnFunc(ctx) - } else { - r0 = ret.Error(0) - } - return r0 -} - -// MockAdapter_Shutdown_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Shutdown' -type MockAdapter_Shutdown_Call struct { - *mock.Call -} - -// Shutdown is a helper method to define mock.On call -// - ctx context.Context -func (_e *MockAdapter_Expecter) Shutdown(ctx interface{}) *MockAdapter_Shutdown_Call { - return &MockAdapter_Shutdown_Call{Call: _e.mock.On("Shutdown", ctx)} -} - -func (_c *MockAdapter_Shutdown_Call) Run(run func(ctx context.Context)) *MockAdapter_Shutdown_Call { - _c.Call.Run(func(args mock.Arguments) { - var arg0 context.Context - if args[0] != nil { - arg0 = args[0].(context.Context) - } - run( - arg0, - ) - }) - return _c -} - -func (_c *MockAdapter_Shutdown_Call) Return(err error) *MockAdapter_Shutdown_Call { - _c.Call.Return(err) - return _c -} - -func (_c *MockAdapter_Shutdown_Call) RunAndReturn(run func(ctx context.Context) error) *MockAdapter_Shutdown_Call { - _c.Call.Return(run) - return _c -} - -// Startup provides a mock function for the type MockAdapter -func (_mock *MockAdapter) Startup(ctx context.Context) error { - ret := _mock.Called(ctx) - - if len(ret) == 0 { - panic("no return value specified for Startup") - } - - var r0 error - if returnFunc, ok := ret.Get(0).(func(context.Context) error); ok { - r0 = returnFunc(ctx) - } else { - r0 = ret.Error(0) - } - return r0 -} - -// MockAdapter_Startup_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Startup' -type MockAdapter_Startup_Call struct { - *mock.Call -} - -// Startup is a helper method to define mock.On call -// - ctx context.Context -func (_e *MockAdapter_Expecter) Startup(ctx interface{}) *MockAdapter_Startup_Call { - return &MockAdapter_Startup_Call{Call: _e.mock.On("Startup", ctx)} -} - -func (_c *MockAdapter_Startup_Call) Run(run func(ctx context.Context)) *MockAdapter_Startup_Call { - _c.Call.Run(func(args mock.Arguments) { - var arg0 context.Context - if args[0] != nil { - arg0 = args[0].(context.Context) - } - run( - arg0, - ) - }) - return _c -} - -func (_c *MockAdapter_Startup_Call) Return(err error) *MockAdapter_Startup_Call { - _c.Call.Return(err) - return _c -} - -func (_c *MockAdapter_Startup_Call) RunAndReturn(run func(ctx context.Context) error) *MockAdapter_Startup_Call { - _c.Call.Return(run) - return _c -} - -// Subscribe provides a mock function for the type MockAdapter -func (_mock *MockAdapter) Subscribe(ctx context.Context, event datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater) error { - ret := _mock.Called(ctx, event, updater) - - if len(ret) == 0 { - panic("no return value specified for Subscribe") - } - - var r0 error - if returnFunc, ok := ret.Get(0).(func(context.Context, datasource.SubscriptionEventConfiguration, datasource.SubscriptionEventUpdater) error); ok { - r0 = returnFunc(ctx, event, updater) - } else { - r0 = ret.Error(0) - } - return r0 -} - -// MockAdapter_Subscribe_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Subscribe' -type MockAdapter_Subscribe_Call struct { - *mock.Call -} - -// Subscribe is a helper method to define mock.On call -// - ctx context.Context -// - event datasource.SubscriptionEventConfiguration -// - updater datasource.SubscriptionEventUpdater -func (_e *MockAdapter_Expecter) Subscribe(ctx interface{}, event interface{}, updater interface{}) *MockAdapter_Subscribe_Call { - return &MockAdapter_Subscribe_Call{Call: _e.mock.On("Subscribe", ctx, event, updater)} -} - -func (_c *MockAdapter_Subscribe_Call) Run(run func(ctx context.Context, event datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater)) *MockAdapter_Subscribe_Call { - _c.Call.Run(func(args mock.Arguments) { - var arg0 context.Context - if args[0] != nil { - arg0 = args[0].(context.Context) - } - var arg1 datasource.SubscriptionEventConfiguration - if args[1] != nil { - arg1 = args[1].(datasource.SubscriptionEventConfiguration) - } - var arg2 datasource.SubscriptionEventUpdater - if args[2] != nil { - arg2 = args[2].(datasource.SubscriptionEventUpdater) - } - run( - arg0, - arg1, - arg2, - ) - }) - return _c -} - -func (_c *MockAdapter_Subscribe_Call) Return(err error) *MockAdapter_Subscribe_Call { - _c.Call.Return(err) - return _c -} - -func (_c *MockAdapter_Subscribe_Call) RunAndReturn(run func(ctx context.Context, event datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater) error) *MockAdapter_Subscribe_Call { - _c.Call.Return(run) - return _c -} diff --git a/router/pkg/pubsub/kafka/provider_builder.go b/router/pkg/pubsub/kafka/provider_builder.go index c88cf814c2..c69a458eba 100644 --- a/router/pkg/pubsub/kafka/provider_builder.go +++ b/router/pkg/pubsub/kafka/provider_builder.go @@ -23,16 +23,15 @@ type ProviderBuilder struct { logger *zap.Logger hostName string routerListenAddr string - adapters map[string]Adapter } func (p *ProviderBuilder) TypeID() string { return providerTypeID } -func (p *ProviderBuilder) BuildEngineDataSourceFactory(data *nodev1.KafkaEventConfiguration) (datasource.EngineDataSourceFactory, error) { +func (p *ProviderBuilder) BuildEngineDataSourceFactory(data *nodev1.KafkaEventConfiguration, providers map[string]datasource.Provider) (datasource.EngineDataSourceFactory, error) { providerId := data.GetEngineEventConfiguration().GetProviderId() - adapter, ok := p.adapters[providerId] + provider, ok := providers[providerId] if !ok { return nil, fmt.Errorf("failed to get adapter for provider %s with ID %s", p.TypeID(), providerId) } @@ -52,18 +51,17 @@ func (p *ProviderBuilder) BuildEngineDataSourceFactory(data *nodev1.KafkaEventCo eventType: eventType, topics: data.GetTopics(), providerId: providerId, - KafkaAdapter: adapter, + KafkaAdapter: provider, + logger: p.logger, }, nil } func (p *ProviderBuilder) BuildProvider(provider config.KafkaEventSource, providerOpts datasource.ProviderOpts) (datasource.Provider, error) { - adapter, pubSubProvider, err := buildProvider(p.ctx, provider, p.logger, providerOpts) + pubSubProvider, err := buildProvider(p.ctx, provider, p.logger, providerOpts) if err != nil { return nil, err } - p.adapters[provider.ID] = adapter - return pubSubProvider, nil } @@ -150,18 +148,18 @@ func buildKafkaOptions(eventSource config.KafkaEventSource, logger *zap.Logger) return opts, nil } -func buildProvider(ctx context.Context, provider config.KafkaEventSource, logger *zap.Logger, providerOpts datasource.ProviderOpts) (Adapter, datasource.Provider, error) { +func buildProvider(ctx context.Context, provider config.KafkaEventSource, logger *zap.Logger, providerOpts datasource.ProviderOpts) (datasource.Provider, error) { kafkaOpts, err := buildKafkaOptions(provider, logger) if err != nil { - return nil, nil, fmt.Errorf("failed to build options for Kafka provider with ID \"%s\": %w", provider.ID, err) + return nil, fmt.Errorf("failed to build options for Kafka provider with ID \"%s\": %w", provider.ID, err) } adapter, err := NewProviderAdapter(ctx, logger, kafkaOpts, providerOpts) if err != nil { - return nil, nil, fmt.Errorf("failed to create adapter for Kafka provider with ID \"%s\": %w", provider.ID, err) + return nil, fmt.Errorf("failed to create adapter for Kafka provider with ID \"%s\": %w", provider.ID, err) } pubSubProvider := datasource.NewPubSubProvider(provider.ID, providerTypeID, adapter, logger) - return adapter, pubSubProvider, nil + return pubSubProvider, nil } func NewProviderBuilder( @@ -175,6 +173,5 @@ func NewProviderBuilder( logger: logger, hostName: hostName, routerListenAddr: routerListenAddr, - adapters: make(map[string]Adapter), } } diff --git a/router/pkg/pubsub/nats/adapter.go b/router/pkg/pubsub/nats/adapter.go index dcba74a03b..def1d19f81 100644 --- a/router/pkg/pubsub/nats/adapter.go +++ b/router/pkg/pubsub/nats/adapter.go @@ -25,18 +25,14 @@ const ( // Adapter defines the methods that a NATS adapter should implement type Adapter interface { - // Subscribe subscribes to the given events and sends updates to the updater - Subscribe(ctx context.Context, event datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater) error - // Publish publishes the given event to the specified subject - Publish(ctx context.Context, event PublishAndRequestEventConfiguration) error + datasource.Adapter // Request sends a request to the specified subject and writes the response to the given writer - Request(ctx context.Context, event PublishAndRequestEventConfiguration, w io.Writer) error - // Startup initializes the adapter - Startup(ctx context.Context) error - // Shutdown gracefully shuts down the adapter - Shutdown(ctx context.Context) error + Request(ctx context.Context, cfg datasource.PublishEventConfiguration, event datasource.StreamEvent, w io.Writer) error } +// Ensure ProviderAdapter implements ProviderSubscriptionHooks +var _ datasource.Adapter = (*ProviderAdapter)(nil) + // ProviderAdapter implements the AdapterInterface for NATS pub/sub type ProviderAdapter struct { ctx context.Context @@ -80,11 +76,12 @@ func (p *ProviderAdapter) getDurableConsumerName(durableName string, subjects [] return fmt.Sprintf("%s-%x", durableName, subjHash.Sum64()), nil } -func (p *ProviderAdapter) Subscribe(ctx context.Context, conf datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater) error { - subConf, ok := conf.(*SubscriptionEventConfiguration) +func (p *ProviderAdapter) Subscribe(ctx context.Context, cfg datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater) error { + subConf, ok := cfg.(*SubscriptionEventConfiguration) if !ok { - return datasource.NewError("invalid event type for Kafka adapter", nil) + return datasource.NewError("subscription event not support by nats provider", nil) } + log := p.logger.With( zap.String("provider_id", subConf.ProviderID()), zap.String("method", "subscribe"), @@ -145,16 +142,16 @@ func (p *ProviderAdapter) Subscribe(ctx context.Context, conf datasource.Subscri log.Debug("subscription update", zap.String("message_subject", msg.Subject()), zap.ByteString("data", msg.Data())) p.streamMetricStore.Consume(p.ctx, metric.StreamsEvent{ - ProviderId: conf.ProviderID(), + ProviderId: subConf.ProviderID(), StreamOperationName: natsReceive, ProviderType: metric.ProviderTypeNats, DestinationName: msg.Subject(), }) - updater.Update(&Event{ + updater.Update([]datasource.StreamEvent{&Event{ Data: msg.Data(), Headers: msg.Headers(), - }) + }}) // Acknowledge the message after it has been processed ackErr := msg.Ack() @@ -191,18 +188,16 @@ func (p *ProviderAdapter) Subscribe(ctx context.Context, conf datasource.Subscri select { case msg := <-msgChan: log.Debug("subscription update", zap.String("message_subject", msg.Subject), zap.ByteString("data", msg.Data)) - p.streamMetricStore.Consume(p.ctx, metric.StreamsEvent{ - ProviderId: conf.ProviderID(), + ProviderId: subConf.ProviderID(), StreamOperationName: natsReceive, ProviderType: metric.ProviderTypeNats, DestinationName: msg.Subject, }) - - updater.Update(&Event{ + updater.Update([]datasource.StreamEvent{&Event{ Data: msg.Data, Headers: msg.Header, - }) + }}) case <-p.ctx.Done(): // When the application context is done, we stop the subscriptions for _, subscription := range subscriptions { @@ -230,73 +225,107 @@ func (p *ProviderAdapter) Subscribe(ctx context.Context, conf datasource.Subscri return nil } -func (p *ProviderAdapter) Publish(ctx context.Context, event PublishAndRequestEventConfiguration) error { +func (p *ProviderAdapter) Publish(ctx context.Context, conf datasource.PublishEventConfiguration, events []datasource.StreamEvent) error { + pubConf, ok := conf.(*PublishAndRequestEventConfiguration) + if !ok { + return datasource.NewError("publish event not support by nats provider", nil) + } + log := p.logger.With( - zap.String("provider_id", event.ProviderID()), + zap.String("provider_id", pubConf.ProviderID()), zap.String("method", "publish"), - zap.String("subject", event.Subject), + zap.String("subject", pubConf.Subject), ) if p.client == nil { return datasource.NewError("nats client not initialized", nil) } - log.Debug("publish", zap.ByteString("data", event.Event.Data)) + log.Debug("publish", zap.Int("event_count", len(events))) - err := p.client.Publish(event.Subject, event.Event.Data) - if err != nil { - log.Error("publish error", zap.Error(err)) - p.streamMetricStore.Produce(ctx, metric.StreamsEvent{ - ProviderId: event.ProviderID(), - StreamOperationName: natsPublish, - ProviderType: metric.ProviderTypeNats, - ErrorType: "publish_error", - DestinationName: event.Subject, - }) - return datasource.NewError(fmt.Sprintf("error publishing to NATS subject %s", event.Subject), err) - } else { - p.streamMetricStore.Produce(ctx, metric.StreamsEvent{ - ProviderId: event.ProviderID(), - StreamOperationName: natsPublish, - ProviderType: metric.ProviderTypeNats, - DestinationName: event.Subject, - }) + for _, streamEvent := range events { + natsEvent, ok := streamEvent.(*Event) + if !ok { + return datasource.NewError("invalid event type for NATS adapter", nil) + } + + err := p.client.Publish(pubConf.Subject, natsEvent.Data) + if err != nil { + p.streamMetricStore.Produce(ctx, metric.StreamsEvent{ + ProviderId: pubConf.ProviderID(), + StreamOperationName: natsPublish, + ProviderType: metric.ProviderTypeNats, + ErrorType: "publish_error", + DestinationName: pubConf.Subject, + }) + log.Error( + "publish error", + zap.Error(err), + zap.String("provider_id", pubConf.ProviderID()), + zap.String("provider_type", string(pubConf.ProviderType())), + zap.String("field_name", pubConf.RootFieldName()), + ) + return datasource.NewError(fmt.Sprintf("error publishing to NATS subject %s", pubConf.Subject), err) + } } + p.streamMetricStore.Produce(ctx, metric.StreamsEvent{ + ProviderId: pubConf.ProviderID(), + StreamOperationName: natsPublish, + ProviderType: metric.ProviderTypeNats, + DestinationName: pubConf.Subject, + }) + return nil } -func (p *ProviderAdapter) Request(ctx context.Context, event PublishAndRequestEventConfiguration, w io.Writer) error { +func (p *ProviderAdapter) Request(ctx context.Context, cfg datasource.PublishEventConfiguration, event datasource.StreamEvent, w io.Writer) error { + reqConf, ok := cfg.(*PublishAndRequestEventConfiguration) + if !ok { + return datasource.NewError("publish event not support by nats provider", nil) + } + log := p.logger.With( - zap.String("provider_id", event.ProviderID()), + zap.String("provider_id", cfg.ProviderID()), zap.String("method", "request"), - zap.String("subject", event.Subject), + zap.String("subject", reqConf.Subject), ) if p.client == nil { return datasource.NewError("nats client not initialized", nil) } - log.Debug("request", zap.ByteString("data", event.Event.Data)) + natsEvent, ok := event.(*Event) + if !ok { + return datasource.NewError("invalid event type for NATS adapter", nil) + } + + log.Debug("request", zap.ByteString("data", natsEvent.Data)) - msg, err := p.client.RequestWithContext(ctx, event.Subject, event.Event.Data) + msg, err := p.client.RequestWithContext(ctx, reqConf.Subject, natsEvent.Data) if err != nil { - log.Error("request error", zap.Error(err)) + log.Error( + "request error", + zap.Error(err), + zap.String("provider_id", reqConf.ProviderID()), + zap.String("provider_type", string(reqConf.ProviderType())), + zap.String("field_name", reqConf.RootFieldName()), + ) p.streamMetricStore.Produce(ctx, metric.StreamsEvent{ - ProviderId: event.ProviderID(), + ProviderId: reqConf.ProviderID(), StreamOperationName: natsRequest, ProviderType: metric.ProviderTypeNats, ErrorType: "request_error", - DestinationName: event.Subject, + DestinationName: reqConf.Subject, }) - return datasource.NewError(fmt.Sprintf("error requesting from NATS subject %s", event.Subject), err) + return datasource.NewError(fmt.Sprintf("error requesting from NATS subject %s", reqConf.Subject), err) } p.streamMetricStore.Produce(ctx, metric.StreamsEvent{ - ProviderId: event.ProviderID(), + ProviderId: reqConf.ProviderID(), StreamOperationName: natsRequest, ProviderType: metric.ProviderTypeNats, - DestinationName: event.Subject, + DestinationName: reqConf.Subject, }) // We don't collect metrics on err here as it's an error related to the writer diff --git a/router/pkg/pubsub/nats/engine_datasource.go b/router/pkg/pubsub/nats/engine_datasource.go index 0fa41e5480..3b2014a71a 100644 --- a/router/pkg/pubsub/nats/engine_datasource.go +++ b/router/pkg/pubsub/nats/engine_datasource.go @@ -6,9 +6,13 @@ import ( "encoding/json" "fmt" "io" + "slices" + "github.com/buger/jsonparser" + "github.com/cespare/xxhash/v2" "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" ) // Event represents an event from NATS @@ -21,6 +25,16 @@ func (e *Event) GetData() []byte { return e.Data } +func (e *Event) Clone() datasource.StreamEvent { + e2 := *e + e2.Data = slices.Clone(e.Data) + e2.Headers = make(map[string][]string, len(e.Headers)) + for k, v := range e.Headers { + e2.Headers[k] = slices.Clone(v) + } + return &e2 +} + type StreamConfiguration struct { Consumer string `json:"consumer"` ConsumerInactiveThreshold int32 `json:"consumerInactiveThreshold"` @@ -49,13 +63,34 @@ func (s *SubscriptionEventConfiguration) RootFieldName() string { return s.FieldName } -type PublishAndRequestEventConfiguration struct { +// publishData is a private type that is used to pass data from the engine to the provider +type publishData struct { Provider string `json:"providerId"` Subject string `json:"subject"` Event Event `json:"event"` FieldName string `json:"rootFieldName"` } +func (p *publishData) PublishEventConfiguration() datasource.PublishEventConfiguration { + return &PublishAndRequestEventConfiguration{ + Provider: p.Provider, + Subject: p.Subject, + FieldName: p.FieldName, + } +} + +func (p *publishData) MarshalJSONTemplate() (string, error) { + // The content of the data field could be not valid JSON, so we can't use json.Marshal + // e.g. {"id":$$0$$,"update":$$1$$} + return fmt.Sprintf(`{"subject":"%s", "event": {"data": %s}, "providerId":"%s", "rootFieldName":"%s"}`, p.Subject, p.Event.Data, p.Provider, p.FieldName), nil +} + +type PublishAndRequestEventConfiguration struct { + Provider string `json:"providerId"` + Subject string `json:"subject"` + FieldName string `json:"rootFieldName"` +} + // ProviderID returns the provider ID func (p *PublishAndRequestEventConfiguration) ProviderID() string { return p.Provider @@ -71,25 +106,68 @@ func (p *PublishAndRequestEventConfiguration) RootFieldName() string { return p.FieldName } -func (p *PublishAndRequestEventConfiguration) MarshalJSONTemplate() (string, error) { - // The content of the data field could be not valid JSON, so we can't use json.Marshal - // e.g. {"id":$$0$$,"update":$$1$$} - return fmt.Sprintf(`{"subject":"%s", "event": {"data": %s}, "providerId":"%s"}`, p.Subject, p.Event.Data, p.ProviderID()), nil +type SubscriptionSource struct { + pubSub datasource.Adapter +} + +func (s *SubscriptionSource) SubscriptionEventConfiguration(input []byte) datasource.SubscriptionEventConfiguration { + var subscriptionConfiguration SubscriptionEventConfiguration + err := json.Unmarshal(input, &subscriptionConfiguration) + if err != nil { + return nil + } + return &subscriptionConfiguration +} + +func (s *SubscriptionSource) UniqueRequestID(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { + + val, _, _, err := jsonparser.Get(input, "subjects") + if err != nil { + return err + } + + _, err = xxh.Write(val) + if err != nil { + return err + } + + val, _, _, err = jsonparser.Get(input, "providerId") + if err != nil { + return err + } + + _, err = xxh.Write(val) + return err +} + +func (s *SubscriptionSource) Start(ctx *resolve.Context, input []byte, updater datasource.SubscriptionEventUpdater) error { + subConf := s.SubscriptionEventConfiguration(input) + if subConf == nil { + return fmt.Errorf("no subscription configuration found") + } + + conf, ok := subConf.(*SubscriptionEventConfiguration) + if !ok { + return fmt.Errorf("invalid subscription configuration") + } + + return s.pubSub.Subscribe(ctx.Context(), conf, updater) } type NatsPublishDataSource struct { - pubSub Adapter + pubSub datasource.Adapter } func (s *NatsPublishDataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) error { - var publishConfiguration PublishAndRequestEventConfiguration - if err := json.Unmarshal(input, &publishConfiguration); err != nil { + var publishData publishData + if err := json.Unmarshal(input, &publishData); err != nil { return err } - if err := s.pubSub.Publish(ctx, publishConfiguration); err != nil { - _, err = io.WriteString(out, `{"success": false}`) - return err + if err := s.pubSub.Publish(ctx, publishData.PublishEventConfiguration(), []datasource.StreamEvent{&publishData.Event}); err != nil { + // err will not be returned but only logged inside PubSubProvider.Publish to avoid a "unable to fetch from subgraph" error + _, errWrite := io.WriteString(out, `{"success": false}`) + return errWrite } _, err := io.WriteString(out, `{"success": true}`) return err @@ -100,16 +178,26 @@ func (s *NatsPublishDataSource) LoadWithFiles(ctx context.Context, input []byte, } type NatsRequestDataSource struct { - pubSub Adapter + pubSub datasource.Adapter } func (s *NatsRequestDataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) error { - var subscriptionConfiguration PublishAndRequestEventConfiguration - if err := json.Unmarshal(input, &subscriptionConfiguration); err != nil { + var publishData publishData + if err := json.Unmarshal(input, &publishData); err != nil { return err } - return s.pubSub.Request(ctx, subscriptionConfiguration, out) + providerBase, ok := s.pubSub.(*datasource.PubSubProvider) + if !ok { + return fmt.Errorf("adapter for provider %s is not of the right type", publishData.Provider) + } + + adapter, ok := providerBase.Adapter.(Adapter) + if !ok { + return fmt.Errorf("adapter for provider %s is not of the right type", publishData.Provider) + } + + return adapter.Request(ctx, publishData.PublishEventConfiguration(), &publishData.Event, out) } func (s *NatsRequestDataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload, out *bytes.Buffer) error { diff --git a/router/pkg/pubsub/nats/engine_datasource_factory.go b/router/pkg/pubsub/nats/engine_datasource_factory.go index 36d3932e0d..d88d25b868 100644 --- a/router/pkg/pubsub/nats/engine_datasource_factory.go +++ b/router/pkg/pubsub/nats/engine_datasource_factory.go @@ -8,8 +8,8 @@ import ( "github.com/buger/jsonparser" "github.com/cespare/xxhash/v2" "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" + "go.uber.org/zap" ) type EventType int @@ -21,12 +21,13 @@ const ( ) type EngineDataSourceFactory struct { - NatsAdapter Adapter + NatsAdapter datasource.Adapter fieldName string eventType EventType subjects []string providerId string + logger *zap.Logger withStreamConfiguration bool consumerName string @@ -64,11 +65,11 @@ func (c *EngineDataSourceFactory) ResolveDataSourceInput(eventData []byte) (stri subject := c.subjects[0] - evtCfg := PublishAndRequestEventConfiguration{ + evtCfg := publishData{ Provider: c.providerId, Subject: subject, - Event: Event{Data: eventData}, FieldName: c.fieldName, + Event: Event{Data: eventData}, } return evtCfg.MarshalJSONTemplate() @@ -95,7 +96,7 @@ func (c *EngineDataSourceFactory) ResolveDataSourceSubscription() (datasource.Su _, err = xxh.Write(val) return err - }), nil + }, c.logger), nil } func (c *EngineDataSourceFactory) ResolveDataSourceSubscriptionInput() (string, error) { diff --git a/router/pkg/pubsub/nats/engine_datasource_factory_test.go b/router/pkg/pubsub/nats/engine_datasource_factory_test.go index a94c8d5941..053ff0d702 100644 --- a/router/pkg/pubsub/nats/engine_datasource_factory_test.go +++ b/router/pkg/pubsub/nats/engine_datasource_factory_test.go @@ -6,14 +6,17 @@ import ( "encoding/json" "errors" "io" + "strings" "testing" "github.com/cespare/xxhash/v2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" "github.com/wundergraph/cosmo/router/pkg/pubsub/pubsubtest" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" + "go.uber.org/zap" ) func TestNatsEngineDataSourceFactory(t *testing.T) { @@ -36,8 +39,10 @@ func TestEngineDataSourceFactoryWithMockAdapter(t *testing.T) { mockAdapter := NewMockAdapter(t) // Configure mock expectations for Publish - mockAdapter.On("Publish", mock.Anything, mock.MatchedBy(func(event PublishAndRequestEventConfiguration) bool { + mockAdapter.On("Publish", mock.Anything, mock.MatchedBy(func(event *PublishAndRequestEventConfiguration) bool { return event.ProviderID() == "test-provider" && event.Subject == "test-subject" + }), mock.MatchedBy(func(events []datasource.StreamEvent) bool { + return len(events) == 1 && strings.EqualFold(string(events[0].GetData()), `{"test":"data"}`) })).Return(nil) // Create the data source with mock adapter @@ -167,12 +172,15 @@ func TestNatsEngineDataSourceFactoryWithStreamConfiguration(t *testing.T) { func TestEngineDataSourceFactory_RequestDataSource(t *testing.T) { // Create mock adapter mockAdapter := NewMockAdapter(t) + provider := datasource.NewPubSubProvider("test-provider", "nats", mockAdapter, zap.NewNop()) // Configure mock expectations for Request - mockAdapter.On("Request", mock.Anything, mock.MatchedBy(func(event PublishAndRequestEventConfiguration) bool { + mockAdapter.On("Request", mock.Anything, mock.MatchedBy(func(event *PublishAndRequestEventConfiguration) bool { return event.ProviderID() == "test-provider" && event.Subject == "test-subject" + }), mock.MatchedBy(func(event datasource.StreamEvent) bool { + return event != nil && strings.EqualFold(string(event.GetData()), `{"test":"data"}`) }), mock.Anything).Return(nil).Run(func(args mock.Arguments) { - w := args.Get(2).(io.Writer) + w := args.Get(3).(io.Writer) w.Write([]byte(`{"response": "test"}`)) }) @@ -182,7 +190,7 @@ func TestEngineDataSourceFactory_RequestDataSource(t *testing.T) { eventType: EventTypeRequest, subjects: []string{"test-subject"}, fieldName: "testField", - NatsAdapter: mockAdapter, + NatsAdapter: provider, } // Get the data source diff --git a/router/pkg/pubsub/nats/engine_datasource_test.go b/router/pkg/pubsub/nats/engine_datasource_test.go index 5d060d2c0d..8665f42181 100644 --- a/router/pkg/pubsub/nats/engine_datasource_test.go +++ b/router/pkg/pubsub/nats/engine_datasource_test.go @@ -6,36 +6,41 @@ import ( "encoding/json" "errors" "io" + "strings" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" + "go.uber.org/zap" ) func TestPublishAndRequestEventConfiguration_MarshalJSONTemplate(t *testing.T) { tests := []struct { name string - config PublishAndRequestEventConfiguration + config publishData wantPattern string }{ { name: "simple configuration", - config: PublishAndRequestEventConfiguration{ + config: publishData{ Provider: "test-provider", Subject: "test-subject", Event: Event{Data: json.RawMessage(`{"message":"hello"}`)}, + FieldName: "test-field", }, - wantPattern: `{"subject":"test-subject", "event": {"data": {"message":"hello"}}, "providerId":"test-provider"}`, + wantPattern: `{"subject":"test-subject", "event": {"data": {"message":"hello"}}, "providerId":"test-provider", "rootFieldName":"test-field"}`, }, { name: "with special characters", - config: PublishAndRequestEventConfiguration{ + config: publishData{ Provider: "test-provider-id", Subject: "subject-with-hyphens", Event: Event{Data: json.RawMessage(`{"message":"special \"quotes\" here"}`)}, + FieldName: "test-field", }, - wantPattern: `{"subject":"subject-with-hyphens", "event": {"data": {"message":"special \"quotes\" here"}}, "providerId":"test-provider-id"}`, + wantPattern: `{"subject":"subject-with-hyphens", "event": {"data": {"message":"special \"quotes\" here"}}, "providerId":"test-provider-id", "rootFieldName":"test-field"}`, }, } @@ -43,11 +48,27 @@ func TestPublishAndRequestEventConfiguration_MarshalJSONTemplate(t *testing.T) { t.Run(tt.name, func(t *testing.T) { result, err := tt.config.MarshalJSONTemplate() assert.NoError(t, err) - assert.Equal(t, tt.wantPattern, result) + assert.Equal(t, tt.wantPattern, string(result)) }) } } +func TestPublishData_PublishEventConfiguration(t *testing.T) { + data := publishData{ + Provider: "test-provider", + Subject: "test-subject", + FieldName: "test-field", + } + + evtCfg := &PublishAndRequestEventConfiguration{ + Provider: data.Provider, + Subject: data.Subject, + FieldName: data.FieldName, + } + + assert.Equal(t, evtCfg, data.PublishEventConfiguration()) +} + func TestNatsPublishDataSource_Load(t *testing.T) { tests := []struct { name string @@ -61,10 +82,11 @@ func TestNatsPublishDataSource_Load(t *testing.T) { name: "successful publish", input: `{"subject":"test-subject", "event": {"data":{"message":"hello"}}, "providerId":"test-provider"}`, mockSetup: func(m *MockAdapter) { - m.On("Publish", mock.Anything, mock.MatchedBy(func(event PublishAndRequestEventConfiguration) bool { + m.On("Publish", mock.Anything, mock.MatchedBy(func(event *PublishAndRequestEventConfiguration) bool { return event.ProviderID() == "test-provider" && - event.Subject == "test-subject" && - string(event.Event.Data) == `{"message":"hello"}` + event.Subject == "test-subject" + }), mock.MatchedBy(func(events []datasource.StreamEvent) bool { + return len(events) == 1 && strings.EqualFold(string(events[0].GetData()), `{"message":"hello"}`) })).Return(nil) }, expectError: false, @@ -75,7 +97,7 @@ func TestNatsPublishDataSource_Load(t *testing.T) { name: "publish error", input: `{"subject":"test-subject", "event": {"data":{"message":"hello"}}, "providerId":"test-provider"}`, mockSetup: func(m *MockAdapter) { - m.On("Publish", mock.Anything, mock.Anything).Return(errors.New("publish error")) + m.On("Publish", mock.Anything, mock.Anything, mock.Anything).Return(errors.New("publish error")) }, expectError: false, // The Load method doesn't return the publish error directly expectedOutput: `{"success": false}`, @@ -136,13 +158,14 @@ func TestNatsRequestDataSource_Load(t *testing.T) { name: "successful request", input: `{"subject":"test-subject", "event": {"data":{"message":"hello"}}, "providerId":"test-provider"}`, mockSetup: func(m *MockAdapter) { - m.On("Request", mock.Anything, mock.MatchedBy(func(event PublishAndRequestEventConfiguration) bool { + m.On("Request", mock.Anything, mock.MatchedBy(func(event *PublishAndRequestEventConfiguration) bool { return event.ProviderID() == "test-provider" && - event.Subject == "test-subject" && - string(event.Event.Data) == `{"message":"hello"}` + event.Subject == "test-subject" + }), mock.MatchedBy(func(event datasource.StreamEvent) bool { + return event != nil && strings.EqualFold(string(event.GetData()), `{"message":"hello"}`) }), mock.Anything).Run(func(args mock.Arguments) { // Write response to the output buffer - w := args.Get(2).(io.Writer) + w := args.Get(3).(io.Writer) _, _ = w.Write([]byte(`{"response":"success"}`)) }).Return(nil) }, @@ -153,7 +176,7 @@ func TestNatsRequestDataSource_Load(t *testing.T) { name: "request error", input: `{"subject":"test-subject", "event": {"data":{"message":"hello"}}, "providerId":"test-provider"}`, mockSetup: func(m *MockAdapter) { - m.On("Request", mock.Anything, mock.Anything, mock.Anything).Return(errors.New("request error")) + m.On("Request", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(errors.New("request error")) }, expectError: true, expectedOutput: "", @@ -170,10 +193,11 @@ func TestNatsRequestDataSource_Load(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { mockAdapter := NewMockAdapter(t) + provider := datasource.NewPubSubProvider("test-provider", "nats", mockAdapter, zap.NewNop()) tt.mockSetup(mockAdapter) dataSource := &NatsRequestDataSource{ - pubSub: mockAdapter, + pubSub: provider, } ctx := context.Background() diff --git a/router/pkg/pubsub/nats/mocks.go b/router/pkg/pubsub/nats/mocks.go index 0bc3ada5f0..cfe1a57d95 100644 --- a/router/pkg/pubsub/nats/mocks.go +++ b/router/pkg/pubsub/nats/mocks.go @@ -40,16 +40,16 @@ func (_m *MockAdapter) EXPECT() *MockAdapter_Expecter { } // Publish provides a mock function for the type MockAdapter -func (_mock *MockAdapter) Publish(ctx context.Context, event PublishAndRequestEventConfiguration) error { - ret := _mock.Called(ctx, event) +func (_mock *MockAdapter) Publish(ctx context.Context, cfg datasource.PublishEventConfiguration, events []datasource.StreamEvent) error { + ret := _mock.Called(ctx, cfg, events) if len(ret) == 0 { panic("no return value specified for Publish") } var r0 error - if returnFunc, ok := ret.Get(0).(func(context.Context, PublishAndRequestEventConfiguration) error); ok { - r0 = returnFunc(ctx, event) + if returnFunc, ok := ret.Get(0).(func(context.Context, datasource.PublishEventConfiguration, []datasource.StreamEvent) error); ok { + r0 = returnFunc(ctx, cfg, events) } else { r0 = ret.Error(0) } @@ -63,24 +63,30 @@ type MockAdapter_Publish_Call struct { // Publish is a helper method to define mock.On call // - ctx context.Context -// - event PublishAndRequestEventConfiguration -func (_e *MockAdapter_Expecter) Publish(ctx interface{}, event interface{}) *MockAdapter_Publish_Call { - return &MockAdapter_Publish_Call{Call: _e.mock.On("Publish", ctx, event)} +// - cfg datasource.PublishEventConfiguration +// - events []datasource.StreamEvent +func (_e *MockAdapter_Expecter) Publish(ctx interface{}, cfg interface{}, events interface{}) *MockAdapter_Publish_Call { + return &MockAdapter_Publish_Call{Call: _e.mock.On("Publish", ctx, cfg, events)} } -func (_c *MockAdapter_Publish_Call) Run(run func(ctx context.Context, event PublishAndRequestEventConfiguration)) *MockAdapter_Publish_Call { +func (_c *MockAdapter_Publish_Call) Run(run func(ctx context.Context, cfg datasource.PublishEventConfiguration, events []datasource.StreamEvent)) *MockAdapter_Publish_Call { _c.Call.Run(func(args mock.Arguments) { var arg0 context.Context if args[0] != nil { arg0 = args[0].(context.Context) } - var arg1 PublishAndRequestEventConfiguration + var arg1 datasource.PublishEventConfiguration if args[1] != nil { - arg1 = args[1].(PublishAndRequestEventConfiguration) + arg1 = args[1].(datasource.PublishEventConfiguration) + } + var arg2 []datasource.StreamEvent + if args[2] != nil { + arg2 = args[2].([]datasource.StreamEvent) } run( arg0, arg1, + arg2, ) }) return _c @@ -91,22 +97,22 @@ func (_c *MockAdapter_Publish_Call) Return(err error) *MockAdapter_Publish_Call return _c } -func (_c *MockAdapter_Publish_Call) RunAndReturn(run func(ctx context.Context, event PublishAndRequestEventConfiguration) error) *MockAdapter_Publish_Call { +func (_c *MockAdapter_Publish_Call) RunAndReturn(run func(ctx context.Context, cfg datasource.PublishEventConfiguration, events []datasource.StreamEvent) error) *MockAdapter_Publish_Call { _c.Call.Return(run) return _c } // Request provides a mock function for the type MockAdapter -func (_mock *MockAdapter) Request(ctx context.Context, event PublishAndRequestEventConfiguration, w io.Writer) error { - ret := _mock.Called(ctx, event, w) +func (_mock *MockAdapter) Request(ctx context.Context, cfg datasource.PublishEventConfiguration, event datasource.StreamEvent, w io.Writer) error { + ret := _mock.Called(ctx, cfg, event, w) if len(ret) == 0 { panic("no return value specified for Request") } var r0 error - if returnFunc, ok := ret.Get(0).(func(context.Context, PublishAndRequestEventConfiguration, io.Writer) error); ok { - r0 = returnFunc(ctx, event, w) + if returnFunc, ok := ret.Get(0).(func(context.Context, datasource.PublishEventConfiguration, datasource.StreamEvent, io.Writer) error); ok { + r0 = returnFunc(ctx, cfg, event, w) } else { r0 = ret.Error(0) } @@ -120,30 +126,36 @@ type MockAdapter_Request_Call struct { // Request is a helper method to define mock.On call // - ctx context.Context -// - event PublishAndRequestEventConfiguration +// - cfg datasource.PublishEventConfiguration +// - event datasource.StreamEvent // - w io.Writer -func (_e *MockAdapter_Expecter) Request(ctx interface{}, event interface{}, w interface{}) *MockAdapter_Request_Call { - return &MockAdapter_Request_Call{Call: _e.mock.On("Request", ctx, event, w)} +func (_e *MockAdapter_Expecter) Request(ctx interface{}, cfg interface{}, event interface{}, w interface{}) *MockAdapter_Request_Call { + return &MockAdapter_Request_Call{Call: _e.mock.On("Request", ctx, cfg, event, w)} } -func (_c *MockAdapter_Request_Call) Run(run func(ctx context.Context, event PublishAndRequestEventConfiguration, w io.Writer)) *MockAdapter_Request_Call { +func (_c *MockAdapter_Request_Call) Run(run func(ctx context.Context, cfg datasource.PublishEventConfiguration, event datasource.StreamEvent, w io.Writer)) *MockAdapter_Request_Call { _c.Call.Run(func(args mock.Arguments) { var arg0 context.Context if args[0] != nil { arg0 = args[0].(context.Context) } - var arg1 PublishAndRequestEventConfiguration + var arg1 datasource.PublishEventConfiguration if args[1] != nil { - arg1 = args[1].(PublishAndRequestEventConfiguration) + arg1 = args[1].(datasource.PublishEventConfiguration) } - var arg2 io.Writer + var arg2 datasource.StreamEvent if args[2] != nil { - arg2 = args[2].(io.Writer) + arg2 = args[2].(datasource.StreamEvent) + } + var arg3 io.Writer + if args[3] != nil { + arg3 = args[3].(io.Writer) } run( arg0, arg1, arg2, + arg3, ) }) return _c @@ -154,7 +166,7 @@ func (_c *MockAdapter_Request_Call) Return(err error) *MockAdapter_Request_Call return _c } -func (_c *MockAdapter_Request_Call) RunAndReturn(run func(ctx context.Context, event PublishAndRequestEventConfiguration, w io.Writer) error) *MockAdapter_Request_Call { +func (_c *MockAdapter_Request_Call) RunAndReturn(run func(ctx context.Context, cfg datasource.PublishEventConfiguration, event datasource.StreamEvent, w io.Writer) error) *MockAdapter_Request_Call { _c.Call.Return(run) return _c } @@ -262,8 +274,8 @@ func (_c *MockAdapter_Startup_Call) RunAndReturn(run func(ctx context.Context) e } // Subscribe provides a mock function for the type MockAdapter -func (_mock *MockAdapter) Subscribe(ctx context.Context, event datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater) error { - ret := _mock.Called(ctx, event, updater) +func (_mock *MockAdapter) Subscribe(ctx context.Context, cfg datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater) error { + ret := _mock.Called(ctx, cfg, updater) if len(ret) == 0 { panic("no return value specified for Subscribe") @@ -271,7 +283,7 @@ func (_mock *MockAdapter) Subscribe(ctx context.Context, event datasource.Subscr var r0 error if returnFunc, ok := ret.Get(0).(func(context.Context, datasource.SubscriptionEventConfiguration, datasource.SubscriptionEventUpdater) error); ok { - r0 = returnFunc(ctx, event, updater) + r0 = returnFunc(ctx, cfg, updater) } else { r0 = ret.Error(0) } @@ -285,13 +297,13 @@ type MockAdapter_Subscribe_Call struct { // Subscribe is a helper method to define mock.On call // - ctx context.Context -// - event datasource.SubscriptionEventConfiguration +// - cfg datasource.SubscriptionEventConfiguration // - updater datasource.SubscriptionEventUpdater -func (_e *MockAdapter_Expecter) Subscribe(ctx interface{}, event interface{}, updater interface{}) *MockAdapter_Subscribe_Call { - return &MockAdapter_Subscribe_Call{Call: _e.mock.On("Subscribe", ctx, event, updater)} +func (_e *MockAdapter_Expecter) Subscribe(ctx interface{}, cfg interface{}, updater interface{}) *MockAdapter_Subscribe_Call { + return &MockAdapter_Subscribe_Call{Call: _e.mock.On("Subscribe", ctx, cfg, updater)} } -func (_c *MockAdapter_Subscribe_Call) Run(run func(ctx context.Context, event datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater)) *MockAdapter_Subscribe_Call { +func (_c *MockAdapter_Subscribe_Call) Run(run func(ctx context.Context, cfg datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater)) *MockAdapter_Subscribe_Call { _c.Call.Run(func(args mock.Arguments) { var arg0 context.Context if args[0] != nil { @@ -319,7 +331,7 @@ func (_c *MockAdapter_Subscribe_Call) Return(err error) *MockAdapter_Subscribe_C return _c } -func (_c *MockAdapter_Subscribe_Call) RunAndReturn(run func(ctx context.Context, event datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater) error) *MockAdapter_Subscribe_Call { +func (_c *MockAdapter_Subscribe_Call) RunAndReturn(run func(ctx context.Context, cfg datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater) error) *MockAdapter_Subscribe_Call { _c.Call.Return(run) return _c } diff --git a/router/pkg/pubsub/nats/provider_builder.go b/router/pkg/pubsub/nats/provider_builder.go index e3ba5f7cb0..2b07c4217a 100644 --- a/router/pkg/pubsub/nats/provider_builder.go +++ b/router/pkg/pubsub/nats/provider_builder.go @@ -20,16 +20,15 @@ type ProviderBuilder struct { logger *zap.Logger hostName string routerListenAddr string - adapters map[string]Adapter } func (p *ProviderBuilder) TypeID() string { return providerTypeID } -func (p *ProviderBuilder) BuildEngineDataSourceFactory(data *nodev1.NatsEventConfiguration) (datasource.EngineDataSourceFactory, error) { +func (p *ProviderBuilder) BuildEngineDataSourceFactory(data *nodev1.NatsEventConfiguration, providers map[string]datasource.Provider) (datasource.EngineDataSourceFactory, error) { providerId := data.GetEngineEventConfiguration().GetProviderId() - adapter, ok := p.adapters[providerId] + provider, ok := providers[providerId] if !ok { return nil, fmt.Errorf("failed to get adapter for provider %s with ID %s", p.TypeID(), providerId) } @@ -46,12 +45,13 @@ func (p *ProviderBuilder) BuildEngineDataSourceFactory(data *nodev1.NatsEventCon return nil, fmt.Errorf("unsupported event type: %s", data.GetEngineEventConfiguration().GetType()) } dataSourceFactory := &EngineDataSourceFactory{ - NatsAdapter: adapter, + NatsAdapter: provider, fieldName: data.GetEngineEventConfiguration().GetFieldName(), eventType: eventType, subjects: data.GetSubjects(), providerId: providerId, withStreamConfiguration: data.GetStreamConfiguration() != nil, + logger: p.logger, } if data.GetStreamConfiguration() != nil { @@ -65,11 +65,10 @@ func (p *ProviderBuilder) BuildEngineDataSourceFactory(data *nodev1.NatsEventCon } func (p *ProviderBuilder) BuildProvider(provider config.NatsEventSource, providerOpts datasource.ProviderOpts) (datasource.Provider, error) { - adapter, pubSubProvider, err := buildProvider(p.ctx, provider, p.logger, p.hostName, p.routerListenAddr, providerOpts) + pubSubProvider, err := buildProvider(p.ctx, provider, p.logger, p.hostName, p.routerListenAddr, providerOpts) if err != nil { return nil, err } - p.adapters[provider.ID] = adapter return pubSubProvider, nil } @@ -118,18 +117,18 @@ func buildNatsOptions(eventSource config.NatsEventSource, logger *zap.Logger) ([ return opts, nil } -func buildProvider(ctx context.Context, provider config.NatsEventSource, logger *zap.Logger, hostName string, routerListenAddr string, providerOpts datasource.ProviderOpts) (Adapter, datasource.Provider, error) { +func buildProvider(ctx context.Context, provider config.NatsEventSource, logger *zap.Logger, hostName string, routerListenAddr string, providerOpts datasource.ProviderOpts) (datasource.Provider, error) { options, err := buildNatsOptions(provider, logger) if err != nil { - return nil, nil, fmt.Errorf("failed to build options for Nats provider with ID \"%s\": %w", provider.ID, err) + return nil, fmt.Errorf("failed to build options for Nats provider with ID \"%s\": %w", provider.ID, err) } adapter, err := NewAdapter(ctx, logger, provider.URL, options, hostName, routerListenAddr, providerOpts) if err != nil { - return nil, nil, fmt.Errorf("failed to create adapter for Nats provider with ID \"%s\": %w", provider.ID, err) + return nil, fmt.Errorf("failed to create adapter for Nats provider with ID \"%s\": %w", provider.ID, err) } pubSubProvider := datasource.NewPubSubProvider(provider.ID, providerTypeID, adapter, logger) - return adapter, pubSubProvider, nil + return pubSubProvider, nil } func NewProviderBuilder( @@ -143,6 +142,5 @@ func NewProviderBuilder( logger: logger, hostName: hostName, routerListenAddr: routerListenAddr, - adapters: make(map[string]Adapter), } } diff --git a/router/pkg/pubsub/pubsub.go b/router/pkg/pubsub/pubsub.go index 085de71a0e..19c908712e 100644 --- a/router/pkg/pubsub/pubsub.go +++ b/router/pkg/pubsub/pubsub.go @@ -51,11 +51,6 @@ func (e *ProviderNotDefinedError) Error() string { return fmt.Sprintf("%s provider with ID %s is not defined", e.ProviderTypeID, e.ProviderID) } -// Hooks contains hooks for the pubsub providers and data sources -type Hooks struct { - SubscriptionOnStart []pubsub_datasource.SubscriptionOnStartFn -} - // BuildProvidersAndDataSources is a generic function that builds providers and data sources for the given // EventsConfiguration and DataSourceConfigurationWithMetadata func BuildProvidersAndDataSources( @@ -66,7 +61,7 @@ func BuildProvidersAndDataSources( dsConfs []DataSourceConfigurationWithMetadata, hostName string, routerListenAddr string, - hooks Hooks, + hooks pubsub_datasource.Hooks, ) ([]pubsub_datasource.Provider, []plan.DataSource, error) { if store == nil { store = metric.NewNoopStreamMetricStore() @@ -88,7 +83,9 @@ func BuildProvidersAndDataSources( if err != nil { return nil, nil, err } - pubSubProviders = append(pubSubProviders, kafkaPubSubProviders...) + for _, provider := range kafkaPubSubProviders { + pubSubProviders = append(pubSubProviders, provider) + } outs = append(outs, kafkaOuts...) // initialize NATS providers and data sources @@ -104,7 +101,9 @@ func BuildProvidersAndDataSources( if err != nil { return nil, nil, err } - pubSubProviders = append(pubSubProviders, natsPubSubProviders...) + for _, provider := range natsPubSubProviders { + pubSubProviders = append(pubSubProviders, provider) + } outs = append(outs, natsOuts...) // initialize Redis providers and data sources @@ -120,7 +119,9 @@ func BuildProvidersAndDataSources( if err != nil { return nil, nil, err } - pubSubProviders = append(pubSubProviders, redisPubSubProviders...) + for _, provider := range redisPubSubProviders { + pubSubProviders = append(pubSubProviders, provider) + } outs = append(outs, redisOuts...) return pubSubProviders, outs, nil @@ -129,12 +130,11 @@ func BuildProvidersAndDataSources( func build[P GetID, E GetEngineEventConfiguration]( ctx context.Context, builder pubsub_datasource.ProviderBuilder[P, E], - providersData []P, - dsConfs []dsConfAndEvents[E], + providersData []P, dsConfs []dsConfAndEvents[E], store metric.StreamMetricStore, - hooks Hooks, -) ([]pubsub_datasource.Provider, []plan.DataSource, error) { - var pubSubProviders []pubsub_datasource.Provider + hooks pubsub_datasource.Hooks, +) (map[string]pubsub_datasource.Provider, []plan.DataSource, error) { + pubSubProviders := make(map[string]pubsub_datasource.Provider) var outs []plan.DataSource // check used providers @@ -148,7 +148,6 @@ func build[P GetID, E GetEngineEventConfiguration]( } // initialize providers if used - providerIds := []string{} for _, providerData := range providersData { if !slices.Contains(usedProviderIds, providerData.GetID()) { continue @@ -159,13 +158,13 @@ func build[P GetID, E GetEngineEventConfiguration]( if err != nil { return nil, nil, err } - pubSubProviders = append(pubSubProviders, provider) - providerIds = append(providerIds, provider.ID()) + provider.SetHooks(hooks) + pubSubProviders[provider.ID()] = provider } // check if all used providers are initialized for _, providerId := range usedProviderIds { - if !slices.Contains(providerIds, providerId) { + if _, ok := pubSubProviders[providerId]; !ok { return pubSubProviders, nil, &ProviderNotDefinedError{ ProviderID: providerId, ProviderTypeID: builder.TypeID(), @@ -176,7 +175,12 @@ func build[P GetID, E GetEngineEventConfiguration]( // build data sources for each event for _, dsConf := range dsConfs { for i, event := range dsConf.events { - plannerConfig := pubsub_datasource.NewPlannerConfig(builder, event, hooks.SubscriptionOnStart) + plannerConfig := pubsub_datasource.NewPlannerConfig( + builder, + event, + pubSubProviders, + hooks, + ) out, err := plan.NewDataSourceConfiguration( dsConf.dsConf.Configuration.Id+"-"+builder.TypeID()+"-"+strconv.Itoa(i), pubsub_datasource.NewPlannerFactory(ctx, plannerConfig), diff --git a/router/pkg/pubsub/pubsub_test.go b/router/pkg/pubsub/pubsub_test.go index 976980b4ff..39444689ac 100644 --- a/router/pkg/pubsub/pubsub_test.go +++ b/router/pkg/pubsub/pubsub_test.go @@ -62,13 +62,17 @@ func TestBuild_OK(t *testing.T) { } mockPubSubProvider.On("ID").Return("provider-1") + mockPubSubProvider.On("SetHooks", datasource.Hooks{ + OnReceiveEvents: []datasource.OnReceiveEventsFn(nil), + OnPublishEvents: []datasource.OnPublishEventsFn(nil), + }).Return(nil) mockBuilder.On("TypeID").Return("nats") mockBuilder.On("BuildProvider", natsEventSources[0]).Return(mockPubSubProvider, nil) // ctx, kafkaBuilder, config.Providers.Kafka, kafkaDsConfsWithEvents // Execute the function - providers, dataSources, err := build(ctx, mockBuilder, natsEventSources, dsConfs, rmetric.NewNoopStreamMetricStore(), Hooks{}) + providers, dataSources, err := build(ctx, mockBuilder, natsEventSources, dsConfs, rmetric.NewNoopStreamMetricStore(), datasource.Hooks{}) // Assertions assert.NoError(t, err) @@ -124,7 +128,7 @@ func TestBuild_ProviderError(t *testing.T) { mockBuilder.On("BuildProvider", natsEventSources[0], mock.Anything).Return(nil, errors.New("provider error")) // Execute the function - providers, dataSources, err := build(ctx, mockBuilder, natsEventSources, dsConfs, rmetric.NewNoopStreamMetricStore(), Hooks{}) + providers, dataSources, err := build(ctx, mockBuilder, natsEventSources, dsConfs, rmetric.NewNoopStreamMetricStore(), datasource.Hooks{}) // Assertions assert.Error(t, err) @@ -179,7 +183,7 @@ func TestBuild_ShouldGetAnErrorIfProviderIsNotDefined(t *testing.T) { mockBuilder.On("TypeID").Return("nats") // Execute the function - providers, dataSources, err := build(ctx, mockBuilder, natsEventSources, dsConfs, rmetric.NewNoopStreamMetricStore(), Hooks{}) + providers, dataSources, err := build(ctx, mockBuilder, natsEventSources, dsConfs, rmetric.NewNoopStreamMetricStore(), datasource.Hooks{}) // Assertions assert.Error(t, err) @@ -237,13 +241,17 @@ func TestBuild_ShouldNotInitializeProviderIfNotUsed(t *testing.T) { } mockPubSubUsedProvider.On("ID").Return("provider-2") + mockPubSubUsedProvider.On("SetHooks", datasource.Hooks{ + OnReceiveEvents: []datasource.OnReceiveEventsFn(nil), + OnPublishEvents: []datasource.OnPublishEventsFn(nil), + }).Return(nil) mockBuilder.On("TypeID").Return("nats") mockBuilder.On("BuildProvider", natsEventSources[1], mock.Anything). Return(mockPubSubUsedProvider, nil) // Execute the function - providers, dataSources, err := build(ctx, mockBuilder, natsEventSources, dsConfs, rmetric.NewNoopStreamMetricStore(), Hooks{}) + providers, dataSources, err := build(ctx, mockBuilder, natsEventSources, dsConfs, rmetric.NewNoopStreamMetricStore(), datasource.Hooks{}) // Assertions assert.NoError(t, err) @@ -294,7 +302,7 @@ func TestBuildProvidersAndDataSources_Nats_OK(t *testing.T) { {ID: "provider-1"}, }, }, - }, nil, zap.NewNop(), dsConfs, "host", "addr", Hooks{}) + }, nil, zap.NewNop(), dsConfs, "host", "addr", datasource.Hooks{}) // Assertions assert.NoError(t, err) @@ -347,7 +355,7 @@ func TestBuildProvidersAndDataSources_Kafka_OK(t *testing.T) { {ID: "provider-1"}, }, }, - }, nil, zap.NewNop(), dsConfs, "host", "addr", Hooks{}) + }, nil, zap.NewNop(), dsConfs, "host", "addr", datasource.Hooks{}) // Assertions assert.NoError(t, err) @@ -400,7 +408,7 @@ func TestBuildProvidersAndDataSources_Redis_OK(t *testing.T) { {ID: "provider-1"}, }, }, - }, nil, zap.NewNop(), dsConfs, "host", "addr", Hooks{}) + }, nil, zap.NewNop(), dsConfs, "host", "addr", datasource.Hooks{}) // Assertions assert.NoError(t, err) diff --git a/router/pkg/pubsub/redis/adapter.go b/router/pkg/pubsub/redis/adapter.go index 5cb0055a36..8c65bc3413 100644 --- a/router/pkg/pubsub/redis/adapter.go +++ b/router/pkg/pubsub/redis/adapter.go @@ -17,19 +17,10 @@ const ( redisReceive = "receive" ) -// Adapter defines the methods that a Redis adapter should implement -type Adapter interface { - // Subscribe subscribes to the given events and sends updates to the updater - Subscribe(ctx context.Context, event datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater) error - // Publish publishes the given event to the specified channel - Publish(ctx context.Context, event PublishEventConfiguration) error - // Startup initializes the adapter - Startup(ctx context.Context) error - // Shutdown gracefully shuts down the adapter - Shutdown(ctx context.Context) error -} +// Ensure ProviderAdapter implements ProviderSubscriptionHooks +var _ datasource.Adapter = (*ProviderAdapter)(nil) -func NewProviderAdapter(ctx context.Context, logger *zap.Logger, urls []string, clusterEnabled bool, opts datasource.ProviderOpts) Adapter { +func NewProviderAdapter(ctx context.Context, logger *zap.Logger, urls []string, clusterEnabled bool, opts datasource.ProviderOpts) datasource.Adapter { ctx, cancel := context.WithCancel(ctx) if logger == nil { logger = zap.NewNop() @@ -96,10 +87,11 @@ func (p *ProviderAdapter) Shutdown(ctx context.Context) error { func (p *ProviderAdapter) Subscribe(ctx context.Context, conf datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater) error { subConf, ok := conf.(*SubscriptionEventConfiguration) if !ok { - return datasource.NewError("invalid event type for Kafka adapter", nil) + return datasource.NewError("subscription event not support by redis provider", nil) } + log := p.logger.With( - zap.String("provider_id", subConf.ProviderID()), + zap.String("provider_id", conf.ProviderID()), zap.String("method", "subscribe"), zap.Strings("channels", subConf.Channels), ) @@ -136,9 +128,9 @@ func (p *ProviderAdapter) Subscribe(ctx context.Context, conf datasource.Subscri ProviderType: metric.ProviderTypeRedis, DestinationName: msg.Channel, }) - updater.Update(&Event{ + updater.Update([]datasource.StreamEvent{&Event{ Data: []byte(msg.Payload), - }) + }}) case <-p.ctx.Done(): // When the application context is done, we stop the subscription if it is not already done log.Debug("application context done, stopping subscription") @@ -156,41 +148,59 @@ func (p *ProviderAdapter) Subscribe(ctx context.Context, conf datasource.Subscri return nil } -func (p *ProviderAdapter) Publish(ctx context.Context, event PublishEventConfiguration) error { +func (p *ProviderAdapter) Publish(ctx context.Context, conf datasource.PublishEventConfiguration, events []datasource.StreamEvent) error { + pubConf, ok := conf.(*PublishEventConfiguration) + if !ok { + return datasource.NewError("publish event not support by redis provider", nil) + } + log := p.logger.With( - zap.String("provider_id", event.ProviderID()), + zap.String("provider_id", conf.ProviderID()), zap.String("method", "publish"), - zap.String("channel", event.Channel), + zap.String("channel", pubConf.Channel), ) - log.Debug("publish", zap.ByteString("data", event.Event.Data)) - - data, dataErr := event.Event.Data.MarshalJSON() - if dataErr != nil { - log.Error("error marshalling data", zap.Error(dataErr)) - return datasource.NewError("error marshalling data", dataErr) - } if p.conn == nil { return datasource.NewError("redis connection not initialized", nil) } - intCmd := p.conn.Publish(ctx, event.Channel, data) - if intCmd.Err() != nil { - log.Error("publish error", zap.Error(intCmd.Err())) - p.streamMetricStore.Produce(ctx, metric.StreamsEvent{ - ProviderId: event.ProviderID(), - StreamOperationName: redisPublish, - ProviderType: metric.ProviderTypeRedis, - ErrorType: "publish_error", - DestinationName: event.Channel, - }) - return datasource.NewError(fmt.Sprintf("error publishing to Redis PubSub channel %s", event.Channel), intCmd.Err()) + + if len(events) == 0 { + return nil + } + + log.Debug("publish", zap.Int("event_count", len(events))) + + for _, streamEvent := range events { + redisEvent, ok := streamEvent.(*Event) + if !ok { + return datasource.NewError("invalid event type for Redis adapter", nil) + } + + data, dataErr := redisEvent.Data.MarshalJSON() + if dataErr != nil { + log.Error("error marshalling data", zap.Error(dataErr)) + return datasource.NewError("error marshalling data", dataErr) + } + + intCmd := p.conn.Publish(ctx, pubConf.Channel, data) + if intCmd.Err() != nil { + log.Error("publish error", zap.Error(intCmd.Err())) + p.streamMetricStore.Produce(ctx, metric.StreamsEvent{ + ProviderId: pubConf.ProviderID(), + StreamOperationName: redisPublish, + ProviderType: metric.ProviderTypeRedis, + ErrorType: "publish_error", + DestinationName: pubConf.Channel, + }) + return datasource.NewError(fmt.Sprintf("error publishing to Redis PubSub channel %s", pubConf.Channel), intCmd.Err()) + } } p.streamMetricStore.Produce(ctx, metric.StreamsEvent{ - ProviderId: event.ProviderID(), + ProviderId: pubConf.ProviderID(), StreamOperationName: redisPublish, ProviderType: metric.ProviderTypeRedis, - DestinationName: event.Channel, + DestinationName: pubConf.Channel, }) return nil } diff --git a/router/pkg/pubsub/redis/engine_datasource.go b/router/pkg/pubsub/redis/engine_datasource.go index 3a685fe9b0..e796b60e66 100644 --- a/router/pkg/pubsub/redis/engine_datasource.go +++ b/router/pkg/pubsub/redis/engine_datasource.go @@ -6,9 +6,13 @@ import ( "encoding/json" "fmt" "io" + "slices" + "github.com/buger/jsonparser" + "github.com/cespare/xxhash/v2" "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" ) // Event represents an event from Redis @@ -20,6 +24,12 @@ func (e *Event) GetData() []byte { return e.Data } +func (e *Event) Clone() datasource.StreamEvent { + return &Event{ + Data: slices.Clone(e.Data), + } +} + // SubscriptionEventConfiguration contains configuration for subscription events type SubscriptionEventConfiguration struct { Provider string `json:"providerId"` @@ -42,11 +52,31 @@ func (s *SubscriptionEventConfiguration) RootFieldName() string { return s.FieldName } +// publishData is a private type that is used to pass data from the engine to the provider + +type publishData struct { + Provider string `json:"providerId"` + Channel string `json:"channel"` + Event Event `json:"event"` + FieldName string `json:"rootFieldName"` +} + +func (p *publishData) PublishEventConfiguration() datasource.PublishEventConfiguration { + return &PublishEventConfiguration{ + Provider: p.Provider, + Channel: p.Channel, + FieldName: p.FieldName, + } +} + +func (p *publishData) MarshalJSONTemplate() (string, error) { + return fmt.Sprintf(`{"channel":"%s", "event": {"data": %s}, "providerId":"%s", "rootFieldName":"%s"}`, p.Channel, p.Event.Data, p.Provider, p.FieldName), nil +} + // PublishEventConfiguration contains configuration for publish events type PublishEventConfiguration struct { Provider string `json:"providerId"` Channel string `json:"channel"` - Event Event `json:"event"` FieldName string `json:"rootFieldName"` } @@ -65,25 +95,77 @@ func (p *PublishEventConfiguration) RootFieldName() string { return p.FieldName } -func (s *PublishEventConfiguration) MarshalJSONTemplate() (string, error) { - return fmt.Sprintf(`{"channel":"%s", "event": {"data": %s}, "providerId":"%s"}`, s.Channel, s.Event.Data, s.ProviderID()), nil +// SubscriptionDataSource implements resolve.SubscriptionDataSource for Redis +type SubscriptionDataSource struct { + pubSub datasource.Adapter +} + +func (s *SubscriptionDataSource) SubscriptionEventConfiguration(input []byte) datasource.SubscriptionEventConfiguration { + var subscriptionConfiguration SubscriptionEventConfiguration + err := json.Unmarshal(input, &subscriptionConfiguration) + if err != nil { + return nil + } + return &subscriptionConfiguration +} + +// UniqueRequestID computes a unique ID for the subscription request +func (s *SubscriptionDataSource) UniqueRequestID(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { + val, _, _, err := jsonparser.Get(input, "channels") + if err != nil { + return err + } + + _, err = xxh.Write(val) + if err != nil { + return err + } + + val, _, _, err = jsonparser.Get(input, "providerId") + if err != nil { + return err + } + + _, err = xxh.Write(val) + return err +} + +// Start starts the subscription +func (s *SubscriptionDataSource) Start(ctx *resolve.Context, input []byte, updater datasource.SubscriptionEventUpdater) error { + subConf := s.SubscriptionEventConfiguration(input) + if subConf == nil { + return fmt.Errorf("no subscription configuration found") + } + + conf, ok := subConf.(*SubscriptionEventConfiguration) + if !ok { + return fmt.Errorf("invalid subscription configuration") + } + + return s.pubSub.Subscribe(ctx.Context(), conf, updater) +} + +// LoadInitialData implements the interface method (not used for this subscription type) +func (s *SubscriptionDataSource) LoadInitialData(ctx context.Context) (initial []byte, err error) { + return nil, nil } // PublishDataSource implements resolve.DataSource for Redis publishing type PublishDataSource struct { - pubSub Adapter + pubSub datasource.Adapter } // Load processes a request to publish to Redis func (s *PublishDataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) error { - var publishConfiguration PublishEventConfiguration - if err := json.Unmarshal(input, &publishConfiguration); err != nil { + var publishData publishData + if err := json.Unmarshal(input, &publishData); err != nil { return err } - if err := s.pubSub.Publish(ctx, publishConfiguration); err != nil { - _, err = io.WriteString(out, `{"success": false}`) - return err + if err := s.pubSub.Publish(ctx, publishData.PublishEventConfiguration(), []datasource.StreamEvent{&publishData.Event}); err != nil { + // err will not be returned but only logged inside PubSubProvider.Publish to avoid a "unable to fetch from subgraph" error + _, errWrite := io.WriteString(out, `{"success": false}`) + return errWrite } _, err := io.WriteString(out, `{"success": true}`) return err diff --git a/router/pkg/pubsub/redis/engine_datasource_factory.go b/router/pkg/pubsub/redis/engine_datasource_factory.go index bce913e54e..46f22e29b9 100644 --- a/router/pkg/pubsub/redis/engine_datasource_factory.go +++ b/router/pkg/pubsub/redis/engine_datasource_factory.go @@ -9,6 +9,7 @@ import ( "github.com/cespare/xxhash/v2" "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" + "go.uber.org/zap" ) type EventType int @@ -20,12 +21,13 @@ const ( // EngineDataSourceFactory implements the datasource.EngineDataSourceFactory interface for Redis type EngineDataSourceFactory struct { - RedisAdapter Adapter + RedisAdapter datasource.Adapter fieldName string eventType EventType channels []string providerId string + logger *zap.Logger } func (c *EngineDataSourceFactory) GetFieldName() string { @@ -60,11 +62,11 @@ func (c *EngineDataSourceFactory) ResolveDataSourceInput(eventData []byte) (stri channel := channels[0] providerId := c.providerId - evtCfg := PublishEventConfiguration{ + evtCfg := publishData{ Provider: providerId, Channel: channel, - Event: Event{Data: eventData}, FieldName: c.fieldName, + Event: Event{Data: eventData}, } return evtCfg.MarshalJSONTemplate() @@ -92,7 +94,7 @@ func (c *EngineDataSourceFactory) ResolveDataSourceSubscription() (datasource.Su _, err = xxh.Write(val) return err - }), nil + }, c.logger), nil } // ResolveDataSourceSubscriptionInput builds the input for the subscription data source diff --git a/router/pkg/pubsub/redis/engine_datasource_factory_test.go b/router/pkg/pubsub/redis/engine_datasource_factory_test.go index f96691583d..7dc4ade017 100644 --- a/router/pkg/pubsub/redis/engine_datasource_factory_test.go +++ b/router/pkg/pubsub/redis/engine_datasource_factory_test.go @@ -5,12 +5,14 @@ import ( "context" "encoding/json" "errors" + "strings" "testing" "github.com/cespare/xxhash/v2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" "github.com/wundergraph/cosmo/router/pkg/pubsub/pubsubtest" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" ) @@ -33,11 +35,13 @@ func TestRedisEngineDataSourceFactory(t *testing.T) { // TestEngineDataSourceFactoryWithMockAdapter tests the EngineDataSourceFactory with a mocked adapter func TestEngineDataSourceFactoryWithMockAdapter(t *testing.T) { // Create mock adapter - mockAdapter := NewMockAdapter(t) + mockAdapter := datasource.NewMockProvider(t) // Configure mock expectations for Publish - mockAdapter.On("Publish", mock.Anything, mock.MatchedBy(func(event PublishEventConfiguration) bool { + mockAdapter.On("Publish", mock.Anything, mock.MatchedBy(func(event *PublishEventConfiguration) bool { return event.ProviderID() == "test-provider" && event.Channel == "test-channel" + }), mock.MatchedBy(func(events []datasource.StreamEvent) bool { + return len(events) == 1 && strings.EqualFold(string(events[0].GetData()), `{"test":"data"}`) })).Return(nil) // Create the data source with mock adapter @@ -67,7 +71,7 @@ func TestEngineDataSourceFactoryWithMockAdapter(t *testing.T) { // TestEngineDataSourceFactory_GetResolveDataSource_WrongType tests the EngineDataSourceFactory with a mocked adapter func TestEngineDataSourceFactory_GetResolveDataSource_WrongType(t *testing.T) { // Create mock adapter - mockAdapter := NewMockAdapter(t) + mockAdapter := datasource.NewMockProvider(t) // Create the data source with mock adapter pubsub := &EngineDataSourceFactory{ @@ -210,7 +214,7 @@ func TestRedisEngineDataSourceFactory_UniqueRequestID(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { factory := &EngineDataSourceFactory{ - RedisAdapter: NewMockAdapter(t), + RedisAdapter: datasource.NewMockProvider(t), } source, err := factory.ResolveDataSourceSubscription() require.NoError(t, err) diff --git a/router/pkg/pubsub/redis/engine_datasource_test.go b/router/pkg/pubsub/redis/engine_datasource_test.go index 74b7d564d7..b322c8a60c 100644 --- a/router/pkg/pubsub/redis/engine_datasource_test.go +++ b/router/pkg/pubsub/redis/engine_datasource_test.go @@ -5,36 +5,40 @@ import ( "context" "encoding/json" "errors" + "strings" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" ) func TestPublishEventConfiguration_MarshalJSONTemplate(t *testing.T) { tests := []struct { name string - config PublishEventConfiguration + config publishData wantPattern string }{ { name: "simple configuration", - config: PublishEventConfiguration{ + config: publishData{ Provider: "test-provider", Channel: "test-channel", Event: Event{Data: json.RawMessage(`{"message":"hello"}`)}, + FieldName: "test-field", }, - wantPattern: `{"channel":"test-channel", "event": {"data": {"message":"hello"}}, "providerId":"test-provider"}`, + wantPattern: `{"channel":"test-channel", "event": {"data": {"message":"hello"}}, "providerId":"test-provider", "rootFieldName":"test-field"}`, }, { name: "with special characters", - config: PublishEventConfiguration{ + config: publishData{ Provider: "test-provider-id", Channel: "channel-with-hyphens", Event: Event{Data: json.RawMessage(`{"message":"special \"quotes\" here"}`)}, + FieldName: "test-field", }, - wantPattern: `{"channel":"channel-with-hyphens", "event": {"data": {"message":"special \"quotes\" here"}}, "providerId":"test-provider-id"}`, + wantPattern: `{"channel":"channel-with-hyphens", "event": {"data": {"message":"special \"quotes\" here"}}, "providerId":"test-provider-id", "rootFieldName":"test-field"}`, }, } @@ -47,11 +51,27 @@ func TestPublishEventConfiguration_MarshalJSONTemplate(t *testing.T) { } } +func TestPublishData_PublishEventConfiguration(t *testing.T) { + data := publishData{ + Provider: "test-provider", + Channel: "test-channel", + FieldName: "test-field", + } + + evtCfg := &PublishEventConfiguration{ + Provider: data.Provider, + Channel: data.Channel, + FieldName: data.FieldName, + } + + assert.Equal(t, evtCfg, data.PublishEventConfiguration()) +} + func TestRedisPublishDataSource_Load(t *testing.T) { tests := []struct { name string input string - mockSetup func(*MockAdapter) + mockSetup func(*datasource.MockProvider) expectError bool expectedOutput string expectPublished bool @@ -59,11 +79,12 @@ func TestRedisPublishDataSource_Load(t *testing.T) { { name: "successful publish", input: `{"channel":"test-channel", "event": {"data":{"message":"hello"}}, "providerId":"test-provider"}`, - mockSetup: func(m *MockAdapter) { - m.On("Publish", mock.Anything, mock.MatchedBy(func(event PublishEventConfiguration) bool { + mockSetup: func(m *datasource.MockProvider) { + m.On("Publish", mock.Anything, mock.MatchedBy(func(event *PublishEventConfiguration) bool { return event.ProviderID() == "test-provider" && - event.Channel == "test-channel" && - string(event.Event.Data) == `{"message":"hello"}` + event.Channel == "test-channel" + }), mock.MatchedBy(func(events []datasource.StreamEvent) bool { + return len(events) == 1 && strings.EqualFold(string(events[0].GetData()), `{"message":"hello"}`) })).Return(nil) }, expectError: false, @@ -73,8 +94,8 @@ func TestRedisPublishDataSource_Load(t *testing.T) { { name: "publish error", input: `{"channel":"test-channel", "event": {"data":{"message":"hello"}}, "providerId":"test-provider"}`, - mockSetup: func(m *MockAdapter) { - m.On("Publish", mock.Anything, mock.Anything).Return(errors.New("publish error")) + mockSetup: func(m *datasource.MockProvider) { + m.On("Publish", mock.Anything, mock.Anything, mock.Anything).Return(errors.New("publish error")) }, expectError: false, // The Load method doesn't return the publish error directly expectedOutput: `{"success": false}`, @@ -83,7 +104,7 @@ func TestRedisPublishDataSource_Load(t *testing.T) { { name: "invalid input json", input: `{"invalid json":`, - mockSetup: func(m *MockAdapter) {}, + mockSetup: func(m *datasource.MockProvider) {}, expectError: true, expectPublished: false, }, @@ -91,7 +112,7 @@ func TestRedisPublishDataSource_Load(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - mockAdapter := NewMockAdapter(t) + mockAdapter := datasource.NewMockProvider(t) tt.mockSetup(mockAdapter) dataSource := &PublishDataSource{ @@ -116,7 +137,7 @@ func TestRedisPublishDataSource_Load(t *testing.T) { func TestRedisPublishDataSource_LoadWithFiles(t *testing.T) { t.Run("panic on not implemented", func(t *testing.T) { dataSource := &PublishDataSource{ - pubSub: NewMockAdapter(t), + pubSub: datasource.NewMockProvider(t), } assert.Panics(t, func() { diff --git a/router/pkg/pubsub/redis/mocks.go b/router/pkg/pubsub/redis/mocks.go deleted file mode 100644 index 6f6938cdd0..0000000000 --- a/router/pkg/pubsub/redis/mocks.go +++ /dev/null @@ -1,261 +0,0 @@ -// Code generated by mockery; DO NOT EDIT. -// github.com/vektra/mockery -// template: testify - -package redis - -import ( - "context" - - mock "github.com/stretchr/testify/mock" - "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" -) - -// NewMockAdapter creates a new instance of MockAdapter. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -// The first argument is typically a *testing.T value. -func NewMockAdapter(t interface { - mock.TestingT - Cleanup(func()) -}) *MockAdapter { - mock := &MockAdapter{} - mock.Mock.Test(t) - - t.Cleanup(func() { mock.AssertExpectations(t) }) - - return mock -} - -// MockAdapter is an autogenerated mock type for the Adapter type -type MockAdapter struct { - mock.Mock -} - -type MockAdapter_Expecter struct { - mock *mock.Mock -} - -func (_m *MockAdapter) EXPECT() *MockAdapter_Expecter { - return &MockAdapter_Expecter{mock: &_m.Mock} -} - -// Publish provides a mock function for the type MockAdapter -func (_mock *MockAdapter) Publish(ctx context.Context, event PublishEventConfiguration) error { - ret := _mock.Called(ctx, event) - - if len(ret) == 0 { - panic("no return value specified for Publish") - } - - var r0 error - if returnFunc, ok := ret.Get(0).(func(context.Context, PublishEventConfiguration) error); ok { - r0 = returnFunc(ctx, event) - } else { - r0 = ret.Error(0) - } - return r0 -} - -// MockAdapter_Publish_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Publish' -type MockAdapter_Publish_Call struct { - *mock.Call -} - -// Publish is a helper method to define mock.On call -// - ctx context.Context -// - event PublishEventConfiguration -func (_e *MockAdapter_Expecter) Publish(ctx interface{}, event interface{}) *MockAdapter_Publish_Call { - return &MockAdapter_Publish_Call{Call: _e.mock.On("Publish", ctx, event)} -} - -func (_c *MockAdapter_Publish_Call) Run(run func(ctx context.Context, event PublishEventConfiguration)) *MockAdapter_Publish_Call { - _c.Call.Run(func(args mock.Arguments) { - var arg0 context.Context - if args[0] != nil { - arg0 = args[0].(context.Context) - } - var arg1 PublishEventConfiguration - if args[1] != nil { - arg1 = args[1].(PublishEventConfiguration) - } - run( - arg0, - arg1, - ) - }) - return _c -} - -func (_c *MockAdapter_Publish_Call) Return(err error) *MockAdapter_Publish_Call { - _c.Call.Return(err) - return _c -} - -func (_c *MockAdapter_Publish_Call) RunAndReturn(run func(ctx context.Context, event PublishEventConfiguration) error) *MockAdapter_Publish_Call { - _c.Call.Return(run) - return _c -} - -// Shutdown provides a mock function for the type MockAdapter -func (_mock *MockAdapter) Shutdown(ctx context.Context) error { - ret := _mock.Called(ctx) - - if len(ret) == 0 { - panic("no return value specified for Shutdown") - } - - var r0 error - if returnFunc, ok := ret.Get(0).(func(context.Context) error); ok { - r0 = returnFunc(ctx) - } else { - r0 = ret.Error(0) - } - return r0 -} - -// MockAdapter_Shutdown_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Shutdown' -type MockAdapter_Shutdown_Call struct { - *mock.Call -} - -// Shutdown is a helper method to define mock.On call -// - ctx context.Context -func (_e *MockAdapter_Expecter) Shutdown(ctx interface{}) *MockAdapter_Shutdown_Call { - return &MockAdapter_Shutdown_Call{Call: _e.mock.On("Shutdown", ctx)} -} - -func (_c *MockAdapter_Shutdown_Call) Run(run func(ctx context.Context)) *MockAdapter_Shutdown_Call { - _c.Call.Run(func(args mock.Arguments) { - var arg0 context.Context - if args[0] != nil { - arg0 = args[0].(context.Context) - } - run( - arg0, - ) - }) - return _c -} - -func (_c *MockAdapter_Shutdown_Call) Return(err error) *MockAdapter_Shutdown_Call { - _c.Call.Return(err) - return _c -} - -func (_c *MockAdapter_Shutdown_Call) RunAndReturn(run func(ctx context.Context) error) *MockAdapter_Shutdown_Call { - _c.Call.Return(run) - return _c -} - -// Startup provides a mock function for the type MockAdapter -func (_mock *MockAdapter) Startup(ctx context.Context) error { - ret := _mock.Called(ctx) - - if len(ret) == 0 { - panic("no return value specified for Startup") - } - - var r0 error - if returnFunc, ok := ret.Get(0).(func(context.Context) error); ok { - r0 = returnFunc(ctx) - } else { - r0 = ret.Error(0) - } - return r0 -} - -// MockAdapter_Startup_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Startup' -type MockAdapter_Startup_Call struct { - *mock.Call -} - -// Startup is a helper method to define mock.On call -// - ctx context.Context -func (_e *MockAdapter_Expecter) Startup(ctx interface{}) *MockAdapter_Startup_Call { - return &MockAdapter_Startup_Call{Call: _e.mock.On("Startup", ctx)} -} - -func (_c *MockAdapter_Startup_Call) Run(run func(ctx context.Context)) *MockAdapter_Startup_Call { - _c.Call.Run(func(args mock.Arguments) { - var arg0 context.Context - if args[0] != nil { - arg0 = args[0].(context.Context) - } - run( - arg0, - ) - }) - return _c -} - -func (_c *MockAdapter_Startup_Call) Return(err error) *MockAdapter_Startup_Call { - _c.Call.Return(err) - return _c -} - -func (_c *MockAdapter_Startup_Call) RunAndReturn(run func(ctx context.Context) error) *MockAdapter_Startup_Call { - _c.Call.Return(run) - return _c -} - -// Subscribe provides a mock function for the type MockAdapter -func (_mock *MockAdapter) Subscribe(ctx context.Context, event datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater) error { - ret := _mock.Called(ctx, event, updater) - - if len(ret) == 0 { - panic("no return value specified for Subscribe") - } - - var r0 error - if returnFunc, ok := ret.Get(0).(func(context.Context, datasource.SubscriptionEventConfiguration, datasource.SubscriptionEventUpdater) error); ok { - r0 = returnFunc(ctx, event, updater) - } else { - r0 = ret.Error(0) - } - return r0 -} - -// MockAdapter_Subscribe_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Subscribe' -type MockAdapter_Subscribe_Call struct { - *mock.Call -} - -// Subscribe is a helper method to define mock.On call -// - ctx context.Context -// - event datasource.SubscriptionEventConfiguration -// - updater datasource.SubscriptionEventUpdater -func (_e *MockAdapter_Expecter) Subscribe(ctx interface{}, event interface{}, updater interface{}) *MockAdapter_Subscribe_Call { - return &MockAdapter_Subscribe_Call{Call: _e.mock.On("Subscribe", ctx, event, updater)} -} - -func (_c *MockAdapter_Subscribe_Call) Run(run func(ctx context.Context, event datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater)) *MockAdapter_Subscribe_Call { - _c.Call.Run(func(args mock.Arguments) { - var arg0 context.Context - if args[0] != nil { - arg0 = args[0].(context.Context) - } - var arg1 datasource.SubscriptionEventConfiguration - if args[1] != nil { - arg1 = args[1].(datasource.SubscriptionEventConfiguration) - } - var arg2 datasource.SubscriptionEventUpdater - if args[2] != nil { - arg2 = args[2].(datasource.SubscriptionEventUpdater) - } - run( - arg0, - arg1, - arg2, - ) - }) - return _c -} - -func (_c *MockAdapter_Subscribe_Call) Return(err error) *MockAdapter_Subscribe_Call { - _c.Call.Return(err) - return _c -} - -func (_c *MockAdapter_Subscribe_Call) RunAndReturn(run func(ctx context.Context, event datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater) error) *MockAdapter_Subscribe_Call { - _c.Call.Return(run) - return _c -} diff --git a/router/pkg/pubsub/redis/provider_builder.go b/router/pkg/pubsub/redis/provider_builder.go index 46340934bd..f8814b7d42 100644 --- a/router/pkg/pubsub/redis/provider_builder.go +++ b/router/pkg/pubsub/redis/provider_builder.go @@ -18,7 +18,6 @@ type ProviderBuilder struct { logger *zap.Logger hostName string routerListenAddr string - adapters map[string]Adapter } // NewProviderBuilder creates a new Redis PubSub provider builder @@ -33,7 +32,6 @@ func NewProviderBuilder( logger: logger, hostName: hostName, routerListenAddr: routerListenAddr, - adapters: make(map[string]Adapter), } } @@ -43,8 +41,12 @@ func (b *ProviderBuilder) TypeID() string { } // DataSource creates a Redis PubSub data source for the given event configuration -func (b *ProviderBuilder) BuildEngineDataSourceFactory(data *nodev1.RedisEventConfiguration) (datasource.EngineDataSourceFactory, error) { +func (b *ProviderBuilder) BuildEngineDataSourceFactory(data *nodev1.RedisEventConfiguration, providers map[string]datasource.Provider) (datasource.EngineDataSourceFactory, error) { providerId := data.GetEngineEventConfiguration().GetProviderId() + provider, ok := providers[providerId] + if !ok { + return nil, fmt.Errorf("failed to get adapter for provider %s with ID %s", b.TypeID(), providerId) + } var eventType EventType switch data.GetEngineEventConfiguration().GetType() { @@ -57,11 +59,12 @@ func (b *ProviderBuilder) BuildEngineDataSourceFactory(data *nodev1.RedisEventCo } return &EngineDataSourceFactory{ - RedisAdapter: b.adapters[providerId], fieldName: data.GetEngineEventConfiguration().GetFieldName(), eventType: eventType, channels: data.GetChannels(), providerId: providerId, + RedisAdapter: provider, + logger: b.logger, }, nil } @@ -69,7 +72,6 @@ func (b *ProviderBuilder) BuildEngineDataSourceFactory(data *nodev1.RedisEventCo func (b *ProviderBuilder) BuildProvider(provider config.RedisEventSource, providerOpts datasource.ProviderOpts) (datasource.Provider, error) { adapter := NewProviderAdapter(b.ctx, b.logger, provider.URLs, provider.ClusterEnabled, providerOpts) pubSubProvider := datasource.NewPubSubProvider(provider.ID, providerTypeID, adapter, b.logger) - b.adapters[provider.ID] = adapter return pubSubProvider, nil }