diff --git a/router/pkg/pubsub/datasource/pubsubprovider.go b/router/pkg/pubsub/datasource/pubsubprovider.go index 2a898b6ce3..3697229182 100644 --- a/router/pkg/pubsub/datasource/pubsubprovider.go +++ b/router/pkg/pubsub/datasource/pubsubprovider.go @@ -2,8 +2,10 @@ package datasource import ( "context" + "fmt" "go.uber.org/zap" + "go.uber.org/zap/zapcore" ) type PubSubProvider struct { @@ -17,12 +19,39 @@ type PubSubProvider struct { // applyPublishEventHooks processes events through a chain of hook functions // Each hook receives the result from the previous hook, creating a proper middleware pipeline -func applyPublishEventHooks(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent, eventBuilder EventBuilderFn, hooks []OnPublishEventsFn) ([]StreamEvent, error) { - currentEvents := events - for _, hook := range hooks { +func (p *PubSubProvider) applyPublishEventHooks(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) (currentEvents []StreamEvent, err error) { + defer func() { + if r := recover(); r != nil { + if p.Logger != nil { + p.Logger. + WithOptions(zap.AddStacktrace(zapcore.ErrorLevel)). + Error("[Recovery from handler panic]", + zap.Any("error", r), + ) + } + + switch v := r.(type) { + case error: + err = v + default: + err = fmt.Errorf("%v", r) + } + } + }() + + currentEvents = events + for _, hook := range p.hooks.OnPublishEvents { var err error - currentEvents, err = hook(ctx, cfg, currentEvents, eventBuilder) + currentEvents, err = hook(ctx, cfg, currentEvents, p.eventBuilder) if err != nil { + p.Logger.Error( + "error applying publish event hooks", + zap.Error(err), + zap.String("provider_id", cfg.ProviderID()), + zap.String("provider_type_id", string(cfg.ProviderType())), + zap.String("field_name", cfg.RootFieldName()), + ) + return currentEvents, err } } @@ -60,16 +89,7 @@ func (p *PubSubProvider) Publish(ctx context.Context, cfg PublishEventConfigurat return p.Adapter.Publish(ctx, cfg, events) } - processedEvents, hooksErr := applyPublishEventHooks(ctx, cfg, events, p.eventBuilder, p.hooks.OnPublishEvents) - if hooksErr != nil { - p.Logger.Error( - "error applying publish event hooks", - zap.Error(hooksErr), - zap.String("provider_id", cfg.ProviderID()), - zap.String("provider_type_id", string(cfg.ProviderType())), - zap.String("field_name", cfg.RootFieldName()), - ) - } + processedEvents, hooksErr := p.applyPublishEventHooks(ctx, cfg, events) errPublish := p.Adapter.Publish(ctx, cfg, processedEvents) if errPublish != nil { diff --git a/router/pkg/pubsub/datasource/pubsubprovider_test.go b/router/pkg/pubsub/datasource/pubsubprovider_test.go index 590297f689..e623de93b0 100644 --- a/router/pkg/pubsub/datasource/pubsubprovider_test.go +++ b/router/pkg/pubsub/datasource/pubsubprovider_test.go @@ -401,8 +401,14 @@ func TestApplyPublishEventHooks_NoHooks(t *testing.T) { originalEvents := []StreamEvent{ &testEvent{mutableTestEvent("test data")}, } + provider := &PubSubProvider{ + Logger: zap.NewNop(), + hooks: Hooks{ + OnPublishEvents: []OnPublishEventsFn{}, + }, + } - result, err := applyPublishEventHooks(ctx, config, originalEvents, testPubSubEventBuilder, []OnPublishEventsFn{}) + result, err := provider.applyPublishEventHooks(ctx, config, originalEvents) assert.NoError(t, err) assert.Equal(t, originalEvents, result) @@ -426,7 +432,14 @@ func TestApplyPublishEventHooks_SingleHook_Success(t *testing.T) { return modifiedEvents, nil } - result, err := applyPublishEventHooks(ctx, config, originalEvents, testPubSubEventBuilder, []OnPublishEventsFn{hook}) + provider := &PubSubProvider{ + Logger: zap.NewNop(), + hooks: Hooks{ + OnPublishEvents: []OnPublishEventsFn{hook}, + }, + } + + result, err := provider.applyPublishEventHooks(ctx, config, originalEvents) assert.NoError(t, err) assert.Equal(t, modifiedEvents, result) @@ -448,7 +461,14 @@ func TestApplyPublishEventHooks_SingleHook_Error(t *testing.T) { return nil, hookError } - result, err := applyPublishEventHooks(ctx, config, originalEvents, testPubSubEventBuilder, []OnPublishEventsFn{hook}) + provider := &PubSubProvider{ + Logger: zap.NewNop(), + hooks: Hooks{ + OnPublishEvents: []OnPublishEventsFn{hook}, + }, + } + + result, err := provider.applyPublishEventHooks(ctx, config, originalEvents) assert.Error(t, err) assert.Equal(t, hookError, err) @@ -476,7 +496,14 @@ func TestApplyPublishEventHooks_MultipleHooks_Success(t *testing.T) { return []StreamEvent{&testEvent{mutableTestEvent("final")}}, nil } - result, err := applyPublishEventHooks(ctx, config, originalEvents, testPubSubEventBuilder, []OnPublishEventsFn{hook1, hook2, hook3}) + provider := &PubSubProvider{ + Logger: zap.NewNop(), + hooks: Hooks{ + OnPublishEvents: []OnPublishEventsFn{hook1, hook2, hook3}, + }, + } + + result, err := provider.applyPublishEventHooks(ctx, config, originalEvents) assert.NoError(t, err) assert.Len(t, result, 1) @@ -505,9 +532,79 @@ func TestApplyPublishEventHooks_MultipleHooks_MiddleHookError(t *testing.T) { return []StreamEvent{&testEvent{mutableTestEvent("final")}}, nil } - result, err := applyPublishEventHooks(ctx, config, originalEvents, testPubSubEventBuilder, []OnPublishEventsFn{hook1, hook2, hook3}) + provider := &PubSubProvider{ + Logger: zap.NewNop(), + hooks: Hooks{ + OnPublishEvents: []OnPublishEventsFn{hook1, hook2, hook3}, + }, + } + + result, err := provider.applyPublishEventHooks(ctx, config, originalEvents) assert.Error(t, err) assert.Equal(t, middleHookError, err) assert.Nil(t, result) } + +func TestApplyPublishEventHooks_PanicRecovery(t *testing.T) { + panicErr := errors.New("panic error") + + tests := []struct { + name string + panicValue any + expectedErr error + expectedErrText string + }{ + { + name: "error type", + panicValue: panicErr, + expectedErr: panicErr, + }, + { + name: "string type", + panicValue: "panic string message", + expectedErrText: "panic string message", + }, + { + name: "other type", + panicValue: 42, + expectedErrText: "42", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + config := &testPublishConfig{ + providerID: "test-provider", + providerType: ProviderTypeKafka, + fieldName: "testField", + } + originalEvents := []StreamEvent{ + &testEvent{mutableTestEvent("original")}, + } + + hook := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent, eventBuilder EventBuilderFn) ([]StreamEvent, error) { + panic(tt.panicValue) + } + + provider := &PubSubProvider{ + Logger: zap.NewNop(), + hooks: Hooks{ + OnPublishEvents: []OnPublishEventsFn{hook}, + }, + } + + result, err := provider.applyPublishEventHooks(ctx, config, originalEvents) + + assert.Error(t, err) + if tt.expectedErr != nil { + assert.Equal(t, tt.expectedErr, err) + } + if tt.expectedErrText != "" { + assert.Contains(t, err.Error(), tt.expectedErrText) + } + assert.Equal(t, originalEvents, result) + }) + } +} diff --git a/router/pkg/pubsub/datasource/subscription_datasource.go b/router/pkg/pubsub/datasource/subscription_datasource.go index fb35054bd5..c625af9c33 100644 --- a/router/pkg/pubsub/datasource/subscription_datasource.go +++ b/router/pkg/pubsub/datasource/subscription_datasource.go @@ -3,10 +3,12 @@ package datasource import ( "encoding/json" "errors" + "fmt" "github.com/cespare/xxhash/v2" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" "go.uber.org/zap" + "go.uber.org/zap/zapcore" ) type uniqueRequestIdFn func(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error @@ -48,6 +50,22 @@ func (s *PubSubSubscriptionDataSource[C]) Start(ctx *resolve.Context, input []by } func (s *PubSubSubscriptionDataSource[C]) SubscriptionOnStart(ctx resolve.StartupHookContext, input []byte) (err error) { + defer func() { + if r := recover(); r != nil { + s.logger. + WithOptions(zap.AddStacktrace(zapcore.ErrorLevel)). + Error("[Recovery from handler panic]", + zap.Any("error", r), + ) + switch v := r.(type) { + case error: + err = v + default: + err = fmt.Errorf("%v", r) + } + } + }() + for _, fn := range s.hooks.SubscriptionOnStart { conf, errConf := s.SubscriptionEventConfiguration(input) if errConf != nil { diff --git a/router/pkg/pubsub/datasource/subscription_datasource_test.go b/router/pkg/pubsub/datasource/subscription_datasource_test.go index 8bba79b259..a292f4b0f4 100644 --- a/router/pkg/pubsub/datasource/subscription_datasource_test.go +++ b/router/pkg/pubsub/datasource/subscription_datasource_test.go @@ -382,3 +382,72 @@ func TestPubSubSubscriptionDataSource_InterfaceCompliance(t *testing.T) { // Test that it implements HookableSubscriptionDataSource interface var _ resolve.HookableSubscriptionDataSource = dataSource } + +func TestPubSubSubscriptionDataSource_SubscriptionOnStart_PanicRecovery(t *testing.T) { + panicErr := errors.New("panic error") + + tests := []struct { + name string + panicValue any + expectedErr error + expectedErrText string + }{ + { + name: "error type", + panicValue: panicErr, + expectedErr: panicErr, + }, + { + name: "string type", + panicValue: "panic string message", + expectedErrText: "panic string message", + }, + { + name: "other type", + panicValue: 42, + expectedErrText: "42", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockAdapter := NewMockProvider(t) + uniqueRequestIDFn := func(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { + return nil + } + + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop(), testSubscriptionDataSourceEventBuilder) + + // Add hook that panics + hook := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration, eventBuilder EventBuilderFn) error { + panic(tt.panicValue) + } + + dataSource.SetHooks(Hooks{ + SubscriptionOnStart: []SubscriptionOnStartFn{hook}, + }) + + testConfig := testSubscriptionEventConfiguration{ + Topic: "test-topic", + Subject: "test-subject", + } + input, err := json.Marshal(testConfig) + assert.NoError(t, err) + + hookCtx := resolve.StartupHookContext{ + Context: context.Background(), + Updater: func(data []byte) {}, + } + + err = dataSource.SubscriptionOnStart(hookCtx, input) + + assert.Error(t, err) + if tt.expectedErr != nil { + assert.Equal(t, tt.expectedErr, err) + } + if tt.expectedErrText != "" { + assert.Contains(t, err.Error(), tt.expectedErrText) + } + }) + } +} diff --git a/router/pkg/pubsub/datasource/subscription_event_updater.go b/router/pkg/pubsub/datasource/subscription_event_updater.go index 5c1ee69ac6..615354ba1a 100644 --- a/router/pkg/pubsub/datasource/subscription_event_updater.go +++ b/router/pkg/pubsub/datasource/subscription_event_updater.go @@ -6,6 +6,7 @@ import ( "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" "go.uber.org/zap" + "go.uber.org/zap/zapcore" ) // SubscriptionEventUpdater is a wrapper around the SubscriptionUpdater interface @@ -73,7 +74,10 @@ func (s *subscriptionEventUpdater) SetHooks(hooks Hooks) { func (s *subscriptionEventUpdater) updateSubscription(ctx context.Context, wg *sync.WaitGroup, errCh chan error, semaphore chan struct{}, subID resolve.SubscriptionIdentifier, events []StreamEvent) { defer wg.Done() defer func() { - <-semaphore // Release the slot when done + if r := recover(); r != nil { + s.recoverPanic(subID, r) + } + <-semaphore // release the slot when done }() hooks := s.hooks.OnReceiveEvents @@ -100,6 +104,17 @@ func (s *subscriptionEventUpdater) updateSubscription(ctx context.Context, wg *s } } +func (s *subscriptionEventUpdater) recoverPanic(subID resolve.SubscriptionIdentifier, err any) { + s.logger. + WithOptions(zap.AddStacktrace(zapcore.ErrorLevel)). + Error("[Recovery from handler panic]", + zap.Int64("subscription_id", subID.SubscriptionID), + zap.Any("error", err), + ) + + s.eventUpdater.CloseSubscription(resolve.SubscriptionCloseKindDownstreamServiceError, subID) +} + // deduplicateAndLogErrors collects errors from errCh // and deduplicates them based on their err.Error() value. // Afterwards it uses s.logger to log the message. diff --git a/router/pkg/pubsub/datasource/subscription_event_updater_test.go b/router/pkg/pubsub/datasource/subscription_event_updater_test.go index 283c624310..2c0295dd1c 100644 --- a/router/pkg/pubsub/datasource/subscription_event_updater_test.go +++ b/router/pkg/pubsub/datasource/subscription_event_updater_test.go @@ -719,3 +719,81 @@ func TestSubscriptionEventUpdater_UpdateSubscription_WithHooks_Error_LoggerWrite return len(logObserver.FilterMessageSnippet("some handlers have thrown an error").TakeAll()) == 1 }, time.Second, 10*time.Millisecond, "expected one deduplicated error log") } + +func TestSubscriptionEventUpdater_OnReceiveEvents_PanicRecovery(t *testing.T) { + panicErr := errors.New("panic error") + + tests := []struct { + name string + panicValue any + }{ + { + name: "error type", + panicValue: panicErr, + }, + { + name: "string type", + panicValue: "panic string message", + }, + { + name: "other type", + panicValue: 42, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + core, logObserver := observer.New(zap.InfoLevel) + logger := zap.New(core) + + mockUpdater := NewMockSubscriptionUpdater(t) + config := &testSubscriptionEventConfig{ + providerID: "test-provider", + providerType: ProviderTypeNats, + fieldName: "testField", + } + events := []StreamEvent{ + &testEvent{mutableTestEvent("test data")}, + } + + // Create hook that panics + testHook := func(ctx context.Context, cfg SubscriptionEventConfiguration, eventBuilder EventBuilderFn, events []StreamEvent) ([]StreamEvent, error) { + panic(tt.panicValue) + } + + subId := resolve.SubscriptionIdentifier{ConnectionID: 1, SubscriptionID: 1} + mockUpdater.On("Subscriptions").Return(map[context.Context]resolve.SubscriptionIdentifier{ + context.Background(): subId, + }) + mockUpdater.On("CloseSubscription", resolve.SubscriptionCloseKindDownstreamServiceError, subId).Return() + + updater := &subscriptionEventUpdater{ + eventUpdater: mockUpdater, + subscriptionEventConfiguration: config, + hooks: Hooks{ + OnReceiveEvents: []OnReceiveEventsFn{testHook}, + }, + logger: logger, + } + + updater.Update(events) + + // Wait for async processing to complete and assert panic was logged + assert.Eventually(t, func() bool { + logs := logObserver.FilterMessage("[Recovery from handler panic]").All() + return len(logs) == 1 + }, 10*time.Millisecond, time.Millisecond, "expected panic recovery log") + + // Assert that subscription was closed due to panic + mockUpdater.AssertCalled(t, "CloseSubscription", resolve.SubscriptionCloseKindDownstreamServiceError, subId) + mockUpdater.AssertNotCalled(t, "UpdateSubscription") + + // Assert that panic was logged with correct details + logs := logObserver.FilterMessage("[Recovery from handler panic]").All() + assert.Len(t, logs, 1) + assert.Equal(t, zap.ErrorLevel, logs[0].Level) + assert.Equal(t, int64(1), logs[0].ContextMap()["subscription_id"]) + assert.NotNil(t, logs[0].ContextMap()["error"]) + }) + } +}