diff --git a/adr/cosmo-streams-v1.md b/adr/cosmo-streams-v1.md new file mode 100644 index 0000000000..436dafe45b --- /dev/null +++ b/adr/cosmo-streams-v1.md @@ -0,0 +1,408 @@ +--- +title: "Cosmo Streams v1" +author: Alessandro Pagnin +date: 2025-07-16 +status: Accepted +--- + +# ADR - Cosmo Streams V1 + +- **Author:** Alessandro Pagnin +- **Date:** 2025-07-16 +- **Status:** Accepted +- **RFC:** ../rfcs/cosmo-streams-v1.md + +## Abstract +This ADR describes new hooks that will be added to the router to support more customizable stream behavior. +The goal is to allow developers to customize the cosmo streams behavior. + +## Decision +The following interfaces will extend the existing logic in the custom modules. +These provide additional control over subscriptions by providing hooks, which are invoked during specific events. + +- `SubscriptionOnStartHandler`: Called once at subscription start. +- `StreamBatchEventHook`: Called each time a batch of events is received from the provider. +- `StreamPublishEventHook`: Called each time a batch of events is going to be sent to the provider. 
+ +```go +// STRUCTURES TO BE ADDED TO PUBSUB PACKAGE +type ProviderType string +const ( + ProviderTypeNats ProviderType = "nats" + ProviderTypeKafka ProviderType = "kafka" + ProviderTypeRedis ProviderType = "redis" +) + +// StreamHookError is used to customize the error messages and the behavior +type StreamHookError struct { + HttpError core.HttpError + CloseSubscription bool +} + +// OperationContext already exists, we just have to add the Variables() method +type OperationContext interface { + Name() string + // the variables are currently not available, so we need to expose them here + Variables() *astjson.Value +} + +// each provider will have its own event type with custom fields +// the StreamEvent interface is used to allow the hooks system to be provider-agnostic +// there could be common fields in future, but for now we don't need them +type StreamEvent interface {} + +// SubscriptionEventConfiguration is the common interface for the subscription event configuration +type SubscriptionEventConfiguration interface { + ProviderID() string + ProviderType() ProviderType + // the root field name of the subscription in the schema + RootFieldName() string +} + +// PublishEventConfiguration is the common interface for the publish event configuration +type PublishEventConfiguration interface { + ProviderID() string + ProviderType() ProviderType + // the root field name of the mutation in the schema + RootFieldName() string +} + +type SubscriptionOnStartHookContext interface { + // Request is the original request received by the router. 
+ Request() *http.Request + // Logger is the logger for the request + Logger() *zap.Logger + // Operation is the GraphQL operation + Operation() OperationContext + // Authentication is the authentication for the request + Authentication() authentication.Authentication + // SubscriptionEventConfiguration is the subscription event configuration (will return nil for engine subscription) + SubscriptionEventConfiguration() SubscriptionEventConfiguration + // WriteEvent writes an event to the stream of the current subscription + // It returns true if the event was written to the stream, false if the event was dropped + WriteEvent(event StreamEvent) bool +} + +type SubscriptionOnStartHandler interface { + // SubscriptionOnStart is called once at subscription start + // Returning an error will result in a GraphQL error being returned to the client, could be customized returning a StreamHookError. + SubscriptionOnStart(ctx SubscriptionOnStartHookContext) error +} + +type StreamBatchEventHookContext interface { + // the request context + RequestContext() RequestContext + // the subscription event configuration + SubscriptionEventConfiguration() SubscriptionEventConfiguration +} + +type StreamBatchEventHook interface { + // OnStreamEvents is called each time a batch of events is received from the provider + // Returning an error will result in a GraphQL error being returned to the client, could be customized returning a StreamHookError. 
+ OnStreamEvents(ctx StreamBatchEventHookContext, events []StreamEvent) ([]StreamEvent, error) +} + +type StreamPublishEventHookContext interface { + // the request context + RequestContext() RequestContext + // the publish event configuration + PublishEventConfiguration() PublishEventConfiguration +} + +type StreamPublishEventHook interface { + // OnPublishEvents is called each time a batch of events is going to be sent to the provider + // Returning an error will result in a GraphQL error being returned to the client, could be customized returning a StreamHookError. + OnPublishEvents(ctx StreamPublishEventHookContext, events []StreamEvent) ([]StreamEvent, error) +} +``` + +## Example Use Cases + +- **Authorization**: Implementing authorization checks at the start of subscriptions +- **Initial message**: Sending an initial message to clients upon subscription start +- **Data mapping**: Transforming events data from the format that could be used by the external system to/from Federation compatible Router events +- **Event filtering**: Filtering events using custom logic + +## Backwards Compatibility + +The new hooks can be integrated in the router in a fully backwards compatible way. + +When the new module system is released, the Cosmo Streams hooks: +- will be moved to the `core/hooks.go` file +- will be added to the `hookRegistry` +- will be initialized in the `coreModuleHooks.initCoreModuleHooks` + + +# Example Modules + +__All examples are pseudocode and not tested, but they are as close as possible to the final implementation__ + +## Filter and remap events + +This example will show how to filter the events based on the client's scopes and remap the messages into the format expected by the `Employee` type. + +### 1. Add a subscription to the cosmo streams graphql schema + +The developer will start by adding a subscription to the cosmo streams graphql schema. + +```graphql +type Subscription { + employeeUpdates: Employee! 
@edfs__natsSubscribe(subjects: ["employeeUpdates"], providerId: "my-nats") +} + +type Employee @key(fields: "id", resolvable: false) { + id: Int! @external +} +``` +After publishing the schema, the developer will need to add the module to the cosmo streams engine. + +### 2. Write the custom module + +The developer will need to write the custom module that will be used to subscribe to the `employeeUpdates` subject and filter the events based on the client's scopes and remap the messages into the format expected by the `Employee` type. + +```go +package mymodule + +import ( + "encoding/json" + "fmt" + "slices" + "github.com/wundergraph/cosmo/router/core" + "github.com/wundergraph/cosmo/router/pkg/pubsub" + "github.com/wundergraph/cosmo/router/pkg/pubsub/nats" +) + +func init() { + // Register your module here and it will be loaded at router start + core.RegisterModule(&MyModule{}) +} + +type MyModule struct {} + +func (m *MyModule) OnStreamEvents(ctx core.StreamBatchEventHookContext, events []core.StreamEvent) ([]core.StreamEvent, error) { + // check if the provider is nats + if ctx.SubscriptionEventConfiguration().ProviderType() != pubsub.ProviderTypeNats { + return events, nil + } + + // check if the provider id is the one expected by the module + if ctx.SubscriptionEventConfiguration().ProviderID() != "my-nats" { + return events, nil + } + + // check if the subject is the one expected by the module + natsConfig := ctx.SubscriptionEventConfiguration().(*nats.SubscriptionEventConfiguration) + if natsConfig.Subjects[0] != "employeeUpdates" { + return events, nil + } + + // check if the client is authenticated + if ctx.RequestContext().Authentication() == nil { + // if the client is not authenticated, return no events + return []core.StreamEvent{}, nil + } + + // check if the client is allowed to subscribe to the stream + clientAllowedEntitiesIds, found := ctx.RequestContext().Authentication().Claims()["allowedEntitiesIds"] + if !found { + return events, fmt.Errorf("client is not allowed to subscribe to the stream") + } + + newEvents := make([]core.StreamEvent, 0, 
len(events)) + + for _, evt := range events { + natsEvent, ok := evt.(*nats.NatsEvent); + if !ok { + newEvents = append(newEvents, evt) + continue + } + + // decode the event data coming from the provider + var dataReceived struct { + EmployeeId string `json:"EmployeeId"` + OtherField string `json:"OtherField"` + } + err := json.Unmarshal(natsEvent.Data, &dataReceived) + if err != nil { + return events, fmt.Errorf("error unmarshalling data: %w", err) + } + + // filter the events based on the client's scopes + if !slices.Contains(clientAllowedEntitiesIds, dataReceived.EmployeeId) { + continue + } + + // prepare the data to send to the client + var dataToSend struct { + Id string `json:"id"` + TypeName string `json:"__typename"` + } + dataToSend.Id = dataReceived.EmployeeId + dataToSend.TypeName = "Employee" + + // marshal the data to send to the client + dataToSendMarshalled, err := json.Marshal(dataToSend) + if err != nil { + return events, fmt.Errorf("error marshalling data: %w", err) + } + + // create the new event + newEvent := &nats.NatsEvent{ + Data: dataToSendMarshalled, + Metadata: natsEvent.Metadata, + } + newEvents = append(newEvents, newEvent) + } + return newEvents, nil +} + +func (m *MyModule) Module() core.ModuleInfo { + return core.ModuleInfo{ + ID: myModuleID, + Priority: 1, + New: func() core.Module { + return &MyModule{} + }, + } +} + +// Interface guards +var ( + _ core.StreamBatchEventHook = (*MyModule)(nil) +) +``` + +### 3. Add the provider configuration to the cosmo router +```yaml +version: "1" + +events: + providers: + nats: + - id: my-nats + url: "nats://localhost:4222" +``` + +## Check authorization at subscription start + +This example will show how to check the authorization at subscription start. + +### 1. Add a subscription to the cosmo streams graphql schema + +The developer will start by adding a subscription to the cosmo streams graphql schema. + +```graphql +type Subscription { + employeeUpdates: Employee! 
@edfs__natsSubscribe(subjects: ["employeeUpdates"], providerId: "my-nats") +} + +type Employee @key(fields: "id", resolvable: false) { + id: Int! @external +} +``` +After publishing the schema, the developer will need to add the module to the cosmo streams engine. + +### 2. Write the custom module + +The developer will need to write the custom module that will be used to check the authorization at subscription start. + +```go +package mymodule + +import ( + "net/http" + "github.com/wundergraph/cosmo/router/core" + "github.com/wundergraph/cosmo/router/pkg/pubsub" + "github.com/wundergraph/cosmo/router/pkg/pubsub/nats" +) + +func init() { + // Register your module here and it will be loaded at router start + core.RegisterModule(&MyModule{}) +} + +type MyModule struct {} + +func (m *MyModule) SubscriptionOnStart(ctx core.SubscriptionOnStartHookContext) error { + // check if the provider is nats + if ctx.SubscriptionEventConfiguration().ProviderType() != pubsub.ProviderTypeNats { + return nil + } + + // check if the provider id is the one expected by the module + if ctx.SubscriptionEventConfiguration().ProviderID() != "my-nats" { + return nil + } + + // check if the subject is the one expected by the module + natsConfig := ctx.SubscriptionEventConfiguration().(*nats.SubscriptionEventConfiguration) + if natsConfig.Subjects[0] != "employeeUpdates" { + return nil + } + + // check if the client is authenticated + if ctx.Authentication() == nil { + // if the client is not authenticated, return an error + return &core.StreamHookError{ + HttpError: core.HttpError{ + Code: http.StatusUnauthorized, + Message: "client is not authenticated", + }, + CloseSubscription: true, + } + } + + // check if the client is allowed to subscribe to the stream + _, found := ctx.Authentication().Claims()["readEmployee"] + if !found { + return &core.StreamHookError{ + HttpError: core.HttpError{ + Code: http.StatusForbidden, + Message: "client is not allowed to read employees", + }, + CloseSubscription: true, + } + } + + 
return nil +} + +func (m *MyModule) Module() core.ModuleInfo { + return core.ModuleInfo{ + ID: myModuleID, + Priority: 1, + New: func() core.Module { + return &MyModule{} + }, + } +} + +// Interface guards +var ( + _ core.SubscriptionOnStartHandler = (*MyModule)(nil) +) +``` + +### 3. Add the provider configuration to the cosmo router +```yaml +version: "1" + +events: + providers: + nats: + - id: my-nats + url: "nats://localhost:4222" +``` + +### 4. Build the cosmo router with the custom module + +Build and run the router with the custom module added. + +# Outlook + +## Using AsyncAPI for Event Data Structure + +We could use AsyncAPI specifications to define the event data structure and generate the Go structs automatically. This would make the development of custom modules easier and more maintainable. +We could also generate the AsyncAPI specification from the schema and the events data, to make it easier for external systems to use the events published by cosmo streams engine. + +## Generate hooks from AsyncAPI specifications + +Building on the AsyncAPI integration, we could allow the user to define their streams using AsyncAPI and generate fully typesafe hooks with all events structures generated from the AsyncAPI specification. 
\ No newline at end of file diff --git a/demo/pkg/subgraphs/availability/subgraph/schema.resolvers.go b/demo/pkg/subgraphs/availability/subgraph/schema.resolvers.go index 3473ad212d..6abb2c062e 100644 --- a/demo/pkg/subgraphs/availability/subgraph/schema.resolvers.go +++ b/demo/pkg/subgraphs/availability/subgraph/schema.resolvers.go @@ -18,7 +18,7 @@ func (r *mutationResolver) UpdateAvailability(ctx context.Context, employeeID in storage.Set(employeeID, isAvailable) err := r.NatsPubSubByProviderID["default"].Publish(ctx, nats.PublishAndRequestEventConfiguration{ Subject: r.GetPubSubName(fmt.Sprintf("employeeUpdated.%d", employeeID)), - Data: []byte(fmt.Sprintf(`{"id":%d,"__typename": "Employee"}`, employeeID)), + Event: nats.Event{Data: []byte(fmt.Sprintf(`{"id":%d,"__typename": "Employee"}`, employeeID))}, }) if err != nil { @@ -26,7 +26,7 @@ func (r *mutationResolver) UpdateAvailability(ctx context.Context, employeeID in } err = r.NatsPubSubByProviderID["my-nats"].Publish(ctx, nats.PublishAndRequestEventConfiguration{ Subject: r.GetPubSubName(fmt.Sprintf("employeeUpdatedMyNats.%d", employeeID)), - Data: []byte(fmt.Sprintf(`{"id":%d,"__typename": "Employee"}`, employeeID)), + Event: nats.Event{Data: []byte(fmt.Sprintf(`{"id":%d,"__typename": "Employee"}`, employeeID))}, }) if err != nil { diff --git a/demo/pkg/subgraphs/mood/subgraph/schema.resolvers.go b/demo/pkg/subgraphs/mood/subgraph/schema.resolvers.go index 2f8ea33149..82a0a7e9f2 100644 --- a/demo/pkg/subgraphs/mood/subgraph/schema.resolvers.go +++ b/demo/pkg/subgraphs/mood/subgraph/schema.resolvers.go @@ -21,7 +21,7 @@ func (r *mutationResolver) UpdateMood(ctx context.Context, employeeID int, mood if r.NatsPubSubByProviderID["default"] != nil { err := r.NatsPubSubByProviderID["default"].Publish(ctx, nats.PublishAndRequestEventConfiguration{ Subject: myNatsTopic, - Data: []byte(payload), + Event: nats.Event{Data: []byte(payload)}, }) if err != nil { return nil, err @@ -34,7 +34,7 @@ func (r 
*mutationResolver) UpdateMood(ctx context.Context, employeeID int, mood if r.NatsPubSubByProviderID["my-nats"] != nil { err := r.NatsPubSubByProviderID["my-nats"].Publish(ctx, nats.PublishAndRequestEventConfiguration{ Subject: defaultTopic, - Data: []byte(payload), + Event: nats.Event{Data: []byte(payload)}, }) if err != nil { return nil, err diff --git a/rfc/cosmo-streams-v1.md b/rfc/cosmo-streams-v1.md new file mode 100644 index 0000000000..2a7cd761f1 --- /dev/null +++ b/rfc/cosmo-streams-v1.md @@ -0,0 +1,1079 @@ +# RFC Cosmo Streams V1 + +Based on customer feedback, we've identified the need for more customizable stream behavior. The key areas for customization include: +- **Authorization**: Implementing authorization checks at the start of subscriptions +- **Initial message**: Sending an initial message to clients upon subscription start +- **Data mapping**: Transforming events data from the format that could be used by the external system to/from Federation compatible Router events +- **Event filtering**: Filtering events using custom logic + +Let's explore how we can address each of these requirements. + +## Authorization + +To support authorization, we need a hook that enables the following key decisions: +- Whether the client or user is authorized to initiate the subscription +- Which topics the client is permitted to subscribe to +- Whether the client is allowed to consume an event from the stream (covered by the Event Filtering hook) + +Additionally, a similar mechanism is required for non-stream subscriptions, allowing: +- Custom JWT validation logic (e.g., expiration checks, signature verification, secret handling) +- The ability to reject unauthenticated or unauthorized requests and close the subscription accordingly + +We already allow some customization using `RouterOnRequestHandler`, but it has no access to the stream data. To access this data, we need to add a new hook that will be called immediately before the subscription starts. 
+ +### Example: Check if the client is allowed to subscribe to the stream + +```go +// the interfaces/structs are reported partially to make the example more readable +// the full new interfaces/structs are available in the appendix 1 + +// This is the new hook that will be called once at subscription start +type SubscriptionOnStartHandler interface { + SubscriptionOnStart(ctx SubscriptionOnStartHookContext) error +} + +// already defined in the provider package +type NatsSubscriptionEventConfiguration struct { + ProviderID string `json:"providerId"` + Subjects []string `json:"subjects"` + StreamConfiguration *StreamConfiguration `json:"streamConfiguration,omitempty"` +} + +type StreamHookError struct { + HttpError core.HttpError + CloseSubscription bool +} + +type MyModule struct {} + +// This is a custom function that will be used to check if the client is allowed to subscribe to the stream +func customCheckIfClientIsAllowedToSubscribe(ctx SubscriptionOnStartHookContext) bool { + // check if the field name is the one expected by the module + if ctx.SubscriptionEventConfiguration().RootFieldName() != "employeeUpdates" { + return true + } + + // get the specific configuration for the provider to make more advanced checks + cfg, ok := ctx.SubscriptionEventConfiguration().(*NatsSubscriptionEventConfiguration) + if !ok { + return true + } + + providerId := cfg.ProviderID + auth := ctx.RequestContext().Authentication() + + // add checks here on client authentication scopes, provider ID, etc. 
+ + return false +} + +// This is the new hook that will be called once at subscription start +func (m *MyModule) SubscriptionOnStart(ctx SubscriptionOnStartHookContext) error { + // check if the client is allowed to subscribe to the stream + if !customCheckIfClientIsAllowedToSubscribe(ctx) { + // if not, return an error to prevent the subscription from starting + return &StreamHookError{ + HttpError: core.NewHttpGraphqlError( + "you should be an admin to subscribe to this or only subscribe to public subscriptions!", + "UNAUTHORIZED", + http.StatusUnauthorized, + ), CloseSubscription: true, + } + } + return nil +} + +func (m *MyModule) Module() core.ModuleInfo { + return core.ModuleInfo{ + ID: myModuleID, + Priority: 1, + New: func() core.Module { + return &MyModule{} + }, + } +} +``` + +### Proposal + +Add a new hook to the subscription lifecycle, `SubscriptionOnStartHandler`, that will be called once at subscription start. + +The hook arguments are: +* `ctx SubscriptionOnStartHookContext`: The subscription context, which contains the request context and, optionally, the subscription event configuration, and a method to emit the event to the stream + +`RequestContext` already exists and requires no changes, but `SubscriptionEventConfiguration` is new. + +The hook should return an error if the client is not allowed to subscribe to the stream, preventing the subscription from starting. +The hook should return `nil` if the client is allowed to subscribe to the stream, allowing the subscription to proceed. + +The hook can return a `StreamHookError` to customize the error messages and the behavior on the subscription. + +I evaluated the possibility of adding the `SubscriptionContext` to the request context and using it within one of the existing hooks, but it would be difficult to build the subscription context without executing the pubsub code. + +The `SubscriptionEventConfiguration()` contains the subscription configuration as used by the provider. 
This allows the hooks system to be provider-agnostic, so adding a new provider will not require changes to the hooks system. To use specific fields, the hook can cast the configuration to the specific type for the provider. +The `WriteEvent()` method is new and allows the hook to emit the event to the stream. + +## Initial Message + +When starting a subscription, the client sends a query to the server containing the operation name and variables. The client must then wait for the broker to send the initial message. This waiting period can lead to a poor user experience, as the client cannot display anything until the initial message is received. To address this, we can emit an initial message on subscription start. + +To emit an initial message on subscription start, we need access to the stream context (to get the provider type and ID) and the query that the client sent. The variables are particularly important, as they allow the module to use them in the initial message. For example, if someone starts a subscription with employee ID 100 as a variable, the custom module can include that ID in the initial message. 
+ +### Example + +```go +// the interfaces/structs are reported partially to make the example more readable +// the full new interfaces/structs are available in the appendix 1 + +// This is the new hook that will be called once at stream start +type SubscriptionOnStartHandler interface { + SubscriptionOnStart(ctx SubscriptionOnStartHookContext) error +} + +// each provider will have its own event type that implements the StreamEvent interface +type NatsEvent struct { + Data json.RawMessage + Metadata map[string]string +} + +type MyModule struct {} + +// This is the new hook that will be called once at subscription start +func (m *MyModule) SubscriptionOnStart(ctx SubscriptionOnStartHookContext) error { + // get the operation name and variables that we need + opName := ctx.RequestContext().Operation().Name() + opVarId := ctx.RequestContext().Operation().Variables().GetInt("id") + + // check if the provider ID is the one expected by the module + if ctx.SubscriptionEventConfiguration().ProviderID() != "my-provider-id" { + return nil + } + + //check if the provider type is the one expected by the module + if ctx.SubscriptionEventConfiguration().ProviderType() != pubsub.ProviderTypeNats { + return nil + } + + // check if the operation name is the one expected by the module + if opName == "employeeSub" { + // create the event to emit using the operation variables + evt := &NatsEvent{ + Data: []byte(fmt.Sprintf("{\"id\": \"%d\", \"__typename\": \"Employee\"}", opVarId)), + Metadata: map[string]string{ + "entity-id": fmt.Sprintf("%d", opVarId), + }, + } + // emit the event to the stream, that will be received only by the client that subscribed to the stream + ctx.WriteEvent(evt) + } + return nil +} + +func (m *MyModule) Module() core.ModuleInfo { + return core.ModuleInfo{ + ID: myModuleID, + Priority: 1, + New: func() core.Module { + return &MyModule{} + }, + } +} +``` + +### Proposal + +Using the new `SubscriptionOnStart` hook that we introduced for the previous 
requirement, we can emit the initial message on subscription start. We will also need access to operation variables, which are currently not available in the request context. + +To emit the message, I propose adding a new method to the stream context, `WriteEvent`, which will emit the event to the stream at the lowest level. The message will pass through all hooks, making it behave like any other event received from the provider. The message will be received only by the client that subscribed to the stream, and not by the other clients that subscribed to the same stream. + +The `StreamEvent` contains the data as used by the provider. This allows the hooks system to be provider-agnostic, so adding a new provider will not require changes to the hooks system. To use events, the hook has to cast the event to the specific type for the provider. + +This change will require adding a new type in each provider package to represent the event with additional fields (metadata, etc.). This is a significant change, but it is necessary to support additional data in events anyway, even if we don't expose them to the custom modules. + +Emitting the initial message with this hook ensures that the client will receive the message before the first event from the provider is received. + +## Data Mapping + +The current approach for emitting and reading data from the stream is not flexible enough. We need to be able to map data from an external format to the internal format, and vice versa. + +Also, different providers can have different additional fields other than the message body. + +As an example: +- NATS provider can have a `Metadata` field +- Kafka provider can have `Headers` and `Key` fields + +And these additional fields could be an important part of integrating with external systems. 
+ +### Example 1: Rewrite the event received from the provider to a format that is usable by Cosmo streams + +```go +// the interfaces/structs are reported partially to make the example more readable +// the full new interfaces/structs are available in the appendix 1 + +// each provider will have its own event type that implements the StreamEvent interface +type NatsEvent struct { + Data json.RawMessage + Metadata map[string]string +} +type KafkaEvent struct { + Key []byte + Data json.RawMessage + Headers map[string][]byte +} + +// StreamBatchEventHook processes a batch of inbound stream events +// +// Return: +// - empty slice: drop all events. +// - non-empty slice: emit those events (can grow, shrink, or reorder the batch). +// - err != nil: abort the subscription with an error. +type StreamBatchEventHook interface { + OnStreamEvents(ctx StreamBatchEventHookContext, events []StreamEvent) ([]StreamEvent, error) +} + +type MyModule struct {} + +// This is the new hook that will be called each time a batch of events is received from the provider +func (m *MyModule) OnStreamEvents( + ctx StreamBatchEventHookContext, + events []StreamEvent, +) ([]StreamEvent, error) { + // check if the provider ID is the one expected by the module + if ctx.SubscriptionEventConfiguration().ProviderID() != "my-provider-id" { + return events, nil + } + + // check if the provider type is the one expected by the module + if ctx.SubscriptionEventConfiguration().ProviderType() != pubsub.ProviderTypeNats { + return events, nil + } + + // check if the subject is the one expected by the module + natsConfig := ctx.SubscriptionEventConfiguration().(*nats.SubscriptionEventConfiguration) + if natsConfig.Subjects[0] != "topic-with-internal-data-format" { + return events, nil + } + + // create a new slice of events that we will return with the events with the new format + newEvents := make([]StreamEvent, 0, len(events)) + for _, evt := range events { + // check if the event is the one expected by the 
module + if natsEvent, ok := evt.(*NatsEvent); ok { + // here you can umarshal the old data and map it to the new format + // for example: + // var dataReceived struct { + // EmployeeName string `json:"EmployeeName"` + // } + // err := json.Unmarshal(natsEvent.Data, &dataReceived) + + // if we have to extract the data from the metadata fields, we can do it like this: + entityId := natsEvent.Metadata["entity-id"] + entityType := natsEvent.Metadata["entity-type"] + // and prepare the new event with the data inside + newDataFormat, _ := json.Marshal(map[string]string{ + "id": entityId, + "name": dataReceived.EmployeeName, + "__typename": entityType, + }) + + // create the new event + newEvent := &NatsEvent{ + Data: newDataFormat, + Metadata: natsEvent.Metadata, + } + + // or for Kafka we would have something like: + // newEvent := &KafkaEvent{ + // Key: kafkaEvent.Key, + // Data: newDataFormat, + // Headers: kafkaEvent.Headers, + // } + + // add the new event to the slice of events to return + newEvents = append(newEvents, newEvent) + continue + } + // add the original event to the slice of events to return + newEvents = append(newEvents, evt) + } + + return newEvents, nil +} + +func (m *MyModule) Module() core.ModuleInfo { + return core.ModuleInfo{ + ID: myModuleID, + Priority: 1, + New: func() core.Module { + return &MyModule{} + }, + } +} +``` + +### Example 2: Rewrite the event before emitting it to the provider to a format that is usable by external systems + +```go +// the interfaces/structs are reported partially to make the example more readable +// the full new interfaces/structs are available in the appendix 1 + +// StreamPublishEventHook processes a batch of outbound stream events +// +// Return: +// - empty slice: drop all events. +// - non-empty slice: emit those events (can grow, shrink, or reorder the batch). +// err != nil: abort the subscription with an error. 
+type StreamPublishEventHook interface { + OnPublishEvents(ctx StreamPublishEventHookContext, events []StreamEvent) ([]StreamEvent, error) +} + +// each provider will have its own event type that implements the StreamEvent interface +type NatsEvent struct { + Data json.RawMessage + Metadata map[string]string +} + +type MyModule struct {} + +// This is the new hook that will be called each time a batch of events is going to be sent to the provider +func (m *MyModule) OnPublishEvents( + ctx StreamPublishEventHookContext, + events []StreamEvent, +) ([]StreamEvent, error) { + // check if the provider ID is the one expected by the module + if ctx.PublishEventConfiguration().ProviderID() != "my-provider-id" { + return events, nil + } + + // check if the provider type is the one expected by the module + if ctx.PublishEventConfiguration().ProviderType() != pubsub.ProviderTypeNats { + return events, nil + } + + // check if the subject is the one expected by the module + natsConfig := ctx.PublishEventConfiguration().(*nats.PublishAndRequestEventConfiguration) + if natsConfig.Subject != "topic-with-internal-data-format" { + return events, nil + } + + // create a new slice of events that we will return with the events with the new format + newEvents := make([]StreamEvent, 0, len(events)) + for _, evt := range events { + // check if the event is the one expected by the module + if natsEvent, ok := evt.(*NatsEvent); ok { + // here you can umarshal the old data and map it to the new format + // for example: + // var dataReceived struct { + // EmployeeId string `json:"EmployeeId"` + // } + // err := json.Unmarshal(natsEvent.Data, &dataReceived) + + // create the new event + newEvent := &NatsEvent{ + Data: dataToSendMarshalled, + Metadata: map[string]string{ + "entity-id": dataReceived.Id, + "entity-domain": "employee", + }, + } + + // add the new event to the slice of events to return + newEvents = append(newEvents, newEvent) + continue + } + newEvents = append(newEvents, evt) + } 
+ return newEvents, nil +} + +func (m *MyModule) Module() core.ModuleInfo { + return core.ModuleInfo{ + ID: myModuleID, + Priority: 1, + New: func() core.Module { + return &MyModule{} + }, + } +} +``` + +### Proposal + +Add two new hooks to the stream lifecycle: `StreamBatchEventHook` and `StreamPublishEventHook`. +The `StreamBatchEventHook` will be called each time a batch of events is received from the provider, making it possible to rewrite, filter or split the event data to a format usable within Cosmo streams. +The `StreamPublishEventHook` will be called each time a batch of events is going to be sent to the provider, making it possible to rewrite, filter or split the event data to a format usable by external systems. + +The hook arguments of `StreamBatchEventHook` are: +* `ctx StreamBatchEventHookContext`: The stream context, which contains the provider ID and the subscription configuration +* `events []StreamEvent`: The events received from the provider + +The hook will return a new slice of events that will be used to emit the events to the client. +The hook will also return an error if one of the events cannot be processed, preventing the events from being processed. + +The hook arguments of `StreamPublishEventHook` are: +* `ctx StreamPublishEventHookContext`: The stream context, which contains the provider ID and the publish configuration +* `events []StreamEvent`: The events that are going to be sent to the provider + +The hook will return a new slice of events that will be used to emit the events to the provider. +The hook will also return an error if one of the events cannot be processed, preventing the events from being processed. + +#### Do we need two new hooks? + +Another possible solution for mapping outward data would be to use the existing middleware hooks `RouterOnRequestHandler` or `RouterMiddlewareHandler` to intercept the mutation, access the stream context, and emit the event to the stream. 
However, this would require exposing a stream context in the request lifecycle, which is difficult. It would also require coordination to ensure that an event emitted on the stream is sent only after the subscription starts.
+
+Additionally, this solution is not usable on the subscription side of streams:
+- The middleware hooks are linked to the request lifecycle, making it difficult to use them to rewrite event data
+- When we use the streams feature internally, we will still need to provide a way to rewrite event data, requiring a new hook in the subscription lifecycle
+
+Therefore, I believe the best solution is to add new hooks to the stream lifecycle.
+
+## Event Filtering
+
+We need to allow customers to filter events based on custom logic. We currently only provide declarative filters, which are quite limited.
+The event filtering hook will also be useful to implement the authorization logic at the events level.
+
+### Example: Filter events based on stream configuration and client's scopes
+
+```go
+// the interfaces/structs are reported partially to make the example more readable
+// the full new interfaces/structs are available in the appendix 1
+
+// StreamBatchEventHook processes a batch of inbound stream events.
+//
+// Return:
+// - empty slice: drop all events.
+// - non-empty slice: emit those events (can grow, shrink, or reorder the batch).
+// err != nil: abort the subscription with an error. 
+type StreamBatchEventHook interface { + OnStreamEvents(ctx StreamBatchEventHookContext, events []StreamEvent) ([]StreamEvent, error) +} + +// each provider will have its own event type that implements the StreamEvent interface +type NatsEvent struct { + Data json.RawMessage + Metadata map[string]string +} + +type MyModule struct {} + +// This is the new hook that will be called each time a batch of events is received from the provider +func (m *MyModule) OnStreamEvents(ctx StreamBatchEventHookContext, events []StreamEvent) ([]StreamEvent, error) { + // check if the provider ID is the one expected by the module + if ctx.SubscriptionEventConfiguration().ProviderID() != "my-provider-id" { + return events, nil + } + + // check if the provider type is the one expected by the module + if ctx.SubscriptionEventConfiguration().ProviderType() != pubsub.ProviderTypeNats { + return events, nil + } + + // check if the subject is the one expected by the module + natsConfig := ctx.SubscriptionEventConfiguration().(*nats.SubscriptionEventConfiguration) + if natsConfig.Subjects[0] != "topic-with-internal-data-format" { + return events, nil + } + + // create a new slice of events that we will return with the events that are allowed to be received by the client + newEvents := make([]StreamEvent, 0, len(events)) + + + if ctx.RequestContext().Authentication() == nil { + // if the client is not authenticated, return no events + return newEvents, nil + } + + // get the client's allowed entities IDs + clientAllowedEntitiesIds, found := ctx.RequestContext().Authentication().Claims()["allowedEntitiesIds"] + if !found { + // if the client doesn't have allowed entities IDs, return the original events + return newEvents, nil + } + + for _, evt := range events { + // check if the event is the one expected by the module + if natsEvent, ok := evt.(*NatsEvent); ok { + // check the entity ID in the metadata + idHeader, ok := natsEvent.Metadata["entity-id"] + if !ok { + continue + } + // check if 
the entity ID is in the client's allowed entities IDs + if slices.Contains(clientAllowedEntitiesIds, idHeader) { + // add the event to the slice of events to return because the client is allowed to receive it + newEvents = append(newEvents, evt) + } + } + } + return newEvents, nil +} + +func (m *MyModule) Module() core.ModuleInfo { + return core.ModuleInfo{ + ID: myModuleID, + Priority: 1, + New: func() core.Module { + return &MyModule{} + }, + } +} +``` + +### Proposal + +We can use the new `StreamBatchEventHook` to filter events based on the stream configuration and the client's scopes. + +The hook arguments are: +* `ctx StreamBatchEventHookContext`: The stream context, which contains the ID of the stream and the request context +* `events []StreamEvent`: The events received from the provider or the events that are going to be sent to the provider + +The hook will return a new slice of events that will be used to emit the events to the client or to the provider. +The hook will also return an error if one of the events cannot be processed, preventing the event from being processed. + +## Architecture + +With this proposal, we will add two new hooks to stream lifecycles and other hooks to the subscription lifecycle. + +### Subscription Lifecycle +``` +Start subscription + │ + └─▶ core.SubscriptionOnStartHandler (Early return, Custom Authentication Logic) + │ + └─▶ "Subscription started" +``` + +### Stream Lifecycle + +``` +One or more batched events are received from the provider + │ + └─▶ core.StreamBatchEventHook (Data mapping, Filtering) + │ + └─▶ "Deliver events to client" + +One or more batched events are published to the provider + │ + └─▶ core.StreamPublishEventHook (Data mapping, Filtering) + │ + └─▶ "Send event to provider" +``` + +### Data Flow + +We will need to change the format of the event data sent within the router. 
Today we use the data that will be sent to the provider directly, but we will need to add a structure where we can include additional fields (metadata, etc.) in the event.
+
+## Implementation Details
+
+The implementation of this solution will only require changes in the Cosmo repository, without any changes to the engine. This implementation will not require additional changes to the hooks structures each time a new provider is added.
+
+## Considerations and Risks
+
+- All hooks could be called in parallel, so we need to handle concurrency carefully
+- All hook implementations could raise a panic, so we need to implement proper error handling
+- Especially the casting of the event to the specific type for the provider could raise a panic if the event is not of the expected type and the developer is not using the type check
+- We should add metrics to track how much time is spent in each hook, to help customers identify slow hooks
+
+## Development workflow of subscription with custom modules
+
+Let's build an example of what the development workflow would look like for a developer who wants to add a custom module to the cosmo streams engine. The idea is to build a module that will be used to subscribe to the `employeeUpdates` subject, filter the events based on the client's scopes, and remap the messages into the format expected for the `Employee` type.
+
+I'll show the workflow for a developer who wants to customize the subscription, but the same workflow can be applied to the mutation.
+
+### 1. Add a subscription to the cosmo streams graphql schema
+
+The developer will start by adding a subscription to the cosmo streams graphql schema.
+```graphql
+type Subscription {
+    employeeUpdates(): Employee! @edfs__natsSubscribe(subjects: ["employeeUpdates"], providerId: "my-nats")
+}
+
+type Employee @key(fields: "id", resolvable: false) {
+    id: Int! @external
+}
+```
+After publishing the schema, the developer will need to add the module to the cosmo streams engine. 
+
+### 2. Write the custom module
+
+The developer will need to write the custom module that will be used to subscribe to the `employeeUpdates` subject, filter the events based on the client's scopes, and remap the messages into the format expected for the `Employee` type.
+
+```go
+package mymodule
+
+import (
+    "encoding/json"
+    "fmt"
+    "slices"
+    "github.com/wundergraph/cosmo/router/core"
+    "github.com/wundergraph/cosmo/router/pkg/pubsub/nats"
+)
+
+func init() {
+    // Register your module here and it will be loaded at router start
+    core.RegisterModule(&MyModule{})
+}
+
+type MyModule struct {}
+
+func (m *MyModule) OnStreamEvents(ctx StreamBatchEventHookContext, events []core.StreamEvent) ([]core.StreamEvent, error) {
+    // check if the provider is nats
+    if ctx.SubscriptionEventConfiguration().ProviderType() != pubsub.ProviderTypeNats {
+        return events, nil
+    }
+
+    // check if the provider id is the one expected by the module
+    if ctx.SubscriptionEventConfiguration().ProviderID() != "my-nats" {
+        return events, nil
+    }
+
+    // check if the subject is the one expected by the module
+    natsConfig := ctx.SubscriptionEventConfiguration().(*nats.SubscriptionEventConfiguration)
+    if natsConfig.Subjects[0] != "employeeUpdates" {
+        return events, nil
+    }
+
+    // check if the client is authenticated
+    if ctx.RequestContext().Authentication() == nil {
+        // if the client is not authenticated, return no events
+        return nil, nil
+    }
+
+    // check if the client is allowed to subscribe to the stream
+    clientAllowedEntitiesIds, found := ctx.RequestContext().Authentication().Claims()["allowedEntitiesIds"]
+    if !found {
+        return events, fmt.Errorf("client is not allowed to subscribe to the stream")
+    }
+
+    newEvents := make([]core.StreamEvent, 0, len(events))
+
+    for _, evt := range events {
+        natsEvent, ok := evt.(*nats.NatsEvent)
+        if !ok {
+            newEvents = append(newEvents, evt)
+            continue
+        }
+
+        // decode the event data coming from the provider
+        var dataReceived struct {
+            
EmployeeId string `json:"EmployeeId"` + OtherField string `json:"OtherField"` + } + err := json.Unmarshal(natsEvent.Data, &dataReceived) + if err != nil { + return events, fmt.Errorf("error unmarshalling data: %w", err) + } + + // filter the events based on the client's scopes + if !slices.Contains(clientAllowedEntitiesIds, dataReceived.EmployeeId) { + continue + } + + // prepare the data to send to the client + var dataToSend struct { + Id string `json:"id"` + TypeName string `json:"__typename"` + } + dataToSend.Id = dataReceived.EmployeeId + dataToSend.TypeName = "Employee" + + // marshal the data to send to the client + dataToSendMarshalled, err := json.Marshal(dataToSend) + if err != nil { + return events, fmt.Errorf("error marshalling data: %w", err) + } + + // create the new event + newEvent := &nats.NatsEvent{ + Data: dataToSendMarshalled, + Metadata: natsEvent.Metadata, + } + newEvents = append(newEvents, newEvent) + } + return newEvents, nil +} + +func (m *MyModule) Module() core.ModuleInfo { + return core.ModuleInfo{ + ID: myModuleID, + Priority: 1, + New: func() core.Module { + return &MyModule{} + }, + } +} + +// Interface guards +var ( + _ core.StreamBatchEventHook = (*MyModule)(nil) +) +``` + +### 3. Add the provider configuration to the cosmo router +```yaml +version: "1" + +events: + providers: + nats: + - id: my-nats + url: "nats://localhost:4222" +``` + +### 4. Build the cosmo router with the custom module + +Build and run the router with the custom module added. 
+ +## Appendix 1, new data structures + +```go +// NEW HOOKS + +// SubscriptionOnStartHandler is a hook that is called once at subscription start +// it is used to validate if the client is allowed to subscribe to the stream +// if returns an error, the subscription will not start +type SubscriptionOnStartHandler interface { + SubscriptionOnStart(ctx SubscriptionOnStartHookContext) error +} + +// StreamBatchEventHook processes a batch of inbound stream events +// +// Return: +// - empty slice: drop all events. +// - non-empty slice: emit those events (can grow, shrink, or reorder the batch). +// err != nil: abort the subscription with an error. +type StreamBatchEventHook interface { + OnStreamEvents(ctx StreamBatchEventHookContext, events []StreamEvent) ([]StreamEvent, error) +} + +// StreamPublishEventHook processes a batch of outbound stream events +// +// Return: +// - empty slice: drop all events. +// - non-empty slice: emit those events (can grow, shrink, or reorder the batch). +// err != nil: abort the subscription with an error. 
+type StreamPublishEventHook interface {
+    OnPublishEvents(ctx StreamPublishEventHookContext, events []StreamEvent) ([]StreamEvent, error)
+}
+
+// NEW INTERFACES
+type SubscriptionEventConfiguration interface {
+    ProviderID() string
+    ProviderType() string
+    RootFieldName() string // the root field name of the subscription in the schema
+}
+
+type PublishEventConfiguration interface {
+    ProviderID() string
+    ProviderType() string
+    RootFieldName() string // the root field name of the mutation in the schema
+}
+
+type StreamEvent interface {}
+
+type StreamBatchEventHookContext interface {
+    RequestContext() RequestContext
+    SubscriptionEventConfiguration() SubscriptionEventConfiguration
+}
+
+type StreamPublishEventHookContext interface {
+    RequestContext() RequestContext
+    PublishEventConfiguration() PublishEventConfiguration
+}
+
+type SubscriptionOnStartHookContext interface {
+    RequestContext() RequestContext
+    SubscriptionEventConfiguration() SubscriptionEventConfiguration
+    WriteEvent(event core.StreamEvent)
+}
+
+// ALREADY EXISTING INTERFACES THAT WILL BE UPDATED
+type OperationContext interface {
+    Name() string
+    // the variables are currently not available, so we need to add them here
+    Variables() *astjson.Value
+}
+
+// NEW STRUCTURES
+// StreamHookError is used to customize the error messages and the behavior
+type StreamHookError struct {
+    HttpError core.HttpError
+    CloseSubscription bool
+}
+
+func (e StreamHookError) Error() string {
+    return e.HttpError.Message()
+}
+
+// STRUCTURES TO BE ADDED TO PUBSUB PACKAGE
+type ProviderType string
+const (
+    ProviderTypeNats ProviderType = "nats"
+    ProviderTypeKafka ProviderType = "kafka"
+    ProviderTypeRedis ProviderType = "redis"
+)
+
+```
+
+## Appendix 2, Using AsyncAPI for Event Data Structure
+
+As a side note, it is important to find ways to document the data that is arriving and going out of the cosmo streams engine. 
This could allow some automatic code generation starting from the schema and the events data. +As an example, we are going to explore how AsyncAPI could be used to generate the data structures for the custom modules and assure the messages format. + +### Example: AsyncAPI Integration for Custom Module Development + +We propose integrating AsyncAPI specifications with Cosmo streams to generate type-safe Go structs that can be used in custom modules. This would significantly improve the developer experience by providing: + +1. **Type Safety**: Generated structs prevent runtime errors from incorrect field access +2. **Documentation**: AsyncAPI specs serve as living documentation for event schemas +3. **Code Generation**: Automatic generation of Go structs from AsyncAPI specifications +4. **IDE Support**: Better autocomplete and error detection in development environments + +### AsyncAPI Specification Example + +So if we have as an example the following AsyncAPI specification: + +```yaml +# employee-events.asyncapi.yaml +asyncapi: 3.0.0 +info: + title: Employee Events API + version: 1.0.0 + description: Events related to employee updates in the system + +channels: + externalSystemEmployeeUpdates: + messages: + EmployeeUpdated: + $ref: '#/components/messages/EmployeeUpdated' + +components: + messages: + ExternalSystemEmployeeUpdated: + name: ExternalSystemEmployeeUpdated + title: External System Employee Updated Event + summary: Sent when an employee is updated in the external system + contentType: application/json + payload: + $ref: '#/components/schemas/ExternalSystemEmployeeFormat' + + schemas: + ExternalSystemEmployeeFormat: + type: object + description: Employee data as received from external systems + properties: + EmployeeId: + type: string + description: Unique identifier for the employee + EmployeeName: + type: string + description: Full name of the employee + EmployeeEmail: + type: string + format: email + description: Email address of the employee + 
OtherField: + type: string + description: Additional field from external system + required: + - EmployeeId + - EmployeeName + - EmployeeEmail +``` + +### Code Generation Workflow + +We could provide a CLI command to WGC to generate the Go structs from AsyncAPI specifications: + +```bash +# Generate Go structs from AsyncAPI spec +wgc streams generate -i employee-events.asyncapi.yaml -o ./generated/events.go -p events +``` + +Before generating the code, we could add to the data that cosmo streams is expecting to receive and send. +```yaml +# cosmo-streams-events.asyncapi.yaml +asyncapi: 3.0.0 +info: + title: Cosmo Streams Employee Events API + version: 1.0.0 + +channels: + cosmoStreamsEmployeeUpdates: + messages: + CosmoStreamsEmployeeUpdated: + $ref: '#/components/messages/CosmoStreamsEmployeeUpdated' + +components: + messages: + CosmoStreamsEmployeeUpdated: + name: CosmoStreamsEmployeeUpdated + title: Cosmo Streams Employee Updated Event + summary: Event published when updating an employee in the cosmo streams + contentType: application/json + payload: + $ref: '#/components/schemas/EmployeeInternalFormat' + + schemas: + CosmoStreamsEmployeeUpdated: + type: object + description: Employee data as used internally by Cosmo streams + properties: + id: + type: string + description: Unique identifier for the employee + name: + type: string + description: Full name of the employee + email: + type: string + format: email + description: Email address of the employee + required: + - id + - __typename +``` + +This command would be a wrapper around asyncapi modelina, and with some additional logic to extract the internal events format from the schema SDL. 
+ +This would generate a second async api specification and Go code like: + +```go +// generated/events.go +package events + +import ( + "encoding/json" + "time" +) + +// ExternalSystemEmployeeUpdated represents employee data as received from external systems +type ExternalSystemEmployeeUpdated struct { + EmployeeId string `json:"EmployeeId"` + EmployeeName string `json:"EmployeeName"` + EmployeeEmail string `json:"EmployeeEmail"` + OtherField string `json:"OtherField"` +} + +// EmployeeInternalFormat represents employee data as used internally by Cosmo streams +type CosmoStreamsEmployeeUpdated struct { + Id string `json:"id"` + Name string `json:"name"` + Email string `json:"email"` +} +``` + +We could than encourage the developers to add conversions in a file in the same package of the generated file, like so: + +```go +// generated/events.go +package events + +import ( + "encoding/json" + "time" +) + +func ExternalSystemEmployeeUpdatedToCosmoStreamsEmployeeUpdated(e *ExternalSystemEmployeeUpdated) *CosmoStreamsEmployeeUpdated { + return &CosmoStreamsEmployeeUpdated{ + Id: e.EmployeeId, + Name: e.EmployeeName, + Email: e.EmployeeEmail, + } +} + +``` + +Also, external systems could use the generated async api specification to generate the code for the events that they are sending/receiving to/from cosmo streams. 
+ +### Enhanced Custom Module Development + +With generated structs, the custom module code becomes more maintainable and type-safe: + +```go +package mymodule + +import ( + "encoding/json" + "fmt" + "slices" + + "github.com/wundergraph/cosmo/router/core" + "github.com/wundergraph/cosmo/router/pkg/pubsub/nats" + "your-project/generated/genevents" +) + +type MyModule struct {} + +func (m *MyModule) OnStreamEvents(ctx StreamBatchEventHookContext, events []core.StreamEvent) ([]core.StreamEvent, error) { + if ctx.SubscriptionEventConfiguration().ProviderType() != pubsub.ProviderTypeNats { + return events, nil + } + + if ctx.SubscriptionEventConfiguration().ProviderID() != "my-nats" { + return events, nil + } + + natsConfig := ctx.SubscriptionEventConfiguration().(*nats.SubscriptionEventConfiguration) + if natsConfig.Subjects[0] != "employeeUpdates" { + return events, nil + } + + clientAllowedEntitiesIds, found := ctx.RequestContext().Authentication().Claims()["allowedEntitiesIds"] + if !found { + return events, fmt.Errorf("client is not allowed to subscribe to the stream") + } + + for _, evt := range events { + natsEvent, ok := evt.(*nats.NatsEvent); + if !ok { + newEvents = append(newEvents, evt) + continue + } + + // Use generated struct for type-safe deserialization + var dataReceived genevents.ExternalSystemEmployeeUpdated + err := json.Unmarshal(natsEvent.Data, &dataReceived) + if err != nil { + return events, fmt.Errorf("error unmarshalling data: %w", err) + } + + // Convert to internal format using generated method + dataToSend := genevents.ExternalSystemEmployeeUpdatedToCosmoStreamsEmployeeUpdated(&dataReceived) + + // Marshal using the generated struct + dataToSendMarshalled, err := json.Marshal(dataToSend) + if err != nil { + return events, fmt.Errorf("error marshalling data: %w", err) + } + + // Create new event + newEvent := &nats.NatsEvent{ + Data: dataToSendMarshalled, + } + newEvents = append(newEvents, newEvent) + } + return newEvents, nil +} + +func (m 
*MyModule) Module() core.ModuleInfo { + return core.ModuleInfo{ + ID: myModuleID, + Priority: 1, + New: func() core.Module { + return &MyModule{} + }, + } +} + +var _ core.StreamBatchEventHook = (*MyModule)(nil) +``` + +### Considerations + +The developers would need to regenerate the code each time the AsyncAPI specification changes or the schema SDL changes. + +### Outlook + +In a second step, we could: +- allow the user to define their streams using AsyncAPI +- generate fully typesafe hooks with all events structures generated from the AsyncAPI specification \ No newline at end of file diff --git a/router-tests/go.mod b/router-tests/go.mod index ffe6d65377..479d44590c 100644 --- a/router-tests/go.mod +++ b/router-tests/go.mod @@ -27,7 +27,7 @@ require ( github.com/wundergraph/cosmo/demo/pkg/subgraphs/projects v0.0.0-20250715110703-10f2e5f9c79e github.com/wundergraph/cosmo/router v0.0.0-20250912064154-106e871ee32e github.com/wundergraph/cosmo/router-plugin v0.0.0-20250808194725-de123ba1c65e - github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.229 + github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.229.0.20250930144208-ddc652f78bbb go.opentelemetry.io/otel v1.36.0 go.opentelemetry.io/otel/sdk v1.36.0 go.opentelemetry.io/otel/sdk/metric v1.36.0 @@ -209,7 +209,7 @@ replace ( github.com/wundergraph/cosmo/demo/pkg/subgraphs/projects => ../demo/pkg/subgraphs/projects github.com/wundergraph/cosmo/router => ../router github.com/wundergraph/cosmo/router-plugin => ../router-plugin -// github.com/wundergraph/graphql-go-tools/v2 => ../../graphql-go-tools/v2 +//github.com/wundergraph/graphql-go-tools/v2 => ../../graphql-go-tools/v2 ) replace github.com/hashicorp/consul/sdk => github.com/wundergraph/consul/sdk v0.0.0-20250204115147-ed842a8fd301 diff --git a/router-tests/go.sum b/router-tests/go.sum index c6ec48801f..947f5ba76c 100644 --- a/router-tests/go.sum +++ b/router-tests/go.sum @@ -354,6 +354,8 @@ github.com/wundergraph/consul/sdk v0.0.0-20250204115147-ed842a8fd301 
h1:EzfKHQoT github.com/wundergraph/consul/sdk v0.0.0-20250204115147-ed842a8fd301/go.mod h1:wxI0Nak5dI5RvJuzGyiEK4nZj0O9X+Aw6U0tC1wPKq0= github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.229 h1:VCfCX/xmpBGQLhTHJMHLugzJrXJk/smjLRAEruCI0HY= github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.229/go.mod h1:g1IFIylu5Fd9pKjzq0mDvpaKhEB/vkwLAIbGdX2djXU= +github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.229.0.20250930144208-ddc652f78bbb h1:stBTAle5FyytsTNxYeCwNzYlyhKzlS4he6f7/y6O3qE= +github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.229.0.20250930144208-ddc652f78bbb/go.mod h1:g1IFIylu5Fd9pKjzq0mDvpaKhEB/vkwLAIbGdX2djXU= github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 h1:gEOO8jv9F4OT7lGCjxCBTO/36wtF6j2nSip77qHd4x4= github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1/go.mod h1:Ohn+xnUBiLI6FVj/9LpzZWtj1/D6lUovWYBkxHVV3aM= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= diff --git a/router-tests/modules/start-subscription/module.go b/router-tests/modules/start-subscription/module.go new file mode 100644 index 0000000000..fd5a9e0088 --- /dev/null +++ b/router-tests/modules/start-subscription/module.go @@ -0,0 +1,61 @@ +package start_subscription + +import ( + "net/http" + + "go.uber.org/zap" + + "github.com/wundergraph/cosmo/router/core" +) + +const myModuleID = "startSubscriptionModule" + +type StartSubscriptionModule struct { + Logger *zap.Logger + Callback func(ctx core.SubscriptionOnStartHookContext) error + CallbackOnOriginResponse func(response *http.Response, ctx core.RequestContext) *http.Response +} + +func (m *StartSubscriptionModule) Provision(ctx *core.ModuleContext) error { + // Assign the logger to the module for non-request related logging + m.Logger = ctx.Logger + + return nil +} + +func (m *StartSubscriptionModule) SubscriptionOnStart(ctx core.SubscriptionOnStartHookContext) error { + + m.Logger.Info("SubscriptionOnStart Hook has been run") + + if m.Callback != nil { + return 
m.Callback(ctx) + } + + return nil +} + +func (m *StartSubscriptionModule) OnOriginResponse(response *http.Response, ctx core.RequestContext) *http.Response { + if m.CallbackOnOriginResponse != nil { + return m.CallbackOnOriginResponse(response, ctx) + } + + return response +} + +func (m *StartSubscriptionModule) Module() core.ModuleInfo { + return core.ModuleInfo{ + // This is the ID of your module, it must be unique + ID: myModuleID, + // The priority of your module, lower numbers are executed first + Priority: 1, + New: func() core.Module { + return &StartSubscriptionModule{} + }, + } +} + +// Interface guard +var ( + _ core.SubscriptionOnStartHandler = (*StartSubscriptionModule)(nil) + _ core.EnginePostOriginHandler = (*StartSubscriptionModule)(nil) +) diff --git a/router-tests/modules/start_subscription_test.go b/router-tests/modules/start_subscription_test.go new file mode 100644 index 0000000000..ad286d54ef --- /dev/null +++ b/router-tests/modules/start_subscription_test.go @@ -0,0 +1,664 @@ +package module_test + +import ( + "errors" + "net/http" + "testing" + "time" + + "github.com/hasura/go-graphql-client" + start_subscription "github.com/wundergraph/cosmo/router-tests/modules/start-subscription" + "go.uber.org/zap/zapcore" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/wundergraph/cosmo/router-tests/testenv" + "github.com/wundergraph/cosmo/router/core" + "github.com/wundergraph/cosmo/router/pkg/config" + "github.com/wundergraph/cosmo/router/pkg/pubsub/kafka" +) + +func TestStartSubscriptionHook(t *testing.T) { + t.Parallel() + + t.Run("Test StartSubscription hook is called", func(t *testing.T) { + t.Parallel() + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "startSubscriptionModule": start_subscription.StartSubscriptionModule{}, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, + EnableKafka: true, + 
RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&start_subscription.StartSubscriptionModule{}), + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.InfoLevel, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + var subscriptionOne struct { + employeeUpdatedMyKafka struct { + ID float64 `graphql:"id"` + Details struct { + Forename string `graphql:"forename"` + Surname string `graphql:"surname"` + } `graphql:"details"` + } `graphql:"employeeUpdatedMyKafka(employeeID: $employeeID)"` + } + + surl := xEnv.GraphQLWebSocketSubscriptionURL() + client := graphql.NewSubscriptionClient(surl) + + vars := map[string]interface{}{ + "employeeID": 3, + } + subscriptionOneID, err := client.Subscribe(&subscriptionOne, vars, func(dataValue []byte, errValue error) error { + return nil + }) + require.NoError(t, err) + require.NotEmpty(t, subscriptionOneID) + + clientRunCh := make(chan error) + go func() { + clientRunCh <- client.Run() + }() + + xEnv.WaitForSubscriptionCount(1, time.Second*10) + + require.NoError(t, client.Close()) + testenv.AwaitChannelWithT(t, time.Second*10, clientRunCh, func(t *testing.T, err error) { + require.NoError(t, err) + + }, "unable to close client before timeout") + + requestLog := xEnv.Observer().FilterMessage("SubscriptionOnStart Hook has been run") + assert.Len(t, requestLog.All(), 1) + }) + }) + + t.Run("Test StartSubscription write event works", func(t *testing.T) { + t.Parallel() + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "startSubscriptionModule": start_subscription.StartSubscriptionModule{ + Callback: func(ctx core.SubscriptionOnStartHookContext) error { + if ctx.SubscriptionEventConfiguration().RootFieldName() != "employeeUpdatedMyKafka" { + return nil + } + ctx.WriteEvent(&kafka.Event{ + Key: []byte("1"), + Data: []byte(`{"id": 1, "__typename": "Employee"}`), + }) + return nil + }, + }, + }, + } + + 
testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, + EnableKafka: true, + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&start_subscription.StartSubscriptionModule{}), + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.InfoLevel, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + var subscriptionOne struct { + employeeUpdatedMyKafka struct { + ID float64 `graphql:"id"` + Details struct { + Forename string `graphql:"forename"` + Surname string `graphql:"surname"` + } `graphql:"details"` + } `graphql:"employeeUpdatedMyKafka(employeeID: $employeeID)"` + } + + surl := xEnv.GraphQLWebSocketSubscriptionURL() + client := graphql.NewSubscriptionClient(surl) + + vars := map[string]interface{}{ + "employeeID": 3, + } + type kafkaSubscriptionArgs struct { + dataValue []byte + errValue error + } + subscriptionArgsCh := make(chan kafkaSubscriptionArgs) + subscriptionOneID, err := client.Subscribe(&subscriptionOne, vars, func(dataValue []byte, errValue error) error { + subscriptionArgsCh <- kafkaSubscriptionArgs{ + dataValue: dataValue, + errValue: errValue, + } + return nil + }) + require.NoError(t, err) + require.NotEmpty(t, subscriptionOneID) + + clientRunCh := make(chan error) + go func() { + clientRunCh <- client.Run() + }() + + xEnv.WaitForSubscriptionCount(1, time.Second*10) + + testenv.AwaitChannelWithT(t, time.Second*10, subscriptionArgsCh, func(t *testing.T, args kafkaSubscriptionArgs) { + require.NoError(t, args.errValue) + require.JSONEq(t, `{"employeeUpdatedMyKafka":{"id":1,"details":{"forename":"Jens","surname":"Neuse"}}}`, string(args.dataValue)) + }) + + require.NoError(t, client.Close()) + testenv.AwaitChannelWithT(t, time.Second*10, clientRunCh, func(t *testing.T, err error) { + require.NoError(t, err) + + }, "unable to close client before timeout") + + requestLog := xEnv.Observer().FilterMessage("SubscriptionOnStart 
Hook has been run") + assert.Len(t, requestLog.All(), 1) + }) + }) + + t.Run("Test StartSubscription with close to true", func(t *testing.T) { + t.Parallel() + + callbackCalled := make(chan bool) + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "startSubscriptionModule": start_subscription.StartSubscriptionModule{ + Callback: func(ctx core.SubscriptionOnStartHookContext) error { + callbackCalled <- true + return core.NewStreamHookError(nil, "subscription closed", http.StatusOK, "") + }, + }, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, + EnableKafka: true, + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&start_subscription.StartSubscriptionModule{}), + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.InfoLevel, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + var subscriptionOne struct { + employeeUpdatedMyKafka struct { + ID float64 `graphql:"id"` + Details struct { + Forename string `graphql:"forename"` + Surname string `graphql:"surname"` + } `graphql:"details"` + } `graphql:"employeeUpdatedMyKafka(employeeID: $employeeID)"` + } + + surl := xEnv.GraphQLWebSocketSubscriptionURL() + client := graphql.NewSubscriptionClient(surl) + + vars := map[string]interface{}{ + "employeeID": 3, + } + type kafkaSubscriptionArgs struct { + dataValue []byte + errValue error + } + subscriptionArgsCh := make(chan kafkaSubscriptionArgs, 1) + subscriptionOneID, err := client.Subscribe(&subscriptionOne, vars, func(dataValue []byte, errValue error) error { + subscriptionArgsCh <- kafkaSubscriptionArgs{ + dataValue: dataValue, + errValue: errValue, + } + return nil + }) + require.NoError(t, err) + require.NotEmpty(t, subscriptionOneID) + + clientRunCh := make(chan error) + go func() { + clientRunCh <- client.Run() + }() + + xEnv.WaitForSubscriptionCount(1, time.Second*10) + 
<-callbackCalled + xEnv.WaitForSubscriptionCount(0, time.Second*10) + + testenv.AwaitChannelWithT(t, time.Second*10, clientRunCh, func(t *testing.T, err error) { + require.NoError(t, err) + + }, "unable to close client before timeout") + + requestLog := xEnv.Observer().FilterMessage("SubscriptionOnStart Hook has been run") + assert.Len(t, requestLog.All(), 1) + + require.Len(t, subscriptionArgsCh, 1) + subscriptionArgs := <-subscriptionArgsCh + require.Error(t, subscriptionArgs.errValue) + require.Empty(t, subscriptionArgs.dataValue) + }) + }) + + t.Run("Test StartSubscription write event sends event only to the subscription", func(t *testing.T) { + t.Parallel() + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "startSubscriptionModule": start_subscription.StartSubscriptionModule{ + Callback: func(ctx core.SubscriptionOnStartHookContext) error { + employeeId := ctx.Operation().Variables().GetInt64("employeeID") + if employeeId != 1 { + return nil + } + ctx.WriteEvent(&kafka.Event{ + Data: []byte(`{"id": 1, "__typename": "Employee"}`), + }) + return nil + }, + }, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, + EnableKafka: true, + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&start_subscription.StartSubscriptionModule{}), + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.InfoLevel, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + var subscription struct { + employeeUpdatedMyKafka struct { + ID float64 `graphql:"id"` + Details struct { + Forename string `graphql:"forename"` + Surname string `graphql:"surname"` + } `graphql:"details"` + } `graphql:"employeeUpdatedMyKafka(employeeID: $employeeID)"` + } + + surl := xEnv.GraphQLWebSocketSubscriptionURL() + client := graphql.NewSubscriptionClient(surl) + + vars := map[string]interface{}{ + "employeeID": 3, + } + vars2 
:= map[string]interface{}{ + "employeeID": 1, + } + type kafkaSubscriptionArgs struct { + dataValue []byte + errValue error + } + subscriptionOneArgsCh := make(chan kafkaSubscriptionArgs) + subscriptionOneID, err := client.Subscribe(&subscription, vars, func(dataValue []byte, errValue error) error { + subscriptionOneArgsCh <- kafkaSubscriptionArgs{ + dataValue: []byte{}, + errValue: errors.New("should not be called"), + } + return nil + }) + require.NoError(t, err) + require.NotEmpty(t, subscriptionOneID) + + subscriptionTwoArgsCh := make(chan kafkaSubscriptionArgs) + subscriptionTwoID, err := client.Subscribe(&subscription, vars2, func(dataValue []byte, errValue error) error { + subscriptionTwoArgsCh <- kafkaSubscriptionArgs{ + dataValue: dataValue, + errValue: errValue, + } + return nil + }) + require.NoError(t, err) + require.NotEmpty(t, subscriptionTwoID) + + clientRunCh := make(chan error) + go func() { + clientRunCh <- client.Run() + }() + + xEnv.WaitForSubscriptionCount(2, time.Second*10) + + testenv.AwaitChannelWithT(t, time.Second*10, subscriptionTwoArgsCh, func(t *testing.T, args kafkaSubscriptionArgs) { + require.NoError(t, args.errValue) + require.JSONEq(t, `{"employeeUpdatedMyKafka":{"id":1,"details":{"forename":"Jens","surname":"Neuse"}}}`, string(args.dataValue)) + }) + + require.NoError(t, client.Close()) + testenv.AwaitChannelWithT(t, time.Second*10, clientRunCh, func(t *testing.T, err error) { + require.NoError(t, err) + + }, "unable to close client before timeout") + + requestLog := xEnv.Observer().FilterMessage("SubscriptionOnStart Hook has been run") + assert.Len(t, requestLog.All(), 2) + t.Cleanup(func() { + require.Len(t, subscriptionOneArgsCh, 0) + }) + }) + }) + + t.Run("Test StartSubscription error is propagated to the client", func(t *testing.T) { + t.Parallel() + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "startSubscriptionModule": start_subscription.StartSubscriptionModule{ + Callback: func(ctx 
core.SubscriptionOnStartHookContext) error { + return core.NewStreamHookError(errors.New("test error"), "test error", http.StatusLoopDetected, http.StatusText(http.StatusLoopDetected)) + }, + }, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, + EnableKafka: true, + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&start_subscription.StartSubscriptionModule{}), + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.InfoLevel, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + var subscription struct { + employeeUpdatedMyKafka struct { + ID float64 `graphql:"id"` + Details struct { + Forename string `graphql:"forename"` + Surname string `graphql:"surname"` + } `graphql:"details"` + } `graphql:"employeeUpdatedMyKafka(employeeID: $employeeID)"` + } + + surl := xEnv.GraphQLWebSocketSubscriptionURL() + client := graphql.NewSubscriptionClient(surl) + + vars := map[string]interface{}{ + "employeeID": 1, + } + type kafkaSubscriptionArgs struct { + dataValue []byte + errValue error + } + subscriptionOneArgsCh := make(chan kafkaSubscriptionArgs) + subscriptionOneID, err := client.Subscribe(&subscription, vars, func(dataValue []byte, errValue error) error { + subscriptionOneArgsCh <- kafkaSubscriptionArgs{ + dataValue: dataValue, + errValue: errValue, + } + return nil + }) + require.NoError(t, err) + require.NotEmpty(t, subscriptionOneID) + + clientRunCh := make(chan error) + go func() { + clientRunCh <- client.Run() + }() + + // Wait for the subscription to be closed + xEnv.WaitForSubscriptionCount(0, time.Second*10) + + testenv.AwaitChannelWithT(t, time.Second*10, subscriptionOneArgsCh, func(t *testing.T, args kafkaSubscriptionArgs) { + var graphqlErrs graphql.Errors + require.ErrorAs(t, args.errValue, &graphqlErrs) + statusCode, ok := graphqlErrs[0].Extensions["statusCode"].(float64) + require.True(t, ok, "statusCode is 
not a float64") + require.Equal(t, http.StatusLoopDetected, int(statusCode)) + require.Equal(t, http.StatusText(http.StatusLoopDetected), graphqlErrs[0].Extensions["code"]) + }) + + require.NoError(t, client.Close()) + testenv.AwaitChannelWithT(t, time.Second*10, clientRunCh, func(t *testing.T, err error) { + require.NoError(t, err) + + }, "unable to close client before timeout") + + requestLog := xEnv.Observer().FilterMessage("SubscriptionOnStart Hook has been run") + assert.Len(t, requestLog.All(), 1) + t.Cleanup(func() { + require.Len(t, subscriptionOneArgsCh, 0) + }) + }) + }) + + t.Run("Test StartSubscription hook is called for engine subscription", func(t *testing.T) { + t.Parallel() + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "startSubscriptionModule": start_subscription.StartSubscriptionModule{}, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&start_subscription.StartSubscriptionModule{}), + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.InfoLevel, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + + var subscriptionCountEmp struct { + CountEmp int `graphql:"countEmp(max: $max, intervalMilliseconds: $interval)"` + } + + surl := xEnv.GraphQLWebSocketSubscriptionURL() + client := graphql.NewSubscriptionClient(surl) + + vars := map[string]interface{}{ + "max": 1, + "interval": 200, + } + subscriptionOneID, err := client.Subscribe(&subscriptionCountEmp, vars, func(dataValue []byte, errValue error) error { + return nil + }) + require.NoError(t, err) + require.NotEmpty(t, subscriptionOneID) + + clientRunCh := make(chan error) + go func() { + clientRunCh <- client.Run() + }() + + xEnv.WaitForSubscriptionCount(1, time.Second*10) + + require.NoError(t, client.Close()) + testenv.AwaitChannelWithT(t, time.Second*10, clientRunCh, func(t *testing.T, err error) { + require.NoError(t, 
err) + + }, "unable to close client before timeout") + + requestLog := xEnv.Observer().FilterMessage("SubscriptionOnStart Hook has been run") + assert.Len(t, requestLog.All(), 1) + }) + }) + + t.Run("Test StartSubscription hook is called for engine subscription and write event works", func(t *testing.T) { + t.Parallel() + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "startSubscriptionModule": start_subscription.StartSubscriptionModule{ + Callback: func(ctx core.SubscriptionOnStartHookContext) error { + ctx.WriteEvent(&core.EngineEvent{ + Data: []byte(`{"data":{"countEmp":1000}}`), + }) + return nil + }, + }, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&start_subscription.StartSubscriptionModule{}), + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.InfoLevel, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + + var subscriptionCountEmp struct { + CountEmp int `graphql:"countEmp(max: $max, intervalMilliseconds: $interval)"` + } + + surl := xEnv.GraphQLWebSocketSubscriptionURL() + client := graphql.NewSubscriptionClient(surl) + + vars := map[string]interface{}{ + "max": 0, + "interval": 0, + } + + type subscriptionArgs struct { + dataValue []byte + errValue error + } + subscriptionOneArgsCh := make(chan subscriptionArgs) + subscriptionOneID, err := client.Subscribe(&subscriptionCountEmp, vars, func(dataValue []byte, errValue error) error { + subscriptionOneArgsCh <- subscriptionArgs{ + dataValue: dataValue, + errValue: errValue, + } + return nil + }) + require.NoError(t, err) + require.NotEmpty(t, subscriptionOneID) + + clientRunCh := make(chan error) + go func() { + clientRunCh <- client.Run() + }() + + xEnv.WaitForSubscriptionCount(1, time.Second*10) + + testenv.AwaitChannelWithT(t, time.Second*10, subscriptionOneArgsCh, func(t *testing.T, args subscriptionArgs) { + 
require.NoError(t, args.errValue) + require.JSONEq(t, `{"countEmp": 1000}`, string(args.dataValue)) + }) + + testenv.AwaitChannelWithT(t, time.Second*10, subscriptionOneArgsCh, func(t *testing.T, args subscriptionArgs) { + require.NoError(t, args.errValue) + require.JSONEq(t, `{"countEmp": 0}`, string(args.dataValue)) + }) + + require.NoError(t, client.Close()) + testenv.AwaitChannelWithT(t, time.Second*10, clientRunCh, func(t *testing.T, err error) { + require.NoError(t, err) + + }, "unable to close client before timeout") + + requestLog := xEnv.Observer().FilterMessage("SubscriptionOnStart Hook has been run") + assert.Len(t, requestLog.All(), 1) + }) + }) + + t.Run("Test StartSubscription hook is called, return StreamHookError, response on OnOriginResponse should still be set", func(t *testing.T) { + t.Parallel() + originResponseCalled := make(chan *http.Response, 1) + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "startSubscriptionModule": start_subscription.StartSubscriptionModule{ + Callback: func(ctx core.SubscriptionOnStartHookContext) error { + return core.NewStreamHookError(errors.New("subscription closed"), "subscription closed", http.StatusOK, "NotFound") + }, + CallbackOnOriginResponse: func(response *http.Response, ctx core.RequestContext) *http.Response { + originResponseCalled <- response + return response + }, + }, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&start_subscription.StartSubscriptionModule{}), + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.InfoLevel, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + var subscriptionCountEmp struct { + CountEmp int `graphql:"countEmp(max: $max, intervalMilliseconds: $interval)"` + } + + surl := xEnv.GraphQLWebSocketSubscriptionURL() + client := graphql.NewSubscriptionClient(surl) + + vars := map[string]interface{}{ + 
"max": 0, + "interval": 0, + } + + type subscriptionArgs struct { + dataValue []byte + errValue error + } + subscriptionOneArgsCh := make(chan subscriptionArgs) + subscriptionOneID, err := client.Subscribe(&subscriptionCountEmp, vars, func(dataValue []byte, errValue error) error { + subscriptionOneArgsCh <- subscriptionArgs{ + dataValue: dataValue, + errValue: errValue, + } + return nil + }) + require.NoError(t, err) + require.NotEmpty(t, subscriptionOneID) + + clientRunCh := make(chan error) + go func() { + clientRunCh <- client.Run() + }() + + testenv.AwaitChannelWithT(t, time.Second*10, subscriptionOneArgsCh, func(t *testing.T, args subscriptionArgs) { + require.Error(t, args.errValue) + require.Empty(t, args.dataValue) + }) + + testenv.AwaitChannelWithT(t, time.Second*10, clientRunCh, func(t *testing.T, err error) { + require.NoError(t, err) + }, "unable to close client before timeout") + + require.Empty(t, originResponseCalled) + + requestLog := xEnv.Observer().FilterMessage("SubscriptionOnStart Hook has been run") + assert.Len(t, requestLog.All(), 1) + }) + }) +} diff --git a/router/.mockery.yml b/router/.mockery.yml index c835d3af50..558bca2185 100644 --- a/router/.mockery.yml +++ b/router/.mockery.yml @@ -13,10 +13,11 @@ template-schema: '{{.Template}}.schema.json' packages: github.com/wundergraph/cosmo/router/pkg/pubsub/datasource: interfaces: - ProviderLifecycle: + Lifecycle: ProviderBuilder: EngineDataSourceFactory: Provider: + SubscriptionEventUpdater: github.com/wundergraph/cosmo/router/pkg/pubsub/nats: interfaces: Adapter: diff --git a/router/core/errors.go b/router/core/errors.go index 7f8df34da2..44e05f327b 100644 --- a/router/core/errors.go +++ b/router/core/errors.go @@ -35,6 +35,7 @@ const ( errorTypeInvalidWsSubprotocol errorTypeEDFSInvalidMessage errorTypeMergeResult + errorTypeStreamHookError ) type ( @@ -89,6 +90,10 @@ func getErrorType(err error) errorType { if errors.As(err, &mergeResultErr) { return errorTypeMergeResult } + var 
streamHookErr *StreamHookError + if errors.As(err, &streamHookErr) { + return errorTypeStreamHookError + } return errorTypeUnknown } diff --git a/router/core/executor.go b/router/core/executor.go index 437634ea8c..e29ed7682b 100644 --- a/router/core/executor.go +++ b/router/core/executor.go @@ -35,6 +35,8 @@ type ExecutorConfigurationBuilder struct { subscriptionClientOptions *SubscriptionClientOptions instanceData InstanceData + + subscriptionHooks subscriptionHooks } type Executor struct { @@ -216,7 +218,7 @@ func (b *ExecutorConfigurationBuilder) buildPlannerConfiguration(ctx context.Con routerEngineCfg.Execution.EnableSingleFlight, routerEngineCfg.Execution.EnableNetPoll, b.instanceData, - ), b.logger) + ), b.logger, b.subscriptionHooks) // this generates the plan config using the data source factories from the config package planConfig, providers, err := loader.Load(engineConfig, subgraphs, routerEngineCfg, pluginsEnabled) diff --git a/router/core/factoryresolver.go b/router/core/factoryresolver.go index b73742a91d..70f3e7917c 100644 --- a/router/core/factoryresolver.go +++ b/router/core/factoryresolver.go @@ -31,8 +31,9 @@ import ( ) type Loader struct { - ctx context.Context - resolver FactoryResolver + ctx context.Context + resolver FactoryResolver + subscriptionHooks subscriptionHooks // includeInfo controls whether additional information like type usage and field usage is included in the plan de includeInfo bool logger *zap.Logger @@ -190,12 +191,13 @@ func (d *DefaultFactoryResolver) InstanceData() InstanceData { return d.instanceData } -func NewLoader(ctx context.Context, includeInfo bool, resolver FactoryResolver, logger *zap.Logger) *Loader { +func NewLoader(ctx context.Context, includeInfo bool, resolver FactoryResolver, logger *zap.Logger, subscriptionHooks subscriptionHooks) *Loader { return &Loader{ - ctx: ctx, - resolver: resolver, - includeInfo: includeInfo, - logger: logger, + ctx: ctx, + resolver: resolver, + includeInfo: includeInfo, + 
logger: logger, + subscriptionHooks: subscriptionHooks, } } @@ -416,6 +418,10 @@ func (l *Loader) Load(engineConfig *nodev1.EngineConfiguration, subgraphs []*nod } } + subscriptionOnStartFns := make([]graphql_datasource.SubscriptionOnStartFn, len(l.subscriptionHooks.onStart)) + for i, fn := range l.subscriptionHooks.onStart { + subscriptionOnStartFns[i] = NewEngineSubscriptionOnStartHook(fn) + } customConfiguration, err := graphql_datasource.NewConfiguration(graphql_datasource.ConfigurationInput{ Fetch: &graphql_datasource.FetchConfiguration{ URL: fetchUrl, @@ -429,6 +435,7 @@ func (l *Loader) Load(engineConfig *nodev1.EngineConfiguration, subgraphs []*nod ForwardedClientHeaderNames: forwardedClientHeaders, ForwardedClientHeaderRegularExpressions: forwardedClientRegexps, WsSubProtocol: wsSubprotocol, + StartupHooks: subscriptionOnStartFns, }, SchemaConfiguration: schemaConfiguration, CustomScalarTypeFields: customScalarTypeFields, @@ -470,6 +477,10 @@ func (l *Loader) Load(engineConfig *nodev1.EngineConfiguration, subgraphs []*nod } } + subscriptionOnStartFns := make([]pubsub_datasource.SubscriptionOnStartFn, len(l.subscriptionHooks.onStart)) + for i, fn := range l.subscriptionHooks.onStart { + subscriptionOnStartFns[i] = NewPubSubSubscriptionOnStartHook(fn) + } factoryProviders, factoryDataSources, err := pubsub.BuildProvidersAndDataSources( l.ctx, routerEngineConfig.Events, @@ -478,6 +489,9 @@ func (l *Loader) Load(engineConfig *nodev1.EngineConfiguration, subgraphs []*nod pubSubDS, l.resolver.InstanceData().HostName, l.resolver.InstanceData().ListenAddress, + pubsub.Hooks{ + SubscriptionOnStart: subscriptionOnStartFns, + }, ) if err != nil { return nil, providers, err diff --git a/router/core/graph_server.go b/router/core/graph_server.go index 95f66c0ac5..c1330f77f5 100644 --- a/router/core/graph_server.go +++ b/router/core/graph_server.go @@ -1189,6 +1189,7 @@ func (s *graphServer) buildGraphMux( EnableTraceClient: enableTraceClient, CircuitBreaker: 
s.circuitBreakerManager, }, + subscriptionHooks: s.subscriptionHooks, } executor, providers, err := ecb.Build( diff --git a/router/core/graphql_handler.go b/router/core/graphql_handler.go index c494fff4ce..f387d73e6c 100644 --- a/router/core/graphql_handler.go +++ b/router/core/graphql_handler.go @@ -400,6 +400,22 @@ func (h *GraphQLHandler) WriteError(ctx *resolve.Context, err error, res *resolv if isHttpResponseWriter { httpWriter.WriteHeader(http.StatusInternalServerError) } + case errorTypeStreamHookError: + var streamHookErr *StreamHookError + if !errors.As(err, &streamHookErr) { + response.Errors[0].Message = "Internal server error" + return + } + response.Errors[0].Message = streamHookErr.Message() + if streamHookErr.Code() != "" || streamHookErr.StatusCode() != 0 { + response.Errors[0].Extensions = &Extensions{ + Code: streamHookErr.Code(), + StatusCode: streamHookErr.StatusCode(), + } + } + if isHttpResponseWriter { + httpWriter.WriteHeader(streamHookErr.StatusCode()) + } } if ctx.TracingOptions.Enable && ctx.TracingOptions.IncludeTraceOutputInResponseExtensions { diff --git a/router/core/plan_generator.go b/router/core/plan_generator.go index 1026fbe592..4c265a67d5 100644 --- a/router/core/plan_generator.go +++ b/router/core/plan_generator.go @@ -323,7 +323,7 @@ func (pg *PlanGenerator) loadConfiguration(routerConfig *nodev1.RouterConfig, lo httpClient: http.DefaultClient, streamingClient: http.DefaultClient, subscriptionClient: subscriptionClient, - }, logger) + }, logger, subscriptionHooks{}) // this generates the plan configuration using the data source factories from the config package planConfig, _, err := loader.Load(routerConfig.GetEngineConfig(), routerConfig.GetSubgraphs(), &routerEngineConfig, false) // TODO: configure plugins diff --git a/router/core/router.go b/router/core/router.go index 3432528a7a..9f07bd723a 100644 --- a/router/core/router.go +++ b/router/core/router.go @@ -666,6 +666,10 @@ func (r *Router) initModules(ctx context.Context) 
error { } } + if handler, ok := moduleInstance.(SubscriptionOnStartHandler); ok { + r.subscriptionHooks.onStart = append(r.subscriptionHooks.onStart, handler.SubscriptionOnStart) + } + r.modules = append(r.modules, moduleInstance) r.logger.Info("Module registered", diff --git a/router/core/router_config.go b/router/core/router_config.go index 89d99f2ce1..ac4f26d4c7 100644 --- a/router/core/router_config.go +++ b/router/core/router_config.go @@ -25,6 +25,10 @@ import ( "go.uber.org/zap" ) +type subscriptionHooks struct { + onStart []func(ctx SubscriptionOnStartHookContext) error +} + type Config struct { clusterName string instanceID string @@ -118,6 +122,7 @@ type Config struct { mcp config.MCPConfiguration plugins config.PluginsConfiguration tracingAttributes []config.CustomAttribute + subscriptionHooks subscriptionHooks } // Usage returns an anonymized version of the config for usage tracking diff --git a/router/core/subscriptions_modules.go b/router/core/subscriptions_modules.go new file mode 100644 index 0000000000..505bbfc44f --- /dev/null +++ b/router/core/subscriptions_modules.go @@ -0,0 +1,188 @@ +package core + +import ( + "net/http" + + "github.com/wundergraph/cosmo/router/pkg/authentication" + "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" + "go.uber.org/zap" +) + +// StreamHookError is used to customize the error messages and the behavior +type StreamHookError struct { + err error + message string + statusCode int + code string +} + +func (e *StreamHookError) Error() string { + if e.err != nil { + return e.err.Error() + } + return e.message +} + +func (e *StreamHookError) Message() string { + return e.message +} + +func (e *StreamHookError) StatusCode() int { + return e.statusCode +} + +func (e *StreamHookError) Code() string { + return e.code +} + +func NewStreamHookError(err error, 
message string, statusCode int, code string) *StreamHookError { + return &StreamHookError{ + err: err, + message: message, + statusCode: statusCode, + code: code, + } +} + +type SubscriptionOnStartHookContext interface { + // Request is the original request received by the router. + Request() *http.Request + // Logger is the logger for the request + Logger() *zap.Logger + // Operation is the GraphQL operation + Operation() OperationContext + // Authentication is the authentication for the request + Authentication() authentication.Authentication + // SubscriptionEventConfiguration is the subscription event configuration (will return nil for engine subscription) + SubscriptionEventConfiguration() datasource.SubscriptionEventConfiguration + // WriteEvent writes an event to the stream of the current subscription + // It returns true if the event was written to the stream, false if the event was dropped + WriteEvent(event datasource.StreamEvent) bool +} + +type pubSubSubscriptionOnStartHookContext struct { + request *http.Request + logger *zap.Logger + operation OperationContext + authentication authentication.Authentication + subscriptionEventConfiguration datasource.SubscriptionEventConfiguration + writeEventHook func(data []byte) +} + +func (c *pubSubSubscriptionOnStartHookContext) Request() *http.Request { + return c.request +} + +func (c *pubSubSubscriptionOnStartHookContext) Logger() *zap.Logger { + return c.logger +} + +func (c *pubSubSubscriptionOnStartHookContext) Operation() OperationContext { + return c.operation +} + +func (c *pubSubSubscriptionOnStartHookContext) Authentication() authentication.Authentication { + return c.authentication +} + +func (c *pubSubSubscriptionOnStartHookContext) SubscriptionEventConfiguration() datasource.SubscriptionEventConfiguration { + return c.subscriptionEventConfiguration +} + +func (c *pubSubSubscriptionOnStartHookContext) WriteEvent(event datasource.StreamEvent) bool { + c.writeEventHook(event.GetData()) + + return true 
+} + +// EngineEvent is the event used to write to the engine subscription +type EngineEvent struct { + Data []byte +} + +func (e *EngineEvent) GetData() []byte { + return e.Data +} + +type engineSubscriptionOnStartHookContext struct { + request *http.Request + logger *zap.Logger + operation OperationContext + authentication authentication.Authentication + writeEventHook func(data []byte) +} + +func (c *engineSubscriptionOnStartHookContext) Request() *http.Request { + return c.request +} + +func (c *engineSubscriptionOnStartHookContext) Logger() *zap.Logger { + return c.logger +} + +func (c *engineSubscriptionOnStartHookContext) Operation() OperationContext { + return c.operation +} + +func (c *engineSubscriptionOnStartHookContext) Authentication() authentication.Authentication { + return c.authentication +} + +func (c *engineSubscriptionOnStartHookContext) WriteEvent(event datasource.StreamEvent) bool { + c.writeEventHook(event.GetData()) + + return true +} + +func (c *engineSubscriptionOnStartHookContext) SubscriptionEventConfiguration() datasource.SubscriptionEventConfiguration { + return nil +} + +type SubscriptionOnStartHandler interface { + // SubscriptionOnStart is called once at subscription start + // The error is propagated to the client. 
+ SubscriptionOnStart(ctx SubscriptionOnStartHookContext) error +} + +// NewPubSubSubscriptionOnStartHook converts a SubscriptionOnStartHandler to a pubsub.SubscriptionOnStartFn +func NewPubSubSubscriptionOnStartHook(fn func(ctx SubscriptionOnStartHookContext) error) datasource.SubscriptionOnStartFn { + if fn == nil { + return nil + } + + return func(resolveCtx resolve.StartupHookContext, subConf datasource.SubscriptionEventConfiguration) error { + requestContext := getRequestContext(resolveCtx.Context) + hookCtx := &pubSubSubscriptionOnStartHookContext{ + request: requestContext.Request(), + logger: requestContext.Logger(), + operation: requestContext.Operation(), + authentication: requestContext.Authentication(), + subscriptionEventConfiguration: subConf, + writeEventHook: resolveCtx.Updater, + } + + return fn(hookCtx) + } +} + +// NewEngineSubscriptionOnStartHook converts a SubscriptionOnStartHandler to a graphql_datasource.SubscriptionOnStartFn +func NewEngineSubscriptionOnStartHook(fn func(ctx SubscriptionOnStartHookContext) error) graphql_datasource.SubscriptionOnStartFn { + if fn == nil { + return nil + } + + return func(resolveCtx resolve.StartupHookContext, input []byte) error { + requestContext := getRequestContext(resolveCtx.Context) + hookCtx := &engineSubscriptionOnStartHookContext{ + request: requestContext.Request(), + logger: requestContext.Logger(), + operation: requestContext.Operation(), + authentication: requestContext.Authentication(), + writeEventHook: resolveCtx.Updater, + } + + return fn(hookCtx) + } +} diff --git a/router/go.mod b/router/go.mod index 180c9f51d0..82ff4f2e73 100644 --- a/router/go.mod +++ b/router/go.mod @@ -31,7 +31,7 @@ require ( github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 github.com/twmb/franz-go v1.16.1 - github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.229 + github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.229.0.20250930144208-ddc652f78bbb // Do not upgrade, it renames attributes we rely on 
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.58.0 go.opentelemetry.io/contrib/propagators/b3 v1.23.0 @@ -196,4 +196,4 @@ replace ( // Remember you can use Go workspaces to avoid using replace directives in multiple go.mod files // Use what is best for your personal workflow. See CONTRIBUTING.md for more information -// replace github.com/wundergraph/graphql-go-tools/v2 => ../../graphql-go-tools/v2 +//replace github.com/wundergraph/graphql-go-tools/v2 => ../../graphql-go-tools/v2 diff --git a/router/go.sum b/router/go.sum index 0263992f20..1a0bc0afe5 100644 --- a/router/go.sum +++ b/router/go.sum @@ -321,8 +321,8 @@ github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/ github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083 h1:8/D7f8gKxTBjW+SZK4mhxTTBVpxcqeBgWF1Rfmltbfk= github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083/go.mod h1:eOTL6acwctsN4F3b7YE+eE2t8zcJ/doLm9sZzsxxxrE= -github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.229 h1:VCfCX/xmpBGQLhTHJMHLugzJrXJk/smjLRAEruCI0HY= -github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.229/go.mod h1:g1IFIylu5Fd9pKjzq0mDvpaKhEB/vkwLAIbGdX2djXU= +github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.229.0.20250930144208-ddc652f78bbb h1:stBTAle5FyytsTNxYeCwNzYlyhKzlS4he6f7/y6O3qE= +github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.229.0.20250930144208-ddc652f78bbb/go.mod h1:g1IFIylu5Fd9pKjzq0mDvpaKhEB/vkwLAIbGdX2djXU= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M= diff --git a/router/pkg/pubsub/datasource/datasource.go b/router/pkg/pubsub/datasource/datasource.go index 2f08b97074..3a3018b745 100644 --- 
a/router/pkg/pubsub/datasource/datasource.go +++ b/router/pkg/pubsub/datasource/datasource.go @@ -1,9 +1,17 @@ package datasource import ( + "github.com/cespare/xxhash/v2" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" ) +type SubscriptionDataSource interface { + SubscriptionEventConfiguration(input []byte) (SubscriptionEventConfiguration, error) + Start(ctx *resolve.Context, input []byte, updater resolve.SubscriptionUpdater) error + UniqueRequestID(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) (err error) + SetSubscriptionOnStartFns(fns ...SubscriptionOnStartFn) +} + // EngineDataSourceFactory is the interface that all pubsub data sources must implement. // It serves three main purposes: // 1. Resolving the data source and subscription data source @@ -23,7 +31,7 @@ type EngineDataSourceFactory interface { // ResolveDataSourceSubscription returns the engine SubscriptionDataSource implementation // that contains methods to start a subscription, which will be called by the Planner // when a subscription is initiated - ResolveDataSourceSubscription() (resolve.SubscriptionDataSource, error) + ResolveDataSourceSubscription() (SubscriptionDataSource, error) // ResolveDataSourceSubscriptionInput build the input that will be passed to the engine SubscriptionDataSource ResolveDataSourceSubscriptionInput() (string, error) // TransformEventData allows the data source to transform the event data using the extractFn diff --git a/router/pkg/pubsub/datasource/factory.go b/router/pkg/pubsub/datasource/factory.go index cbceb1a651..5c42161776 100644 --- a/router/pkg/pubsub/datasource/factory.go +++ b/router/pkg/pubsub/datasource/factory.go @@ -9,14 +9,16 @@ import ( ) type PlannerConfig[PB ProviderBuilder[P, E], P any, E any] struct { - ProviderBuilder PB - Event E + ProviderBuilder PB + Event E + SubscriptionOnStartFns []SubscriptionOnStartFn } -func NewPlannerConfig[PB ProviderBuilder[P, E], P any, E any](providerBuilder PB, event E) *PlannerConfig[PB, 
P, E] { +func NewPlannerConfig[PB ProviderBuilder[P, E], P any, E any](providerBuilder PB, event E, subscriptionOnStartFns []SubscriptionOnStartFn) *PlannerConfig[PB, P, E] { return &PlannerConfig[PB, P, E]{ - ProviderBuilder: providerBuilder, - Event: event, + ProviderBuilder: providerBuilder, + Event: event, + SubscriptionOnStartFns: subscriptionOnStartFns, } } diff --git a/router/pkg/pubsub/datasource/mocks.go b/router/pkg/pubsub/datasource/mocks.go index a6bbb19e18..861beb3987 100644 --- a/router/pkg/pubsub/datasource/mocks.go +++ b/router/pkg/pubsub/datasource/mocks.go @@ -198,23 +198,23 @@ func (_c *MockEngineDataSourceFactory_ResolveDataSourceInput_Call) RunAndReturn( } // ResolveDataSourceSubscription provides a mock function for the type MockEngineDataSourceFactory -func (_mock *MockEngineDataSourceFactory) ResolveDataSourceSubscription() (resolve.SubscriptionDataSource, error) { +func (_mock *MockEngineDataSourceFactory) ResolveDataSourceSubscription() (SubscriptionDataSource, error) { ret := _mock.Called() if len(ret) == 0 { panic("no return value specified for ResolveDataSourceSubscription") } - var r0 resolve.SubscriptionDataSource + var r0 SubscriptionDataSource var r1 error - if returnFunc, ok := ret.Get(0).(func() (resolve.SubscriptionDataSource, error)); ok { + if returnFunc, ok := ret.Get(0).(func() (SubscriptionDataSource, error)); ok { return returnFunc() } - if returnFunc, ok := ret.Get(0).(func() resolve.SubscriptionDataSource); ok { + if returnFunc, ok := ret.Get(0).(func() SubscriptionDataSource); ok { r0 = returnFunc() } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(resolve.SubscriptionDataSource) + r0 = ret.Get(0).(SubscriptionDataSource) } } if returnFunc, ok := ret.Get(1).(func() error); ok { @@ -242,12 +242,12 @@ func (_c *MockEngineDataSourceFactory_ResolveDataSourceSubscription_Call) Run(ru return _c } -func (_c *MockEngineDataSourceFactory_ResolveDataSourceSubscription_Call) Return(subscriptionDataSource 
resolve.SubscriptionDataSource, err error) *MockEngineDataSourceFactory_ResolveDataSourceSubscription_Call { +func (_c *MockEngineDataSourceFactory_ResolveDataSourceSubscription_Call) Return(subscriptionDataSource SubscriptionDataSource, err error) *MockEngineDataSourceFactory_ResolveDataSourceSubscription_Call { _c.Call.Return(subscriptionDataSource, err) return _c } -func (_c *MockEngineDataSourceFactory_ResolveDataSourceSubscription_Call) RunAndReturn(run func() (resolve.SubscriptionDataSource, error)) *MockEngineDataSourceFactory_ResolveDataSourceSubscription_Call { +func (_c *MockEngineDataSourceFactory_ResolveDataSourceSubscription_Call) RunAndReturn(run func() (SubscriptionDataSource, error)) *MockEngineDataSourceFactory_ResolveDataSourceSubscription_Call { _c.Call.Return(run) return _c } @@ -356,13 +356,13 @@ func (_c *MockEngineDataSourceFactory_TransformEventData_Call) RunAndReturn(run return _c } -// NewMockProviderLifecycle creates a new instance of MockProviderLifecycle. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// NewMockLifecycle creates a new instance of MockLifecycle. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. 
-func NewMockProviderLifecycle(t interface { +func NewMockLifecycle(t interface { mock.TestingT Cleanup(func()) -}) *MockProviderLifecycle { - mock := &MockProviderLifecycle{} +}) *MockLifecycle { + mock := &MockLifecycle{} mock.Mock.Test(t) t.Cleanup(func() { mock.AssertExpectations(t) }) @@ -370,21 +370,21 @@ func NewMockProviderLifecycle(t interface { return mock } -// MockProviderLifecycle is an autogenerated mock type for the ProviderLifecycle type -type MockProviderLifecycle struct { +// MockLifecycle is an autogenerated mock type for the Lifecycle type +type MockLifecycle struct { mock.Mock } -type MockProviderLifecycle_Expecter struct { +type MockLifecycle_Expecter struct { mock *mock.Mock } -func (_m *MockProviderLifecycle) EXPECT() *MockProviderLifecycle_Expecter { - return &MockProviderLifecycle_Expecter{mock: &_m.Mock} +func (_m *MockLifecycle) EXPECT() *MockLifecycle_Expecter { + return &MockLifecycle_Expecter{mock: &_m.Mock} } -// Shutdown provides a mock function for the type MockProviderLifecycle -func (_mock *MockProviderLifecycle) Shutdown(ctx context.Context) error { +// Shutdown provides a mock function for the type MockLifecycle +func (_mock *MockLifecycle) Shutdown(ctx context.Context) error { ret := _mock.Called(ctx) if len(ret) == 0 { @@ -400,18 +400,18 @@ func (_mock *MockProviderLifecycle) Shutdown(ctx context.Context) error { return r0 } -// MockProviderLifecycle_Shutdown_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Shutdown' -type MockProviderLifecycle_Shutdown_Call struct { +// MockLifecycle_Shutdown_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Shutdown' +type MockLifecycle_Shutdown_Call struct { *mock.Call } // Shutdown is a helper method to define mock.On call // - ctx context.Context -func (_e *MockProviderLifecycle_Expecter) Shutdown(ctx interface{}) *MockProviderLifecycle_Shutdown_Call { - return &MockProviderLifecycle_Shutdown_Call{Call: 
_e.mock.On("Shutdown", ctx)} +func (_e *MockLifecycle_Expecter) Shutdown(ctx interface{}) *MockLifecycle_Shutdown_Call { + return &MockLifecycle_Shutdown_Call{Call: _e.mock.On("Shutdown", ctx)} } -func (_c *MockProviderLifecycle_Shutdown_Call) Run(run func(ctx context.Context)) *MockProviderLifecycle_Shutdown_Call { +func (_c *MockLifecycle_Shutdown_Call) Run(run func(ctx context.Context)) *MockLifecycle_Shutdown_Call { _c.Call.Run(func(args mock.Arguments) { var arg0 context.Context if args[0] != nil { @@ -424,18 +424,18 @@ func (_c *MockProviderLifecycle_Shutdown_Call) Run(run func(ctx context.Context) return _c } -func (_c *MockProviderLifecycle_Shutdown_Call) Return(err error) *MockProviderLifecycle_Shutdown_Call { +func (_c *MockLifecycle_Shutdown_Call) Return(err error) *MockLifecycle_Shutdown_Call { _c.Call.Return(err) return _c } -func (_c *MockProviderLifecycle_Shutdown_Call) RunAndReturn(run func(ctx context.Context) error) *MockProviderLifecycle_Shutdown_Call { +func (_c *MockLifecycle_Shutdown_Call) RunAndReturn(run func(ctx context.Context) error) *MockLifecycle_Shutdown_Call { _c.Call.Return(run) return _c } -// Startup provides a mock function for the type MockProviderLifecycle -func (_mock *MockProviderLifecycle) Startup(ctx context.Context) error { +// Startup provides a mock function for the type MockLifecycle +func (_mock *MockLifecycle) Startup(ctx context.Context) error { ret := _mock.Called(ctx) if len(ret) == 0 { @@ -451,18 +451,18 @@ func (_mock *MockProviderLifecycle) Startup(ctx context.Context) error { return r0 } -// MockProviderLifecycle_Startup_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Startup' -type MockProviderLifecycle_Startup_Call struct { +// MockLifecycle_Startup_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Startup' +type MockLifecycle_Startup_Call struct { *mock.Call } // Startup is a helper method to define mock.On call // - ctx 
context.Context -func (_e *MockProviderLifecycle_Expecter) Startup(ctx interface{}) *MockProviderLifecycle_Startup_Call { - return &MockProviderLifecycle_Startup_Call{Call: _e.mock.On("Startup", ctx)} +func (_e *MockLifecycle_Expecter) Startup(ctx interface{}) *MockLifecycle_Startup_Call { + return &MockLifecycle_Startup_Call{Call: _e.mock.On("Startup", ctx)} } -func (_c *MockProviderLifecycle_Startup_Call) Run(run func(ctx context.Context)) *MockProviderLifecycle_Startup_Call { +func (_c *MockLifecycle_Startup_Call) Run(run func(ctx context.Context)) *MockLifecycle_Startup_Call { _c.Call.Run(func(args mock.Arguments) { var arg0 context.Context if args[0] != nil { @@ -475,12 +475,12 @@ func (_c *MockProviderLifecycle_Startup_Call) Run(run func(ctx context.Context)) return _c } -func (_c *MockProviderLifecycle_Startup_Call) Return(err error) *MockProviderLifecycle_Startup_Call { +func (_c *MockLifecycle_Startup_Call) Return(err error) *MockLifecycle_Startup_Call { _c.Call.Return(err) return _c } -func (_c *MockProviderLifecycle_Startup_Call) RunAndReturn(run func(ctx context.Context) error) *MockProviderLifecycle_Startup_Call { +func (_c *MockLifecycle_Startup_Call) RunAndReturn(run func(ctx context.Context) error) *MockLifecycle_Startup_Call { _c.Call.Return(run) return _c } @@ -658,6 +658,69 @@ func (_c *MockProvider_Startup_Call) RunAndReturn(run func(ctx context.Context) return _c } +// Subscribe provides a mock function for the type MockProvider +func (_mock *MockProvider) Subscribe(ctx context.Context, cfg SubscriptionEventConfiguration, updater SubscriptionEventUpdater) error { + ret := _mock.Called(ctx, cfg, updater) + + if len(ret) == 0 { + panic("no return value specified for Subscribe") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, SubscriptionEventConfiguration, SubscriptionEventUpdater) error); ok { + r0 = returnFunc(ctx, cfg, updater) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// MockProvider_Subscribe_Call is 
a *mock.Call that shadows Run/Return methods with type explicit version for method 'Subscribe' +type MockProvider_Subscribe_Call struct { + *mock.Call +} + +// Subscribe is a helper method to define mock.On call +// - ctx context.Context +// - cfg SubscriptionEventConfiguration +// - updater SubscriptionEventUpdater +func (_e *MockProvider_Expecter) Subscribe(ctx interface{}, cfg interface{}, updater interface{}) *MockProvider_Subscribe_Call { + return &MockProvider_Subscribe_Call{Call: _e.mock.On("Subscribe", ctx, cfg, updater)} +} + +func (_c *MockProvider_Subscribe_Call) Run(run func(ctx context.Context, cfg SubscriptionEventConfiguration, updater SubscriptionEventUpdater)) *MockProvider_Subscribe_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 SubscriptionEventConfiguration + if args[1] != nil { + arg1 = args[1].(SubscriptionEventConfiguration) + } + var arg2 SubscriptionEventUpdater + if args[2] != nil { + arg2 = args[2].(SubscriptionEventUpdater) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *MockProvider_Subscribe_Call) Return(err error) *MockProvider_Subscribe_Call { + _c.Call.Return(err) + return _c +} + +func (_c *MockProvider_Subscribe_Call) RunAndReturn(run func(ctx context.Context, cfg SubscriptionEventConfiguration, updater SubscriptionEventUpdater) error) *MockProvider_Subscribe_Call { + _c.Call.Return(run) + return _c +} + // TypeID provides a mock function for the type MockProvider func (_mock *MockProvider) TypeID() string { ret := _mock.Called() @@ -896,3 +959,143 @@ func (_c *MockProviderBuilder_TypeID_Call[P, E]) RunAndReturn(run func() string) _c.Call.Return(run) return _c } + +// NewMockSubscriptionEventUpdater creates a new instance of MockSubscriptionEventUpdater. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. 
+// The first argument is typically a *testing.T value. +func NewMockSubscriptionEventUpdater(t interface { + mock.TestingT + Cleanup(func()) +}) *MockSubscriptionEventUpdater { + mock := &MockSubscriptionEventUpdater{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// MockSubscriptionEventUpdater is an autogenerated mock type for the SubscriptionEventUpdater type +type MockSubscriptionEventUpdater struct { + mock.Mock +} + +type MockSubscriptionEventUpdater_Expecter struct { + mock *mock.Mock +} + +func (_m *MockSubscriptionEventUpdater) EXPECT() *MockSubscriptionEventUpdater_Expecter { + return &MockSubscriptionEventUpdater_Expecter{mock: &_m.Mock} +} + +// Close provides a mock function for the type MockSubscriptionEventUpdater +func (_mock *MockSubscriptionEventUpdater) Close(kind resolve.SubscriptionCloseKind) { + _mock.Called(kind) + return +} + +// MockSubscriptionEventUpdater_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close' +type MockSubscriptionEventUpdater_Close_Call struct { + *mock.Call +} + +// Close is a helper method to define mock.On call +// - kind resolve.SubscriptionCloseKind +func (_e *MockSubscriptionEventUpdater_Expecter) Close(kind interface{}) *MockSubscriptionEventUpdater_Close_Call { + return &MockSubscriptionEventUpdater_Close_Call{Call: _e.mock.On("Close", kind)} +} + +func (_c *MockSubscriptionEventUpdater_Close_Call) Run(run func(kind resolve.SubscriptionCloseKind)) *MockSubscriptionEventUpdater_Close_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 resolve.SubscriptionCloseKind + if args[0] != nil { + arg0 = args[0].(resolve.SubscriptionCloseKind) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *MockSubscriptionEventUpdater_Close_Call) Return() *MockSubscriptionEventUpdater_Close_Call { + _c.Call.Return() + return _c +} + +func (_c *MockSubscriptionEventUpdater_Close_Call) RunAndReturn(run func(kind 
resolve.SubscriptionCloseKind)) *MockSubscriptionEventUpdater_Close_Call { + _c.Run(run) + return _c +} + +// Complete provides a mock function for the type MockSubscriptionEventUpdater +func (_mock *MockSubscriptionEventUpdater) Complete() { + _mock.Called() + return +} + +// MockSubscriptionEventUpdater_Complete_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Complete' +type MockSubscriptionEventUpdater_Complete_Call struct { + *mock.Call +} + +// Complete is a helper method to define mock.On call +func (_e *MockSubscriptionEventUpdater_Expecter) Complete() *MockSubscriptionEventUpdater_Complete_Call { + return &MockSubscriptionEventUpdater_Complete_Call{Call: _e.mock.On("Complete")} +} + +func (_c *MockSubscriptionEventUpdater_Complete_Call) Run(run func()) *MockSubscriptionEventUpdater_Complete_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockSubscriptionEventUpdater_Complete_Call) Return() *MockSubscriptionEventUpdater_Complete_Call { + _c.Call.Return() + return _c +} + +func (_c *MockSubscriptionEventUpdater_Complete_Call) RunAndReturn(run func()) *MockSubscriptionEventUpdater_Complete_Call { + _c.Run(run) + return _c +} + +// Update provides a mock function for the type MockSubscriptionEventUpdater +func (_mock *MockSubscriptionEventUpdater) Update(event StreamEvent) { + _mock.Called(event) + return +} + +// MockSubscriptionEventUpdater_Update_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Update' +type MockSubscriptionEventUpdater_Update_Call struct { + *mock.Call +} + +// Update is a helper method to define mock.On call +// - event StreamEvent +func (_e *MockSubscriptionEventUpdater_Expecter) Update(event interface{}) *MockSubscriptionEventUpdater_Update_Call { + return &MockSubscriptionEventUpdater_Update_Call{Call: _e.mock.On("Update", event)} +} + +func (_c *MockSubscriptionEventUpdater_Update_Call) Run(run func(event 
StreamEvent)) *MockSubscriptionEventUpdater_Update_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 StreamEvent + if args[0] != nil { + arg0 = args[0].(StreamEvent) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *MockSubscriptionEventUpdater_Update_Call) Return() *MockSubscriptionEventUpdater_Update_Call { + _c.Call.Return() + return _c +} + +func (_c *MockSubscriptionEventUpdater_Update_Call) RunAndReturn(run func(event StreamEvent)) *MockSubscriptionEventUpdater_Update_Call { + _c.Run(run) + return _c +} diff --git a/router/pkg/pubsub/datasource/planner.go b/router/pkg/pubsub/datasource/planner.go index e3b54a92ec..a480f8270e 100644 --- a/router/pkg/pubsub/datasource/planner.go +++ b/router/pkg/pubsub/datasource/planner.go @@ -109,6 +109,7 @@ func (p *Planner[PB, P, E]) ConfigureSubscription() plan.SubscriptionConfigurati p.visitor.Walker.StopWithInternalErr(fmt.Errorf("failed to get resolve data source subscription: %w", err)) return plan.SubscriptionConfiguration{} } + dataSource.SetSubscriptionOnStartFns(p.config.SubscriptionOnStartFns...) 
input, err := pubSubDataSource.ResolveDataSourceSubscriptionInput() if err != nil { diff --git a/router/pkg/pubsub/datasource/provider.go b/router/pkg/pubsub/datasource/provider.go index d9138630ca..33cac33782 100644 --- a/router/pkg/pubsub/datasource/provider.go +++ b/router/pkg/pubsub/datasource/provider.go @@ -4,20 +4,30 @@ import ( "context" "github.com/wundergraph/cosmo/router/pkg/metric" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" ) type ArgumentTemplateCallback func(tpl string) (string, error) -type ProviderLifecycle interface { +// Lifecycle is the interface that the provider must implement +// to allow the router to start and stop the provider +type Lifecycle interface { // Startup is the method called when the provider is started Startup(ctx context.Context) error // Shutdown is the method called when the provider is shut down Shutdown(ctx context.Context) error } +// Adapter is the interface that the provider must implement +// to implement the basic functionality +type Adapter interface { + Lifecycle + Subscribe(ctx context.Context, cfg SubscriptionEventConfiguration, updater SubscriptionEventUpdater) error +} + // Provider is the interface that the PubSub provider must implement type Provider interface { - ProviderLifecycle + Adapter // ID Get the provider ID as specified in the configuration ID() string // TypeID Get the provider type id (e.g. 
"kafka", "nats") @@ -34,6 +44,38 @@ type ProviderBuilder[P, E any] interface { BuildEngineDataSourceFactory(data E) (EngineDataSourceFactory, error) } +// ProviderType represents the type of pubsub provider +type ProviderType string + +const ( + ProviderTypeNats ProviderType = "nats" + ProviderTypeKafka ProviderType = "kafka" + ProviderTypeRedis ProviderType = "redis" +) + +// StreamEvent is a generic interface for all stream events +// Each provider will have its own event type that implements this interface +// there could be other common fields in the future, but for now we only have data +type StreamEvent interface { + GetData() []byte +} + +type SubscriptionOnStartFn func(ctx resolve.StartupHookContext, subConf SubscriptionEventConfiguration) error + +// SubscriptionEventConfiguration is the interface that all subscription event configurations must implement +type SubscriptionEventConfiguration interface { + ProviderID() string + ProviderType() ProviderType + RootFieldName() string // the root field name of the subscription in the schema +} + +// PublishEventConfiguration is the interface that all publish event configurations must implement +type PublishEventConfiguration interface { + ProviderID() string + ProviderType() ProviderType + RootFieldName() string // the root field name of the mutation in the schema +} + type ProviderOpts struct { StreamMetricStore metric.StreamMetricStore } diff --git a/router/pkg/pubsub/datasource/pubsubprovider.go b/router/pkg/pubsub/datasource/pubsubprovider.go index 9e1223d950..84561b06db 100644 --- a/router/pkg/pubsub/datasource/pubsubprovider.go +++ b/router/pkg/pubsub/datasource/pubsubprovider.go @@ -9,7 +9,7 @@ import ( type PubSubProvider struct { id string typeID string - Adapter ProviderLifecycle + Adapter Adapter Logger *zap.Logger } @@ -35,7 +35,11 @@ func (p *PubSubProvider) Shutdown(ctx context.Context) error { return nil } -func NewPubSubProvider(id string, typeID string, adapter ProviderLifecycle, logger 
*zap.Logger) *PubSubProvider { +func (p *PubSubProvider) Subscribe(ctx context.Context, cfg SubscriptionEventConfiguration, updater SubscriptionEventUpdater) error { + return p.Adapter.Subscribe(ctx, cfg, updater) +} + +func NewPubSubProvider(id string, typeID string, adapter Adapter, logger *zap.Logger) *PubSubProvider { return &PubSubProvider{ id: id, typeID: typeID, diff --git a/router/pkg/pubsub/datasource/pubsubprovider_test.go b/router/pkg/pubsub/datasource/pubsubprovider_test.go index 6579b62072..134bfbd6bb 100644 --- a/router/pkg/pubsub/datasource/pubsubprovider_test.go +++ b/router/pkg/pubsub/datasource/pubsubprovider_test.go @@ -10,7 +10,7 @@ import ( ) func TestProvider_Startup_Success(t *testing.T) { - mockAdapter := NewMockProviderLifecycle(t) + mockAdapter := NewMockProvider(t) mockAdapter.On("Startup", mock.Anything).Return(nil) provider := PubSubProvider{ @@ -22,7 +22,7 @@ func TestProvider_Startup_Success(t *testing.T) { } func TestProvider_Startup_Error(t *testing.T) { - mockAdapter := NewMockProviderLifecycle(t) + mockAdapter := NewMockProvider(t) mockAdapter.On("Startup", mock.Anything).Return(errors.New("connect error")) provider := PubSubProvider{ @@ -34,7 +34,7 @@ func TestProvider_Startup_Error(t *testing.T) { } func TestProvider_Shutdown_Success(t *testing.T) { - mockAdapter := NewMockProviderLifecycle(t) + mockAdapter := NewMockProvider(t) mockAdapter.On("Shutdown", mock.Anything).Return(nil) provider := PubSubProvider{ @@ -46,7 +46,7 @@ func TestProvider_Shutdown_Success(t *testing.T) { } func TestProvider_Shutdown_Error(t *testing.T) { - mockAdapter := NewMockProviderLifecycle(t) + mockAdapter := NewMockProvider(t) mockAdapter.On("Shutdown", mock.Anything).Return(errors.New("close error")) provider := PubSubProvider{ diff --git a/router/pkg/pubsub/datasource/subscription_datasource.go b/router/pkg/pubsub/datasource/subscription_datasource.go new file mode 100644 index 0000000000..e5c9c26ab6 --- /dev/null +++ 
b/router/pkg/pubsub/datasource/subscription_datasource.go
@@ -0,0 +1,72 @@
+package datasource
+
+import (
+	"encoding/json"
+	"errors"
+
+	"github.com/cespare/xxhash/v2"
+	"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve"
+)
+
+type uniqueRequestIdFn func(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error
+
+// PubSubSubscriptionDataSource is a data source for handling subscriptions using a Pub/Sub mechanism.
+// It implements the SubscriptionDataSource interface and HookableSubscriptionDataSource
+type PubSubSubscriptionDataSource[C SubscriptionEventConfiguration] struct {
+	pubSub                 Adapter
+	uniqueRequestID        uniqueRequestIdFn
+	subscriptionOnStartFns []SubscriptionOnStartFn
+}
+
+func (s *PubSubSubscriptionDataSource[C]) SubscriptionEventConfiguration(input []byte) (SubscriptionEventConfiguration, error) {
+	var subscriptionConfiguration C
+	err := json.Unmarshal(input, &subscriptionConfiguration)
+	return subscriptionConfiguration, err
+}
+
+func (s *PubSubSubscriptionDataSource[C]) UniqueRequestID(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error {
+	return s.uniqueRequestID(ctx, input, xxh)
+}
+
+func (s *PubSubSubscriptionDataSource[C]) Start(ctx *resolve.Context, input []byte, updater resolve.SubscriptionUpdater) error {
+	subConf, err := s.SubscriptionEventConfiguration(input)
+	if err != nil {
+		return err
+	}
+
+	conf, ok := subConf.(C)
+	if !ok {
+		return errors.New("invalid subscription configuration")
+	}
+
+	return s.pubSub.Subscribe(ctx.Context(), conf, NewSubscriptionEventUpdater(updater))
+}
+
+func (s *PubSubSubscriptionDataSource[C]) SubscriptionOnStart(ctx resolve.StartupHookContext, input []byte) (err error) {
+	for _, fn := range s.subscriptionOnStartFns {
+		conf, errConf := s.SubscriptionEventConfiguration(input)
+		if errConf != nil {
+			return errConf
+		}
+		err = fn(ctx, conf)
+		if err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+func (s *PubSubSubscriptionDataSource[C]) SetSubscriptionOnStartFns(fns 
...SubscriptionOnStartFn) { + s.subscriptionOnStartFns = append(s.subscriptionOnStartFns, fns...) +} + +var _ SubscriptionDataSource = (*PubSubSubscriptionDataSource[SubscriptionEventConfiguration])(nil) +var _ resolve.HookableSubscriptionDataSource = (*PubSubSubscriptionDataSource[SubscriptionEventConfiguration])(nil) + +func NewPubSubSubscriptionDataSource[C SubscriptionEventConfiguration](pubSub Adapter, uniqueRequestIdFn uniqueRequestIdFn) *PubSubSubscriptionDataSource[C] { + return &PubSubSubscriptionDataSource[C]{ + pubSub: pubSub, + uniqueRequestID: uniqueRequestIdFn, + } +} diff --git a/router/pkg/pubsub/datasource/subscription_datasource_test.go b/router/pkg/pubsub/datasource/subscription_datasource_test.go new file mode 100644 index 0000000000..a9170d7edd --- /dev/null +++ b/router/pkg/pubsub/datasource/subscription_datasource_test.go @@ -0,0 +1,327 @@ +package datasource + +import ( + "context" + "encoding/json" + "errors" + "testing" + + "github.com/cespare/xxhash/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" +) + +// testSubscriptionEventConfiguration implements SubscriptionEventConfiguration for testing +type testSubscriptionEventConfiguration struct { + Topic string `json:"topic"` + Subject string `json:"subject"` +} + +func (t testSubscriptionEventConfiguration) ProviderID() string { + return "test-provider" +} + +func (t testSubscriptionEventConfiguration) ProviderType() ProviderType { + return ProviderTypeNats +} + +func (t testSubscriptionEventConfiguration) RootFieldName() string { + return "testSubscription" +} + +func TestPubSubSubscriptionDataSource_SubscriptionEventConfiguration_Success(t *testing.T) { + mockAdapter := NewMockProvider(t) + uniqueRequestIDFn := func(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { + return nil + } + + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, 
uniqueRequestIDFn) + + testConfig := testSubscriptionEventConfiguration{ + Topic: "test-topic", + Subject: "test-subject", + } + input, err := json.Marshal(testConfig) + assert.NoError(t, err) + + result, err := dataSource.SubscriptionEventConfiguration(input) + assert.NoError(t, err) + assert.NotNil(t, result) + + typedResult, ok := result.(testSubscriptionEventConfiguration) + assert.True(t, ok) + assert.Equal(t, "test-topic", typedResult.Topic) + assert.Equal(t, "test-subject", typedResult.Subject) +} + +func TestPubSubSubscriptionDataSource_SubscriptionEventConfiguration_InvalidJSON(t *testing.T) { + mockAdapter := NewMockProvider(t) + uniqueRequestIDFn := func(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { + return nil + } + + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn) + + invalidInput := []byte(`{"invalid": json}`) + result, err := dataSource.SubscriptionEventConfiguration(invalidInput) + assert.Error(t, err) + assert.Equal(t, testSubscriptionEventConfiguration{}, result) +} + +func TestPubSubSubscriptionDataSource_UniqueRequestID_Success(t *testing.T) { + mockAdapter := NewMockProvider(t) + uniqueRequestIDFn := func(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { + return nil + } + + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn) + + ctx := &resolve.Context{} + input := []byte(`{"test": "data"}`) + xxh := xxhash.New() + + err := dataSource.UniqueRequestID(ctx, input, xxh) + assert.NoError(t, err) +} + +func TestPubSubSubscriptionDataSource_UniqueRequestID_Error(t *testing.T) { + mockAdapter := NewMockProvider(t) + expectedError := errors.New("unique ID generation error") + uniqueRequestIDFn := func(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { + return expectedError + } + + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, 
uniqueRequestIDFn) + + ctx := &resolve.Context{} + input := []byte(`{"test": "data"}`) + xxh := xxhash.New() + + err := dataSource.UniqueRequestID(ctx, input, xxh) + assert.Error(t, err) + assert.Equal(t, expectedError, err) +} + +func TestPubSubSubscriptionDataSource_Start_Success(t *testing.T) { + mockAdapter := NewMockProvider(t) + uniqueRequestIDFn := func(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { + return nil + } + + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn) + + testConfig := testSubscriptionEventConfiguration{ + Topic: "test-topic", + Subject: "test-subject", + } + input, err := json.Marshal(testConfig) + assert.NoError(t, err) + + ctx := resolve.NewContext(context.Background()) + mockUpdater := NewMockSubscriptionUpdater(t) + + mockAdapter.On("Subscribe", ctx.Context(), testConfig, mock.AnythingOfType("*datasource.subscriptionEventUpdater")).Return(nil) + + err = dataSource.Start(ctx, input, mockUpdater) + assert.NoError(t, err) + mockAdapter.AssertExpectations(t) +} + +func TestPubSubSubscriptionDataSource_Start_NoConfiguration(t *testing.T) { + mockAdapter := NewMockProvider(t) + uniqueRequestIDFn := func(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { + return nil + } + + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn) + + invalidInput := []byte(`{"invalid": json}`) + ctx := resolve.NewContext(context.Background()) + mockUpdater := NewMockSubscriptionUpdater(t) + + err := dataSource.Start(ctx, invalidInput, mockUpdater) + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid character 'j' looking for beginning of value") +} + +func TestPubSubSubscriptionDataSource_Start_SubscribeError(t *testing.T) { + mockAdapter := NewMockProvider(t) + uniqueRequestIDFn := func(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { + return nil + } + + dataSource := 
NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn) + + testConfig := testSubscriptionEventConfiguration{ + Topic: "test-topic", + Subject: "test-subject", + } + input, err := json.Marshal(testConfig) + assert.NoError(t, err) + + ctx := resolve.NewContext(context.Background()) + mockUpdater := NewMockSubscriptionUpdater(t) + expectedError := errors.New("subscription error") + + mockAdapter.On("Subscribe", ctx.Context(), testConfig, mock.AnythingOfType("*datasource.subscriptionEventUpdater")).Return(expectedError) + + err = dataSource.Start(ctx, input, mockUpdater) + assert.Error(t, err) + assert.Equal(t, expectedError, err) + mockAdapter.AssertExpectations(t) +} + +func TestPubSubSubscriptionDataSource_SubscriptionOnStart_Success(t *testing.T) { + mockAdapter := NewMockProvider(t) + uniqueRequestIDFn := func(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { + return nil + } + + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn) + + testConfig := testSubscriptionEventConfiguration{ + Topic: "test-topic", + Subject: "test-subject", + } + input, err := json.Marshal(testConfig) + assert.NoError(t, err) + + ctx := resolve.StartupHookContext{ + Context: context.Background(), + Updater: func(data []byte) {}, + } + + err = dataSource.SubscriptionOnStart(ctx, input) + assert.NoError(t, err) +} + +func TestPubSubSubscriptionDataSource_SubscriptionOnStart_WithHooks(t *testing.T) { + mockAdapter := NewMockProvider(t) + uniqueRequestIDFn := func(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { + return nil + } + + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn) + + // Add subscription start hooks + hook1Called := false + hook2Called := false + + hook1 := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration) error { + hook1Called = true + return nil + } + 
+ hook2 := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration) error { + hook2Called = true + return nil + } + + dataSource.SetSubscriptionOnStartFns(hook1, hook2) + + testConfig := testSubscriptionEventConfiguration{ + Topic: "test-topic", + Subject: "test-subject", + } + input, err := json.Marshal(testConfig) + assert.NoError(t, err) + + ctx := resolve.StartupHookContext{ + Context: context.Background(), + Updater: func(data []byte) {}, + } + + err = dataSource.SubscriptionOnStart(ctx, input) + assert.NoError(t, err) + assert.True(t, hook1Called) + assert.True(t, hook2Called) +} + +func TestPubSubSubscriptionDataSource_SubscriptionOnStart_HookReturnsError(t *testing.T) { + mockAdapter := NewMockProvider(t) + uniqueRequestIDFn := func(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { + return nil + } + + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn) + + expectedError := errors.New("hook error") + // Add hook that returns an error + hook := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration) error { + return expectedError + } + + dataSource.SetSubscriptionOnStartFns(hook) + + testConfig := testSubscriptionEventConfiguration{ + Topic: "test-topic", + Subject: "test-subject", + } + input, err := json.Marshal(testConfig) + assert.NoError(t, err) + + ctx := resolve.StartupHookContext{ + Context: context.Background(), + Updater: func(data []byte) {}, + } + + err = dataSource.SubscriptionOnStart(ctx, input) + assert.Error(t, err) + assert.Equal(t, expectedError, err) +} + +func TestPubSubSubscriptionDataSource_SetSubscriptionOnStartFns(t *testing.T) { + mockAdapter := NewMockProvider(t) + uniqueRequestIDFn := func(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { + return nil + } + + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn) + + // Initially should have no 
hooks + assert.Len(t, dataSource.subscriptionOnStartFns, 0) + + // Add hooks + hook1 := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration) error { + return nil + } + hook2 := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration) error { + return nil + } + + dataSource.SetSubscriptionOnStartFns(hook1) + assert.Len(t, dataSource.subscriptionOnStartFns, 1) + + dataSource.SetSubscriptionOnStartFns(hook2) + assert.Len(t, dataSource.subscriptionOnStartFns, 2) +} + +func TestNewPubSubSubscriptionDataSource(t *testing.T) { + mockAdapter := NewMockProvider(t) + uniqueRequestIDFn := func(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { + return nil + } + + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn) + + assert.NotNil(t, dataSource) + assert.Equal(t, mockAdapter, dataSource.pubSub) + assert.NotNil(t, dataSource.uniqueRequestID) + assert.Empty(t, dataSource.subscriptionOnStartFns) +} + +func TestPubSubSubscriptionDataSource_InterfaceCompliance(t *testing.T) { + mockAdapter := NewMockProvider(t) + uniqueRequestIDFn := func(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { + return nil + } + + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn) + + // Test that it implements SubscriptionDataSource interface + var _ SubscriptionDataSource = dataSource + + // Test that it implements HookableSubscriptionDataSource interface + var _ resolve.HookableSubscriptionDataSource = dataSource +} diff --git a/router/pkg/pubsub/datasource/subscription_event_updater.go b/router/pkg/pubsub/datasource/subscription_event_updater.go new file mode 100644 index 0000000000..9332d10f7a --- /dev/null +++ b/router/pkg/pubsub/datasource/subscription_event_updater.go @@ -0,0 +1,34 @@ +package datasource + +import "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" + +// 
SubscriptionEventUpdater is a wrapper around the SubscriptionUpdater interface +// that provides a way to send the event struct instead of the raw data +// It is used to give access to the event additional fields to the hooks. +type SubscriptionEventUpdater interface { + Update(event StreamEvent) + Complete() + Close(kind resolve.SubscriptionCloseKind) +} + +type subscriptionEventUpdater struct { + eventUpdater resolve.SubscriptionUpdater +} + +func (h *subscriptionEventUpdater) Update(event StreamEvent) { + h.eventUpdater.Update(event.GetData()) +} + +func (h *subscriptionEventUpdater) Complete() { + h.eventUpdater.Complete() +} + +func (h *subscriptionEventUpdater) Close(kind resolve.SubscriptionCloseKind) { + h.eventUpdater.Close(kind) +} + +func NewSubscriptionEventUpdater(eventUpdater resolve.SubscriptionUpdater) SubscriptionEventUpdater { + return &subscriptionEventUpdater{ + eventUpdater: eventUpdater, + } +} diff --git a/router/pkg/pubsub/kafka/adapter.go b/router/pkg/pubsub/kafka/adapter.go index e11993b668..fa906370ab 100644 --- a/router/pkg/pubsub/kafka/adapter.go +++ b/router/pkg/pubsub/kafka/adapter.go @@ -13,7 +13,6 @@ import ( "github.com/twmb/franz-go/pkg/kerr" "github.com/twmb/franz-go/pkg/kgo" "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" "go.uber.org/zap" ) @@ -28,7 +27,7 @@ const ( // Adapter defines the interface for Kafka adapter operations type Adapter interface { - Subscribe(ctx context.Context, event SubscriptionEventConfiguration, updater resolve.SubscriptionUpdater) error + Subscribe(ctx context.Context, event datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater) error Publish(ctx context.Context, event PublishEventConfiguration) error Startup(ctx context.Context) error Shutdown(ctx context.Context) error @@ -54,7 +53,7 @@ type PollerOpts struct { } // topicPoller polls the Kafka topic for new records and calls the 
updateTriggers function. -func (p *ProviderAdapter) topicPoller(ctx context.Context, client *kgo.Client, updater resolve.SubscriptionUpdater, pollerOpts PollerOpts) error { +func (p *ProviderAdapter) topicPoller(ctx context.Context, client *kgo.Client, updater datasource.SubscriptionEventUpdater, pollerOpts PollerOpts) error { for { select { case <-p.ctx.Done(): // Close the poller if the application context was canceled @@ -100,13 +99,25 @@ func (p *ProviderAdapter) topicPoller(ctx context.Context, client *kgo.Client, u r := iter.Next() p.logger.Debug("subscription update", zap.String("topic", r.Topic), zap.ByteString("data", r.Value)) + + headers := make(map[string][]byte) + for _, header := range r.Headers { + headers[header.Key] = header.Value + } + p.streamMetricStore.Consume(p.ctx, metric.StreamsEvent{ ProviderId: pollerOpts.providerId, StreamOperationName: kafkaReceive, ProviderType: metric.ProviderTypeKafka, DestinationName: r.Topic, }) - updater.Update(r.Value) + + updater.Update(&Event{ + Data: r.Value, + Headers: headers, + Key: r.Key, + }) + } } } @@ -114,23 +125,27 @@ func (p *ProviderAdapter) topicPoller(ctx context.Context, client *kgo.Client, u // Subscribe subscribes to the given topics and updates the subscription updater. // The engine already deduplicates subscriptions with the same topics, stream configuration, extensions, headers, etc. 
-func (p *ProviderAdapter) Subscribe(ctx context.Context, event SubscriptionEventConfiguration, updater resolve.SubscriptionUpdater) error { +func (p *ProviderAdapter) Subscribe(ctx context.Context, conf datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater) error { + subConf, ok := conf.(*SubscriptionEventConfiguration) + if !ok { + return datasource.NewError("invalid event type for Kafka adapter", nil) + } log := p.logger.With( - zap.String("provider_id", event.ProviderID), + zap.String("provider_id", subConf.ProviderID()), zap.String("method", "subscribe"), - zap.Strings("topics", event.Topics), + zap.Strings("topics", subConf.Topics), ) // Create a new client for the topic client, err := kgo.NewClient(append(p.opts, - kgo.ConsumeTopics(event.Topics...), + kgo.ConsumeTopics(subConf.Topics...), // We want to consume the events produced after the first subscription was created // Messages are shared among all subscriptions, therefore old events are not redelivered // This replicates a stateless publish-subscribe model kgo.ConsumeResetOffset(kgo.NewOffset().AfterMilli(time.Now().UnixMilli())), // For observability, we set the client ID to "router" - kgo.ClientID(fmt.Sprintf("cosmo.router.consumer.%s", strings.Join(event.Topics, "-"))), + kgo.ClientID(fmt.Sprintf("cosmo.router.consumer.%s", strings.Join(subConf.Topics, "-"))), // FIXME: the client id should have some unique identifier, like in nats // What if we have multiple subscriptions for the same topics? // What if we have more router instances? 
@@ -146,7 +161,7 @@ func (p *ProviderAdapter) Subscribe(ctx context.Context, event SubscriptionEvent defer p.closeWg.Done() - err := p.topicPoller(ctx, client, updater, PollerOpts{providerId: event.ProviderID}) + err := p.topicPoller(ctx, client, updater, PollerOpts{providerId: conf.ProviderID()}) if err != nil { if errors.Is(err, errClientClosed) || errors.Is(err, context.Canceled) { log.Debug("poller canceled", zap.Error(err)) @@ -166,7 +181,7 @@ func (p *ProviderAdapter) Subscribe(ctx context.Context, event SubscriptionEvent // The event is written with a dedicated write client. func (p *ProviderAdapter) Publish(ctx context.Context, event PublishEventConfiguration) error { log := p.logger.With( - zap.String("provider_id", event.ProviderID), + zap.String("provider_id", event.ProviderID()), zap.String("method", "publish"), zap.String("topic", event.Topic), ) @@ -175,16 +190,26 @@ func (p *ProviderAdapter) Publish(ctx context.Context, event PublishEventConfigu return datasource.NewError("kafka write client not initialized", nil) } - log.Debug("publish", zap.ByteString("data", event.Data)) + log.Debug("publish", zap.ByteString("data", event.Event.Data)) var wg sync.WaitGroup wg.Add(1) var pErr error + headers := make([]kgo.RecordHeader, 0, len(event.Event.Headers)) + for key, value := range event.Event.Headers { + headers = append(headers, kgo.RecordHeader{ + Key: key, + Value: value, + }) + } + p.writeClient.Produce(ctx, &kgo.Record{ - Topic: event.Topic, - Value: event.Data, + Key: event.Event.Key, + Topic: event.Topic, + Value: event.Event.Data, + Headers: headers, }, func(record *kgo.Record, err error) { defer wg.Done() if err != nil { @@ -198,7 +223,7 @@ func (p *ProviderAdapter) Publish(ctx context.Context, event PublishEventConfigu log.Error("publish error", zap.Error(pErr)) // failure emission: include error.type generic p.streamMetricStore.Produce(ctx, metric.StreamsEvent{ - ProviderId: event.ProviderID, + ProviderId: event.ProviderID(), 
StreamOperationName: kafkaProduce, ProviderType: metric.ProviderTypeKafka, ErrorType: "publish_error", @@ -208,7 +233,7 @@ func (p *ProviderAdapter) Publish(ctx context.Context, event PublishEventConfigu } p.streamMetricStore.Produce(ctx, metric.StreamsEvent{ - ProviderId: event.ProviderID, + ProviderId: event.ProviderID(), StreamOperationName: kafkaProduce, ProviderType: metric.ProviderTypeKafka, DestinationName: event.Topic, diff --git a/router/pkg/pubsub/kafka/engine_datasource.go b/router/pkg/pubsub/kafka/engine_datasource.go index 7b82a766b0..723c0d0bd0 100644 --- a/router/pkg/pubsub/kafka/engine_datasource.go +++ b/router/pkg/pubsub/kafka/engine_datasource.go @@ -7,61 +7,78 @@ import ( "fmt" "io" - "github.com/buger/jsonparser" - "github.com/cespare/xxhash/v2" + "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" ) +// Event represents an event from Kafka +type Event struct { + Key []byte `json:"key"` + Data json.RawMessage `json:"data"` + Headers map[string][]byte `json:"headers"` +} + +func (e *Event) GetData() []byte { + return e.Data +} + type SubscriptionEventConfiguration struct { - ProviderID string `json:"providerId"` - Topics []string `json:"topics"` + Provider string `json:"providerId"` + Topics []string `json:"topics"` + FieldName string `json:"rootFieldName"` } -type PublishEventConfiguration struct { - ProviderID string `json:"providerId"` - Topic string `json:"topic"` - Data json.RawMessage `json:"data"` +// ProviderID returns the provider ID +func (s *SubscriptionEventConfiguration) ProviderID() string { + return s.Provider } -func (s *PublishEventConfiguration) MarshalJSONTemplate() string { - // The content of the data field could be not valid JSON, so we can't use json.Marshal - // e.g. 
{"id":$$0$$,"update":$$1$$} - return fmt.Sprintf(`{"topic":"%s", "data": %s, "providerId":"%s"}`, s.Topic, s.Data, s.ProviderID) +// ProviderType returns the provider type +func (s *SubscriptionEventConfiguration) ProviderType() datasource.ProviderType { + return datasource.ProviderTypeKafka } -type SubscriptionDataSource struct { - pubSub Adapter +// RootFieldName returns the root field name +func (s *SubscriptionEventConfiguration) RootFieldName() string { + return s.FieldName } -func (s *SubscriptionDataSource) UniqueRequestID(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { - val, _, _, err := jsonparser.Get(input, "topics") - if err != nil { - return err - } +type PublishEventConfiguration struct { + Provider string `json:"providerId"` + Topic string `json:"topic"` + Event Event `json:"event"` + FieldName string `json:"rootFieldName"` +} - _, err = xxh.Write(val) - if err != nil { - return err - } +// ProviderID returns the provider ID +func (p *PublishEventConfiguration) ProviderID() string { + return p.Provider +} - val, _, _, err = jsonparser.Get(input, "providerId") - if err != nil { - return err - } +// ProviderType returns the provider type +func (p *PublishEventConfiguration) ProviderType() datasource.ProviderType { + return datasource.ProviderTypeKafka +} - _, err = xxh.Write(val) - return err +// RootFieldName returns the root field name +func (p *PublishEventConfiguration) RootFieldName() string { + return p.FieldName } -func (s *SubscriptionDataSource) Start(ctx *resolve.Context, input []byte, updater resolve.SubscriptionUpdater) error { - var subscriptionConfiguration SubscriptionEventConfiguration - err := json.Unmarshal(input, &subscriptionConfiguration) +func (s *PublishEventConfiguration) MarshalJSONTemplate() (string, error) { + // The content of the data field could be not valid JSON, so we can't use json.Marshal + // e.g. 
{"id":$$0$$,"update":$$1$$} + headers := s.Event.Headers + if headers == nil { + headers = make(map[string][]byte) + } + + headersBytes, err := json.Marshal(headers) if err != nil { - return err + return "", err } - return s.pubSub.Subscribe(ctx.Context(), subscriptionConfiguration, updater) + return fmt.Sprintf(`{"topic":"%s", "event": {"data": %s, "key": "%s", "headers": %s}, "providerId":"%s"}`, s.Topic, s.Event.Data, s.Event.Key, headersBytes, s.ProviderID()), nil } type PublishDataSource struct { @@ -70,8 +87,7 @@ type PublishDataSource struct { func (s *PublishDataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) error { var publishConfiguration PublishEventConfiguration - err := json.Unmarshal(input, &publishConfiguration) - if err != nil { + if err := json.Unmarshal(input, &publishConfiguration); err != nil { return err } @@ -79,10 +95,15 @@ func (s *PublishDataSource) Load(ctx context.Context, input []byte, out *bytes.B _, err = io.WriteString(out, `{"success": false}`) return err } - _, err = io.WriteString(out, `{"success": true}`) + _, err := io.WriteString(out, `{"success": true}`) return err } func (s *PublishDataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload, out *bytes.Buffer) (err error) { panic("not implemented") } + +// Interface compliance checks +var _ datasource.SubscriptionEventConfiguration = (*SubscriptionEventConfiguration)(nil) +var _ datasource.PublishEventConfiguration = (*PublishEventConfiguration)(nil) +var _ datasource.StreamEvent = (*Event)(nil) diff --git a/router/pkg/pubsub/kafka/engine_datasource_factory.go b/router/pkg/pubsub/kafka/engine_datasource_factory.go index d360f02f26..30507bc13b 100644 --- a/router/pkg/pubsub/kafka/engine_datasource_factory.go +++ b/router/pkg/pubsub/kafka/engine_datasource_factory.go @@ -4,6 +4,8 @@ import ( "encoding/json" "fmt" + "github.com/buger/jsonparser" + "github.com/cespare/xxhash/v2" 
"github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" ) @@ -49,24 +51,44 @@ func (c *EngineDataSourceFactory) ResolveDataSourceInput(eventData []byte) (stri } evtCfg := PublishEventConfiguration{ - ProviderID: c.providerId, - Topic: c.topics[0], - Data: eventData, + Provider: c.providerId, + Topic: c.topics[0], + Event: Event{Data: eventData}, + FieldName: c.fieldName, } - return evtCfg.MarshalJSONTemplate(), nil + return evtCfg.MarshalJSONTemplate() } -func (c *EngineDataSourceFactory) ResolveDataSourceSubscription() (resolve.SubscriptionDataSource, error) { - return &SubscriptionDataSource{ - pubSub: c.KafkaAdapter, - }, nil +func (c *EngineDataSourceFactory) ResolveDataSourceSubscription() (datasource.SubscriptionDataSource, error) { + return datasource.NewPubSubSubscriptionDataSource[*SubscriptionEventConfiguration]( + c.KafkaAdapter, + func(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { + val, _, _, err := jsonparser.Get(input, "topics") + if err != nil { + return err + } + + _, err = xxh.Write(val) + if err != nil { + return err + } + + val, _, _, err = jsonparser.Get(input, "providerId") + if err != nil { + return err + } + + _, err = xxh.Write(val) + return err + }), nil } func (c *EngineDataSourceFactory) ResolveDataSourceSubscriptionInput() (string, error) { evtCfg := SubscriptionEventConfiguration{ - ProviderID: c.providerId, - Topics: c.topics, + Provider: c.providerId, + Topics: c.topics, + FieldName: c.fieldName, } object, err := json.Marshal(evtCfg) if err != nil { diff --git a/router/pkg/pubsub/kafka/engine_datasource_factory_test.go b/router/pkg/pubsub/kafka/engine_datasource_factory_test.go index 254359a4bc..0b4ea9c59c 100644 --- a/router/pkg/pubsub/kafka/engine_datasource_factory_test.go +++ b/router/pkg/pubsub/kafka/engine_datasource_factory_test.go @@ -4,11 +4,15 @@ import ( "bytes" "context" "encoding/json" + "errors" "testing" + 
"github.com/cespare/xxhash/v2" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/wundergraph/cosmo/router/pkg/pubsub/pubsubtest" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" ) func TestKafkaEngineDataSourceFactory(t *testing.T) { @@ -33,7 +37,7 @@ func TestEngineDataSourceFactoryWithMockAdapter(t *testing.T) { // Configure mock expectations for Publish mockAdapter.On("Publish", mock.Anything, mock.MatchedBy(func(event PublishEventConfiguration) bool { - return event.ProviderID == "test-provider" && event.Topic == "test-topic" + return event.ProviderID() == "test-provider" && event.Topic == "test-topic" })).Return(nil) // Create the data source with mock adapter @@ -137,3 +141,57 @@ func TestKafkaEngineDataSourceFactoryMultiTopicSubscription(t *testing.T) { require.Equal(t, "test-topic-1", subscriptionConfig.Topics[0], "Expected first topic to be 'test-topic-1'") require.Equal(t, "test-topic-2", subscriptionConfig.Topics[1], "Expected second topic to be 'test-topic-2'") } + +func TestKafkaEngineDataSourceFactory_UniqueRequestID(t *testing.T) { + tests := []struct { + name string + input string + expectError bool + expectedError error + }{ + { + name: "valid input", + input: `{"topics":["topic1", "topic2"], "providerId":"test-provider"}`, + expectError: false, + }, + { + name: "missing topics", + input: `{"providerId":"test-provider"}`, + expectError: true, + expectedError: errors.New("Key path not found"), + }, + { + name: "missing providerId", + input: `{"topics":["topic1", "topic2"]}`, + expectError: true, + expectedError: errors.New("Key path not found"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + factory := &EngineDataSourceFactory{ + KafkaAdapter: NewMockAdapter(t), + } + source, err := factory.ResolveDataSourceSubscription() + require.NoError(t, err) + ctx := &resolve.Context{} + input := []byte(tt.input) + xxh := xxhash.New() + + 
err = source.UniqueRequestID(ctx, input, xxh) + + if tt.expectError { + require.Error(t, err) + if tt.expectedError != nil { + // For jsonparser errors, just check if the error message contains the expected text + assert.Contains(t, err.Error(), tt.expectedError.Error()) + } + } else { + require.NoError(t, err) + // Check that the hash has been updated + assert.NotEqual(t, 0, xxh.Sum64()) + } + }) + } +} diff --git a/router/pkg/pubsub/kafka/engine_datasource_test.go b/router/pkg/pubsub/kafka/engine_datasource_test.go index 0ad92aeb20..eed485b246 100644 --- a/router/pkg/pubsub/kafka/engine_datasource_test.go +++ b/router/pkg/pubsub/kafka/engine_datasource_test.go @@ -7,12 +7,9 @@ import ( "errors" "testing" - "github.com/cespare/xxhash/v2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" - "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" ) func TestPublishEventConfiguration_MarshalJSONTemplate(t *testing.T) { @@ -24,145 +21,46 @@ func TestPublishEventConfiguration_MarshalJSONTemplate(t *testing.T) { { name: "simple configuration", config: PublishEventConfiguration{ - ProviderID: "test-provider", - Topic: "test-topic", - Data: json.RawMessage(`{"message":"hello"}`), + Provider: "test-provider", + Topic: "test-topic", + Event: Event{Data: json.RawMessage(`{"message":"hello"}`)}, }, - wantPattern: `{"topic":"test-topic", "data": {"message":"hello"}, "providerId":"test-provider"}`, + wantPattern: `{"topic":"test-topic", "event": {"data": {"message":"hello"}, "key": "", "headers": {}}, "providerId":"test-provider"}`, }, { name: "with special characters", config: PublishEventConfiguration{ - ProviderID: "test-provider-id", - Topic: "topic-with-hyphens", - Data: json.RawMessage(`{"message":"special \"quotes\" here"}`), + Provider: "test-provider-id", + Topic: "topic-with-hyphens", + Event: Event{Data: json.RawMessage(`{"message":"special 
\"quotes\" here"}`)}, }, - wantPattern: `{"topic":"topic-with-hyphens", "data": {"message":"special \"quotes\" here"}, "providerId":"test-provider-id"}`, + wantPattern: `{"topic":"topic-with-hyphens", "event": {"data": {"message":"special \"quotes\" here"}, "key": "", "headers": {}}, "providerId":"test-provider-id"}`, }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := tt.config.MarshalJSONTemplate() - assert.Equal(t, tt.wantPattern, result) - }) - } -} - -func TestSubscriptionSource_UniqueRequestID(t *testing.T) { - tests := []struct { - name string - input string - expectError bool - expectedError error - }{ - { - name: "valid input", - input: `{"topics":["topic1", "topic2"], "providerId":"test-provider"}`, - expectError: false, - }, - { - name: "missing topics", - input: `{"providerId":"test-provider"}`, - expectError: true, - expectedError: errors.New("Key path not found"), - }, - { - name: "missing providerId", - input: `{"topics":["topic1", "topic2"]}`, - expectError: true, - expectedError: errors.New("Key path not found"), - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - source := &SubscriptionDataSource{ - pubSub: NewMockAdapter(t), - } - ctx := &resolve.Context{} - input := []byte(tt.input) - xxh := xxhash.New() - - err := source.UniqueRequestID(ctx, input, xxh) - - if tt.expectError { - require.Error(t, err) - if tt.expectedError != nil { - // For jsonparser errors, just check if the error message contains the expected text - assert.Contains(t, err.Error(), tt.expectedError.Error()) - } - } else { - require.NoError(t, err) - // Check that the hash has been updated - assert.NotEqual(t, 0, xxh.Sum64()) - } - }) - } -} - -func TestSubscriptionSource_Start(t *testing.T) { - tests := []struct { - name string - input string - mockSetup func(*MockAdapter, *datasource.MockSubscriptionUpdater) - expectError bool - }{ { - name: "successful subscription", - input: `{"topics":["topic1", "topic2"], 
"providerId":"test-provider"}`, - mockSetup: func(m *MockAdapter, updater *datasource.MockSubscriptionUpdater) { - m.On("Subscribe", mock.Anything, SubscriptionEventConfiguration{ - ProviderID: "test-provider", - Topics: []string{"topic1", "topic2"}, - }, mock.Anything).Return(nil) + name: "with key", + config: PublishEventConfiguration{ + Provider: "test-provider-id", + Topic: "topic-with-hyphens", + Event: Event{Key: []byte("blablabla"), Data: json.RawMessage(`{}`)}, }, - expectError: false, + wantPattern: `{"topic":"topic-with-hyphens", "event": {"data": {}, "key": "blablabla", "headers": {}}, "providerId":"test-provider-id"}`, }, { - name: "adapter returns error", - input: `{"topics":["topic1"], "providerId":"test-provider"}`, - mockSetup: func(m *MockAdapter, updater *datasource.MockSubscriptionUpdater) { - m.On("Subscribe", mock.Anything, SubscriptionEventConfiguration{ - ProviderID: "test-provider", - Topics: []string{"topic1"}, - }, mock.Anything).Return(errors.New("subscription error")) + name: "with headers", + config: PublishEventConfiguration{ + Provider: "test-provider-id", + Topic: "topic-with-hyphens", + Event: Event{Headers: map[string][]byte{"key": []byte(`blablabla`)}, Data: json.RawMessage(`{}`)}, }, - expectError: true, - }, - { - name: "invalid input json", - input: `{"invalid json":`, - mockSetup: func(m *MockAdapter, updater *datasource.MockSubscriptionUpdater) {}, - expectError: true, + wantPattern: `{"topic":"topic-with-hyphens", "event": {"data": {}, "key": "", "headers": {"key":"YmxhYmxhYmxh"}}, "providerId":"test-provider-id"}`, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - mockAdapter := NewMockAdapter(t) - updater := datasource.NewMockSubscriptionUpdater(t) - tt.mockSetup(mockAdapter, updater) - - source := &SubscriptionDataSource{ - pubSub: mockAdapter, - } - - // Set up go context - goCtx := context.Background() - - // Create a resolve.Context with the standard context - resolveCtx := &resolve.Context{} - 
resolveCtx = resolveCtx.WithContext(goCtx) - - input := []byte(tt.input) - err := source.Start(resolveCtx, input, updater) - - if tt.expectError { - require.Error(t, err) - } else { - require.NoError(t, err) - } + result, err := tt.config.MarshalJSONTemplate() + assert.NoError(t, err) + assert.Equal(t, tt.wantPattern, result) }) } } @@ -178,12 +76,12 @@ func TestKafkaPublishDataSource_Load(t *testing.T) { }{ { name: "successful publish", - input: `{"topic":"test-topic", "data":{"message":"hello"}, "providerId":"test-provider"}`, + input: `{"topic":"test-topic", "event": {"data":{"message":"hello"}}, "providerId":"test-provider"}`, mockSetup: func(m *MockAdapter) { m.On("Publish", mock.Anything, mock.MatchedBy(func(event PublishEventConfiguration) bool { - return event.ProviderID == "test-provider" && + return event.ProviderID() == "test-provider" && event.Topic == "test-topic" && - string(event.Data) == `{"message":"hello"}` + string(event.Event.Data) == `{"message":"hello"}` })).Return(nil) }, expectError: false, @@ -192,7 +90,7 @@ func TestKafkaPublishDataSource_Load(t *testing.T) { }, { name: "publish error", - input: `{"topic":"test-topic", "data":{"message":"hello"}, "providerId":"test-provider"}`, + input: `{"topic":"test-topic", "event": {"data":{"message":"hello"}}, "providerId":"test-provider"}`, mockSetup: func(m *MockAdapter) { m.On("Publish", mock.Anything, mock.Anything).Return(errors.New("publish error")) }, diff --git a/router/pkg/pubsub/kafka/mocks.go b/router/pkg/pubsub/kafka/mocks.go index f39aee8b4e..08faa08eb2 100644 --- a/router/pkg/pubsub/kafka/mocks.go +++ b/router/pkg/pubsub/kafka/mocks.go @@ -8,7 +8,7 @@ import ( "context" mock "github.com/stretchr/testify/mock" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" + "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" ) // NewMockAdapter creates a new instance of MockAdapter. 
It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. @@ -198,7 +198,7 @@ func (_c *MockAdapter_Startup_Call) RunAndReturn(run func(ctx context.Context) e } // Subscribe provides a mock function for the type MockAdapter -func (_mock *MockAdapter) Subscribe(ctx context.Context, event SubscriptionEventConfiguration, updater resolve.SubscriptionUpdater) error { +func (_mock *MockAdapter) Subscribe(ctx context.Context, event datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater) error { ret := _mock.Called(ctx, event, updater) if len(ret) == 0 { @@ -206,7 +206,7 @@ func (_mock *MockAdapter) Subscribe(ctx context.Context, event SubscriptionEvent } var r0 error - if returnFunc, ok := ret.Get(0).(func(context.Context, SubscriptionEventConfiguration, resolve.SubscriptionUpdater) error); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, datasource.SubscriptionEventConfiguration, datasource.SubscriptionEventUpdater) error); ok { r0 = returnFunc(ctx, event, updater) } else { r0 = ret.Error(0) @@ -221,25 +221,25 @@ type MockAdapter_Subscribe_Call struct { // Subscribe is a helper method to define mock.On call // - ctx context.Context -// - event SubscriptionEventConfiguration -// - updater resolve.SubscriptionUpdater +// - event datasource.SubscriptionEventConfiguration +// - updater datasource.SubscriptionEventUpdater func (_e *MockAdapter_Expecter) Subscribe(ctx interface{}, event interface{}, updater interface{}) *MockAdapter_Subscribe_Call { return &MockAdapter_Subscribe_Call{Call: _e.mock.On("Subscribe", ctx, event, updater)} } -func (_c *MockAdapter_Subscribe_Call) Run(run func(ctx context.Context, event SubscriptionEventConfiguration, updater resolve.SubscriptionUpdater)) *MockAdapter_Subscribe_Call { +func (_c *MockAdapter_Subscribe_Call) Run(run func(ctx context.Context, event datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater)) 
*MockAdapter_Subscribe_Call { _c.Call.Run(func(args mock.Arguments) { var arg0 context.Context if args[0] != nil { arg0 = args[0].(context.Context) } - var arg1 SubscriptionEventConfiguration + var arg1 datasource.SubscriptionEventConfiguration if args[1] != nil { - arg1 = args[1].(SubscriptionEventConfiguration) + arg1 = args[1].(datasource.SubscriptionEventConfiguration) } - var arg2 resolve.SubscriptionUpdater + var arg2 datasource.SubscriptionEventUpdater if args[2] != nil { - arg2 = args[2].(resolve.SubscriptionUpdater) + arg2 = args[2].(datasource.SubscriptionEventUpdater) } run( arg0, @@ -255,7 +255,7 @@ func (_c *MockAdapter_Subscribe_Call) Return(err error) *MockAdapter_Subscribe_C return _c } -func (_c *MockAdapter_Subscribe_Call) RunAndReturn(run func(ctx context.Context, event SubscriptionEventConfiguration, updater resolve.SubscriptionUpdater) error) *MockAdapter_Subscribe_Call { +func (_c *MockAdapter_Subscribe_Call) RunAndReturn(run func(ctx context.Context, event datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater) error) *MockAdapter_Subscribe_Call { _c.Call.Return(run) return _c } diff --git a/router/pkg/pubsub/nats/adapter.go b/router/pkg/pubsub/nats/adapter.go index d10f8cf93d..dcba74a03b 100644 --- a/router/pkg/pubsub/nats/adapter.go +++ b/router/pkg/pubsub/nats/adapter.go @@ -14,7 +14,6 @@ import ( "github.com/nats-io/nats.go" "github.com/nats-io/nats.go/jetstream" "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" "go.uber.org/zap" ) @@ -27,7 +26,7 @@ const ( // Adapter defines the methods that a NATS adapter should implement type Adapter interface { // Subscribe subscribes to the given events and sends updates to the updater - Subscribe(ctx context.Context, event SubscriptionEventConfiguration, updater resolve.SubscriptionUpdater) error + Subscribe(ctx context.Context, event datasource.SubscriptionEventConfiguration, updater 
datasource.SubscriptionEventUpdater) error // Publish publishes the given event to the specified subject Publish(ctx context.Context, event PublishAndRequestEventConfiguration) error // Request sends a request to the specified subject and writes the response to the given writer @@ -81,11 +80,15 @@ func (p *ProviderAdapter) getDurableConsumerName(durableName string, subjects [] return fmt.Sprintf("%s-%x", durableName, subjHash.Sum64()), nil } -func (p *ProviderAdapter) Subscribe(ctx context.Context, event SubscriptionEventConfiguration, updater resolve.SubscriptionUpdater) error { +func (p *ProviderAdapter) Subscribe(ctx context.Context, conf datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater) error { + subConf, ok := conf.(*SubscriptionEventConfiguration) + if !ok { + return datasource.NewError("invalid event type for NATS adapter", nil) + } log := p.logger.With( - zap.String("provider_id", event.ProviderID), + zap.String("provider_id", subConf.ProviderID()), zap.String("method", "subscribe"), - zap.Strings("subjects", event.Subjects), + zap.Strings("subjects", subConf.Subjects), ) if p.client == nil { @@ -96,24 +99,24 @@ func (p *ProviderAdapter) Subscribe(ctx context.Context, event SubscriptionEvent return datasource.NewError("nats jetstream not initialized", nil) } - if event.StreamConfiguration != nil { - durableConsumerName, err := p.getDurableConsumerName(event.StreamConfiguration.Consumer, event.Subjects) + if subConf.StreamConfiguration != nil { + durableConsumerName, err := p.getDurableConsumerName(subConf.StreamConfiguration.Consumer, subConf.Subjects) if err != nil { return err } consumerConfig := jetstream.ConsumerConfig{ Durable: durableConsumerName, - FilterSubjects: event.Subjects, + FilterSubjects: subConf.Subjects, } // Durable consumers are removed automatically only if the InactiveThreshold value is set - if event.StreamConfiguration.ConsumerInactiveThreshold > 0 { - consumerConfig.InactiveThreshold = 
time.Duration(event.StreamConfiguration.ConsumerInactiveThreshold) * time.Second + if subConf.StreamConfiguration.ConsumerInactiveThreshold > 0 { + consumerConfig.InactiveThreshold = time.Duration(subConf.StreamConfiguration.ConsumerInactiveThreshold) * time.Second } - consumer, err := p.js.CreateOrUpdateConsumer(ctx, event.StreamConfiguration.StreamName, consumerConfig) + consumer, err := p.js.CreateOrUpdateConsumer(ctx, subConf.StreamConfiguration.StreamName, consumerConfig) if err != nil { log.Error("creating or updating consumer", zap.Error(err)) - return datasource.NewError(fmt.Sprintf(`failed to create or update consumer for stream "%s"`, event.StreamConfiguration.StreamName), err) + return datasource.NewError(fmt.Sprintf(`failed to create or update consumer for stream "%s"`, subConf.StreamConfiguration.StreamName), err) } p.closeWg.Add(1) @@ -142,12 +145,16 @@ func (p *ProviderAdapter) Subscribe(ctx context.Context, event SubscriptionEvent log.Debug("subscription update", zap.String("message_subject", msg.Subject()), zap.ByteString("data", msg.Data())) p.streamMetricStore.Consume(p.ctx, metric.StreamsEvent{ - ProviderId: event.ProviderID, + ProviderId: conf.ProviderID(), StreamOperationName: natsReceive, ProviderType: metric.ProviderTypeNats, DestinationName: msg.Subject(), }) - updater.Update(msg.Data()) + + updater.Update(&Event{ + Data: msg.Data(), + Headers: msg.Headers(), + }) // Acknowledge the message after it has been processed ackErr := msg.Ack() @@ -165,8 +172,8 @@ func (p *ProviderAdapter) Subscribe(ctx context.Context, event SubscriptionEvent } msgChan := make(chan *nats.Msg) - subscriptions := make([]*nats.Subscription, len(event.Subjects)) - for i, subject := range event.Subjects { + subscriptions := make([]*nats.Subscription, len(subConf.Subjects)) + for i, subject := range subConf.Subjects { subscription, err := p.client.ChanSubscribe(subject, msgChan) if err != nil { log.Error("subscribing to NATS subject", zap.Error(err), 
zap.String("subscription_subject", subject)) @@ -184,13 +191,18 @@ func (p *ProviderAdapter) Subscribe(ctx context.Context, event SubscriptionEvent select { case msg := <-msgChan: log.Debug("subscription update", zap.String("message_subject", msg.Subject), zap.ByteString("data", msg.Data)) + p.streamMetricStore.Consume(p.ctx, metric.StreamsEvent{ - ProviderId: event.ProviderID, + ProviderId: conf.ProviderID(), StreamOperationName: natsReceive, ProviderType: metric.ProviderTypeNats, DestinationName: msg.Subject, }) - updater.Update(msg.Data) + + updater.Update(&Event{ + Data: msg.Data, + Headers: msg.Header, + }) case <-p.ctx.Done(): // When the application context is done, we stop the subscriptions for _, subscription := range subscriptions { @@ -220,7 +232,7 @@ func (p *ProviderAdapter) Subscribe(ctx context.Context, event SubscriptionEvent func (p *ProviderAdapter) Publish(ctx context.Context, event PublishAndRequestEventConfiguration) error { log := p.logger.With( - zap.String("provider_id", event.ProviderID), + zap.String("provider_id", event.ProviderID()), zap.String("method", "publish"), zap.String("subject", event.Subject), ) @@ -229,13 +241,13 @@ func (p *ProviderAdapter) Publish(ctx context.Context, event PublishAndRequestEv return datasource.NewError("nats client not initialized", nil) } - log.Debug("publish", zap.ByteString("data", event.Data)) + log.Debug("publish", zap.ByteString("data", event.Event.Data)) - err := p.client.Publish(event.Subject, event.Data) + err := p.client.Publish(event.Subject, event.Event.Data) if err != nil { log.Error("publish error", zap.Error(err)) p.streamMetricStore.Produce(ctx, metric.StreamsEvent{ - ProviderId: event.ProviderID, + ProviderId: event.ProviderID(), StreamOperationName: natsPublish, ProviderType: metric.ProviderTypeNats, ErrorType: "publish_error", @@ -244,7 +256,7 @@ func (p *ProviderAdapter) Publish(ctx context.Context, event PublishAndRequestEv return datasource.NewError(fmt.Sprintf("error publishing to 
NATS subject %s", event.Subject), err) } else { p.streamMetricStore.Produce(ctx, metric.StreamsEvent{ - ProviderId: event.ProviderID, + ProviderId: event.ProviderID(), StreamOperationName: natsPublish, ProviderType: metric.ProviderTypeNats, DestinationName: event.Subject, @@ -256,7 +268,7 @@ func (p *ProviderAdapter) Publish(ctx context.Context, event PublishAndRequestEv func (p *ProviderAdapter) Request(ctx context.Context, event PublishAndRequestEventConfiguration, w io.Writer) error { log := p.logger.With( - zap.String("provider_id", event.ProviderID), + zap.String("provider_id", event.ProviderID()), zap.String("method", "request"), zap.String("subject", event.Subject), ) @@ -265,13 +277,13 @@ func (p *ProviderAdapter) Request(ctx context.Context, event PublishAndRequestEv return datasource.NewError("nats client not initialized", nil) } - log.Debug("request", zap.ByteString("data", event.Data)) + log.Debug("request", zap.ByteString("data", event.Event.Data)) - msg, err := p.client.RequestWithContext(ctx, event.Subject, event.Data) + msg, err := p.client.RequestWithContext(ctx, event.Subject, event.Event.Data) if err != nil { log.Error("request error", zap.Error(err)) p.streamMetricStore.Produce(ctx, metric.StreamsEvent{ - ProviderId: event.ProviderID, + ProviderId: event.ProviderID(), StreamOperationName: natsRequest, ProviderType: metric.ProviderTypeNats, ErrorType: "request_error", @@ -281,7 +293,7 @@ func (p *ProviderAdapter) Request(ctx context.Context, event PublishAndRequestEv } p.streamMetricStore.Produce(ctx, metric.StreamsEvent{ - ProviderId: event.ProviderID, + ProviderId: event.ProviderID(), StreamOperationName: natsRequest, ProviderType: metric.ProviderTypeNats, DestinationName: event.Subject, diff --git a/router/pkg/pubsub/nats/engine_datasource.go b/router/pkg/pubsub/nats/engine_datasource.go index ffc23ca838..0fa41e5480 100644 --- a/router/pkg/pubsub/nats/engine_datasource.go +++ b/router/pkg/pubsub/nats/engine_datasource.go @@ -7,12 +7,20 @@ 
import ( "fmt" "io" - "github.com/buger/jsonparser" - "github.com/cespare/xxhash/v2" + "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" ) +// Event represents an event from NATS +type Event struct { + Data json.RawMessage `json:"data"` + Headers map[string][]string `json:"headers"` +} + +func (e *Event) GetData() []byte { + return e.Data +} + type StreamConfiguration struct { Consumer string `json:"consumer"` ConsumerInactiveThreshold int32 `json:"consumerInactiveThreshold"` @@ -20,56 +28,53 @@ type StreamConfiguration struct { } type SubscriptionEventConfiguration struct { - ProviderID string `json:"providerId"` + Provider string `json:"providerId"` Subjects []string `json:"subjects"` StreamConfiguration *StreamConfiguration `json:"streamConfiguration,omitempty"` + FieldName string `json:"rootFieldName"` } -type PublishAndRequestEventConfiguration struct { - ProviderID string `json:"providerId"` - Subject string `json:"subject"` - Data json.RawMessage `json:"data"` +// ProviderID returns the provider ID +func (s *SubscriptionEventConfiguration) ProviderID() string { + return s.Provider } -func (s *PublishAndRequestEventConfiguration) MarshalJSONTemplate() string { - // The content of the data field could be not valid JSON, so we can't use json.Marshal - // e.g. 
{"id":$$0$$,"update":$$1$$} - return fmt.Sprintf(`{"subject":"%s", "data": %s, "providerId":"%s"}`, s.Subject, s.Data, s.ProviderID) +// ProviderType returns the provider type +func (s *SubscriptionEventConfiguration) ProviderType() datasource.ProviderType { + return datasource.ProviderTypeNats } -type SubscriptionSource struct { - pubSub Adapter +// RootFieldName returns the root field name +func (s *SubscriptionEventConfiguration) RootFieldName() string { + return s.FieldName } -func (s *SubscriptionSource) UniqueRequestID(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { - - val, _, _, err := jsonparser.Get(input, "subjects") - if err != nil { - return err - } - - _, err = xxh.Write(val) - if err != nil { - return err - } +type PublishAndRequestEventConfiguration struct { + Provider string `json:"providerId"` + Subject string `json:"subject"` + Event Event `json:"event"` + FieldName string `json:"rootFieldName"` +} - val, _, _, err = jsonparser.Get(input, "providerId") - if err != nil { - return err - } +// ProviderID returns the provider ID +func (p *PublishAndRequestEventConfiguration) ProviderID() string { + return p.Provider +} - _, err = xxh.Write(val) - return err +// ProviderType returns the provider type +func (p *PublishAndRequestEventConfiguration) ProviderType() datasource.ProviderType { + return datasource.ProviderTypeNats } -func (s *SubscriptionSource) Start(ctx *resolve.Context, input []byte, updater resolve.SubscriptionUpdater) error { - var subscriptionConfiguration SubscriptionEventConfiguration - err := json.Unmarshal(input, &subscriptionConfiguration) - if err != nil { - return err - } +// RootFieldName returns the root field name +func (p *PublishAndRequestEventConfiguration) RootFieldName() string { + return p.FieldName +} - return s.pubSub.Subscribe(ctx.Context(), subscriptionConfiguration, updater) +func (p *PublishAndRequestEventConfiguration) MarshalJSONTemplate() (string, error) { + // The content of the data field could 
be not valid JSON, so we can't use json.Marshal + // e.g. {"id":$$0$$,"update":$$1$$} + return fmt.Sprintf(`{"subject":"%s", "event": {"data": %s}, "providerId":"%s"}`, p.Subject, p.Event.Data, p.ProviderID()), nil } type NatsPublishDataSource struct { @@ -78,8 +83,7 @@ type NatsPublishDataSource struct { func (s *NatsPublishDataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) error { var publishConfiguration PublishAndRequestEventConfiguration - err := json.Unmarshal(input, &publishConfiguration) - if err != nil { + if err := json.Unmarshal(input, &publishConfiguration); err != nil { return err } @@ -87,7 +91,7 @@ func (s *NatsPublishDataSource) Load(ctx context.Context, input []byte, out *byt _, err = io.WriteString(out, `{"success": false}`) return err } - _, err = io.WriteString(out, `{"success": true}`) + _, err := io.WriteString(out, `{"success": true}`) return err } @@ -101,8 +105,7 @@ type NatsRequestDataSource struct { func (s *NatsRequestDataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) error { var subscriptionConfiguration PublishAndRequestEventConfiguration - err := json.Unmarshal(input, &subscriptionConfiguration) - if err != nil { + if err := json.Unmarshal(input, &subscriptionConfiguration); err != nil { return err } @@ -112,3 +115,8 @@ func (s *NatsRequestDataSource) Load(ctx context.Context, input []byte, out *byt func (s *NatsRequestDataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload, out *bytes.Buffer) error { panic("not implemented") } + +// Interface compliance checks +var _ datasource.SubscriptionEventConfiguration = (*SubscriptionEventConfiguration)(nil) +var _ datasource.PublishEventConfiguration = (*PublishAndRequestEventConfiguration)(nil) +var _ datasource.StreamEvent = (*Event)(nil) diff --git a/router/pkg/pubsub/nats/engine_datasource_factory.go b/router/pkg/pubsub/nats/engine_datasource_factory.go index 48fd2849f7..36d3932e0d 100644 --- 
a/router/pkg/pubsub/nats/engine_datasource_factory.go +++ b/router/pkg/pubsub/nats/engine_datasource_factory.go @@ -5,6 +5,8 @@ import ( "fmt" "slices" + "github.com/buger/jsonparser" + "github.com/cespare/xxhash/v2" "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" @@ -63,24 +65,44 @@ func (c *EngineDataSourceFactory) ResolveDataSourceInput(eventData []byte) (stri subject := c.subjects[0] evtCfg := PublishAndRequestEventConfiguration{ - ProviderID: c.providerId, - Subject: subject, - Data: eventData, + Provider: c.providerId, + Subject: subject, + Event: Event{Data: eventData}, + FieldName: c.fieldName, } - return evtCfg.MarshalJSONTemplate(), nil + return evtCfg.MarshalJSONTemplate() } -func (c *EngineDataSourceFactory) ResolveDataSourceSubscription() (resolve.SubscriptionDataSource, error) { - return &SubscriptionSource{ - pubSub: c.NatsAdapter, - }, nil +func (c *EngineDataSourceFactory) ResolveDataSourceSubscription() (datasource.SubscriptionDataSource, error) { + return datasource.NewPubSubSubscriptionDataSource[*SubscriptionEventConfiguration]( + c.NatsAdapter, + func(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { + val, _, _, err := jsonparser.Get(input, "subjects") + if err != nil { + return err + } + + _, err = xxh.Write(val) + if err != nil { + return err + } + + val, _, _, err = jsonparser.Get(input, "providerId") + if err != nil { + return err + } + + _, err = xxh.Write(val) + return err + }), nil } func (c *EngineDataSourceFactory) ResolveDataSourceSubscriptionInput() (string, error) { evtCfg := SubscriptionEventConfiguration{ - ProviderID: c.providerId, - Subjects: c.subjects, + Provider: c.providerId, + Subjects: c.subjects, + FieldName: c.fieldName, } if c.withStreamConfiguration { evtCfg.StreamConfiguration = &StreamConfiguration{ diff --git a/router/pkg/pubsub/nats/engine_datasource_factory_test.go b/router/pkg/pubsub/nats/engine_datasource_factory_test.go 
index 57426ad34c..a94c8d5941 100644 --- a/router/pkg/pubsub/nats/engine_datasource_factory_test.go +++ b/router/pkg/pubsub/nats/engine_datasource_factory_test.go @@ -4,13 +4,16 @@ import ( "bytes" "context" "encoding/json" + "errors" "io" "testing" + "github.com/cespare/xxhash/v2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/wundergraph/cosmo/router/pkg/pubsub/pubsubtest" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" ) func TestNatsEngineDataSourceFactory(t *testing.T) { @@ -34,7 +37,7 @@ func TestEngineDataSourceFactoryWithMockAdapter(t *testing.T) { // Configure mock expectations for Publish mockAdapter.On("Publish", mock.Anything, mock.MatchedBy(func(event PublishAndRequestEventConfiguration) bool { - return event.ProviderID == "test-provider" && event.Subject == "test-subject" + return event.ProviderID() == "test-provider" && event.Subject == "test-subject" })).Return(nil) // Create the data source with mock adapter @@ -167,7 +170,7 @@ func TestEngineDataSourceFactory_RequestDataSource(t *testing.T) { // Configure mock expectations for Request mockAdapter.On("Request", mock.Anything, mock.MatchedBy(func(event PublishAndRequestEventConfiguration) bool { - return event.ProviderID == "test-provider" && event.Subject == "test-subject" + return event.ProviderID() == "test-provider" && event.Subject == "test-subject" }), mock.Anything).Return(nil).Run(func(args mock.Arguments) { w := args.Get(2).(io.Writer) w.Write([]byte(`{"response": "test"}`)) @@ -253,3 +256,57 @@ func TestTransformEventConfig(t *testing.T) { assert.Contains(t, err.Error(), "invalid subject") }) } + +func TestNatsEngineDataSourceFactory_UniqueRequestID(t *testing.T) { + tests := []struct { + name string + input string + expectError bool + expectedError error + }{ + { + name: "valid input", + input: `{"subjects":["subject1", "subject2"], "providerId":"test-provider"}`, + expectError: false, + }, + { + 
name: "missing subjects", + input: `{"providerId":"test-provider"}`, + expectError: true, + expectedError: errors.New("Key path not found"), + }, + { + name: "missing providerId", + input: `{"subjects":["subject1", "subject2"]}`, + expectError: true, + expectedError: errors.New("Key path not found"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + factory := &EngineDataSourceFactory{ + NatsAdapter: NewMockAdapter(t), + } + source, err := factory.ResolveDataSourceSubscription() + require.NoError(t, err) + ctx := &resolve.Context{} + input := []byte(tt.input) + xxh := xxhash.New() + + err = source.UniqueRequestID(ctx, input, xxh) + + if tt.expectError { + require.Error(t, err) + if tt.expectedError != nil { + // For jsonparser errors, just check if the error message contains the expected text + assert.Contains(t, err.Error(), tt.expectedError.Error()) + } + } else { + require.NoError(t, err) + // Check that the hash has been updated + assert.NotEqual(t, 0, xxh.Sum64()) + } + }) + } +} diff --git a/router/pkg/pubsub/nats/engine_datasource_test.go b/router/pkg/pubsub/nats/engine_datasource_test.go index da21d4de88..5d060d2c0d 100644 --- a/router/pkg/pubsub/nats/engine_datasource_test.go +++ b/router/pkg/pubsub/nats/engine_datasource_test.go @@ -8,48 +8,11 @@ import ( "io" "testing" - "github.com/cespare/xxhash/v2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" - "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" ) -func TestPublishEventConfiguration_MarshalJSONTemplate(t *testing.T) { - tests := []struct { - name string - config PublishAndRequestEventConfiguration - wantPattern string - }{ - { - name: "simple configuration", - config: PublishAndRequestEventConfiguration{ - ProviderID: "test-provider", - Subject: "test-subject", - Data: json.RawMessage(`{"message":"hello"}`), - }, - wantPattern: 
`{"subject":"test-subject", "data": {"message":"hello"}, "providerId":"test-provider"}`, - }, - { - name: "with special characters", - config: PublishAndRequestEventConfiguration{ - ProviderID: "test-provider-id", - Subject: "subject-with-hyphens", - Data: json.RawMessage(`{"message":"special \"quotes\" here"}`), - }, - wantPattern: `{"subject":"subject-with-hyphens", "data": {"message":"special \"quotes\" here"}, "providerId":"test-provider-id"}`, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := tt.config.MarshalJSONTemplate() - assert.Equal(t, tt.wantPattern, result) - }) - } -} - func TestPublishAndRequestEventConfiguration_MarshalJSONTemplate(t *testing.T) { tests := []struct { name string @@ -59,149 +22,32 @@ func TestPublishAndRequestEventConfiguration_MarshalJSONTemplate(t *testing.T) { { name: "simple configuration", config: PublishAndRequestEventConfiguration{ - ProviderID: "test-provider", - Subject: "test-subject", - Data: json.RawMessage(`{"message":"hello"}`), + Provider: "test-provider", + Subject: "test-subject", + Event: Event{Data: json.RawMessage(`{"message":"hello"}`)}, }, - wantPattern: `{"subject":"test-subject", "data": {"message":"hello"}, "providerId":"test-provider"}`, + wantPattern: `{"subject":"test-subject", "event": {"data": {"message":"hello"}}, "providerId":"test-provider"}`, }, { name: "with special characters", config: PublishAndRequestEventConfiguration{ - ProviderID: "test-provider-id", - Subject: "subject-with-hyphens", - Data: json.RawMessage(`{"message":"special \"quotes\" here"}`), + Provider: "test-provider-id", + Subject: "subject-with-hyphens", + Event: Event{Data: json.RawMessage(`{"message":"special \"quotes\" here"}`)}, }, - wantPattern: `{"subject":"subject-with-hyphens", "data": {"message":"special \"quotes\" here"}, "providerId":"test-provider-id"}`, + wantPattern: `{"subject":"subject-with-hyphens", "event": {"data": {"message":"special \"quotes\" here"}}, 
"providerId":"test-provider-id"}`, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := tt.config.MarshalJSONTemplate() + result, err := tt.config.MarshalJSONTemplate() + assert.NoError(t, err) assert.Equal(t, tt.wantPattern, result) }) } } -func TestSubscriptionSource_UniqueRequestID(t *testing.T) { - tests := []struct { - name string - input string - expectError bool - expectedError error - }{ - { - name: "valid input", - input: `{"subjects":["subject1", "subject2"], "providerId":"test-provider"}`, - expectError: false, - }, - { - name: "missing subjects", - input: `{"providerId":"test-provider"}`, - expectError: true, - expectedError: errors.New("Key path not found"), - }, - { - name: "missing providerId", - input: `{"subjects":["subject1", "subject2"]}`, - expectError: true, - expectedError: errors.New("Key path not found"), - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - source := &SubscriptionSource{ - pubSub: NewMockAdapter(t), - } - ctx := &resolve.Context{} - input := []byte(tt.input) - xxh := xxhash.New() - - err := source.UniqueRequestID(ctx, input, xxh) - - if tt.expectError { - require.Error(t, err) - if tt.expectedError != nil { - // For jsonparser errors, just check if the error message contains the expected text - assert.Contains(t, err.Error(), tt.expectedError.Error()) - } - } else { - require.NoError(t, err) - // Check that the hash has been updated - assert.NotEqual(t, 0, xxh.Sum64()) - } - }) - } -} - -func TestSubscriptionSource_Start(t *testing.T) { - tests := []struct { - name string - input string - mockSetup func(*MockAdapter, *datasource.MockSubscriptionUpdater) - expectError bool - }{ - { - name: "successful subscription", - input: `{"subjects":["subject1", "subject2"], "providerId":"test-provider"}`, - mockSetup: func(m *MockAdapter, updater *datasource.MockSubscriptionUpdater) { - m.On("Subscribe", mock.Anything, SubscriptionEventConfiguration{ - ProviderID: "test-provider", - 
Subjects: []string{"subject1", "subject2"}, - }, mock.Anything).Return(nil) - }, - expectError: false, - }, - { - name: "adapter returns error", - input: `{"subjects":["subject1"], "providerId":"test-provider"}`, - mockSetup: func(m *MockAdapter, updater *datasource.MockSubscriptionUpdater) { - m.On("Subscribe", mock.Anything, SubscriptionEventConfiguration{ - ProviderID: "test-provider", - Subjects: []string{"subject1"}, - }, mock.Anything).Return(errors.New("subscription error")) - }, - expectError: true, - }, - { - name: "invalid input json", - input: `{"invalid json":`, - mockSetup: func(m *MockAdapter, updater *datasource.MockSubscriptionUpdater) {}, - expectError: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mockAdapter := NewMockAdapter(t) - updater := datasource.NewMockSubscriptionUpdater(t) - tt.mockSetup(mockAdapter, updater) - - source := &SubscriptionSource{ - pubSub: mockAdapter, - } - - // Set up go context - goCtx := context.Background() - - // Create a resolve.Context with the standard context - resolveCtx := &resolve.Context{} - resolveCtx = resolveCtx.WithContext(goCtx) - - input := []byte(tt.input) - err := source.Start(resolveCtx, input, updater) - - if tt.expectError { - require.Error(t, err) - } else { - require.NoError(t, err) - } - }) - } -} - func TestNatsPublishDataSource_Load(t *testing.T) { tests := []struct { name string @@ -213,12 +59,12 @@ func TestNatsPublishDataSource_Load(t *testing.T) { }{ { name: "successful publish", - input: `{"subject":"test-subject", "data":{"message":"hello"}, "providerId":"test-provider"}`, + input: `{"subject":"test-subject", "event": {"data":{"message":"hello"}}, "providerId":"test-provider"}`, mockSetup: func(m *MockAdapter) { m.On("Publish", mock.Anything, mock.MatchedBy(func(event PublishAndRequestEventConfiguration) bool { - return event.ProviderID == "test-provider" && + return event.ProviderID() == "test-provider" && event.Subject == "test-subject" && - 
string(event.Data) == `{"message":"hello"}` + string(event.Event.Data) == `{"message":"hello"}` })).Return(nil) }, expectError: false, @@ -227,7 +73,7 @@ func TestNatsPublishDataSource_Load(t *testing.T) { }, { name: "publish error", - input: `{"subject":"test-subject", "data":{"message":"hello"}, "providerId":"test-provider"}`, + input: `{"subject":"test-subject", "event": {"data":{"message":"hello"}}, "providerId":"test-provider"}`, mockSetup: func(m *MockAdapter) { m.On("Publish", mock.Anything, mock.Anything).Return(errors.New("publish error")) }, @@ -288,12 +134,12 @@ func TestNatsRequestDataSource_Load(t *testing.T) { }{ { name: "successful request", - input: `{"subject":"test-subject", "data":{"message":"hello"}, "providerId":"test-provider"}`, + input: `{"subject":"test-subject", "event": {"data":{"message":"hello"}}, "providerId":"test-provider"}`, mockSetup: func(m *MockAdapter) { m.On("Request", mock.Anything, mock.MatchedBy(func(event PublishAndRequestEventConfiguration) bool { - return event.ProviderID == "test-provider" && + return event.ProviderID() == "test-provider" && event.Subject == "test-subject" && - string(event.Data) == `{"message":"hello"}` + string(event.Event.Data) == `{"message":"hello"}` }), mock.Anything).Run(func(args mock.Arguments) { // Write response to the output buffer w := args.Get(2).(io.Writer) @@ -305,7 +151,7 @@ func TestNatsRequestDataSource_Load(t *testing.T) { }, { name: "request error", - input: `{"subject":"test-subject", "data":{"message":"hello"}, "providerId":"test-provider"}`, + input: `{"subject":"test-subject", "event": {"data":{"message":"hello"}}, "providerId":"test-provider"}`, mockSetup: func(m *MockAdapter) { m.On("Request", mock.Anything, mock.Anything, mock.Anything).Return(errors.New("request error")) }, diff --git a/router/pkg/pubsub/nats/mocks.go b/router/pkg/pubsub/nats/mocks.go index de49c6ae7e..0bc3ada5f0 100644 --- a/router/pkg/pubsub/nats/mocks.go +++ b/router/pkg/pubsub/nats/mocks.go @@ -9,7 +9,7 
@@ import ( "io" mock "github.com/stretchr/testify/mock" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" + "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" ) // NewMockAdapter creates a new instance of MockAdapter. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. @@ -262,7 +262,7 @@ func (_c *MockAdapter_Startup_Call) RunAndReturn(run func(ctx context.Context) e } // Subscribe provides a mock function for the type MockAdapter -func (_mock *MockAdapter) Subscribe(ctx context.Context, event SubscriptionEventConfiguration, updater resolve.SubscriptionUpdater) error { +func (_mock *MockAdapter) Subscribe(ctx context.Context, event datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater) error { ret := _mock.Called(ctx, event, updater) if len(ret) == 0 { @@ -270,7 +270,7 @@ func (_mock *MockAdapter) Subscribe(ctx context.Context, event SubscriptionEvent } var r0 error - if returnFunc, ok := ret.Get(0).(func(context.Context, SubscriptionEventConfiguration, resolve.SubscriptionUpdater) error); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, datasource.SubscriptionEventConfiguration, datasource.SubscriptionEventUpdater) error); ok { r0 = returnFunc(ctx, event, updater) } else { r0 = ret.Error(0) @@ -285,25 +285,25 @@ type MockAdapter_Subscribe_Call struct { // Subscribe is a helper method to define mock.On call // - ctx context.Context -// - event SubscriptionEventConfiguration -// - updater resolve.SubscriptionUpdater +// - event datasource.SubscriptionEventConfiguration +// - updater datasource.SubscriptionEventUpdater func (_e *MockAdapter_Expecter) Subscribe(ctx interface{}, event interface{}, updater interface{}) *MockAdapter_Subscribe_Call { return &MockAdapter_Subscribe_Call{Call: _e.mock.On("Subscribe", ctx, event, updater)} } -func (_c *MockAdapter_Subscribe_Call) Run(run func(ctx context.Context, event SubscriptionEventConfiguration, 
updater resolve.SubscriptionUpdater)) *MockAdapter_Subscribe_Call { +func (_c *MockAdapter_Subscribe_Call) Run(run func(ctx context.Context, event datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater)) *MockAdapter_Subscribe_Call { _c.Call.Run(func(args mock.Arguments) { var arg0 context.Context if args[0] != nil { arg0 = args[0].(context.Context) } - var arg1 SubscriptionEventConfiguration + var arg1 datasource.SubscriptionEventConfiguration if args[1] != nil { - arg1 = args[1].(SubscriptionEventConfiguration) + arg1 = args[1].(datasource.SubscriptionEventConfiguration) } - var arg2 resolve.SubscriptionUpdater + var arg2 datasource.SubscriptionEventUpdater if args[2] != nil { - arg2 = args[2].(resolve.SubscriptionUpdater) + arg2 = args[2].(datasource.SubscriptionEventUpdater) } run( arg0, @@ -319,7 +319,7 @@ func (_c *MockAdapter_Subscribe_Call) Return(err error) *MockAdapter_Subscribe_C return _c } -func (_c *MockAdapter_Subscribe_Call) RunAndReturn(run func(ctx context.Context, event SubscriptionEventConfiguration, updater resolve.SubscriptionUpdater) error) *MockAdapter_Subscribe_Call { +func (_c *MockAdapter_Subscribe_Call) RunAndReturn(run func(ctx context.Context, event datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater) error) *MockAdapter_Subscribe_Call { _c.Call.Return(run) return _c } diff --git a/router/pkg/pubsub/pubsub.go b/router/pkg/pubsub/pubsub.go index b92aaad6f7..085de71a0e 100644 --- a/router/pkg/pubsub/pubsub.go +++ b/router/pkg/pubsub/pubsub.go @@ -51,9 +51,23 @@ func (e *ProviderNotDefinedError) Error() string { return fmt.Sprintf("%s provider with ID %s is not defined", e.ProviderTypeID, e.ProviderID) } +// Hooks contains hooks for the pubsub providers and data sources +type Hooks struct { + SubscriptionOnStart []pubsub_datasource.SubscriptionOnStartFn +} + // BuildProvidersAndDataSources is a generic function that builds providers and data sources for the given // 
EventsConfiguration and DataSourceConfigurationWithMetadata -func BuildProvidersAndDataSources(ctx context.Context, config config.EventsConfiguration, store metric.StreamMetricStore, logger *zap.Logger, dsConfs []DataSourceConfigurationWithMetadata, hostName string, routerListenAddr string) ([]pubsub_datasource.Provider, []plan.DataSource, error) { +func BuildProvidersAndDataSources( + ctx context.Context, + config config.EventsConfiguration, + store metric.StreamMetricStore, + logger *zap.Logger, + dsConfs []DataSourceConfigurationWithMetadata, + hostName string, + routerListenAddr string, + hooks Hooks, +) ([]pubsub_datasource.Provider, []plan.DataSource, error) { if store == nil { store = metric.NewNoopStreamMetricStore() } @@ -70,7 +84,7 @@ func BuildProvidersAndDataSources(ctx context.Context, config config.EventsConfi events: dsConf.Configuration.GetCustomEvents().GetKafka(), }) } - kafkaPubSubProviders, kafkaOuts, err := build(ctx, kafkaBuilder, config.Providers.Kafka, kafkaDsConfsWithEvents, store) + kafkaPubSubProviders, kafkaOuts, err := build(ctx, kafkaBuilder, config.Providers.Kafka, kafkaDsConfsWithEvents, store, hooks) if err != nil { return nil, nil, err } @@ -86,7 +100,7 @@ func BuildProvidersAndDataSources(ctx context.Context, config config.EventsConfi events: dsConf.Configuration.GetCustomEvents().GetNats(), }) } - natsPubSubProviders, natsOuts, err := build(ctx, natsBuilder, config.Providers.Nats, natsDsConfsWithEvents, store) + natsPubSubProviders, natsOuts, err := build(ctx, natsBuilder, config.Providers.Nats, natsDsConfsWithEvents, store, hooks) if err != nil { return nil, nil, err } @@ -102,7 +116,7 @@ func BuildProvidersAndDataSources(ctx context.Context, config config.EventsConfi events: dsConf.Configuration.GetCustomEvents().GetRedis(), }) } - redisPubSubProviders, redisOuts, err := build(ctx, redisBuilder, config.Providers.Redis, redisDsConfsWithEvents, store) + redisPubSubProviders, redisOuts, err := build(ctx, redisBuilder, 
config.Providers.Redis, redisDsConfsWithEvents, store, hooks) if err != nil { return nil, nil, err } @@ -118,6 +132,7 @@ func build[P GetID, E GetEngineEventConfiguration]( providersData []P, dsConfs []dsConfAndEvents[E], store metric.StreamMetricStore, + hooks Hooks, ) ([]pubsub_datasource.Provider, []plan.DataSource, error) { var pubSubProviders []pubsub_datasource.Provider var outs []plan.DataSource @@ -161,7 +176,7 @@ func build[P GetID, E GetEngineEventConfiguration]( // build data sources for each event for _, dsConf := range dsConfs { for i, event := range dsConf.events { - plannerConfig := pubsub_datasource.NewPlannerConfig(builder, event) + plannerConfig := pubsub_datasource.NewPlannerConfig(builder, event, hooks.SubscriptionOnStart) out, err := plan.NewDataSourceConfiguration( dsConf.dsConf.Configuration.Id+"-"+builder.TypeID()+"-"+strconv.Itoa(i), pubsub_datasource.NewPlannerFactory(ctx, plannerConfig), diff --git a/router/pkg/pubsub/pubsub_test.go b/router/pkg/pubsub/pubsub_test.go index 2173e46c3d..976980b4ff 100644 --- a/router/pkg/pubsub/pubsub_test.go +++ b/router/pkg/pubsub/pubsub_test.go @@ -3,9 +3,10 @@ package pubsub import ( "context" "errors" + "testing" + "github.com/stretchr/testify/mock" rmetric "github.com/wundergraph/cosmo/router/pkg/metric" - "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -67,7 +68,7 @@ func TestBuild_OK(t *testing.T) { // ctx, kafkaBuilder, config.Providers.Kafka, kafkaDsConfsWithEvents // Execute the function - providers, dataSources, err := build(ctx, mockBuilder, natsEventSources, dsConfs, rmetric.NewNoopStreamMetricStore()) + providers, dataSources, err := build(ctx, mockBuilder, natsEventSources, dsConfs, rmetric.NewNoopStreamMetricStore(), Hooks{}) // Assertions assert.NoError(t, err) @@ -123,7 +124,7 @@ func TestBuild_ProviderError(t *testing.T) { mockBuilder.On("BuildProvider", natsEventSources[0], mock.Anything).Return(nil, errors.New("provider error")) // Execute the 
function - providers, dataSources, err := build(ctx, mockBuilder, natsEventSources, dsConfs, rmetric.NewNoopStreamMetricStore()) + providers, dataSources, err := build(ctx, mockBuilder, natsEventSources, dsConfs, rmetric.NewNoopStreamMetricStore(), Hooks{}) // Assertions assert.Error(t, err) @@ -178,7 +179,7 @@ func TestBuild_ShouldGetAnErrorIfProviderIsNotDefined(t *testing.T) { mockBuilder.On("TypeID").Return("nats") // Execute the function - providers, dataSources, err := build(ctx, mockBuilder, natsEventSources, dsConfs, rmetric.NewNoopStreamMetricStore()) + providers, dataSources, err := build(ctx, mockBuilder, natsEventSources, dsConfs, rmetric.NewNoopStreamMetricStore(), Hooks{}) // Assertions assert.Error(t, err) @@ -242,7 +243,7 @@ func TestBuild_ShouldNotInitializeProviderIfNotUsed(t *testing.T) { Return(mockPubSubUsedProvider, nil) // Execute the function - providers, dataSources, err := build(ctx, mockBuilder, natsEventSources, dsConfs, rmetric.NewNoopStreamMetricStore()) + providers, dataSources, err := build(ctx, mockBuilder, natsEventSources, dsConfs, rmetric.NewNoopStreamMetricStore(), Hooks{}) // Assertions assert.NoError(t, err) @@ -293,7 +294,7 @@ func TestBuildProvidersAndDataSources_Nats_OK(t *testing.T) { {ID: "provider-1"}, }, }, - }, nil, zap.NewNop(), dsConfs, "host", "addr") + }, nil, zap.NewNop(), dsConfs, "host", "addr", Hooks{}) // Assertions assert.NoError(t, err) @@ -346,7 +347,7 @@ func TestBuildProvidersAndDataSources_Kafka_OK(t *testing.T) { {ID: "provider-1"}, }, }, - }, nil, zap.NewNop(), dsConfs, "host", "addr") + }, nil, zap.NewNop(), dsConfs, "host", "addr", Hooks{}) // Assertions assert.NoError(t, err) @@ -399,7 +400,7 @@ func TestBuildProvidersAndDataSources_Redis_OK(t *testing.T) { {ID: "provider-1"}, }, }, - }, nil, zap.NewNop(), dsConfs, "host", "addr") + }, nil, zap.NewNop(), dsConfs, "host", "addr", Hooks{}) // Assertions assert.NoError(t, err) diff --git a/router/pkg/pubsub/redis/adapter.go 
b/router/pkg/pubsub/redis/adapter.go index 8de962d2b6..5cb0055a36 100644 --- a/router/pkg/pubsub/redis/adapter.go +++ b/router/pkg/pubsub/redis/adapter.go @@ -9,7 +9,6 @@ import ( rd "github.com/wundergraph/cosmo/router/internal/rediscloser" "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" "go.uber.org/zap" ) @@ -21,7 +20,7 @@ const ( // Adapter defines the methods that a Redis adapter should implement type Adapter interface { // Subscribe subscribes to the given events and sends updates to the updater - Subscribe(ctx context.Context, event SubscriptionEventConfiguration, updater resolve.SubscriptionUpdater) error + Subscribe(ctx context.Context, event datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater) error // Publish publishes the given event to the specified channel Publish(ctx context.Context, event PublishEventConfiguration) error // Startup initializes the adapter @@ -94,19 +93,23 @@ func (p *ProviderAdapter) Shutdown(ctx context.Context) error { return p.conn.Close() } -func (p *ProviderAdapter) Subscribe(ctx context.Context, event SubscriptionEventConfiguration, updater resolve.SubscriptionUpdater) error { +func (p *ProviderAdapter) Subscribe(ctx context.Context, conf datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater) error { + subConf, ok := conf.(*SubscriptionEventConfiguration) + if !ok { + return datasource.NewError("invalid event type for Redis adapter", nil) + } log := p.logger.With( - zap.String("provider_id", event.ProviderID), + zap.String("provider_id", subConf.ProviderID()), zap.String("method", "subscribe"), - zap.Strings("channels", event.Channels), + zap.Strings("channels", subConf.Channels), ) - sub := p.conn.PSubscribe(ctx, event.Channels...) + sub := p.conn.PSubscribe(ctx, subConf.Channels...) msgChan := sub.Channel() cleanup := func() { - err := sub.PUnsubscribe(ctx, event.Channels...) 
+ err := sub.PUnsubscribe(ctx, subConf.Channels...) if err != nil { - log.Error(fmt.Sprintf("error unsubscribing from redis for topics %v", event.Channels), zap.Error(err)) + log.Error(fmt.Sprintf("error unsubscribing from redis for topics %v", subConf.Channels), zap.Error(err)) } } @@ -128,12 +131,14 @@ func (p *ProviderAdapter) Subscribe(ctx context.Context, event SubscriptionEvent } log.Debug("subscription update", zap.String("message_channel", msg.Channel), zap.String("data", msg.Payload)) p.streamMetricStore.Consume(ctx, metric.StreamsEvent{ - ProviderId: event.ProviderID, + ProviderId: conf.ProviderID(), StreamOperationName: redisReceive, ProviderType: metric.ProviderTypeRedis, DestinationName: msg.Channel, }) - updater.Update([]byte(msg.Payload)) + updater.Update(&Event{ + Data: []byte(msg.Payload), + }) case <-p.ctx.Done(): // When the application context is done, we stop the subscription if it is not already done log.Debug("application context done, stopping subscription") @@ -153,14 +158,14 @@ func (p *ProviderAdapter) Subscribe(ctx context.Context, event SubscriptionEvent func (p *ProviderAdapter) Publish(ctx context.Context, event PublishEventConfiguration) error { log := p.logger.With( - zap.String("provider_id", event.ProviderID), + zap.String("provider_id", event.ProviderID()), zap.String("method", "publish"), zap.String("channel", event.Channel), ) - log.Debug("publish", zap.ByteString("data", event.Data)) + log.Debug("publish", zap.ByteString("data", event.Event.Data)) - data, dataErr := event.Data.MarshalJSON() + data, dataErr := event.Event.Data.MarshalJSON() if dataErr != nil { log.Error("error marshalling data", zap.Error(dataErr)) return datasource.NewError("error marshalling data", dataErr) @@ -172,7 +177,7 @@ func (p *ProviderAdapter) Publish(ctx context.Context, event PublishEventConfigu if intCmd.Err() != nil { log.Error("publish error", zap.Error(intCmd.Err())) p.streamMetricStore.Produce(ctx, metric.StreamsEvent{ - ProviderId: 
event.ProviderID, + ProviderId: event.ProviderID(), StreamOperationName: redisPublish, ProviderType: metric.ProviderTypeRedis, ErrorType: "publish_error", @@ -182,7 +187,7 @@ func (p *ProviderAdapter) Publish(ctx context.Context, event PublishEventConfigu } p.streamMetricStore.Produce(ctx, metric.StreamsEvent{ - ProviderId: event.ProviderID, + ProviderId: event.ProviderID(), StreamOperationName: redisPublish, ProviderType: metric.ProviderTypeRedis, DestinationName: event.Channel, diff --git a/router/pkg/pubsub/redis/engine_datasource.go b/router/pkg/pubsub/redis/engine_datasource.go index d24a4fb959..3a685fe9b0 100644 --- a/router/pkg/pubsub/redis/engine_datasource.go +++ b/router/pkg/pubsub/redis/engine_datasource.go @@ -7,69 +7,66 @@ import ( "fmt" "io" - "github.com/buger/jsonparser" - "github.com/cespare/xxhash/v2" + "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" ) -// SubscriptionEventConfiguration contains configuration for subscription events -type SubscriptionEventConfiguration struct { - ProviderID string `json:"providerId"` - Channels []string `json:"channels"` +// Event represents an event from Redis +type Event struct { + Data json.RawMessage `json:"data"` } -// PublishEventConfiguration contains configuration for publish events -type PublishEventConfiguration struct { - ProviderID string `json:"providerId"` - Channel string `json:"channel"` - Data json.RawMessage `json:"data"` +func (e *Event) GetData() []byte { + return e.Data } -func (s *PublishEventConfiguration) MarshalJSONTemplate() (string, error) { - return fmt.Sprintf(`{"channel":"%s", "data": %s, "providerId":"%s"}`, s.Channel, s.Data, s.ProviderID), nil +// SubscriptionEventConfiguration contains configuration for subscription events +type SubscriptionEventConfiguration struct { + Provider string `json:"providerId"` + Channels []string 
`json:"channels"` + FieldName string `json:"rootFieldName"` } -// SubscriptionDataSource implements resolve.SubscriptionDataSource for Redis -type SubscriptionDataSource struct { - pubSub Adapter +// ProviderID returns the provider ID +func (s *SubscriptionEventConfiguration) ProviderID() string { + return s.Provider } -// UniqueRequestID computes a unique ID for the subscription request -func (s *SubscriptionDataSource) UniqueRequestID(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { - val, _, _, err := jsonparser.Get(input, "channels") - if err != nil { - return err - } +// ProviderType returns the provider type +func (s *SubscriptionEventConfiguration) ProviderType() datasource.ProviderType { + return datasource.ProviderTypeRedis +} - _, err = xxh.Write(val) - if err != nil { - return err - } +// RootFieldName returns the root field name +func (s *SubscriptionEventConfiguration) RootFieldName() string { + return s.FieldName +} - val, _, _, err = jsonparser.Get(input, "providerId") - if err != nil { - return err - } +// PublishEventConfiguration contains configuration for publish events +type PublishEventConfiguration struct { + Provider string `json:"providerId"` + Channel string `json:"channel"` + Event Event `json:"event"` + FieldName string `json:"rootFieldName"` +} - _, err = xxh.Write(val) - return err +// ProviderID returns the provider ID +func (p *PublishEventConfiguration) ProviderID() string { + return p.Provider } -// Start starts the subscription -func (s *SubscriptionDataSource) Start(ctx *resolve.Context, input []byte, updater resolve.SubscriptionUpdater) error { - var subscriptionConfiguration SubscriptionEventConfiguration - err := json.Unmarshal(input, &subscriptionConfiguration) - if err != nil { - return err - } +// ProviderType returns the provider type +func (p *PublishEventConfiguration) ProviderType() datasource.ProviderType { + return datasource.ProviderTypeRedis +} - return s.pubSub.Subscribe(ctx.Context(), 
subscriptionConfiguration, updater) +// RootFieldName returns the root field name +func (p *PublishEventConfiguration) RootFieldName() string { + return p.FieldName } -// LoadInitialData implements the interface method (not used for this subscription type) -func (s *SubscriptionDataSource) LoadInitialData(ctx context.Context) (initial []byte, err error) { - return nil, nil +func (s *PublishEventConfiguration) MarshalJSONTemplate() (string, error) { + return fmt.Sprintf(`{"channel":"%s", "event": {"data": %s}, "providerId":"%s"}`, s.Channel, s.Event.Data, s.ProviderID()), nil } // PublishDataSource implements resolve.DataSource for Redis publishing @@ -80,8 +77,7 @@ type PublishDataSource struct { // Load processes a request to publish to Redis func (s *PublishDataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) error { var publishConfiguration PublishEventConfiguration - err := json.Unmarshal(input, &publishConfiguration) - if err != nil { + if err := json.Unmarshal(input, &publishConfiguration); err != nil { return err } @@ -89,7 +85,7 @@ func (s *PublishDataSource) Load(ctx context.Context, input []byte, out *bytes.B _, err = io.WriteString(out, `{"success": false}`) return err } - _, err = io.WriteString(out, `{"success": true}`) + _, err := io.WriteString(out, `{"success": true}`) return err } @@ -97,3 +93,8 @@ func (s *PublishDataSource) Load(ctx context.Context, input []byte, out *bytes.B func (s *PublishDataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload, out *bytes.Buffer) (err error) { panic("not implemented") } + +// Interface compliance checks +var _ datasource.SubscriptionEventConfiguration = (*SubscriptionEventConfiguration)(nil) +var _ datasource.PublishEventConfiguration = (*PublishEventConfiguration)(nil) +var _ datasource.StreamEvent = (*Event)(nil) diff --git a/router/pkg/pubsub/redis/engine_datasource_factory.go b/router/pkg/pubsub/redis/engine_datasource_factory.go index 
c5383ff16a..bce913e54e 100644 --- a/router/pkg/pubsub/redis/engine_datasource_factory.go +++ b/router/pkg/pubsub/redis/engine_datasource_factory.go @@ -5,6 +5,8 @@ import ( "fmt" "slices" + "github.com/buger/jsonparser" + "github.com/cespare/xxhash/v2" "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" ) @@ -59,26 +61,46 @@ func (c *EngineDataSourceFactory) ResolveDataSourceInput(eventData []byte) (stri providerId := c.providerId evtCfg := PublishEventConfiguration{ - ProviderID: providerId, - Channel: channel, - Data: eventData, + Provider: providerId, + Channel: channel, + Event: Event{Data: eventData}, + FieldName: c.fieldName, } return evtCfg.MarshalJSONTemplate() } // ResolveDataSourceSubscription returns the subscription data source -func (c *EngineDataSourceFactory) ResolveDataSourceSubscription() (resolve.SubscriptionDataSource, error) { - return &SubscriptionDataSource{ - pubSub: c.RedisAdapter, - }, nil +func (c *EngineDataSourceFactory) ResolveDataSourceSubscription() (datasource.SubscriptionDataSource, error) { + return datasource.NewPubSubSubscriptionDataSource[*SubscriptionEventConfiguration]( + c.RedisAdapter, + func(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { + val, _, _, err := jsonparser.Get(input, "channels") + if err != nil { + return err + } + + _, err = xxh.Write(val) + if err != nil { + return err + } + + val, _, _, err = jsonparser.Get(input, "providerId") + if err != nil { + return err + } + + _, err = xxh.Write(val) + return err + }), nil } // ResolveDataSourceSubscriptionInput builds the input for the subscription data source func (c *EngineDataSourceFactory) ResolveDataSourceSubscriptionInput() (string, error) { evtCfg := SubscriptionEventConfiguration{ - ProviderID: c.providerId, - Channels: c.channels, + Provider: c.providerId, + Channels: c.channels, + FieldName: c.fieldName, } object, err := json.Marshal(evtCfg) if err != nil { diff --git 
a/router/pkg/pubsub/redis/engine_datasource_factory_test.go b/router/pkg/pubsub/redis/engine_datasource_factory_test.go index 0c1344048a..f96691583d 100644 --- a/router/pkg/pubsub/redis/engine_datasource_factory_test.go +++ b/router/pkg/pubsub/redis/engine_datasource_factory_test.go @@ -4,11 +4,15 @@ import ( "bytes" "context" "encoding/json" + "errors" "testing" + "github.com/cespare/xxhash/v2" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/wundergraph/cosmo/router/pkg/pubsub/pubsubtest" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" ) func TestRedisEngineDataSourceFactory(t *testing.T) { @@ -33,7 +37,7 @@ func TestEngineDataSourceFactoryWithMockAdapter(t *testing.T) { // Configure mock expectations for Publish mockAdapter.On("Publish", mock.Anything, mock.MatchedBy(func(event PublishEventConfiguration) bool { - return event.ProviderID == "test-provider" && event.Channel == "test-channel" + return event.ProviderID() == "test-provider" && event.Channel == "test-channel" })).Return(nil) // Create the data source with mock adapter @@ -176,3 +180,57 @@ func TestTransformEventConfig(t *testing.T) { require.Equal(t, []string{"transformed.original.subject1", "transformed.original.subject2"}, cfg.channels) }) } + +func TestRedisEngineDataSourceFactory_UniqueRequestID(t *testing.T) { + tests := []struct { + name string + input string + expectError bool + expectedError error + }{ + { + name: "valid input", + input: `{"channels":["channel1", "channel2"], "providerId":"test-provider"}`, + expectError: false, + }, + { + name: "missing channels", + input: `{"providerId":"test-provider"}`, + expectError: true, + expectedError: errors.New("Key path not found"), + }, + { + name: "missing providerId", + input: `{"channels":["channel1", "channel2"]}`, + expectError: true, + expectedError: errors.New("Key path not found"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t 
*testing.T) { + factory := &EngineDataSourceFactory{ + RedisAdapter: NewMockAdapter(t), + } + source, err := factory.ResolveDataSourceSubscription() + require.NoError(t, err) + ctx := &resolve.Context{} + input := []byte(tt.input) + xxh := xxhash.New() + + err = source.UniqueRequestID(ctx, input, xxh) + + if tt.expectError { + require.Error(t, err) + if tt.expectedError != nil { + // For jsonparser errors, just check if the error message contains the expected text + assert.Contains(t, err.Error(), tt.expectedError.Error()) + } + } else { + require.NoError(t, err) + // Check that the hash has been updated + assert.NotEqual(t, 0, xxh.Sum64()) + } + }) + } +} diff --git a/router/pkg/pubsub/redis/engine_datasource_test.go b/router/pkg/pubsub/redis/engine_datasource_test.go index 7c47d47cc6..74b7d564d7 100644 --- a/router/pkg/pubsub/redis/engine_datasource_test.go +++ b/router/pkg/pubsub/redis/engine_datasource_test.go @@ -7,12 +7,9 @@ import ( "errors" "testing" - "github.com/cespare/xxhash/v2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" - "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" ) func TestPublishEventConfiguration_MarshalJSONTemplate(t *testing.T) { @@ -24,20 +21,20 @@ func TestPublishEventConfiguration_MarshalJSONTemplate(t *testing.T) { { name: "simple configuration", config: PublishEventConfiguration{ - ProviderID: "test-provider", - Channel: "test-channel", - Data: json.RawMessage(`{"message":"hello"}`), + Provider: "test-provider", + Channel: "test-channel", + Event: Event{Data: json.RawMessage(`{"message":"hello"}`)}, }, - wantPattern: `{"channel":"test-channel", "data": {"message":"hello"}, "providerId":"test-provider"}`, + wantPattern: `{"channel":"test-channel", "event": {"data": {"message":"hello"}}, "providerId":"test-provider"}`, }, { name: "with special characters", config: PublishEventConfiguration{ - 
ProviderID: "test-provider-id", - Channel: "channel-with-hyphens", - Data: json.RawMessage(`{"message":"special \"quotes\" here"}`), + Provider: "test-provider-id", + Channel: "channel-with-hyphens", + Event: Event{Data: json.RawMessage(`{"message":"special \"quotes\" here"}`)}, }, - wantPattern: `{"channel":"channel-with-hyphens", "data": {"message":"special \"quotes\" here"}, "providerId":"test-provider-id"}`, + wantPattern: `{"channel":"channel-with-hyphens", "event": {"data": {"message":"special \"quotes\" here"}}, "providerId":"test-provider-id"}`, }, } @@ -50,124 +47,6 @@ func TestPublishEventConfiguration_MarshalJSONTemplate(t *testing.T) { } } -func TestSubscriptionSource_UniqueRequestID(t *testing.T) { - tests := []struct { - name string - input string - expectError bool - expectedError error - }{ - { - name: "valid input", - input: `{"channels":["channel1", "channel2"], "providerId":"test-provider"}`, - expectError: false, - }, - { - name: "missing channels", - input: `{"providerId":"test-provider"}`, - expectError: true, - expectedError: errors.New("Key path not found"), - }, - { - name: "missing providerId", - input: `{"channels":["channel1", "channel2"]}`, - expectError: true, - expectedError: errors.New("Key path not found"), - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - source := &SubscriptionDataSource{ - pubSub: NewMockAdapter(t), - } - ctx := &resolve.Context{} - input := []byte(tt.input) - xxh := xxhash.New() - - err := source.UniqueRequestID(ctx, input, xxh) - - if tt.expectError { - require.Error(t, err) - if tt.expectedError != nil { - // For jsonparser errors, just check if the error message contains the expected text - assert.Contains(t, err.Error(), tt.expectedError.Error()) - } - } else { - require.NoError(t, err) - // Check that the hash has been updated - assert.NotEqual(t, 0, xxh.Sum64()) - } - }) - } -} - -func TestSubscriptionSource_Start(t *testing.T) { - tests := []struct { - name string - input 
string - mockSetup func(*MockAdapter, *datasource.MockSubscriptionUpdater) - expectError bool - }{ - { - name: "successful subscription", - input: `{"channels":["channel1", "channel2"], "providerId":"test-provider"}`, - mockSetup: func(m *MockAdapter, updater *datasource.MockSubscriptionUpdater) { - m.On("Subscribe", mock.Anything, SubscriptionEventConfiguration{ - ProviderID: "test-provider", - Channels: []string{"channel1", "channel2"}, - }, mock.Anything).Return(nil) - }, - expectError: false, - }, - { - name: "adapter returns error", - input: `{"channels":["channel1"], "providerId":"test-provider"}`, - mockSetup: func(m *MockAdapter, updater *datasource.MockSubscriptionUpdater) { - m.On("Subscribe", mock.Anything, SubscriptionEventConfiguration{ - ProviderID: "test-provider", - Channels: []string{"channel1"}, - }, mock.Anything).Return(errors.New("subscription error")) - }, - expectError: true, - }, - { - name: "invalid input json", - input: `{"invalid json":`, - mockSetup: func(m *MockAdapter, updater *datasource.MockSubscriptionUpdater) {}, - expectError: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mockAdapter := NewMockAdapter(t) - updater := datasource.NewMockSubscriptionUpdater(t) - tt.mockSetup(mockAdapter, updater) - - source := &SubscriptionDataSource{ - pubSub: mockAdapter, - } - - // Set up go context - goCtx := context.Background() - - // Create a resolve.Context with the standard context - resolveCtx := &resolve.Context{} - resolveCtx = resolveCtx.WithContext(goCtx) - - input := []byte(tt.input) - err := source.Start(resolveCtx, input, updater) - - if tt.expectError { - require.Error(t, err) - } else { - require.NoError(t, err) - } - }) - } -} - func TestRedisPublishDataSource_Load(t *testing.T) { tests := []struct { name string @@ -179,12 +58,12 @@ func TestRedisPublishDataSource_Load(t *testing.T) { }{ { name: "successful publish", - input: `{"channel":"test-channel", "data":{"message":"hello"}, 
"providerId":"test-provider"}`, + input: `{"channel":"test-channel", "event": {"data":{"message":"hello"}}, "providerId":"test-provider"}`, mockSetup: func(m *MockAdapter) { m.On("Publish", mock.Anything, mock.MatchedBy(func(event PublishEventConfiguration) bool { - return event.ProviderID == "test-provider" && + return event.ProviderID() == "test-provider" && event.Channel == "test-channel" && - string(event.Data) == `{"message":"hello"}` + string(event.Event.Data) == `{"message":"hello"}` })).Return(nil) }, expectError: false, @@ -193,7 +72,7 @@ func TestRedisPublishDataSource_Load(t *testing.T) { }, { name: "publish error", - input: `{"channel":"test-channel", "data":{"message":"hello"}, "providerId":"test-provider"}`, + input: `{"channel":"test-channel", "event": {"data":{"message":"hello"}}, "providerId":"test-provider"}`, mockSetup: func(m *MockAdapter) { m.On("Publish", mock.Anything, mock.Anything).Return(errors.New("publish error")) }, diff --git a/router/pkg/pubsub/redis/mocks.go b/router/pkg/pubsub/redis/mocks.go index 603a5dd548..6f6938cdd0 100644 --- a/router/pkg/pubsub/redis/mocks.go +++ b/router/pkg/pubsub/redis/mocks.go @@ -8,7 +8,7 @@ import ( "context" mock "github.com/stretchr/testify/mock" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" + "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" ) // NewMockAdapter creates a new instance of MockAdapter. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. 
@@ -198,7 +198,7 @@ func (_c *MockAdapter_Startup_Call) RunAndReturn(run func(ctx context.Context) e } // Subscribe provides a mock function for the type MockAdapter -func (_mock *MockAdapter) Subscribe(ctx context.Context, event SubscriptionEventConfiguration, updater resolve.SubscriptionUpdater) error { +func (_mock *MockAdapter) Subscribe(ctx context.Context, event datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater) error { ret := _mock.Called(ctx, event, updater) if len(ret) == 0 { @@ -206,7 +206,7 @@ func (_mock *MockAdapter) Subscribe(ctx context.Context, event SubscriptionEvent } var r0 error - if returnFunc, ok := ret.Get(0).(func(context.Context, SubscriptionEventConfiguration, resolve.SubscriptionUpdater) error); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, datasource.SubscriptionEventConfiguration, datasource.SubscriptionEventUpdater) error); ok { r0 = returnFunc(ctx, event, updater) } else { r0 = ret.Error(0) @@ -221,25 +221,25 @@ type MockAdapter_Subscribe_Call struct { // Subscribe is a helper method to define mock.On call // - ctx context.Context -// - event SubscriptionEventConfiguration -// - updater resolve.SubscriptionUpdater +// - event datasource.SubscriptionEventConfiguration +// - updater datasource.SubscriptionEventUpdater func (_e *MockAdapter_Expecter) Subscribe(ctx interface{}, event interface{}, updater interface{}) *MockAdapter_Subscribe_Call { return &MockAdapter_Subscribe_Call{Call: _e.mock.On("Subscribe", ctx, event, updater)} } -func (_c *MockAdapter_Subscribe_Call) Run(run func(ctx context.Context, event SubscriptionEventConfiguration, updater resolve.SubscriptionUpdater)) *MockAdapter_Subscribe_Call { +func (_c *MockAdapter_Subscribe_Call) Run(run func(ctx context.Context, event datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater)) *MockAdapter_Subscribe_Call { _c.Call.Run(func(args mock.Arguments) { var arg0 context.Context if args[0] != nil { 
arg0 = args[0].(context.Context) } - var arg1 SubscriptionEventConfiguration + var arg1 datasource.SubscriptionEventConfiguration if args[1] != nil { - arg1 = args[1].(SubscriptionEventConfiguration) + arg1 = args[1].(datasource.SubscriptionEventConfiguration) } - var arg2 resolve.SubscriptionUpdater + var arg2 datasource.SubscriptionEventUpdater if args[2] != nil { - arg2 = args[2].(resolve.SubscriptionUpdater) + arg2 = args[2].(datasource.SubscriptionEventUpdater) } run( arg0, @@ -255,7 +255,7 @@ func (_c *MockAdapter_Subscribe_Call) Return(err error) *MockAdapter_Subscribe_C return _c } -func (_c *MockAdapter_Subscribe_Call) RunAndReturn(run func(ctx context.Context, event SubscriptionEventConfiguration, updater resolve.SubscriptionUpdater) error) *MockAdapter_Subscribe_Call { +func (_c *MockAdapter_Subscribe_Call) RunAndReturn(run func(ctx context.Context, event datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater) error) *MockAdapter_Subscribe_Call { _c.Call.Return(run) return _c }