diff --git a/v2/pkg/engine/datasource/graphql_datasource/configuration.go b/v2/pkg/engine/datasource/graphql_datasource/configuration.go index c5a7f3f361..ec398c0e05 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/configuration.go +++ b/v2/pkg/engine/datasource/graphql_datasource/configuration.go @@ -9,6 +9,7 @@ import ( "github.com/wundergraph/graphql-go-tools/v2/pkg/astparser" "github.com/wundergraph/graphql-go-tools/v2/pkg/asttransform" grpcdatasource "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/grpc_datasource" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" "github.com/wundergraph/graphql-go-tools/v2/pkg/federation" "github.com/wundergraph/graphql-go-tools/v2/pkg/operationreport" ) @@ -103,6 +104,11 @@ type SingleTypeField struct { FieldName string } +// SubscriptionOnStartFn defines a hook function that is called when a subscription starts. +// It receives the resolve context and the input of the subscription. +// The function can return an error. +type SubscriptionOnStartFn func(ctx resolve.StartupHookContext, input []byte) (err error) + type SubscriptionConfiguration struct { URL string Header http.Header @@ -119,6 +125,8 @@ type SubscriptionConfiguration struct { // these headers by itself. ForwardedClientHeaderRegularExpressions []RegularExpression WsSubProtocol string + // StartupHooks contains the method called when a subscription is started + StartupHooks []SubscriptionOnStartFn } type FetchConfiguration struct { diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go index 3acdd07603..53c4778cbc 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go @@ -449,7 +449,8 @@ func (p *Planner[T]) ConfigureSubscription() plan.SubscriptionConfiguration { return plan.SubscriptionConfiguration{ Input: string(input), DataSource: &SubscriptionSource{ - client: p.subscriptionClient, + client: p.subscriptionClient, + subscriptionOnStartFns: p.config.subscription.StartupHooks, }, Variables: p.variables, PostProcessing: DefaultPostProcessingConfiguration, @@ -1953,7 +1954,8 @@ type RegularExpression struct { } type SubscriptionSource struct { - client GraphQLSubscriptionClient + client GraphQLSubscriptionClient + subscriptionOnStartFns []SubscriptionOnStartFn } func (s *SubscriptionSource) AsyncStart(ctx *resolve.Context, id uint64, input []byte, updater resolve.SubscriptionUpdater) error { @@ -2003,3 +2005,12 @@ func (s *SubscriptionSource) UniqueRequestID(ctx *resolve.Context, input []byte, } return s.client.UniqueRequestID(ctx, options, xxh) } + +// SubscriptionOnStart is called when a subscription is started. +// Each hook is called in a separate goroutine. +func (s *SubscriptionSource) SubscriptionOnStart(ctx resolve.StartupHookContext, input []byte) (err error) { + for _, fn := range s.subscriptionOnStartFns { + return fn(ctx, input) + } + return nil +} diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go index f7031fc3a3..1077daa1ad 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go @@ -4018,7 +4018,7 @@ func TestGraphQLDataSource(t *testing.T) { Trigger: resolve.GraphQLSubscriptionTrigger{ Input: []byte(`{"url":"wss://swapi.com/graphql","body":{"query":"subscription{remainingJedis}"}}`), Source: &SubscriptionSource{ - NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, ctx), + client: NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, ctx), }, PostProcessing: DefaultPostProcessingConfiguration, }, @@ -8904,6 +8904,60 @@ func TestSanitizeKey(t *testing.T) { } } +func TestSubscriptionSource_SubscriptionOnStart(t *testing.T) { + + t.Run("SubscriptionOnStart calls subscriptionOnStartFns", func(t *testing.T) { + ctx := resolve.StartupHookContext{ + Context: context.Background(), + Updater: func(data []byte) {}, + } + + type fnData struct { + ctx resolve.StartupHookContext + input []byte + } + + startFnCalled := make(chan fnData, 1) + subscriptionSource := SubscriptionSource{ + subscriptionOnStartFns: []SubscriptionOnStartFn{ + func(ctx resolve.StartupHookContext, input []byte) error { + startFnCalled <- fnData{ctx, input} + return nil + }, + }, + } + + err := subscriptionSource.SubscriptionOnStart(ctx, []byte(`{"variables": {}, "extensions": {}, "operationName": "LiveMessages", "query": "subscription LiveMessages { messageAdded(roomName: \"#test\") { text createdBy } }"}`)) + require.NoError(t, err) + var called fnData + select { + case called = <-startFnCalled: + case <-time.After(1 * time.Second): + t.Fatal("SubscriptionOnStartFn was not called") + } + assert.Equal(t, []byte(`{"variables": {}, "extensions": {}, "operationName": "LiveMessages", "query": "subscription LiveMessages { messageAdded(roomName: \"#test\") { text createdBy } }"}`), called.input) + }) + + t.Run("SubscriptionOnStart calls subscriptionOnStartFns and returns error if one of the functions returns an error", func(t *testing.T) { + ctx := resolve.StartupHookContext{ + Context: context.Background(), + Updater: func(data []byte) {}, + } + + subscriptionSource := SubscriptionSource{ + subscriptionOnStartFns: []SubscriptionOnStartFn{ + func(ctx resolve.StartupHookContext, input []byte) error { + return errors.New("test error") + }, + }, + } + + err := subscriptionSource.SubscriptionOnStart(ctx, []byte(`{"variables": {}, "extensions": {}, "operationName": "LiveMessages", "query": "subscription LiveMessages { messageAdded(roomName: \"#test\") { text createdBy } }"}`)) + require.Error(t, err) + assert.ErrorContains(t, err, "test error") + }) +} + const interfaceSelectionSchema = ` scalar String diff --git a/v2/pkg/engine/datasource/pubsub_datasource/kafka_event_manager.go b/v2/pkg/engine/datasource/pubsub_datasource/kafka_event_manager.go deleted file mode 100644 index 80b02e1165..0000000000 --- a/v2/pkg/engine/datasource/pubsub_datasource/kafka_event_manager.go +++ /dev/null @@ -1,72 +0,0 @@ -package pubsub_datasource - -import ( - "encoding/json" - "fmt" - "slices" - - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/plan" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" -) - -type KafkaSubscriptionEventConfiguration struct { - ProviderID string `json:"providerId"` - Topics []string `json:"topics"` -} - -type KafkaPublishEventConfiguration struct { - ProviderID string `json:"providerId"` - Topic string `json:"topic"` - Data json.RawMessage `json:"data"` -} - -func (s *KafkaPublishEventConfiguration) MarshalJSONTemplate() string { - return fmt.Sprintf(`{"topic":"%s", "data": %s, "providerId":"%s"}`, s.Topic, s.Data, s.ProviderID) -} - -type KafkaEventManager struct { - visitor *plan.Visitor - variables *resolve.Variables - eventMetadata EventMetadata - eventConfiguration *KafkaEventConfiguration - publishEventConfiguration *KafkaPublishEventConfiguration - subscriptionEventConfiguration *KafkaSubscriptionEventConfiguration -} - -func (p *KafkaEventManager) eventDataBytes(ref int) ([]byte, error) { - return buildEventDataBytes(ref, p.visitor, p.variables) -} - -func (p *KafkaEventManager) handlePublishEvent(ref int) { - if len(p.eventConfiguration.Topics) != 1 { - p.visitor.Walker.StopWithInternalErr(fmt.Errorf("publish and request events should define one subject but received %d", len(p.eventConfiguration.Topics))) - return - } - topic := p.eventConfiguration.Topics[0] - dataBytes, err := p.eventDataBytes(ref) - if err != nil { - p.visitor.Walker.StopWithInternalErr(fmt.Errorf("failed to write event data bytes: %w", err)) - return - } - - p.publishEventConfiguration = &KafkaPublishEventConfiguration{ - ProviderID: p.eventMetadata.ProviderID, - Topic: topic, - Data: dataBytes, - } -} - -func (p *KafkaEventManager) handleSubscriptionEvent(ref int) { - - if len(p.eventConfiguration.Topics) == 0 { - p.visitor.Walker.StopWithInternalErr(fmt.Errorf("expected at least one subscription topic but received %d", len(p.eventConfiguration.Topics))) - return - } - - slices.Sort(p.eventConfiguration.Topics) - - p.subscriptionEventConfiguration = &KafkaSubscriptionEventConfiguration{ - ProviderID: p.eventMetadata.ProviderID, - Topics: p.eventConfiguration.Topics, - } -} diff --git a/v2/pkg/engine/datasource/pubsub_datasource/nats_event_manager.go b/v2/pkg/engine/datasource/pubsub_datasource/nats_event_manager.go deleted file mode 100644 index 04fc7e70ad..0000000000 --- a/v2/pkg/engine/datasource/pubsub_datasource/nats_event_manager.go +++ /dev/null @@ -1,189 +0,0 @@ -package pubsub_datasource - -import ( - "encoding/json" - "fmt" - "regexp" - "slices" - "strings" - - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/argument_templates" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/plan" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" -) - -const ( - fwc = '>' - tsep = "." -) - -// A variable template has form $$number$$ where the number can range from one to multiple digits -var ( - variableTemplateRegex = regexp.MustCompile(`\$\$\d+\$\$`) -) - -type NatsSubscriptionEventConfiguration struct { - ProviderID string `json:"providerId"` - Subjects []string `json:"subjects"` - StreamConfiguration *NatsStreamConfiguration `json:"streamConfiguration,omitempty"` -} - -type NatsPublishAndRequestEventConfiguration struct { - ProviderID string `json:"providerId"` - Subject string `json:"subject"` - Data json.RawMessage `json:"data"` -} - -func (s *NatsPublishAndRequestEventConfiguration) MarshalJSONTemplate() string { - return fmt.Sprintf(`{"subject":"%s", "data": %s, "providerId":"%s"}`, s.Subject, s.Data, s.ProviderID) -} - -type NatsEventManager struct { - visitor *plan.Visitor - variables *resolve.Variables - eventMetadata EventMetadata - eventConfiguration *NatsEventConfiguration - publishAndRequestEventConfiguration *NatsPublishAndRequestEventConfiguration - subscriptionEventConfiguration *NatsSubscriptionEventConfiguration -} - -func isValidNatsSubject(subject string) bool { - if subject == "" { - return false - } - sfwc := false - tokens := strings.Split(subject, tsep) - for _, t := range tokens { - length := len(t) - if length == 0 || sfwc { - return false - } - if length > 1 { - if strings.ContainsAny(t, "\t\n\f\r ") { - return false - } - continue - } - switch t[0] { - case fwc: - sfwc = true - case ' ', '\t', '\n', '\r', '\f': - return false - } - } - return true -} - -func (p *NatsEventManager) addContextVariableByArgumentRef(argumentRef int, argumentPath []string) (string, error) { - variablePath, err := p.visitor.Operation.VariablePathByArgumentRefAndArgumentPath(argumentRef, argumentPath, p.visitor.Walker.Ancestors[0].Ref) - if err != nil { - return "", err - } - /* The definition is passed as both definition and operation below because getJSONRootType resolves the type - * from the first argument, but finalInputValueTypeRef comes from the definition - */ - contextVariable := &resolve.ContextVariable{ - Path: variablePath, - Renderer: resolve.NewPlainVariableRenderer(), - } - variablePlaceHolder, _ := p.variables.AddVariable(contextVariable) - return variablePlaceHolder, nil -} - -func (p *NatsEventManager) extractEventSubject(fieldRef int, subject string) (string, error) { - matches := argument_templates.ArgumentTemplateRegex.FindAllStringSubmatch(subject, -1) - // If no argument templates are defined, there are only static values - if len(matches) < 1 { - if isValidNatsSubject(subject) { - return subject, nil - } - return "", fmt.Errorf(`subject "%s" is not a valid NATS subject`, subject) - } - fieldNameBytes := p.visitor.Operation.FieldNameBytes(fieldRef) - // TODO: handling for interfaces and unions - fieldDefinitionRef, ok := p.visitor.Definition.ObjectTypeDefinitionFieldWithName(p.visitor.Walker.EnclosingTypeDefinition.Ref, fieldNameBytes) - if !ok { - return "", fmt.Errorf(`expected field definition to exist for field "%s"`, fieldNameBytes) - } - subjectWithVariableTemplateReplacements := subject - for templateNumber, groups := range matches { - // The first group is the whole template; the second is the period delimited argument path - if len(groups) != 2 { - return "", fmt.Errorf(`argument template #%d defined on field "%s" is invalid: expected 2 matching groups but received %d`, templateNumber+1, fieldNameBytes, len(groups)-1) - } - validationResult, err := argument_templates.ValidateArgumentPath(p.visitor.Definition, groups[1], fieldDefinitionRef) - if err != nil { - return "", fmt.Errorf(`argument template #%d defined on field "%s" is invalid: %w`, templateNumber+1, fieldNameBytes, err) - } - argumentNameBytes := []byte(validationResult.ArgumentPath[0]) - argumentRef, ok := p.visitor.Operation.FieldArgument(fieldRef, argumentNameBytes) - if !ok { - return "", fmt.Errorf(`operation field "%s" does not define argument "%s"`, fieldNameBytes, argumentNameBytes) - } - // variablePlaceholder has the form $$0$$, $$1$$, etc. - variablePlaceholder, err := p.addContextVariableByArgumentRef(argumentRef, validationResult.ArgumentPath) - if err != nil { - return "", fmt.Errorf(`failed to retrieve variable placeholder for argument ""%s" defined on operation field "%s": %w`, argumentNameBytes, fieldNameBytes, err) - } - // Replace the template literal with the variable placeholder (and reuse the variable if it already exists) - subjectWithVariableTemplateReplacements = strings.ReplaceAll(subjectWithVariableTemplateReplacements, groups[0], variablePlaceholder) - } - // Substitute the variable templates for dummy values to check naïvely that the string is a valid NATS subject - if isValidNatsSubject(variableTemplateRegex.ReplaceAllLiteralString(subjectWithVariableTemplateReplacements, "a")) { - return subjectWithVariableTemplateReplacements, nil - } - return "", fmt.Errorf(`subject "%s" is not a valid NATS subject`, subject) -} - -func (p *NatsEventManager) eventDataBytes(ref int) ([]byte, error) { - return buildEventDataBytes(ref, p.visitor, p.variables) -} - -func (p *NatsEventManager) handlePublishAndRequestEvent(ref int) { - if len(p.eventConfiguration.Subjects) != 1 { - p.visitor.Walker.StopWithInternalErr(fmt.Errorf("publish and request events should define one subject but received %d", len(p.eventConfiguration.Subjects))) - return - } - rawSubject := p.eventConfiguration.Subjects[0] - extractedSubject, err := p.extractEventSubject(ref, rawSubject) - if err != nil { - p.visitor.Walker.StopWithInternalErr(fmt.Errorf("could not extract event subject: %w", err)) - return - } - dataBytes, err := p.eventDataBytes(ref) - if err != nil { - p.visitor.Walker.StopWithInternalErr(fmt.Errorf("failed to write event data bytes: %w", err)) - return - } - - p.publishAndRequestEventConfiguration = &NatsPublishAndRequestEventConfiguration{ - ProviderID: p.eventMetadata.ProviderID, - Subject: extractedSubject, - Data: dataBytes, - } -} - -func (p *NatsEventManager) handleSubscriptionEvent(ref int) { - - if len(p.eventConfiguration.Subjects) == 0 { - p.visitor.Walker.StopWithInternalErr(fmt.Errorf("expected at least one subscription subject but received %d", len(p.eventConfiguration.Subjects))) - return - } - extractedSubjects := make([]string, 0, len(p.eventConfiguration.Subjects)) - for _, rawSubject := range p.eventConfiguration.Subjects { - extractedSubject, err := p.extractEventSubject(ref, rawSubject) - if err != nil { - p.visitor.Walker.StopWithInternalErr(fmt.Errorf("could not extract subscription event subjects: %w", err)) - return - } - extractedSubjects = append(extractedSubjects, extractedSubject) - } - - slices.Sort(extractedSubjects) - - p.subscriptionEventConfiguration = &NatsSubscriptionEventConfiguration{ - ProviderID: p.eventMetadata.ProviderID, - Subjects: extractedSubjects, - StreamConfiguration: p.eventConfiguration.StreamConfiguration, - } -} diff --git a/v2/pkg/engine/datasource/pubsub_datasource/pubsub_datasource.go b/v2/pkg/engine/datasource/pubsub_datasource/pubsub_datasource.go deleted file mode 100644 index bdf12304dd..0000000000 --- a/v2/pkg/engine/datasource/pubsub_datasource/pubsub_datasource.go +++ /dev/null @@ -1,346 +0,0 @@ -package pubsub_datasource - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "regexp" - "strings" - - "github.com/jensneuse/abstractlogger" - - "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/plan" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" -) - -type EventType string - -const ( - EventTypePublish EventType = "publish" - EventTypeRequest EventType = "request" - EventTypeSubscribe EventType = "subscribe" -) - -var eventSubjectRegex = regexp.MustCompile(`{{ args.([a-zA-Z0-9_]+) }}`) - -func EventTypeFromString(s string) (EventType, error) { - et := EventType(strings.ToLower(s)) - switch et { - case EventTypePublish, EventTypeRequest, EventTypeSubscribe: - return et, nil - default: - return "", fmt.Errorf("invalid event type: %q", s) - } -} - -type EventMetadata struct { - ProviderID string `json:"providerId"` - Type EventType `json:"type"` - TypeName string `json:"typeName"` - FieldName string `json:"fieldName"` -} - -type EventConfiguration struct { - Metadata *EventMetadata `json:"metadata"` - Configuration any `json:"configuration"` -} - -type Configuration struct { - Events []EventConfiguration `json:"events"` -} - -type Planner[T Configuration] struct { - id int - config Configuration - natsPubSubByProviderID map[string]NatsPubSub - kafkaPubSubByProviderID map[string]KafkaPubSub - eventManager any - rootFieldRef int - variables resolve.Variables - visitor *plan.Visitor -} - -func (p *Planner[T]) SetID(id int) { - p.id = id -} - -func (p *Planner[T]) ID() (id int) { - return p.id -} - -func (p *Planner[T]) EnterField(ref int) { - if p.rootFieldRef != -1 { - // This is a nested field; nothing needs to be done - return - } - p.rootFieldRef = ref - - fieldName := p.visitor.Operation.FieldNameString(ref) - typeName := p.visitor.Walker.EnclosingTypeDefinition.NameString(p.visitor.Definition) - - var eventConfig *EventConfiguration - for _, cfg := range p.config.Events { - if cfg.Metadata.TypeName == typeName && cfg.Metadata.FieldName == fieldName { - eventConfig = &cfg - break - } - } - if eventConfig == nil { - return - } - - switch v := eventConfig.Configuration.(type) { - case *NatsEventConfiguration: - em := &NatsEventManager{ - visitor: p.visitor, - variables: &p.variables, - eventMetadata: *eventConfig.Metadata, - eventConfiguration: v, - } - p.eventManager = em - - switch eventConfig.Metadata.Type { - case EventTypePublish, EventTypeRequest: - em.handlePublishAndRequestEvent(ref) - case EventTypeSubscribe: - em.handleSubscriptionEvent(ref) - default: - p.visitor.Walker.StopWithInternalErr(fmt.Errorf("invalid EventType \"%s\" for Nats", eventConfig.Metadata.Type)) - } - case *KafkaEventConfiguration: - em := &KafkaEventManager{ - visitor: p.visitor, - variables: &p.variables, - eventMetadata: *eventConfig.Metadata, - eventConfiguration: v, - } - p.eventManager = em - - switch eventConfig.Metadata.Type { - case EventTypePublish: - em.handlePublishEvent(ref) - case EventTypeSubscribe: - em.handleSubscriptionEvent(ref) - default: - p.visitor.Walker.StopWithInternalErr(fmt.Errorf("invalid EventType \"%s\" for Kafka", eventConfig.Metadata.Type)) - } - default: - p.visitor.Walker.StopWithInternalErr(fmt.Errorf("invalid event configuration type: %T", v)) - } -} - -func (p *Planner[T]) EnterDocument(_, _ *ast.Document) { - p.rootFieldRef = -1 - p.eventManager = nil -} - -func (p *Planner[T]) Register(visitor *plan.Visitor, configuration plan.DataSourceConfiguration[T], dataSourcePlannerConfiguration plan.DataSourcePlannerConfiguration) error { - p.visitor = visitor - visitor.Walker.RegisterEnterFieldVisitor(p) - visitor.Walker.RegisterEnterDocumentVisitor(p) - p.config = Configuration(configuration.CustomConfiguration()) - return nil -} - -func (p *Planner[T]) ConfigureFetch() resolve.FetchConfiguration { - if p.eventManager == nil { - p.visitor.Walker.StopWithInternalErr(fmt.Errorf("failed to configure fetch: event manager is nil")) - return resolve.FetchConfiguration{} - } - - var dataSource resolve.DataSource - - switch v := p.eventManager.(type) { - case *NatsEventManager: - pubsub, ok := p.natsPubSubByProviderID[v.eventMetadata.ProviderID] - if !ok { - p.visitor.Walker.StopWithInternalErr(fmt.Errorf("no pubsub connection exists with provider id \"%s\"", v.eventMetadata.ProviderID)) - return resolve.FetchConfiguration{} - } - - switch v.eventMetadata.Type { - case EventTypePublish: - dataSource = &NatsPublishDataSource{ - pubSub: pubsub, - } - case EventTypeRequest: - dataSource = &NatsRequestDataSource{ - pubSub: pubsub, - } - default: - p.visitor.Walker.StopWithInternalErr(fmt.Errorf("failed to configure fetch: invalid event type \"%s\" for Nats", v.eventMetadata.Type)) - return resolve.FetchConfiguration{} - } - - return resolve.FetchConfiguration{ - Input: v.publishAndRequestEventConfiguration.MarshalJSONTemplate(), - Variables: p.variables, - DataSource: dataSource, - PostProcessing: resolve.PostProcessingConfiguration{ - MergePath: []string{v.eventMetadata.FieldName}, - }, - } - - case *KafkaEventManager: - pubsub, ok := p.kafkaPubSubByProviderID[v.eventMetadata.ProviderID] - if !ok { - p.visitor.Walker.StopWithInternalErr(fmt.Errorf("no pubsub connection exists with provider id \"%s\"", v.eventMetadata.ProviderID)) - return resolve.FetchConfiguration{} - } - - switch v.eventMetadata.Type { - case EventTypePublish: - dataSource = &KafkaPublishDataSource{ - pubSub: pubsub, - } - case EventTypeRequest: - p.visitor.Walker.StopWithInternalErr(fmt.Errorf("event type \"%s\" is not supported for Kafka", v.eventMetadata.Type)) - return resolve.FetchConfiguration{} - default: - p.visitor.Walker.StopWithInternalErr(fmt.Errorf("failed to configure fetch: invalid event type \"%s\" for Kafka", v.eventMetadata.Type)) - return resolve.FetchConfiguration{} - } - - return resolve.FetchConfiguration{ - Input: v.publishEventConfiguration.MarshalJSONTemplate(), - Variables: p.variables, - DataSource: dataSource, - PostProcessing: resolve.PostProcessingConfiguration{ - MergePath: []string{v.eventMetadata.FieldName}, - }, - } - - default: - p.visitor.Walker.StopWithInternalErr(fmt.Errorf("failed to configure fetch: invalid event manager type: %T", p.eventManager)) - } - - return resolve.FetchConfiguration{} -} - -func (p *Planner[T]) ConfigureSubscription() plan.SubscriptionConfiguration { - if p.eventManager == nil { - p.visitor.Walker.StopWithInternalErr(fmt.Errorf("failed to configure subscription: event manager is nil")) - return plan.SubscriptionConfiguration{} - } - - switch v := p.eventManager.(type) { - case *NatsEventManager: - pubsub, ok := p.natsPubSubByProviderID[v.eventMetadata.ProviderID] - if !ok { - p.visitor.Walker.StopWithInternalErr(fmt.Errorf("no pubsub connection exists with provider id \"%s\"", v.eventMetadata.ProviderID)) - return plan.SubscriptionConfiguration{} - } - object, err := json.Marshal(v.subscriptionEventConfiguration) - if err != nil { - p.visitor.Walker.StopWithInternalErr(fmt.Errorf("failed to marshal event subscription streamConfiguration")) - return plan.SubscriptionConfiguration{} - } - return plan.SubscriptionConfiguration{ - Input: string(object), - Variables: p.variables, - DataSource: &NatsSubscriptionSource{ - pubSub: pubsub, - }, - PostProcessing: resolve.PostProcessingConfiguration{ - MergePath: []string{v.eventMetadata.FieldName}, - }, - } - case *KafkaEventManager: - pubsub, ok := p.kafkaPubSubByProviderID[v.eventMetadata.ProviderID] - if !ok { - p.visitor.Walker.StopWithInternalErr(fmt.Errorf("no pubsub connection exists with provider id \"%s\"", v.eventMetadata.ProviderID)) - return plan.SubscriptionConfiguration{} - } - object, err := json.Marshal(v.subscriptionEventConfiguration) - if err != nil { - p.visitor.Walker.StopWithInternalErr(fmt.Errorf("failed to marshal event subscription streamConfiguration")) - return plan.SubscriptionConfiguration{} - } - return plan.SubscriptionConfiguration{ - Input: string(object), - Variables: p.variables, - DataSource: &KafkaSubscriptionSource{ - pubSub: pubsub, - }, - PostProcessing: resolve.PostProcessingConfiguration{ - MergePath: []string{v.eventMetadata.FieldName}, - }, - } - default: - p.visitor.Walker.StopWithInternalErr(fmt.Errorf("failed to configure subscription: invalid event manager type: %T", p.eventManager)) - } - - return plan.SubscriptionConfiguration{} -} - -func (p *Planner[T]) DownstreamResponseFieldAlias(_ int) (alias string, exists bool) { - return "", false -} - -func NewFactory[T Configuration](executionContext context.Context, natsPubSubByProviderID map[string]NatsPubSub, kafkaPubSubByProviderID map[string]KafkaPubSub) *Factory[T] { - return &Factory[T]{ - executionContext: executionContext, - natsPubSubByProviderID: natsPubSubByProviderID, - kafkaPubSubByProviderID: kafkaPubSubByProviderID, - } -} - -type Factory[T Configuration] struct { - executionContext context.Context - natsPubSubByProviderID map[string]NatsPubSub - kafkaPubSubByProviderID map[string]KafkaPubSub -} - -func (f *Factory[T]) Planner(_ abstractlogger.Logger) plan.DataSourcePlanner[T] { - return &Planner[T]{ - natsPubSubByProviderID: f.natsPubSubByProviderID, - kafkaPubSubByProviderID: f.kafkaPubSubByProviderID, - } -} - -func (f *Factory[T]) Context() context.Context { - return f.executionContext -} - -func (f *Factory[T]) UpstreamSchema(_ plan.DataSourceConfiguration[T]) (*ast.Document, bool) { - return nil, false -} - -func (f *Factory[T]) PlanningBehavior() plan.DataSourcePlanningBehavior { - return plan.DataSourcePlanningBehavior{ - MergeAliasedRootNodes: false, - OverrideFieldPathFromAlias: false, - AllowPlanningTypeName: true, - } -} - -func buildEventDataBytes(ref int, visitor *plan.Visitor, variables *resolve.Variables) ([]byte, error) { - // Collect the field arguments for fetch based operations - fieldArgs := visitor.Operation.FieldArguments(ref) - var dataBuffer bytes.Buffer - dataBuffer.WriteByte('{') - for i, arg := range fieldArgs { - if i > 0 { - dataBuffer.WriteByte(',') - } - argValue := visitor.Operation.ArgumentValue(arg) - variableName := visitor.Operation.VariableValueNameBytes(argValue.Ref) - contextVariable := &resolve.ContextVariable{ - Path: []string{string(variableName)}, - Renderer: resolve.NewJSONVariableRenderer(), - } - variablePlaceHolder, _ := variables.AddVariable(contextVariable) - argumentName := visitor.Operation.ArgumentNameString(arg) - escapedKey, err := json.Marshal(argumentName) - if err != nil { - return nil, err - } - dataBuffer.Write(escapedKey) - dataBuffer.WriteByte(':') - dataBuffer.WriteString(variablePlaceHolder) - } - dataBuffer.WriteByte('}') - return dataBuffer.Bytes(), nil -} diff --git a/v2/pkg/engine/datasource/pubsub_datasource/pubsub_datasource_test.go b/v2/pkg/engine/datasource/pubsub_datasource/pubsub_datasource_test.go deleted file mode 100644 index 28a37df33b..0000000000 --- a/v2/pkg/engine/datasource/pubsub_datasource/pubsub_datasource_test.go +++ /dev/null @@ -1,669 +0,0 @@ -package pubsub_datasource - -import ( - "bytes" - "context" - "errors" - "io" - "testing" - - "github.com/stretchr/testify/require" - - "github.com/wundergraph/astjson" - - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasourcetesting" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/plan" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" - "github.com/wundergraph/graphql-go-tools/v2/pkg/internal/unsafeparser" -) - -type testPubsub struct { -} - -func (t *testPubsub) Subscribe(_ context.Context, _ NatsSubscriptionEventConfiguration, _ resolve.SubscriptionUpdater) error { - return errors.New("not implemented") -} -func (t *testPubsub) Publish(_ context.Context, _ NatsPublishAndRequestEventConfiguration) error { - return errors.New("not implemented") -} - -func (t *testPubsub) Request(_ context.Context, _ NatsPublishAndRequestEventConfiguration, _ io.Writer) error { - return errors.New("not implemented") -} - -func TestPubSub(t *testing.T) { - factory := &Factory[Configuration]{ - natsPubSubByProviderID: map[string]NatsPubSub{"default": &testPubsub{}}, - } - - const schema = ` - type Query { - helloQuery(userKey: UserKey!): User! @edfs__natsRequest(subject: "tenants.{{ args.userKey.tenantId }}.users.{{ args.userKey.id }}") - } - - type Mutation { - helloMutation(userKey: UserKey!): edfs__PublishResult! @edfs__natsPublish(subject: "tenants.{{ args.userKey.tenantId }}.users.{{ args.userKey.id }}") - } - - type Subscription { - helloSubscription(userKey: UserKey!): User! @edfs__natsSubscribe(subjects: ["tenants.{{ args.userKey.tenantId }}.users.{{ args.userKey.id }}"]) - subscriptionWithMultipleSubjects(userKeyOne: UserKey!, userKeyTwo: UserKey!): User! @edfs__natsSubscribe(subjects: ["tenantsOne.{{ args.userKeyOne.tenantId }}.users.{{ args.userKeyOne.id }}", "tenantsTwo.{{ args.userKeyTwo.tenantId }}.users.{{ args.userKeyTwo.id }}"]) - subscriptionWithStaticValues: User! @edfs__natsSubscribe(subjects: ["tenants.1.users.1"]) - subscriptionWithArgTemplateAndStaticValue(nestedUserKey: NestedUserKey!): User! @edfs__natsSubscribe(subjects: ["tenants.1.users.{{ args.nestedUserKey.user.id }}"]) - } - - type User @key(fields: "id tenant { id }") { - id: Int! @external - tenant: Tenant! @external - } - - type Tenant { - id: Int! @external - } - - input UserKey { - id: Int! - tenantId: Int! - } - - input NestedUserKey { - user: UserInput! - tenant: TenantInput! - } - - input UserInput { - id: Int! - } - - input TenantInput { - id: Int! - } - - type edfs__PublishResult { - success: Boolean! - } - - input edfs__NatsStreamConfiguration { - consumerName: String! - streamName: String! - } - ` - - dataSourceCustomConfig := Configuration{ - Events: []EventConfiguration{ - { - Metadata: &EventMetadata{ - FieldName: "helloQuery", - ProviderID: "default", - Type: EventTypeRequest, - TypeName: "Query", - }, - Configuration: &NatsEventConfiguration{ - Subjects: []string{"tenants.{{ args.userKey.tenantId }}.users.{{ args.userKey.id }}"}, - }, - }, - { - Metadata: &EventMetadata{ - FieldName: "helloMutation", - ProviderID: "default", - Type: EventTypePublish, - TypeName: "Mutation", - }, - Configuration: &NatsEventConfiguration{ - Subjects: []string{"tenants.{{ args.userKey.tenantId }}.users.{{ args.userKey.id }}"}, - }, - }, - { - Metadata: &EventMetadata{ - FieldName: "helloSubscription", - ProviderID: "default", - Type: EventTypeSubscribe, - TypeName: "Subscription", - }, - Configuration: &NatsEventConfiguration{ - Subjects: []string{"tenants.{{ args.userKey.tenantId }}.users.{{ args.userKey.id }}"}, - }, - }, - { - Metadata: &EventMetadata{ - FieldName: "subscriptionWithMultipleSubjects", - ProviderID: "default", - Type: EventTypeSubscribe, - TypeName: "Subscription", - }, - Configuration: &NatsEventConfiguration{ - Subjects: []string{"tenantsOne.{{ args.userKeyOne.tenantId }}.users.{{ args.userKeyOne.id }}", "tenantsTwo.{{ args.userKeyTwo.tenantId }}.users.{{ args.userKeyTwo.id }}"}, - }, - }, - { - Metadata: &EventMetadata{ - FieldName: "subscriptionWithStaticValues", - ProviderID: "default", - Type: EventTypeSubscribe, - TypeName: "Subscription", - }, - Configuration: &NatsEventConfiguration{ - Subjects: []string{"tenants.1.users.1"}, - }, - }, - { - Metadata: &EventMetadata{ - FieldName: "subscriptionWithArgTemplateAndStaticValue", - ProviderID: "default", - Type: EventTypeSubscribe, - TypeName: "Subscription", - }, - Configuration: &NatsEventConfiguration{ - Subjects: []string{"tenants.1.users.{{ args.nestedUserKey.user.id }}"}, - }, - }, - }, - } - - dataSourceConfiguration, err := plan.NewDataSourceConfiguration[Configuration]( - "test", - factory, - &plan.DataSourceMetadata{ - RootNodes: []plan.TypeField{ - { - TypeName: "Query", - FieldNames: []string{"helloQuery"}, - }, - { - TypeName: "Mutation", - FieldNames: []string{"helloMutation"}, - }, - { - TypeName: "Subscription", - FieldNames: []string{"helloSubscription"}, - }, - { - TypeName: "Subscription", - FieldNames: []string{"subscriptionWithMultipleSubjects"}, - }, - { - TypeName: "Subscription", - FieldNames: []string{"subscriptionWithStaticValues"}, - }, - { - TypeName: "Subscription", - FieldNames: []string{"subscriptionWithArgTemplateAndStaticValue"}, - }, - }, - ChildNodes: []plan.TypeField{ - // Entities are child nodes in pubsub datasources - { - TypeName: "User", - FieldNames: []string{"id", "tenant"}, - }, - { - TypeName: "Tenant", - FieldNames: []string{"id"}, - }, - { - TypeName: "edfs__PublishResult", - FieldNames: []string{"success"}, - }, - }, - }, - dataSourceCustomConfig, - ) - require.NoError(t, err) - - planConfig := plan.Configuration{ - DataSources: []plan.DataSource{ - dataSourceConfiguration, - }, - Fields: []plan.FieldConfiguration{ - { - TypeName: "Query", - FieldName: "helloQuery", - Arguments: []plan.ArgumentConfiguration{ - { - Name: "userKey", - SourceType: plan.FieldArgumentSource, - }, - }, - }, - { - TypeName: "Mutation", - FieldName: "helloMutation", - Arguments: []plan.ArgumentConfiguration{ - { - Name: "userKey", - SourceType: plan.FieldArgumentSource, - }, - }, - }, - { - TypeName: "Subscription", - FieldName: "helloSubscription", - Arguments: []plan.ArgumentConfiguration{ - { - Name: "userKey", - SourceType: plan.FieldArgumentSource, - }, - }, - }, - { - TypeName: "Subscription", - FieldName: "subscriptionWithMultipleSubjects", - Arguments: []plan.ArgumentConfiguration{ - { - Name: "userKeyOne", - SourceType: plan.FieldArgumentSource, - }, - { - Name: "userKeyTwo", - SourceType: plan.FieldArgumentSource, - }, - }, - }, - { - TypeName: "Subscription", - FieldName: "subscriptionWithArgTemplateAndStaticValue", - Arguments: []plan.ArgumentConfiguration{ - { - Name: "nestedUserKey", - SourceType: plan.FieldArgumentSource, - }, - }, - }, - }, - DisableResolveFieldPositions: true, - } - - t.Run("query", func(t *testing.T) { - const operation = "query HelloQuery { helloQuery(userKey:{id:42,tenantId:3}) { id } }" - const operationName = `HelloQuery` - expect := &plan.SynchronousResponsePlan{ - Response: &resolve.GraphQLResponse{ - RawFetches: []*resolve.FetchItem{ - { - Fetch: &resolve.SingleFetch{ - FetchConfiguration: resolve.FetchConfiguration{ - Input: `{"subject":"tenants.$$0$$.users.$$1$$", "data": {"userKey":$$2$$}, "providerId":"default"}`, - Variables: resolve.Variables{ - &resolve.ContextVariable{ - Path: []string{"a", "tenantId"}, - Renderer: resolve.NewPlainVariableRenderer(), - }, - &resolve.ContextVariable{ - Path: []string{"a", "id"}, - Renderer: resolve.NewPlainVariableRenderer(), - }, - &resolve.ContextVariable{ - Path: []string{"a"}, - Renderer: resolve.NewJSONVariableRenderer(), - }, - }, - DataSource: &NatsRequestDataSource{ - pubSub: &testPubsub{}, - }, - PostProcessing: resolve.PostProcessingConfiguration{ - MergePath: []string{"helloQuery"}, - }, - }, - DataSourceIdentifier: []byte("pubsub_datasource.NatsRequestDataSource"), - }, - }, - }, - Data: &resolve.Object{ - Fields: []*resolve.Field{ - { - Name: []byte("helloQuery"), - Value: &resolve.Object{ - Path: []string{"helloQuery"}, - Nullable: false, - PossibleTypes: map[string]struct{}{ - "User": {}, - }, - TypeName: "User", - Fields: []*resolve.Field{ - { - Name: []byte("id"), - Value: &resolve.Integer{ - Path: []string{"id"}, - Nullable: false, - }, - }, - }, - }, - }, - }, - }, - }, - } - datasourcetesting.RunTest(schema, operation, operationName, expect, planConfig)(t) - }) - - t.Run("mutation", func(t *testing.T) { - const operation = "mutation HelloMutation { helloMutation(userKey:{id:42,tenantId:3}) { success } }" - const operationName = `HelloMutation` - expect := &plan.SynchronousResponsePlan{ - Response: &resolve.GraphQLResponse{ - RawFetches: []*resolve.FetchItem{ - { - Fetch: &resolve.SingleFetch{ - FetchConfiguration: resolve.FetchConfiguration{ - Input: `{"subject":"tenants.$$0$$.users.$$1$$", "data": {"userKey":$$2$$}, "providerId":"default"}`, - Variables: resolve.Variables{ - &resolve.ContextVariable{ - Path: []string{"a", "tenantId"}, - Renderer: resolve.NewPlainVariableRenderer(), - }, - &resolve.ContextVariable{ - Path: []string{"a", "id"}, - Renderer: resolve.NewPlainVariableRenderer(), - }, - &resolve.ContextVariable{ - Path: []string{"a"}, - Renderer: resolve.NewJSONVariableRenderer(), - }, - }, - DataSource: &NatsPublishDataSource{ - pubSub: &testPubsub{}, - }, - PostProcessing: resolve.PostProcessingConfiguration{ - MergePath: []string{"helloMutation"}, - }, - }, - DataSourceIdentifier: []byte("pubsub_datasource.NatsPublishDataSource"), - }, - }, - }, - Data: &resolve.Object{ - Fields: []*resolve.Field{ - { - Name: []byte("helloMutation"), - Value: &resolve.Object{ - Path: []string{"helloMutation"}, - Nullable: false, - PossibleTypes: map[string]struct{}{ - "edfs__PublishResult": {}, - }, - TypeName: "edfs__PublishResult", - Fields: []*resolve.Field{ - { - Name: []byte("success"), - Value: &resolve.Boolean{ - Path: []string{"success"}, - Nullable: false, - }, - }, - }, - }, - }, - }, - }, - }, - } - datasourcetesting.RunTest(schema, operation, operationName, expect, planConfig)(t) - }) - - t.Run("subscription", func(t *testing.T) { - const operation = "subscription HelloSubscription { helloSubscription(userKey:{id:42,tenantId:3}) { id } }" - const operationName = `HelloSubscription` - expect := &plan.SubscriptionResponsePlan{ - Response: &resolve.GraphQLSubscription{ - Trigger: resolve.GraphQLSubscriptionTrigger{ - Input: []byte(`{"providerId":"default","subjects":["tenants.$$0$$.users.$$1$$"]}`), - Variables: resolve.Variables{ - &resolve.ContextVariable{ - Path: []string{"a", "tenantId"}, - Renderer: resolve.NewPlainVariableRenderer(), - }, - &resolve.ContextVariable{ - Path: []string{"a", "id"}, - Renderer: resolve.NewPlainVariableRenderer(), - }, - }, - Source: &NatsSubscriptionSource{ - pubSub: &testPubsub{}, - }, - PostProcessing: resolve.PostProcessingConfiguration{ - MergePath: []string{"helloSubscription"}, - }, - }, - Response: &resolve.GraphQLResponse{ - Data: &resolve.Object{ - Fields: []*resolve.Field{ - { - Name: []byte("helloSubscription"), - Value: &resolve.Object{ - Path: []string{"helloSubscription"}, - Nullable: false, - PossibleTypes: map[string]struct{}{ - "User": {}, - }, - TypeName: "User", - Fields: []*resolve.Field{ - { - Name: []byte("id"), - Value: &resolve.Integer{ - Path: []string{"id"}, - Nullable: false, - }, - }, - }, - }, - }, - }, - }, - }, - }, - } - datasourcetesting.RunTest(schema, operation, operationName, expect, planConfig)(t) - }) - - t.Run("subscription with multiple subjects", func(t *testing.T) { - const operation = "subscription SubscriptionWithMultipleSubjects { subscriptionWithMultipleSubjects(userKeyOne:{id:42,tenantId:3},userKeyTwo:{id:24,tenantId:99}) { id } }" - const operationName = `SubscriptionWithMultipleSubjects` - expect := &plan.SubscriptionResponsePlan{ - Response: &resolve.GraphQLSubscription{ - Trigger: resolve.GraphQLSubscriptionTrigger{ - Input: []byte(`{"providerId":"default","subjects":["tenantsOne.$$0$$.users.$$1$$","tenantsTwo.$$2$$.users.$$3$$"]}`), - Variables: resolve.Variables{ - &resolve.ContextVariable{ - Path: []string{"a", "tenantId"}, - Renderer: resolve.NewPlainVariableRenderer(), - }, - &resolve.ContextVariable{ - Path: []string{"a", "id"}, - Renderer: resolve.NewPlainVariableRenderer(), - }, - &resolve.ContextVariable{ - Path: []string{"b", "tenantId"}, - Renderer: resolve.NewPlainVariableRenderer(), - }, - &resolve.ContextVariable{ - Path: []string{"b", "id"}, - Renderer: resolve.NewPlainVariableRenderer(), - }, - }, - Source: &NatsSubscriptionSource{ - pubSub: &testPubsub{}, - }, - PostProcessing: resolve.PostProcessingConfiguration{ - MergePath: []string{"subscriptionWithMultipleSubjects"}, - }, - }, - Response: &resolve.GraphQLResponse{ - Data: &resolve.Object{ - Fields: []*resolve.Field{ - { - Name: []byte("subscriptionWithMultipleSubjects"), - Value: &resolve.Object{ - Path: []string{"subscriptionWithMultipleSubjects"}, - Nullable: false, - PossibleTypes: map[string]struct{}{ - "User": {}, - }, - TypeName: "User", - Fields: []*resolve.Field{ - { - Name: []byte("id"), - Value: &resolve.Integer{ - Path: []string{"id"}, - Nullable: false, - }, - }, - }, - }, - }, - }, - }, - }, - }, - } - datasourcetesting.RunTest(schema, operation, operationName, expect, planConfig)(t) - }) - - t.Run("subscription with only static values", func(t *testing.T) { - const operation = "subscription SubscriptionWithStaticValues { subscriptionWithStaticValues { id } }" - const operationName = `SubscriptionWithStaticValues` - expect := &plan.SubscriptionResponsePlan{ - Response: &resolve.GraphQLSubscription{ - Trigger: resolve.GraphQLSubscriptionTrigger{ - Input: []byte(`{"providerId":"default","subjects":["tenants.1.users.1"]}`), - Source: &NatsSubscriptionSource{ - pubSub: &testPubsub{}, - }, - PostProcessing: resolve.PostProcessingConfiguration{ - MergePath: []string{"subscriptionWithStaticValues"}, - }, - }, - Response: &resolve.GraphQLResponse{ - Data: &resolve.Object{ - Fields: []*resolve.Field{ - { - Name: []byte("subscriptionWithStaticValues"), - Value: &resolve.Object{ - Path: []string{"subscriptionWithStaticValues"}, - Nullable: false, - PossibleTypes: map[string]struct{}{ - "User": {}, - }, - TypeName: "User", - Fields: []*resolve.Field{ - { - Name: []byte("id"), - Value: &resolve.Integer{ - Path: []string{"id"}, - Nullable: false, - }, - }, - }, - }, - }, - }, - }, - }, - }, - } - datasourcetesting.RunTest(schema, operation, operationName, expect, planConfig)(t) - }) - - t.Run("subscription with deeply nested argument and static value", func(t *testing.T) { - const operation = "subscription SubscriptionWithArgTemplateAndStaticValue { subscriptionWithArgTemplateAndStaticValue(nestedUserKey: { user: { id: 44, tenantId: 2 } }) { id } }" - const operationName = `SubscriptionWithArgTemplateAndStaticValue` - expect := &plan.SubscriptionResponsePlan{ - Response: &resolve.GraphQLSubscription{ - Trigger: resolve.GraphQLSubscriptionTrigger{ - Input: []byte(`{"providerId":"default","subjects":["tenants.1.users.$$0$$"]}`), - Variables: resolve.Variables{ - &resolve.ContextVariable{ - Path: []string{"a", "user", "id"}, - Renderer: resolve.NewPlainVariableRenderer(), - }, - }, - Source: &NatsSubscriptionSource{ - pubSub: &testPubsub{}, - }, - PostProcessing: resolve.PostProcessingConfiguration{ - MergePath: []string{"subscriptionWithArgTemplateAndStaticValue"}, - }, - }, - Response: &resolve.GraphQLResponse{ - Data: &resolve.Object{ - Fields: []*resolve.Field{ - { - Name: []byte("subscriptionWithArgTemplateAndStaticValue"), - Value: &resolve.Object{ - Path: []string{"subscriptionWithArgTemplateAndStaticValue"}, - Nullable: false, - PossibleTypes: map[string]struct{}{ - "User": {}, - }, - TypeName: "User", - Fields: []*resolve.Field{ - { - Name: []byte("id"), - Value: &resolve.Integer{ - Path: []string{"id"}, - Nullable: false, - }, - }, - }, - }, - }, - }, - }, - }, - }, - } - datasourcetesting.RunTest(schema, operation, operationName, expect, planConfig)(t) - }) -} - -func TestBuildEventDataBytes(t *testing.T) { - t.Run("check string serialization", func(t *testing.T) { - const operation = "mutation HelloMutation($id: ID!) { helloMutation(userKey:{id:$id,tenantId:3}) { success } }" - op := unsafeparser.ParseGraphqlDocumentString(operation) - var vars resolve.Variables - visitor := plan.Visitor{ - Operation: &op, - } - _, err := buildEventDataBytes(1, &visitor, &vars) - require.NoError(t, err) - require.Len(t, vars, 1) - - template := resolve.InputTemplate{ - Segments: []resolve.TemplateSegment{ - vars[0].TemplateSegment(), - }, - } - ctx := &resolve.Context{ - Variables: astjson.MustParseBytes([]byte(`{"id":"asdf"}`)), - } - buf := &bytes.Buffer{} - err = template.Render(ctx, nil, buf) - require.NoError(t, err) - require.Equal(t, `"asdf"`, buf.String()) - }) - - t.Run("check int serialization", func(t *testing.T) { - const operation = "mutation HelloMutation($id: Int!) { helloMutation(userKey:{id:$id,tenantId:3}) { success } }" - op := unsafeparser.ParseGraphqlDocumentString(operation) - var vars resolve.Variables - visitor := plan.Visitor{ - Operation: &op, - } - _, err := buildEventDataBytes(1, &visitor, &vars) - require.NoError(t, err) - require.Len(t, vars, 1) - - template := resolve.InputTemplate{ - Segments: []resolve.TemplateSegment{ - vars[0].TemplateSegment(), - }, - } - ctx := &resolve.Context{ - Variables: astjson.MustParseBytes([]byte(`{"id":5}`)), - } - buf := &bytes.Buffer{} - err = template.Render(ctx, nil, buf) - require.NoError(t, err) - require.Equal(t, `5`, buf.String()) - }) -} diff --git a/v2/pkg/engine/datasource/pubsub_datasource/pubsub_kafka.go b/v2/pkg/engine/datasource/pubsub_datasource/pubsub_kafka.go deleted file mode 100644 index cc562b803e..0000000000 --- a/v2/pkg/engine/datasource/pubsub_datasource/pubsub_kafka.go +++ /dev/null @@ -1,88 +0,0 @@ -package pubsub_datasource - -import ( - "bytes" - "context" - "encoding/json" - "io" - - "github.com/buger/jsonparser" - "github.com/cespare/xxhash/v2" - - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" -) - -type KafkaEventConfiguration struct { - Topics []string `json:"topics"` -} - -type KafkaConnector interface { - New(ctx context.Context) KafkaPubSub -} - -// KafkaPubSub describe the interface that implements the primitive operations for pubsub -type KafkaPubSub interface { - // Subscribe starts listening on the given subjects and sends the received messages to the given next channel - Subscribe(ctx context.Context, config KafkaSubscriptionEventConfiguration, updater resolve.SubscriptionUpdater) error - // Publish sends the given data to the given subject - Publish(ctx context.Context, config KafkaPublishEventConfiguration) error -} - -type KafkaSubscriptionSource struct { - pubSub KafkaPubSub -} - -func (s *KafkaSubscriptionSource) UniqueRequestID(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { - - val, _, _, err := jsonparser.Get(input, "topics") - if err != nil { - return err - } - - _, err = xxh.Write(val) - if err != nil { - return err - } - - val, _, _, err = jsonparser.Get(input, "providerId") - if err != nil { - return err - } - - _, err = xxh.Write(val) - return err -} - -func (s *KafkaSubscriptionSource) Start(ctx *resolve.Context, input []byte, updater resolve.SubscriptionUpdater) error { - var subscriptionConfiguration KafkaSubscriptionEventConfiguration - err := json.Unmarshal(input, &subscriptionConfiguration) - if err != nil { - return err - } - - return s.pubSub.Subscribe(ctx.Context(), subscriptionConfiguration, updater) -} - -type KafkaPublishDataSource struct { - pubSub KafkaPubSub -} - -func (s *KafkaPublishDataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) error { - var publishConfiguration KafkaPublishEventConfiguration - err := json.Unmarshal(input, &publishConfiguration) - if err != nil { - return err - } - - if err := s.pubSub.Publish(ctx, publishConfiguration); err != nil { - _, err = io.WriteString(out, `{"success": false}`) - return err - } - _, err = io.WriteString(out, `{"success": true}`) - return err -} - -func (s *KafkaPublishDataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload, out *bytes.Buffer) (err error) { - panic("not implemented") -} diff --git a/v2/pkg/engine/datasource/pubsub_datasource/pubsub_nats.go b/v2/pkg/engine/datasource/pubsub_datasource/pubsub_nats.go deleted file mode 100644 index 31cb6d4154..0000000000 --- a/v2/pkg/engine/datasource/pubsub_datasource/pubsub_nats.go +++ /dev/null @@ -1,116 +0,0 @@ -package pubsub_datasource - -import ( - "bytes" - "context" - "encoding/json" - "io" - - "github.com/buger/jsonparser" - "github.com/cespare/xxhash/v2" - - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" -) - -type NatsStreamConfiguration struct { - Consumer string `json:"consumer"` - ConsumerInactiveThreshold int32 `json:"consumerInactiveThreshold"` - StreamName string `json:"streamName"` -} - -type NatsEventConfiguration struct { - StreamConfiguration *NatsStreamConfiguration `json:"streamConfiguration,omitempty"` - Subjects []string `json:"subjects"` -} - -type NatsConnector interface { - New(ctx context.Context) NatsPubSub -} - -// NatsPubSub describe the interface that implements the primitive operations for pubsub -type NatsPubSub interface { - // Subscribe starts listening on the given subjects and sends the received messages to the given next channel - Subscribe(ctx context.Context, event NatsSubscriptionEventConfiguration, updater resolve.SubscriptionUpdater) error - // Publish sends the given data to the given subject - Publish(ctx context.Context, event NatsPublishAndRequestEventConfiguration) error - // Request sends a request on the given subject and writes the response to the given writer - Request(ctx context.Context, event NatsPublishAndRequestEventConfiguration, w io.Writer) error -} - -type NatsSubscriptionSource struct { - pubSub NatsPubSub -} - -func (s *NatsSubscriptionSource) UniqueRequestID(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { - - val, _, _, err := jsonparser.Get(input, "subjects") - if err != nil { - return err - } - - _, err = xxh.Write(val) - if err != nil { - return err - } - - val, _, _, err = jsonparser.Get(input, "providerId") - if err != nil { - return err - } - - _, err = xxh.Write(val) - return err -} - -func (s *NatsSubscriptionSource) Start(ctx *resolve.Context, input []byte, updater resolve.SubscriptionUpdater) error { - var subscriptionConfiguration NatsSubscriptionEventConfiguration - err := json.Unmarshal(input, &subscriptionConfiguration) - if err != nil { - return err - } - - return s.pubSub.Subscribe(ctx.Context(), subscriptionConfiguration, updater) -} - -type NatsPublishDataSource struct { - pubSub NatsPubSub -} - -func (s *NatsPublishDataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) error { - var publishConfiguration NatsPublishAndRequestEventConfiguration - err := json.Unmarshal(input, &publishConfiguration) - if err != nil { - return err - } - - if err := s.pubSub.Publish(ctx, publishConfiguration); err != nil { - _, err = io.WriteString(out, `{"success": false}`) - return err - } - - _, err = io.WriteString(out, `{"success": true}`) - return err -} - -func (s *NatsPublishDataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload, out *bytes.Buffer) error { - panic("not implemented") -} - -type NatsRequestDataSource struct { - pubSub NatsPubSub -} - -func (s *NatsRequestDataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) error { - var subscriptionConfiguration NatsPublishAndRequestEventConfiguration - err := json.Unmarshal(input, &subscriptionConfiguration) - if err != nil { - return err - } - - return s.pubSub.Request(ctx, subscriptionConfiguration, out) -} - -func (s *NatsRequestDataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload, out *bytes.Buffer) error { - panic("not implemented") -} diff --git a/v2/pkg/engine/resolve/datasource.go b/v2/pkg/engine/resolve/datasource.go index c679d7693a..27b421b61c 100644 --- a/v2/pkg/engine/resolve/datasource.go +++ b/v2/pkg/engine/resolve/datasource.go @@ -26,3 +26,15 @@ type AsyncSubscriptionDataSource interface { AsyncStop(id uint64) UniqueRequestID(ctx *Context, input []byte, xxh *xxhash.Digest) (err error) } + +// HookableSubscriptionDataSource is a hookable interface for subscription data sources. +// It is used to call a function when a subscription is started. +// This is useful for data sources that need to do some work when a subscription is started, +// e.g. to establish a connection to the data source or to emit updates to the client. +// The function is called with the context and the input of the subscription. +// The function is called before the subscription is started and can be used to emit updates to the client. +type HookableSubscriptionDataSource interface { + // SubscriptionOnStart is called when a new subscription is created + // If an error is returned, the error is propagated to the client. + SubscriptionOnStart(ctx StartupHookContext, input []byte) (err error) +} diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index 14d8ad4b52..82558795cb 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -297,6 +297,7 @@ type trigger struct { subscriptions map[*Context]*sub // initialized is set to true when the trigger is started and initialized initialized bool + updater *subscriptionUpdater } // workItem is used to encapsulate a function that needs to be @@ -507,6 +508,8 @@ func (r *Resolver) handleEvent(event subscriptionEvent) { r.handleCompleteSubscription(event.id) case subscriptionEventKindRemoveClient: r.handleRemoveClient(event.id.ConnectionID) + case subscriptionEventKindUpdateSubscription: + r.handleUpdateSubscription(event.triggerID, event.data, event.id) case subscriptionEventKindTriggerUpdate: r.handleTriggerUpdate(event.triggerID, event.data) case subscriptionEventKindTriggerComplete: @@ -581,6 +584,36 @@ func (r *Resolver) handleTriggerComplete(triggerID uint64) { r.completeTrigger(triggerID) } +type StartupHookContext struct { + Context context.Context + Updater func(data []byte) +} + +func (r *Resolver) executeStartupHooks(add *addSubscription, updater *subscriptionUpdater) error { + hook, ok := add.resolve.Trigger.Source.(HookableSubscriptionDataSource) + if ok { + hookCtx := StartupHookContext{ + Context: add.ctx.Context(), + Updater: func(data []byte) { + // Writing on the updater channel is safe but has to happen outside of the event loop + // to respect order and not block the event loop + updater.UpdateSubscription(add.id, data) + }, + } + + err := hook.SubscriptionOnStart(hookCtx, add.input) + if err != nil { + if r.options.Debug { + fmt.Printf("resolver:trigger:subscription:startup:failed:%d\n", add.id.SubscriptionID) + } + r.asyncErrorWriter.WriteError(add.ctx, err, add.resolve.Response, add.writer) + _ = r.AsyncUnsubscribeSubscription(add.id) + return err + } + } + return nil +} + func (r *Resolver) handleAddSubscription(triggerID uint64, add *addSubscription) { var ( err error @@ -608,13 +641,23 @@ func (r *Resolver) handleAddSubscription(triggerID uint64, add *addSubscription) trig, ok := r.triggers[triggerID] if ok { - trig.subscriptions[add.ctx] = s if r.reporter != nil { r.reporter.SubscriptionCountInc(1) } if r.options.Debug { fmt.Printf("resolver:trigger:subscription:added:%d:%d\n", triggerID, add.id.SubscriptionID) } + // Execute the startup hooks in a separate goroutine to avoid blocking the event loop + s.workChan <- workItem{ + fn: func() { + _ = r.executeStartupHooks(add, trig.updater) + // if the startup hooks return an error, we don't have to do anything else + }, + final: false, + } + // After the startup hooks are executed, we can add the subscription to the subscriptions registry + // so that it can start receive events + trig.subscriptions[add.ctx] = s return } @@ -633,6 +676,7 @@ func (r *Resolver) handleAddSubscription(triggerID uint64, add *addSubscription) id: triggerID, subscriptions: make(map[*Context]*sub), cancel: cancel, + updater: updater, } r.triggers[triggerID] = trig trig.subscriptions[add.ctx] = s @@ -655,10 +699,17 @@ func (r *Resolver) handleAddSubscription(triggerID uint64, add *addSubscription) if r.options.Debug { fmt.Printf("resolver:trigger:start:%d\n", triggerID) } + + // This is blocking so the startup hook can decide if a subscription should be started or not by returning an error + err = r.executeStartupHooks(add, trig.updater) + if err != nil { + return + } + if asyncDataSource != nil { - err = asyncDataSource.AsyncStart(cloneCtx, triggerID, add.input, updater) + err = asyncDataSource.AsyncStart(cloneCtx, triggerID, add.input, trig.updater) } else { - err = add.resolve.Trigger.Source.Start(cloneCtx, add.input, updater) + err = add.resolve.Trigger.Source.Start(cloneCtx, add.input, trig.updater) } if err != nil { if r.options.Debug { @@ -779,32 +830,54 @@ func (r *Resolver) handleTriggerUpdate(id uint64, data []byte) { } for c, s := range trig.subscriptions { - c, s := c, s - if err := c.ctx.Err(); err != nil { - continue // no need to schedule an event update when the client already disconnected - } - skip, err := s.resolve.Filter.SkipEvent(c, data, r.triggerUpdateBuf) - if err != nil { - r.asyncErrorWriter.WriteError(c, err, s.resolve.Response, s.writer) - continue - } - if skip { + r.sendUpdateToSubscription(id, data, c, s) + } +} + +func (r *Resolver) handleUpdateSubscription(id uint64, data []byte, subIdentifier SubscriptionIdentifier) { + trig, ok := r.triggers[id] + if !ok { + return + } + + if r.options.Debug { + fmt.Printf("resolver:trigger:subscription:update:%d:%d,%d\n", id, subIdentifier.ConnectionID, subIdentifier.SubscriptionID) + } + + for c, s := range trig.subscriptions { + if s.id != subIdentifier { continue } + r.sendUpdateToSubscription(id, data, c, s) + break + } +} - fn := func() { - r.executeSubscriptionUpdate(c, s, data) - } +func (r *Resolver) sendUpdateToSubscription(id uint64, data []byte, c *Context, s *sub) { + if err := c.ctx.Err(); err != nil { + return // no need to schedule an event update when the client already disconnected + } + skip, err := s.resolve.Filter.SkipEvent(c, data, r.triggerUpdateBuf) + if err != nil { + r.asyncErrorWriter.WriteError(c, err, s.resolve.Response, s.writer) + return + } + if skip { + return + } - select { - case <-r.ctx.Done(): - // Skip sending all events if the resolver is shutting down - return - case <-c.ctx.Done(): - // Skip sending the event if the client disconnected - case s.workChan <- workItem{fn, false}: - // Send the event to the subscription worker - } + fn := func() { + r.executeSubscriptionUpdate(c, s, data) + } + + select { + case <-r.ctx.Done(): + // Skip sending all events if the resolver is shutting down + return + case <-c.ctx.Done(): + // Skip sending the event if the client disconnected + case s.workChan <- workItem{fn, false}: + // Send the event to the subscription worker } } @@ -1224,6 +1297,24 @@ func (s *subscriptionUpdater) Update(data []byte) { } } +func (s *subscriptionUpdater) UpdateSubscription(id SubscriptionIdentifier, data []byte) { + if s.debug { + fmt.Printf("resolver:subscription_updater:update:%d\n", s.triggerID) + } + + select { + case <-s.ctx.Done(): + // Skip sending events if trigger is already done + return + case s.ch <- subscriptionEvent{ + triggerID: s.triggerID, + kind: subscriptionEventKindUpdateSubscription, + data: data, + id: id, + }: + } +} + func (s *subscriptionUpdater) Complete() { if s.debug { fmt.Printf("resolver:subscription_updater:complete:%d\n", s.triggerID) @@ -1299,6 +1390,7 @@ const ( subscriptionEventKindRemoveClient subscriptionEventKindTriggerInitialized subscriptionEventKindTriggerClose + subscriptionEventKindUpdateSubscription ) type SubscriptionUpdater interface { diff --git a/v2/pkg/engine/resolve/resolve_test.go b/v2/pkg/engine/resolve/resolve_test.go index 8e15ff98a9..c8045d35fe 100644 --- a/v2/pkg/engine/resolve/resolve_test.go +++ b/v2/pkg/engine/resolve/resolve_test.go @@ -4794,11 +4794,12 @@ func (s *SubscriptionRecorder) Messages() []string { return s.messages } -func createFakeStream(messageFunc messageFunc, delay time.Duration, onStart func(input []byte)) *_fakeStream { +func createFakeStream(messageFunc messageFunc, delay time.Duration, onStart func(input []byte), subscriptionOnStartFn func(ctx StartupHookContext, input []byte) (err error)) *_fakeStream { return &_fakeStream{ - messageFunc: messageFunc, - delay: delay, - onStart: onStart, + messageFunc: messageFunc, + delay: delay, + onStart: onStart, + subscriptionOnStartFn: subscriptionOnStartFn, } } @@ -4807,10 +4808,19 @@ type messageFunc func(counter int) (message string, done bool) var fakeStreamRequestId atomic.Int32 type _fakeStream struct { - messageFunc messageFunc - onStart func(input []byte) - delay time.Duration - isDone atomic.Bool + uniqueRequestFn func(ctx *Context, input []byte, xxh *xxhash.Digest) (err error) + messageFunc messageFunc + onStart func(input []byte) + delay time.Duration + isDone atomic.Bool + subscriptionOnStartFn func(ctx StartupHookContext, input []byte) (err error) +} + +func (f *_fakeStream) SubscriptionOnStart(ctx StartupHookContext, input []byte) (err error) { + if f.subscriptionOnStartFn == nil { + return nil + } + return f.subscriptionOnStartFn(ctx, input) } func (f *_fakeStream) AwaitIsDone(t *testing.T, timeout time.Duration) { @@ -4828,6 +4838,10 @@ func (f *_fakeStream) AwaitIsDone(t *testing.T, timeout time.Duration) { } func (f *_fakeStream) UniqueRequestID(ctx *Context, input []byte, xxh *xxhash.Digest) (err error) { + if f.uniqueRequestFn != nil { + return f.uniqueRequestFn(ctx, input, xxh) + } + _, err = fmt.Fprint(xxh, fakeStreamRequestId.Add(1)) if err != nil { return @@ -5071,7 +5085,7 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { return `{"errors":[{"message":"Validation error occurred","locations":[{"line":1,"column":1}],"extensions":{"code":"GRAPHQL_VALIDATION_FAILED"}}],"data":null}`, true }, 0, func(input []byte) { assert.Equal(t, `{"method":"POST","url":"http://localhost:4000","body":{"query":"subscription { counter }"}}`, string(input)) - }) + }, nil) resolver, plan, recorder, id := setup(c, fakeStream) @@ -5109,7 +5123,7 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { return fmt.Sprintf(`{"data":{"counter":%d}}`, counter), counter == 2 }, 1*time.Millisecond, func(input []byte) { assert.Equal(t, `{"method":"POST","url":"http://localhost:4000","body":{"query":"subscription { counter }"}}`, string(input)) - }) + }, nil) resolver, plan, recorder, id := setup(c, fakeStream) @@ -5142,7 +5156,7 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { return fmt.Sprintf(`{"data":{"counter":%d}}`, counter), counter == 1 }, 1*time.Millisecond, func(input []byte) { assert.Equal(t, `{"method":"POST","url":"http://localhost:4000","body":{"query":"subscription { counter }"}}`, string(input)) - }) + }, nil) resolver, plan, recorder, id := setup(c, fakeStream) @@ -5200,7 +5214,7 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { return fmt.Sprintf(`{"data":{"counter":%d}}`, counter), counter == 2 }, 100*time.Millisecond, func(input []byte) { assert.Equal(t, `{"method":"POST","url":"http://localhost:4000","body":{"query":"subscription { counter }","extensions":{"foo":"bar"}}}`, string(input)) - }) + }, nil) resolver, plan, recorder, id := setup(c, fakeStream) @@ -5228,7 +5242,7 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { return fmt.Sprintf(`{"data":{"counter":%d}}`, counter), counter == 2 }, 100*time.Millisecond, func(input []byte) { assert.Equal(t, `{"method":"POST","url":"http://localhost:4000","body":{"query":"subscription { counter }"},"initial_payload":{"hello":"world"}}`, string(input)) - }) + }, nil) resolver, plan, recorder, id := setup(c, fakeStream) @@ -5256,7 +5270,7 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { return fmt.Sprintf(`{"data":{"counter":%d}}`, counter), false }, time.Millisecond*10, func(input []byte) { assert.Equal(t, `{"method":"POST","url":"http://localhost:4000","body":{"query":"subscription { counter }"}}`, string(input)) - }) + }, nil) resolver, plan, recorder, id := setup(c, fakeStream) @@ -5281,7 +5295,7 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { return fmt.Sprintf(`{"data":{"counter":%d}}`, counter), false }, time.Millisecond*10, func(input []byte) { assert.Equal(t, `{"method":"POST","url":"http://localhost:4000","body":{"query":"subscription { counter }"}}`, string(input)) - }) + }, nil) resolver, plan, recorder, id := setup(c, fakeStream) @@ -5306,7 +5320,7 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { return fmt.Sprintf(`{"data":{"counter":%d}}`, counter), counter == 0 }, 100*time.Millisecond, func(input []byte) { assert.Equal(t, `{"method":"POST","url":"http://localhost:4000","body":{"query":"subscription { counter }"}}`, string(input)) - }) + }, nil) resolver, plan, recorder, id := setup(c, fakeStream) @@ -5335,7 +5349,7 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { return fmt.Sprintf(`{"data":{"counter":%d}}`, counter), counter == 0 }, 100*time.Millisecond, func(input []byte) { assert.Equal(t, `{"method":"POST","url":"http://localhost:4000","body":{"query":"subscription { countryUpdated { name time { local } } }"}}`, string(input)) - }) + }, nil) resolver, plan, recorder, id := setupWithAdditionalDataLoad(c, fakeStream) @@ -5364,7 +5378,7 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { return fmt.Sprintf(`{"data":{"counter":%d}}`, counter), false }, 100*time.Millisecond, func(input []byte) { assert.Equal(t, `{"method":"POST","url":"http://localhost:4000","body":{"query":"subscription { counter }"}}`, string(input)) - }) + }, nil) resolver, plan, _, _ := setup(c, fakeStream) @@ -5418,7 +5432,7 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { return fmt.Sprintf(`{"data":{"counter":%d}}`, counter), true }, 100*time.Millisecond, func(input []byte) { assert.Equal(t, `{"method":"POST","url":"http://localhost:4000","body":{"query":"subscription { counter }"}}`, string(input)) - }) + }, nil) resolver, plan, _, id := setup(c, fakeStream) recorder := &SubscriptionRecorder{ @@ -5441,6 +5455,384 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { recorder.AwaitClosed(t, defaultTimeout) fakeStream.AwaitIsDone(t, defaultTimeout) }) + + t.Run("should call SubscriptionOnStart hook", func(t *testing.T) { + c, cancel := context.WithCancel(context.Background()) + defer cancel() + + called := make(chan bool, 1) + + fakeStream := createFakeStream(func(counter int) (message string, done bool) { + return fmt.Sprintf(`{"data":{"counter":%d}}`, counter), counter == 0 + }, 1*time.Millisecond, func(input []byte) { + assert.Equal(t, `{"method":"POST","url":"http://localhost:4000","body":{"query":"subscription { counter }"}}`, string(input)) + }, func(ctx StartupHookContext, input []byte) (err error) { + called <- true + return nil + }) + + resolver, plan, recorder, id := setup(c, fakeStream) + + ctx := &Context{ + ctx: context.Background(), + } + + err := resolver.AsyncResolveGraphQLSubscription(ctx, plan, recorder, id) + assert.NoError(t, err) + + select { + case <-called: + t.Log("SubscriptionOnStart hook was called") + case <-time.After(defaultTimeout): + t.Fatal("SubscriptionOnStart hook was not called") + } + + recorder.AwaitComplete(t, defaultTimeout) + }) + + t.Run("SubscriptionOnStart ctx has a working subscription updater", func(t *testing.T) { + c, cancel := context.WithCancel(context.Background()) + defer cancel() + + fakeStream := createFakeStream(func(counter int) (message string, done bool) { + return fmt.Sprintf(`{"data":{"counter":%d}}`, counter), counter == 0 + }, 1*time.Millisecond, func(input []byte) { + assert.Equal(t, `{"method":"POST","url":"http://localhost:4000","body":{"query":"subscription { counter }"}}`, string(input)) + }, func(ctx StartupHookContext, input []byte) (err error) { + ctx.Updater([]byte(`{"data":{"counter":1000}}`)) + return nil + }) + + resolver, plan, recorder, id := setup(c, fakeStream) + + ctx := &Context{ + ctx: context.Background(), + ExecutionOptions: ExecutionOptions{ + SendHeartbeat: true, + }, + } + + err := resolver.AsyncResolveGraphQLSubscription(ctx, plan, recorder, id) + assert.NoError(t, err) + + recorder.AwaitComplete(t, defaultTimeout) + assert.Equal(t, 2, len(recorder.Messages())) + assert.Equal(t, `{"data":{"counter":1000}}`, recorder.Messages()[0]) + assert.Equal(t, `{"data":{"counter":0}}`, recorder.Messages()[1]) + }) + + t.Run("SubscriptionOnStart ctx updater only updates the right subscription", func(t *testing.T) { + c, cancel := context.WithCancel(context.Background()) + defer cancel() + + executed := atomic.Bool{} + subsStarted := sync.WaitGroup{} + subsStarted.Add(2) + + id2 := SubscriptionIdentifier{ + ConnectionID: 1, + SubscriptionID: 2, + } + + fakeStream := createFakeStream(func(counter int) (message string, done bool) { + return fmt.Sprintf(`{"data":{"counter":%d}}`, counter), counter == 0 + }, 1*time.Millisecond, func(input []byte) { + assert.Equal(t, `{"method":"POST","url":"http://localhost:4000","body":{"query":"subscription { counter }"}}`, string(input)) + }, func(ctx StartupHookContext, input []byte) (err error) { + if executed.Load() { + return + } + executed.Store(true) + ctx.Updater([]byte(`{"data":{"counter":1000}}`)) + return nil + }) + fakeStream.uniqueRequestFn = func(ctx *Context, input []byte, xxh *xxhash.Digest) (err error) { + return nil + } + + resolver, plan, recorder, id := setup(c, fakeStream) + + recorder2 := &SubscriptionRecorder{ + buf: &bytes.Buffer{}, + messages: []string{}, + complete: atomic.Bool{}, + } + recorder2.complete.Store(false) + + ctx := &Context{ + ctx: context.Background(), + ExecutionOptions: ExecutionOptions{ + SendHeartbeat: true, + }, + } + + ctx2 := &Context{ + ctx: context.Background(), + ExecutionOptions: ExecutionOptions{ + SendHeartbeat: true, + }, + } + + err := resolver.AsyncResolveGraphQLSubscription(ctx, plan, recorder, id) + assert.NoError(t, err) + subsStarted.Done() + + err2 := resolver.AsyncResolveGraphQLSubscription(ctx2, plan, recorder2, id2) + assert.NoError(t, err2) + subsStarted.Done() + + recorder.AwaitComplete(t, defaultTimeout) + recorder2.AwaitComplete(t, defaultTimeout) + + recorders := []*SubscriptionRecorder{recorder, recorder2} + + recorderWith1Message := false + recorderWith2Messages := false + + for _, r := range recorders { + if len(r.Messages()) == 2 { + recorderWith2Messages = true + assert.Equal(t, `{"data":{"counter":1000}}`, r.Messages()[0]) + assert.Equal(t, `{"data":{"counter":0}}`, r.Messages()[1]) + } + if len(r.Messages()) == 1 { + recorderWith1Message = true + assert.Equal(t, `{"data":{"counter":0}}`, r.Messages()[0]) + } + } + + assert.True(t, recorderWith1Message) + assert.True(t, recorderWith2Messages) + }) + + t.Run("SubscriptionOnStart ctx updater on multiple subscriptions with same trigger works", func(t *testing.T) { + c, cancel := context.WithCancel(context.Background()) + defer cancel() + + subsStarted := sync.WaitGroup{} + subsStarted.Add(2) + + id2 := SubscriptionIdentifier{ + ConnectionID: 1, + SubscriptionID: 2, + } + + fakeStream := createFakeStream(func(counter int) (message string, done bool) { + return fmt.Sprintf(`{"data":{"counter":%d}}`, counter), counter == 0 + }, 1*time.Millisecond, func(input []byte) { + assert.Equal(t, `{"method":"POST","url":"http://localhost:4000","body":{"query":"subscription { counter }"}}`, string(input)) + }, func(ctx StartupHookContext, input []byte) (err error) { + ctx.Updater([]byte(`{"data":{"counter":1000}}`)) + return nil + }) + fakeStream.uniqueRequestFn = func(ctx *Context, input []byte, xxh *xxhash.Digest) (err error) { + _, err = xxh.WriteString("unique") + return + } + + resolver, plan, recorder, id := setup(c, fakeStream) + + recorder2 := &SubscriptionRecorder{ + buf: &bytes.Buffer{}, + messages: []string{}, + complete: atomic.Bool{}, + } + recorder2.complete.Store(false) + + ctx := &Context{ + ctx: context.Background(), + ExecutionOptions: ExecutionOptions{ + SendHeartbeat: true, + }, + } + + ctx2 := &Context{ + ctx: context.Background(), + ExecutionOptions: ExecutionOptions{ + SendHeartbeat: true, + }, + } + + err := resolver.AsyncResolveGraphQLSubscription(ctx, plan, recorder, id) + assert.NoError(t, err) + subsStarted.Done() + + err2 := resolver.AsyncResolveGraphQLSubscription(ctx2, plan, recorder2, id2) + assert.NoError(t, err2) + subsStarted.Done() + + recorder.AwaitComplete(t, defaultTimeout) + recorder2.AwaitComplete(t, defaultTimeout) + + recorders := []*SubscriptionRecorder{recorder, recorder2} + + for _, r := range recorders { + if len(r.Messages()) == 2 { + assert.Equal(t, `{"data":{"counter":1000}}`, r.Messages()[0]) + assert.Equal(t, `{"data":{"counter":0}}`, r.Messages()[1]) + } else { + assert.Fail(t, "should not be here") + } + } + }) + + t.Run("SubscriptionOnStart can send a lot of updates without blocking", func(t *testing.T) { + c, cancel := context.WithCancel(context.Background()) + defer cancel() + workChanBufferSize := 10000 + + fakeStream := createFakeStream(func(counter int) (message string, done bool) { + return fmt.Sprintf(`{"data":{"counter":%d}}`, counter), counter == 0 + }, 1*time.Millisecond, func(input []byte) { + assert.Equal(t, `{"method":"POST","url":"http://localhost:4000","body":{"query":"subscription { counter }"}}`, string(input)) + }, func(ctx StartupHookContext, input []byte) (err error) { + for i := 0; i < workChanBufferSize+1; i++ { + ctx.Updater([]byte(fmt.Sprintf(`{"data":{"counter":%d}}`, i+100))) + } + return nil + }) + + resolver, plan, recorder, id := setup(c, fakeStream) + + ctx := &Context{ + ctx: context.Background(), + ExecutionOptions: ExecutionOptions{ + SendHeartbeat: true, + }, + } + + err := resolver.AsyncResolveGraphQLSubscription(ctx, plan, recorder, id) + assert.NoError(t, err) + + recorder.AwaitComplete(t, defaultTimeout) + assert.Equal(t, workChanBufferSize+2, len(recorder.Messages())) + for i := 0; i < workChanBufferSize; i++ { + assert.Equal(t, fmt.Sprintf(`{"data":{"counter":%d}}`, i+100), recorder.Messages()[i]) + } + assert.Equal(t, `{"data":{"counter":0}}`, recorder.Messages()[workChanBufferSize+1]) + }) + + t.Run("SubscriptionOnStart can send a lot of updates in a go routine while updates are coming from other sources", func(t *testing.T) { + c, cancel := context.WithCancel(context.Background()) + defer cancel() + + messagesToSendFromHook := int32(100) + messagesDroppedFromHook := &atomic.Int32{} + messagesToSendFromOtherSources := int32(100) + + firstMessageArrived := make(chan bool, 1) + hookCompleted := make(chan bool, 1) + fakeStream := createFakeStream(func(counter int) (message string, done bool) { + if counter == 0 { + select { + case firstMessageArrived <- true: + default: + } + } + if counter == int(messagesToSendFromOtherSources)-1 { + select { + case hookCompleted <- true: + case <-time.After(defaultTimeout): + } + } + return fmt.Sprintf(`{"data":{"counter":%d}}`, counter), counter == int(messagesToSendFromOtherSources)-1 + }, 1*time.Millisecond, func(input []byte) { + assert.Equal(t, `{"method":"POST","url":"http://localhost:4000","body":{"query":"subscription { counter }"}}`, string(input)) + }, func(ctx StartupHookContext, input []byte) (err error) { + // send the first update immediately + ctx.Updater([]byte(fmt.Sprintf(`{"data":{"counter":%d}}`, 0+20000))) + + // start a go routine to send the updates after the source started emitting messages + go func() { + // Wait for the first message to arrive before sending updates + select { + case <-firstMessageArrived: + for i := 1; i < int(messagesToSendFromHook); i++ { + ctx.Updater([]byte(fmt.Sprintf(`{"data":{"counter":%d}}`, i+20000))) + } + hookCompleted <- true + case <-time.After(defaultTimeout): + // if the first message did not arrive, do not send any updates + return + } + }() + + return nil + }) + + resolver, plan, recorder, id := setup(c, fakeStream) + + ctx := &Context{ + ctx: context.Background(), + ExecutionOptions: ExecutionOptions{ + SendHeartbeat: false, + }, + } + + err := resolver.AsyncResolveGraphQLSubscription(ctx, plan, recorder, id) + assert.NoError(t, err) + + recorder.AwaitComplete(t, defaultTimeout*2) + + var messagesHeartbeat int32 + for _, m := range recorder.Messages() { + if m == "{}" { + messagesHeartbeat++ + } + } + assert.Equal(t, int32(messagesToSendFromHook+messagesToSendFromOtherSources-messagesDroppedFromHook.Load()+messagesHeartbeat), int32(len(recorder.Messages()))) + assert.Equal(t, `{"data":{"counter":20000}}`, recorder.Messages()[0]) + }) + + t.Run("it is possible to have two subscriptions to the same trigger", func(t *testing.T) { + c, cancel := context.WithCancel(context.Background()) + defer cancel() + + fakeStream := createFakeStream(func(counter int) (message string, done bool) { + return fmt.Sprintf(`{"data":{"counter":%d}}`, counter), counter == 100 + }, 1*time.Millisecond, func(input []byte) { + assert.Equal(t, `{"method":"POST","url":"http://localhost:4000","body":{"query":"subscription { counter }"}}`, string(input)) + }, func(ctx StartupHookContext, input []byte) (err error) { + return nil + }) + fakeStream.uniqueRequestFn = func(ctx *Context, input []byte, xxh *xxhash.Digest) (err error) { + _, err = xxh.WriteString("unique") + if err != nil { + return + } + _, err = xxh.Write(input) + return err + } + + resolver1, plan1, recorder1, id1 := setup(c, fakeStream) + _, _, recorder2, id2 := setup(c, fakeStream) + id2.ConnectionID = id1.ConnectionID + 1 + id2.SubscriptionID = id1.SubscriptionID + 1 + + ctx1 := &Context{ + ctx: context.Background(), + } + ctx2 := &Context{ + ctx: context.Background(), + } + + err1 := resolver1.AsyncResolveGraphQLSubscription(ctx1, plan1, recorder1, id1) + assert.NoError(t, err1) + + err2 := resolver1.AsyncResolveGraphQLSubscription(ctx2, plan1, recorder2, id2) + assert.NoError(t, err2) + + // complete is called only on the last recorder + recorder1.AwaitComplete(t, defaultTimeout) + require.Equal(t, 101, len(recorder1.Messages())) + assert.Equal(t, `{"data":{"counter":0}}`, recorder1.Messages()[0]) + assert.Equal(t, `{"data":{"counter":100}}`, recorder1.Messages()[100]) + + recorder2.AwaitComplete(t, defaultTimeout) + require.Equal(t, 101, len(recorder2.Messages())) + assert.Equal(t, `{"data":{"counter":0}}`, recorder2.Messages()[0]) + assert.Equal(t, `{"data":{"counter":100}}`, recorder2.Messages()[100]) + }) } func Test_ResolveGraphQLSubscriptionWithFilter(t *testing.T) { @@ -5513,7 +5905,7 @@ func Test_ResolveGraphQLSubscriptionWithFilter(t *testing.T) { return `{"id":2}`, true }, 100*time.Millisecond, func(input []byte) { assert.Equal(t, `{"method":"POST","url":"http://localhost:4000"}`, string(input)) - }) + }, nil) plan := &GraphQLSubscription{ Trigger: GraphQLSubscriptionTrigger{ @@ -5609,7 +6001,7 @@ func Test_ResolveGraphQLSubscriptionWithFilter(t *testing.T) { return `{"id":2}`, true }, 100*time.Millisecond, func(input []byte) { assert.Equal(t, `{"method":"POST","url":"http://localhost:4000"}`, string(input)) - }) + }, nil) plan := &GraphQLSubscription{ Trigger: GraphQLSubscriptionTrigger{ @@ -5703,7 +6095,7 @@ func Test_ResolveGraphQLSubscriptionWithFilter(t *testing.T) { return `{"id":4}`, true }, 100*time.Millisecond, func(input []byte) { assert.Equal(t, `{"method":"POST","url":"http://localhost:4000"}`, string(input)) - }) + }, nil) plan := &GraphQLSubscription{ Trigger: GraphQLSubscriptionTrigger{ @@ -5798,7 +6190,7 @@ func Test_ResolveGraphQLSubscriptionWithFilter(t *testing.T) { return `{"id":"x.4"}`, true }, 100*time.Millisecond, func(input []byte) { assert.Equal(t, `{"method":"POST","url":"http://localhost:4000"}`, string(input)) - }) + }, nil) plan := &GraphQLSubscription{ Trigger: GraphQLSubscriptionTrigger{ @@ -5897,7 +6289,7 @@ func Test_ResolveGraphQLSubscriptionWithFilter(t *testing.T) { return `{"id":"x.2"}`, true }, 100*time.Millisecond, func(input []byte) { assert.Equal(t, `{"method":"POST","url":"http://localhost:4000"}`, string(input)) - }) + }, nil) plan := &GraphQLSubscription{ Trigger: GraphQLSubscriptionTrigger{