diff --git a/router-tests/modules/stream_receive_test.go b/router-tests/modules/stream_receive_test.go index a1658dc35c..21f62a9b53 100644 --- a/router-tests/modules/stream_receive_test.go +++ b/router-tests/modules/stream_receive_test.go @@ -2,7 +2,9 @@ package module_test import ( "errors" + "fmt" "net/http" + "sync/atomic" "testing" "time" @@ -196,7 +198,6 @@ func TestReceiveHook(t *testing.T) { cfg := config.Config{ Graph: config.Graph{}, - Modules: map[string]interface{}{ "streamReceiveModule": stream_receive.StreamReceiveModule{ Callback: func(ctx core.StreamReceiveEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { @@ -353,7 +354,6 @@ func TestReceiveHook(t *testing.T) { cfg := config.Config{ Graph: config.Graph{}, - Modules: map[string]interface{}{ "streamReceiveModule": stream_receive.StreamReceiveModule{ Callback: func(ctx core.StreamReceiveEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { @@ -518,4 +518,376 @@ func TestReceiveHook(t *testing.T) { xEnv.WaitForTriggerCount(0, Timeout) }) }) + + t.Run("Test error deduplication with multiple subscriptions", func(t *testing.T) { + t.Parallel() + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "streamReceiveModule": stream_receive.StreamReceiveModule{ + Callback: func(ctx core.StreamReceiveEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { + return nil, errors.New("deduplicated error") + }, + }, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, + EnableKafka: true, + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&stream_receive.StreamReceiveModule{}), + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.ErrorLevel, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + topics := []string{"employeeUpdated"} + events.KafkaEnsureTopicExists(t, xEnv, time.Second, topics...) + + var subscriptionOne struct { + employeeUpdatedMyKafka struct { + ID float64 `graphql:"id"` + Details struct { + Forename string `graphql:"forename"` + Surname string `graphql:"surname"` + } `graphql:"details"` + } `graphql:"employeeUpdatedMyKafka(employeeID: 3)"` + } + + surl := xEnv.GraphQLWebSocketSubscriptionURL() + + // Create 3 subscriptions that will all receive the same error + clients := make([]*graphql.SubscriptionClient, 3) + clientRunChs := make([]chan error, 3) + + for i := range 3 { + clients[i] = graphql.NewSubscriptionClient(surl) + clientRunChs[i] = make(chan error) + + subscriptionID, err := clients[i].Subscribe(&subscriptionOne, nil, func(dataValue []byte, errValue error) error { + return nil + }) + require.NoError(t, err) + require.NotEmpty(t, subscriptionID) + + go func() { + clientRunChs[i] <- clients[i].Run() + }() + } + + // Wait for all subscriptions to be established + xEnv.WaitForSubscriptionCount(3, Timeout) + + // Produce a message that will trigger the error in all handlers + events.ProduceKafkaMessage(t, xEnv, Timeout, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) + + // Wait for all subscriptions to be closed due to the error + xEnv.WaitForSubscriptionCount(0, Timeout) + + // Verify all clients completed + for i := 0; i < 3; i++ { + testenv.AwaitChannelWithT(t, Timeout, clientRunChs[i], func(t *testing.T, err error) { + require.NoError(t, err) + }, "client should have completed when server closed connection") + } + + xEnv.WaitForTriggerCount(0, Timeout) + + // Verify error deduplication: should see only one error log entry + errorLogs := xEnv.Observer().FilterMessage("some handlers have thrown an error") + assert.Len(t, errorLogs.All(), 1, "should have exactly one deduplicated error log entry") + + // Verify the error log contains the correct error message and count + if len(errorLogs.All()) > 0 { + logEntry := errorLogs.All()[0] + fields := logEntry.ContextMap() + + assert.Equal(t, "deduplicated error", fields["error"], "error message should match") + assert.Equal(t, int64(3), fields["amount_handlers"], "should count all 3 handlers that threw the error") + } + }) + }) + + t.Run("Test unique error messages are all logged", func(t *testing.T) { + t.Parallel() + + var errorCounter atomic.Int32 + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "streamReceiveModule": stream_receive.StreamReceiveModule{ + Callback: func(ctx core.StreamReceiveEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { + count := errorCounter.Add(1) + return nil, fmt.Errorf("unique error %d", count) + }, + }, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, + EnableKafka: true, + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&stream_receive.StreamReceiveModule{}), + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.ErrorLevel, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + topics := []string{"employeeUpdated"} + events.KafkaEnsureTopicExists(t, xEnv, time.Second, topics...) + + var subscriptionOne struct { + employeeUpdatedMyKafka struct { + ID float64 `graphql:"id"` + Details struct { + Forename string `graphql:"forename"` + Surname string `graphql:"surname"` + } `graphql:"details"` + } `graphql:"employeeUpdatedMyKafka(employeeID: 3)"` + } + + surl := xEnv.GraphQLWebSocketSubscriptionURL() + + // Create 3 subscriptions that will each receive a unique error + clients := make([]*graphql.SubscriptionClient, 3) + clientRunChs := make([]chan error, 3) + + for i := range 3 { + clients[i] = graphql.NewSubscriptionClient(surl) + clientRunChs[i] = make(chan error) + + subscriptionID, err := clients[i].Subscribe(&subscriptionOne, nil, func(dataValue []byte, errValue error) error { + return nil + }) + require.NoError(t, err) + require.NotEmpty(t, subscriptionID) + + go func() { + clientRunChs[i] <- clients[i].Run() + }() + } + + // Wait for all subscriptions to be established + xEnv.WaitForSubscriptionCount(3, Timeout) + + // Produce a message that will trigger a unique error in each handler + events.ProduceKafkaMessage(t, xEnv, Timeout, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) + + // Wait for all subscriptions to be closed due to the error + xEnv.WaitForSubscriptionCount(0, Timeout) + + // Verify all clients completed + for i := range 3 { + testenv.AwaitChannelWithT(t, Timeout, clientRunChs[i], func(t *testing.T, err error) { + require.NoError(t, err) + }, "client should have completed when server closed connection") + } + + xEnv.WaitForTriggerCount(0, Timeout) + + // Verify no deduplication: should see three error log entries (one for each unique error) + errorLogs := xEnv.Observer().FilterMessage("some handlers have thrown an error") + assert.Len(t, errorLogs.All(), 3, "should have three separate error log entries for unique errors") + + // Verify each error log contains a unique error message and count of 1 + if len(errorLogs.All()) == 3 { + var errorMessages []string + for _, logEntry := range errorLogs.All() { + fields := logEntry.ContextMap() + errorMsg, ok := fields["error"].(string) + require.True(t, ok, "error field should be a string") + + // Check that error message is unique (starts with "unique error") + assert.Contains(t, errorMsg, "unique error", "error message should contain 'unique error'") + assert.NotContains(t, errorMessages, errorMsg, "each error message should be unique") + errorMessages = append(errorMessages, errorMsg) + + // Each unique error should have been thrown by exactly 1 handler + assert.Equal(t, int64(1), fields["amount_handlers"], "each unique error should have amount_handlers = 1") + } + + // Verify we got exactly 3 unique error messages + assert.Len(t, errorMessages, 3, "should have exactly 3 unique error messages") + } + }) + }) + + t.Run("Test concurrent handler execution works", func(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + maxConcurrent int + numSubscribers int + }{ + { + name: "1 concurrent handler", + maxConcurrent: 1, + numSubscribers: 5, + }, + { + name: "2 concurrent handlers", + maxConcurrent: 2, + numSubscribers: 10, + }, + { + name: "10 concurrent handlers", + maxConcurrent: 10, + numSubscribers: 20, + }, + { + name: "20 concurrent handlers", + maxConcurrent: 20, + numSubscribers: 40, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + var ( + currentHandlers atomic.Int32 + maxCurrentHandlers atomic.Int32 + finishedHandlers atomic.Int32 + ) + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "streamReceiveModule": stream_receive.StreamReceiveModule{ + Callback: func(ctx core.StreamReceiveEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { + currentHandlers.Add(1) + + // wait for other handlers in the batch + for { + current := currentHandlers.Load() + max := maxCurrentHandlers.Load() + + if current > max { + maxCurrentHandlers.CompareAndSwap(max, current) + } + + if current >= int32(tc.maxConcurrent) { + // wait to see if the updater spawns too many concurrent handlers + deadline := time.Now().Add(300 * time.Millisecond) + for time.Now().Before(deadline) { + if currentHandlers.Load() > int32(tc.maxConcurrent) { + break + } + } + break + } + + // Let handlers continue if we never reach a batch size = tc.maxConcurrent + // because there are not enough remaining subscribers to be updated. + remainingSubs := tc.numSubscribers - int(finishedHandlers.Load()) + if remainingSubs < tc.maxConcurrent { + break + } + } + + currentHandlers.Add(-1) + finishedHandlers.Add(1) + return events, nil + }, + }, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, + EnableKafka: true, + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&stream_receive.StreamReceiveModule{}), + core.WithSubscriptionHooks(config.SubscriptionHooksConfiguration{ + MaxConcurrentEventReceiveHandlers: tc.maxConcurrent, + }), + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.InfoLevel, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + topics := []string{"employeeUpdated"} + events.KafkaEnsureTopicExists(t, xEnv, time.Second, topics...) + + var subscriptionQuery struct { + employeeUpdatedMyKafka struct { + ID float64 `graphql:"id"` + Details struct { + Forename string `graphql:"forename"` + Surname string `graphql:"surname"` + } `graphql:"details"` + } `graphql:"employeeUpdatedMyKafka(employeeID: 3)"` + } + + surl := xEnv.GraphQLWebSocketSubscriptionURL() + + clients := make([]*graphql.SubscriptionClient, tc.numSubscribers) + clientRunChs := make([]chan error, tc.numSubscribers) + subscriptionArgsChs := make([]chan kafkaSubscriptionArgs, tc.numSubscribers) + + for i := range tc.numSubscribers { + clients[i] = graphql.NewSubscriptionClient(surl) + clientRunChs[i] = make(chan error) + subscriptionArgsChs[i] = make(chan kafkaSubscriptionArgs, 1) + + idx := i + subscriptionID, err := clients[i].Subscribe(&subscriptionQuery, nil, func(dataValue []byte, errValue error) error { + subscriptionArgsChs[idx] <- kafkaSubscriptionArgs{ + dataValue: dataValue, + errValue: errValue, + } + return nil + }) + require.NoError(t, err) + require.NotEmpty(t, subscriptionID) + + go func(i int) { + clientRunChs[i] <- clients[i].Run() + }(i) + } + + xEnv.WaitForSubscriptionCount(uint64(tc.numSubscribers), Timeout) + + events.ProduceKafkaMessage(t, xEnv, Timeout, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) + + // Collect events from all subscribers + for i := 0; i < tc.numSubscribers; i++ { + testenv.AwaitChannelWithT(t, Timeout, subscriptionArgsChs[i], func(t *testing.T, args kafkaSubscriptionArgs) { + require.NoError(t, args.errValue) + require.JSONEq(t, `{"employeeUpdatedMyKafka":{"id":1,"details":{"forename":"Jens","surname":"Neuse"}}}`, string(args.dataValue)) + }) + } + + // Close all clients + for i := 0; i < tc.numSubscribers; i++ { + require.NoError(t, clients[i].Close()) + testenv.AwaitChannelWithT(t, Timeout, clientRunChs[i], func(t *testing.T, err error) { + require.NoError(t, err) + }, "unable to close client before timeout") + } + + for i := range subscriptionArgsChs { + close(subscriptionArgsChs[i]) + } + + assert.Equal(t, int32(tc.maxConcurrent), maxCurrentHandlers.Load(), "amount of concurrent handlers not what was expected") + + requestLog := xEnv.Observer().FilterMessage("Stream Hook has been run") + assert.Len(t, requestLog.All(), tc.numSubscribers) + }) + }) + } + }) } diff --git a/router/core/factoryresolver.go b/router/core/factoryresolver.go index d7c72fe579..d155c6c5b7 100644 --- a/router/core/factoryresolver.go +++ b/router/core/factoryresolver.go @@ -501,9 +501,10 @@ func (l *Loader) Load(engineConfig *nodev1.EngineConfiguration, subgraphs []*nod l.resolver.InstanceData().HostName, l.resolver.InstanceData().ListenAddress, pubsub_datasource.Hooks{ - SubscriptionOnStart: subscriptionOnStartFns, - OnReceiveEvents: onReceiveEventsFns, - OnPublishEvents: onPublishEventsFns, + SubscriptionOnStart: subscriptionOnStartFns, + OnReceiveEvents: onReceiveEventsFns, + OnPublishEvents: onPublishEventsFns, + MaxConcurrentOnReceiveHandlers: l.subscriptionHooks.maxConcurrentOnReceiveHooks, }, ) if err != nil { diff --git a/router/core/router.go b/router/core/router.go index 919ac49c29..1dfb038f20 100644 --- a/router/core/router.go +++ b/router/core/router.go @@ -244,6 +244,11 @@ func NewRouter(opts ...Option) (*Router, error) { r.metricConfig = rmetric.DefaultConfig(Version) } + // Default value for maxConcurrentOnReceiveHooks + if r.subscriptionHooks.maxConcurrentOnReceiveHooks == 0 { + r.subscriptionHooks.maxConcurrentOnReceiveHooks = 100 + } + if r.corsOptions == nil { r.corsOptions = CorsDefaultOptions() } @@ -2122,6 +2127,12 @@ func WithDemoMode(demoMode bool) Option { } } +func WithSubscriptionHooks(cfg config.SubscriptionHooksConfiguration) Option { + return func(r *Router) { + r.subscriptionHooks.maxConcurrentOnReceiveHooks = cfg.MaxConcurrentEventReceiveHandlers + } +} + type ProxyFunc func(req *http.Request) (*url.URL, error) func newHTTPTransport(opts *TransportRequestOptions, proxy ProxyFunc, traceDialer *TraceDialer, subgraph string) *http.Transport { diff --git a/router/core/router_config.go b/router/core/router_config.go index 3e282d3c65..ba24294104 100644 --- a/router/core/router_config.go +++ b/router/core/router_config.go @@ -27,9 +27,10 @@ import ( ) type subscriptionHooks struct { - onStart []func(ctx SubscriptionOnStartHandlerContext) error - onPublishEvents []func(ctx StreamPublishEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) - onReceiveEvents []func(ctx StreamReceiveEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) + onStart []func(ctx SubscriptionOnStartHandlerContext) error + onPublishEvents []func(ctx StreamPublishEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) + onReceiveEvents []func(ctx StreamReceiveEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) + maxConcurrentOnReceiveHooks int } type Config struct { diff --git a/router/core/supervisor_instance.go b/router/core/supervisor_instance.go index 01f1daaef1..3f5cc37d7e 100644 --- a/router/core/supervisor_instance.go +++ b/router/core/supervisor_instance.go @@ -3,6 +3,10 @@ package core import ( "context" "fmt" + "net/http" + "os" + "strings" + "github.com/KimMachineGun/automemlimit/memlimit" "github.com/dustin/go-humanize" "github.com/wundergraph/cosmo/router/pkg/authentication" @@ -13,9 +17,6 @@ import ( "go.uber.org/automaxprocs/maxprocs" "go.uber.org/zap" "go.uber.org/zap/zapcore" - "net/http" - "os" - "strings" ) // newRouter creates a new router instance. @@ -251,6 +252,7 @@ func optionsFromResources(logger *zap.Logger, config *config.Config) []Option { WithMCP(config.MCP), WithPlugins(config.Plugins), WithDemoMode(config.DemoMode), + WithSubscriptionHooks(config.Events.SubscriptionHooks), } return options diff --git a/router/demo.config.yaml b/router/demo.config.yaml index 9a72e31de2..2a081e74be 100644 --- a/router/demo.config.yaml +++ b/router/demo.config.yaml @@ -19,4 +19,4 @@ events: redis: - id: my-redis urls: - - "redis://localhost:6379/2" \ No newline at end of file + - "redis://localhost:6379/2" diff --git a/router/pkg/config/config.go b/router/pkg/config/config.go index 73e8f85e28..8d52b4c4a3 100644 --- a/router/pkg/config/config.go +++ b/router/pkg/config/config.go @@ -639,7 +639,12 @@ type EventProviders struct { } type EventsConfiguration struct { - Providers EventProviders `yaml:"providers,omitempty"` + Providers EventProviders `yaml:"providers,omitempty"` + SubscriptionHooks SubscriptionHooksConfiguration `yaml:"subscription_hooks,omitempty"` +} + +type SubscriptionHooksConfiguration struct { + MaxConcurrentEventReceiveHandlers int `yaml:"max_concurrent_event_receive_handlers" envDefault:"100"` } type Cluster struct { diff --git a/router/pkg/config/config.schema.json b/router/pkg/config/config.schema.json index 528e0e1ce7..4992e504a8 100644 --- a/router/pkg/config/config.schema.json +++ b/router/pkg/config/config.schema.json @@ -2273,6 +2273,19 @@ } } } + }, + "subscription_hooks": { + "type": "object", + "description": "Configuration for subscription custom modules that are executed when events are received from a broker.", + "additionalProperties": false, + "properties": { + "max_concurrent_event_receive_handlers": { + "type": "integer", + "description": "The maximum number of concurrent event receive handlers. This controls the concurrency of the OnReceiveEvents custom modules.", + "minimum": 1, + "default": 100 + } + } } } }, diff --git a/router/pkg/config/fixtures/full.yaml b/router/pkg/config/fixtures/full.yaml index a43691cc12..d010d457d7 100644 --- a/router/pkg/config/fixtures/full.yaml +++ b/router/pkg/config/fixtures/full.yaml @@ -325,6 +325,8 @@ events: urls: - 'redis://localhost:6379/11' cluster_enabled: true + subscription_hooks: + max_concurrent_event_receive_handlers: 100 engine: enable_single_flight: true diff --git a/router/pkg/config/testdata/config_defaults.json b/router/pkg/config/testdata/config_defaults.json index c14af6023c..714acd66a9 100644 --- a/router/pkg/config/testdata/config_defaults.json +++ b/router/pkg/config/testdata/config_defaults.json @@ -289,6 +289,9 @@ "Nats": null, "Kafka": null, "Redis": null + }, + "SubscriptionHooks": { + "MaxConcurrentEventReceiveHandlers": 100 } }, "CacheWarmup": { diff --git a/router/pkg/config/testdata/config_full.json b/router/pkg/config/testdata/config_full.json index d2a5695072..003883b338 100644 --- a/router/pkg/config/testdata/config_full.json +++ b/router/pkg/config/testdata/config_full.json @@ -635,6 +635,9 @@ "ClusterEnabled": true } ] + }, + "SubscriptionHooks": { + "MaxConcurrentEventReceiveHandlers": 100 } }, "CacheWarmup": { diff --git a/router/pkg/pubsub/datasource/hooks.go b/router/pkg/pubsub/datasource/hooks.go index abab8b8ef1..e07fc7f81a 100644 --- a/router/pkg/pubsub/datasource/hooks.go +++ b/router/pkg/pubsub/datasource/hooks.go @@ -14,7 +14,8 @@ type OnReceiveEventsFn func(ctx context.Context, subConf SubscriptionEventConfig // Hooks contains hooks for the pubsub providers and data sources type Hooks struct { - SubscriptionOnStart []SubscriptionOnStartFn - OnReceiveEvents []OnReceiveEventsFn - OnPublishEvents []OnPublishEventsFn + SubscriptionOnStart []SubscriptionOnStartFn + OnReceiveEvents []OnReceiveEventsFn + OnPublishEvents []OnPublishEventsFn + MaxConcurrentOnReceiveHandlers int } diff --git a/router/pkg/pubsub/datasource/subscription_event_updater.go b/router/pkg/pubsub/datasource/subscription_event_updater.go index 95289bb313..b0ef4dbd71 100644 --- a/router/pkg/pubsub/datasource/subscription_event_updater.go +++ b/router/pkg/pubsub/datasource/subscription_event_updater.go @@ -2,6 +2,7 @@ package datasource import ( "context" + "sync" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" "go.uber.org/zap" @@ -31,34 +32,30 @@ func (s *subscriptionEventUpdater) Update(events []StreamEvent) { } return } - // If there are hooks, we should apply them separated for each subscription - for ctx, subId := range s.eventUpdater.Subscriptions() { - processedEvents, err := applyStreamEventHooks( - ctx, - s.subscriptionEventConfiguration, - events, - s.hooks.OnReceiveEvents, - ) - // updates the events even if the hooks fail - // if a hook doesn't want to send the events, it should return no events! - for _, event := range processedEvents { - s.eventUpdater.UpdateSubscription(subId, event.GetData()) - } - if err != nil { - // For all errors, log them - if s.logger != nil { - s.logger.Error( - "An error occurred while processing stream events hooks", - zap.Error(err), - zap.String("provider_type", string(s.subscriptionEventConfiguration.ProviderType())), - zap.String("provider_id", s.subscriptionEventConfiguration.ProviderID()), - zap.String("field_name", s.subscriptionEventConfiguration.RootFieldName()), - ) - } - // Always close the subscription when a hook reports an error to avoid inconsistent state. - s.eventUpdater.CloseSubscription(resolve.SubscriptionCloseKindNormal, subId) - } + + subscriptions := s.eventUpdater.Subscriptions() + limit := max(s.hooks.MaxConcurrentOnReceiveHandlers, 1) + semaphore := make(chan struct{}, limit) + wg := sync.WaitGroup{} + errCh := make(chan error, len(subscriptions)) + + for ctx, subId := range subscriptions { + semaphore <- struct{}{} // Acquire a slot + eventsCopy := copyEvents(events) + wg.Add(1) + go s.updateSubscription(ctx, &wg, errCh, semaphore, subId, eventsCopy) } + + doneLogging := make(chan struct{}) + go func() { + s.deduplicateAndLogErrors(errCh, len(subscriptions)) + doneLogging <- struct{}{} + }() + + wg.Wait() + close(semaphore) + close(errCh) + <-doneLogging } func (s *subscriptionEventUpdater) Complete() { @@ -73,9 +70,9 @@ func (s *subscriptionEventUpdater) SetHooks(hooks Hooks) { s.hooks = hooks } -// applyStreamEventHooks processes events through a chain of hook functions +// applyReceiveEventHooks processes events through a chain of hook functions // Each hook receives the result from the previous hook, creating a proper middleware pipeline -func applyStreamEventHooks( +func applyReceiveEventHooks( ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent, @@ -97,6 +94,74 @@ func applyStreamEventHooks( return currentEvents, nil } +func copyEvents(in []StreamEvent) []StreamEvent { + out := make([]StreamEvent, len(in)) + for i := range in { + out[i] = in[i].Clone() + } + return out +} + +func (s *subscriptionEventUpdater) updateSubscription(ctx context.Context, wg *sync.WaitGroup, errCh chan error, semaphore chan struct{}, subID resolve.SubscriptionIdentifier, events []StreamEvent) { + defer wg.Done() + defer func() { + <-semaphore // Release the slot when done + }() + + hooks := s.hooks.OnReceiveEvents + + // modify events with hooks + var err error + for i := range hooks { + events, err = hooks[i](ctx, s.subscriptionEventConfiguration, events) + if err != nil { + errCh <- err + } + } + + // send events to the subscription, + // regardless if there was an error during hook processing. + // If no events should be sent, hook must return no events. + for _, event := range events { + s.eventUpdater.UpdateSubscription(subID, event.GetData()) + } + + // In case there was an error we close the affected subscription. + if err != nil { + s.eventUpdater.CloseSubscription(resolve.SubscriptionCloseKindNormal, subID) + } +} + +// deduplicateAndLogErrors collects errors from errCh +// and deduplicates them based on their err.Error() value. +// Afterwards it uses s.logger to log the message. +func (s *subscriptionEventUpdater) deduplicateAndLogErrors(errCh chan error, size int) { + if s.logger == nil { + return + } + + errs := make(map[string]int, size) + for err := range errCh { + amount, found := errs[err.Error()] + if found { + errs[err.Error()] = amount + 1 + continue + } + errs[err.Error()] = 1 + } + + for err, amount := range errs { + s.logger.Error( + "some handlers have thrown an error", + zap.String("error", err), + zap.Int("amount_handlers", amount), + zap.String("provider_type", string(s.subscriptionEventConfiguration.ProviderType())), + zap.String("provider_id", s.subscriptionEventConfiguration.ProviderID()), + zap.String("field_name", s.subscriptionEventConfiguration.RootFieldName()), + ) + } +} + func NewSubscriptionEventUpdater( cfg SubscriptionEventConfiguration, hooks Hooks, diff --git a/router/pkg/pubsub/datasource/subscription_event_updater_test.go b/router/pkg/pubsub/datasource/subscription_event_updater_test.go index 79fd140a51..d5ba1fcd90 100644 --- a/router/pkg/pubsub/datasource/subscription_event_updater_test.go +++ b/router/pkg/pubsub/datasource/subscription_event_updater_test.go @@ -302,7 +302,7 @@ func TestNewSubscriptionEventUpdater(t *testing.T) { assert.Equal(t, mockUpdater, concreteUpdater.eventUpdater) } -func TestApplyStreamEventHooks_NoHooks(t *testing.T) { +func TestApplyReceiveEventHooks_NoHooks(t *testing.T) { ctx := context.Background() config := &testSubscriptionEventConfig{ providerID: "test-provider", @@ -313,13 +313,13 @@ func TestApplyStreamEventHooks_NoHooks(t *testing.T) { &testEvent{data: []byte("test data")}, } - result, err := applyStreamEventHooks(ctx, config, originalEvents, []OnReceiveEventsFn{}) + result, err := applyReceiveEventHooks(ctx, config, originalEvents, []OnReceiveEventsFn{}) assert.NoError(t, err) assert.Equal(t, originalEvents, result) } -func TestApplyStreamEventHooks_SingleHook_Success(t *testing.T) { +func TestApplyReceiveEventHooks_SingleHook_Success(t *testing.T) { ctx := context.Background() config := &testSubscriptionEventConfig{ providerID: "test-provider", @@ -337,13 +337,13 @@ func TestApplyStreamEventHooks_SingleHook_Success(t *testing.T) { return modifiedEvents, nil } - result, err := applyStreamEventHooks(ctx, config, originalEvents, []OnReceiveEventsFn{hook}) + result, err := applyReceiveEventHooks(ctx, config, originalEvents, []OnReceiveEventsFn{hook}) assert.NoError(t, err) assert.Equal(t, modifiedEvents, result) } -func TestApplyStreamEventHooks_SingleHook_Error(t *testing.T) { +func TestApplyReceiveEventHooks_SingleHook_Error(t *testing.T) { ctx := context.Background() config := &testSubscriptionEventConfig{ providerID: "test-provider", @@ -359,14 +359,14 @@ func TestApplyStreamEventHooks_SingleHook_Error(t *testing.T) { return nil, hookError } - result, err := applyStreamEventHooks(ctx, config, originalEvents, []OnReceiveEventsFn{hook}) + result, err := applyReceiveEventHooks(ctx, config, originalEvents, []OnReceiveEventsFn{hook}) assert.Error(t, err) assert.Equal(t, hookError, err) assert.Nil(t, result) } -func TestApplyStreamEventHooks_MultipleHooks_Success(t *testing.T) { +func TestApplyReceiveEventHooks_MultipleHooks_Success(t *testing.T) { ctx := context.Background() config := &testSubscriptionEventConfig{ providerID: "test-provider", @@ -393,7 +393,7 @@ func TestApplyStreamEventHooks_MultipleHooks_Success(t *testing.T) { return []StreamEvent{&testEvent{data: []byte("final")}}, nil } - result, err := applyStreamEventHooks(ctx, config, originalEvents, []OnReceiveEventsFn{hook1, hook2, hook3}) + result, err := applyReceiveEventHooks(ctx, config, originalEvents, []OnReceiveEventsFn{hook1, hook2, hook3}) select { case receivedArgs1 := <-receivedArgs1: @@ -424,7 +424,7 @@ func TestApplyStreamEventHooks_MultipleHooks_Success(t *testing.T) { assert.Equal(t, "final", string(result[0].GetData())) } -func TestApplyStreamEventHooks_MultipleHooks_MiddleHookError(t *testing.T) { +func TestApplyReceiveEventHooks_MultipleHooks_MiddleHookError(t *testing.T) { ctx := context.Background() config := &testSubscriptionEventConfig{ providerID: "test-provider", @@ -452,7 +452,7 @@ func TestApplyStreamEventHooks_MultipleHooks_MiddleHookError(t *testing.T) { return []StreamEvent{&testEvent{data: []byte("final")}}, nil } - result, err := applyStreamEventHooks(ctx, config, originalEvents, []OnReceiveEventsFn{hook1, hook2, hook3}) + result, err := applyReceiveEventHooks(ctx, config, originalEvents, []OnReceiveEventsFn{hook1, hook2, hook3}) assert.Error(t, err) assert.Equal(t, middleHookError, err) @@ -622,6 +622,8 @@ func TestSubscriptionEventUpdater_UpdateSubscription_WithHooks_Error_LoggerWrite mockUpdater.AssertNotCalled(t, "UpdateSubscription") mockUpdater.AssertCalled(t, "CloseSubscription", resolve.SubscriptionCloseKindNormal, subId) - msgs := logObserver.FilterMessageSnippet("An error occurred while processing stream events hooks").TakeAll() - assert.Equal(t, 1, len(msgs)) + // log error messages for hooks are written async, we need to wait for them to be written + assert.Eventually(t, func() bool { + return len(logObserver.FilterMessageSnippet("some handlers have thrown an error").TakeAll()) == 1 + }, time.Second, 10*time.Millisecond, "expected one deduplicated error log") }