Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 34 additions & 14 deletions router/pkg/pubsub/datasource/pubsubprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package datasource

import (
"context"
"fmt"

"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)

type PubSubProvider struct {
Expand All @@ -17,12 +19,39 @@ type PubSubProvider struct {

// applyPublishEventHooks processes events through a chain of hook functions
// Each hook receives the result from the previous hook, creating a proper middleware pipeline
func applyPublishEventHooks(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent, eventBuilder EventBuilderFn, hooks []OnPublishEventsFn) ([]StreamEvent, error) {
currentEvents := events
for _, hook := range hooks {
func (p *PubSubProvider) applyPublishEventHooks(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) (currentEvents []StreamEvent, err error) {
defer func() {
if r := recover(); r != nil {
if p.Logger != nil {
p.Logger.
WithOptions(zap.AddStacktrace(zapcore.ErrorLevel)).
Error("[Recovery from handler panic]",
zap.Any("error", r),
)
}

switch v := r.(type) {
case error:
err = v
default:
err = fmt.Errorf("%v", r)
}
}
}()

currentEvents = events
for _, hook := range p.hooks.OnPublishEvents {
var err error
currentEvents, err = hook(ctx, cfg, currentEvents, eventBuilder)
currentEvents, err = hook(ctx, cfg, currentEvents, p.eventBuilder)
if err != nil {
p.Logger.Error(
"error applying publish event hooks",
zap.Error(err),
zap.String("provider_id", cfg.ProviderID()),
zap.String("provider_type_id", string(cfg.ProviderType())),
zap.String("field_name", cfg.RootFieldName()),
)

return currentEvents, err
}
}
Expand Down Expand Up @@ -60,16 +89,7 @@ func (p *PubSubProvider) Publish(ctx context.Context, cfg PublishEventConfigurat
return p.Adapter.Publish(ctx, cfg, events)
}

processedEvents, hooksErr := applyPublishEventHooks(ctx, cfg, events, p.eventBuilder, p.hooks.OnPublishEvents)
if hooksErr != nil {
p.Logger.Error(
"error applying publish event hooks",
zap.Error(hooksErr),
zap.String("provider_id", cfg.ProviderID()),
zap.String("provider_type_id", string(cfg.ProviderType())),
zap.String("field_name", cfg.RootFieldName()),
)
}
processedEvents, hooksErr := p.applyPublishEventHooks(ctx, cfg, events)

errPublish := p.Adapter.Publish(ctx, cfg, processedEvents)
if errPublish != nil {
Expand Down
107 changes: 102 additions & 5 deletions router/pkg/pubsub/datasource/pubsubprovider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -401,8 +401,14 @@ func TestApplyPublishEventHooks_NoHooks(t *testing.T) {
originalEvents := []StreamEvent{
&testEvent{mutableTestEvent("test data")},
}
provider := &PubSubProvider{
Logger: zap.NewNop(),
hooks: Hooks{
OnPublishEvents: []OnPublishEventsFn{},
},
}

result, err := applyPublishEventHooks(ctx, config, originalEvents, testPubSubEventBuilder, []OnPublishEventsFn{})
result, err := provider.applyPublishEventHooks(ctx, config, originalEvents)

assert.NoError(t, err)
assert.Equal(t, originalEvents, result)
Expand All @@ -426,7 +432,14 @@ func TestApplyPublishEventHooks_SingleHook_Success(t *testing.T) {
return modifiedEvents, nil
}

result, err := applyPublishEventHooks(ctx, config, originalEvents, testPubSubEventBuilder, []OnPublishEventsFn{hook})
provider := &PubSubProvider{
Logger: zap.NewNop(),
hooks: Hooks{
OnPublishEvents: []OnPublishEventsFn{hook},
},
}

result, err := provider.applyPublishEventHooks(ctx, config, originalEvents)

assert.NoError(t, err)
assert.Equal(t, modifiedEvents, result)
Expand All @@ -448,7 +461,14 @@ func TestApplyPublishEventHooks_SingleHook_Error(t *testing.T) {
return nil, hookError
}

result, err := applyPublishEventHooks(ctx, config, originalEvents, testPubSubEventBuilder, []OnPublishEventsFn{hook})
provider := &PubSubProvider{
Logger: zap.NewNop(),
hooks: Hooks{
OnPublishEvents: []OnPublishEventsFn{hook},
},
}

result, err := provider.applyPublishEventHooks(ctx, config, originalEvents)

assert.Error(t, err)
assert.Equal(t, hookError, err)
Expand Down Expand Up @@ -476,7 +496,14 @@ func TestApplyPublishEventHooks_MultipleHooks_Success(t *testing.T) {
return []StreamEvent{&testEvent{mutableTestEvent("final")}}, nil
}

result, err := applyPublishEventHooks(ctx, config, originalEvents, testPubSubEventBuilder, []OnPublishEventsFn{hook1, hook2, hook3})
provider := &PubSubProvider{
Logger: zap.NewNop(),
hooks: Hooks{
OnPublishEvents: []OnPublishEventsFn{hook1, hook2, hook3},
},
}

result, err := provider.applyPublishEventHooks(ctx, config, originalEvents)

assert.NoError(t, err)
assert.Len(t, result, 1)
Expand Down Expand Up @@ -505,9 +532,79 @@ func TestApplyPublishEventHooks_MultipleHooks_MiddleHookError(t *testing.T) {
return []StreamEvent{&testEvent{mutableTestEvent("final")}}, nil
}

result, err := applyPublishEventHooks(ctx, config, originalEvents, testPubSubEventBuilder, []OnPublishEventsFn{hook1, hook2, hook3})
provider := &PubSubProvider{
Logger: zap.NewNop(),
hooks: Hooks{
OnPublishEvents: []OnPublishEventsFn{hook1, hook2, hook3},
},
}

result, err := provider.applyPublishEventHooks(ctx, config, originalEvents)

assert.Error(t, err)
assert.Equal(t, middleHookError, err)
assert.Nil(t, result)
}

func TestApplyPublishEventHooks_PanicRecovery(t *testing.T) {
panicErr := errors.New("panic error")

tests := []struct {
name string
panicValue any
expectedErr error
expectedErrText string
}{
{
name: "error type",
panicValue: panicErr,
expectedErr: panicErr,
},
{
name: "string type",
panicValue: "panic string message",
expectedErrText: "panic string message",
},
{
name: "other type",
panicValue: 42,
expectedErrText: "42",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
config := &testPublishConfig{
providerID: "test-provider",
providerType: ProviderTypeKafka,
fieldName: "testField",
}
originalEvents := []StreamEvent{
&testEvent{mutableTestEvent("original")},
}

hook := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent, eventBuilder EventBuilderFn) ([]StreamEvent, error) {
panic(tt.panicValue)
}

provider := &PubSubProvider{
Logger: zap.NewNop(),
hooks: Hooks{
OnPublishEvents: []OnPublishEventsFn{hook},
},
}

result, err := provider.applyPublishEventHooks(ctx, config, originalEvents)

assert.Error(t, err)
if tt.expectedErr != nil {
assert.Equal(t, tt.expectedErr, err)
}
if tt.expectedErrText != "" {
assert.Contains(t, err.Error(), tt.expectedErrText)
}
assert.Equal(t, originalEvents, result)
})
}
}
18 changes: 18 additions & 0 deletions router/pkg/pubsub/datasource/subscription_datasource.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@ package datasource
import (
"encoding/json"
"errors"
"fmt"

"github.com/cespare/xxhash/v2"
"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)

type uniqueRequestIdFn func(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error
Expand Down Expand Up @@ -48,6 +50,22 @@ func (s *PubSubSubscriptionDataSource[C]) Start(ctx *resolve.Context, input []by
}

func (s *PubSubSubscriptionDataSource[C]) SubscriptionOnStart(ctx resolve.StartupHookContext, input []byte) (err error) {
defer func() {
if r := recover(); r != nil {
s.logger.
WithOptions(zap.AddStacktrace(zapcore.ErrorLevel)).
Error("[Recovery from handler panic]",
zap.Any("error", r),
)
switch v := r.(type) {
case error:
err = v
default:
err = fmt.Errorf("%v", r)
}
}
}()

for _, fn := range s.hooks.SubscriptionOnStart {
conf, errConf := s.SubscriptionEventConfiguration(input)
if errConf != nil {
Expand Down
69 changes: 69 additions & 0 deletions router/pkg/pubsub/datasource/subscription_datasource_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -382,3 +382,72 @@ func TestPubSubSubscriptionDataSource_InterfaceCompliance(t *testing.T) {
// Test that it implements HookableSubscriptionDataSource interface
var _ resolve.HookableSubscriptionDataSource = dataSource
}

func TestPubSubSubscriptionDataSource_SubscriptionOnStart_PanicRecovery(t *testing.T) {
panicErr := errors.New("panic error")

tests := []struct {
name string
panicValue any
expectedErr error
expectedErrText string
}{
{
name: "error type",
panicValue: panicErr,
expectedErr: panicErr,
},
{
name: "string type",
panicValue: "panic string message",
expectedErrText: "panic string message",
},
{
name: "other type",
panicValue: 42,
expectedErrText: "42",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockAdapter := NewMockProvider(t)
uniqueRequestIDFn := func(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error {
return nil
}

dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop(), testSubscriptionDataSourceEventBuilder)

// Add hook that panics
hook := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration, eventBuilder EventBuilderFn) error {
panic(tt.panicValue)
}

dataSource.SetHooks(Hooks{
SubscriptionOnStart: []SubscriptionOnStartFn{hook},
})

testConfig := testSubscriptionEventConfiguration{
Topic: "test-topic",
Subject: "test-subject",
}
input, err := json.Marshal(testConfig)
assert.NoError(t, err)

hookCtx := resolve.StartupHookContext{
Context: context.Background(),
Updater: func(data []byte) {},
}

err = dataSource.SubscriptionOnStart(hookCtx, input)

assert.Error(t, err)
if tt.expectedErr != nil {
assert.Equal(t, tt.expectedErr, err)
}
if tt.expectedErrText != "" {
assert.Contains(t, err.Error(), tt.expectedErrText)
}
})
}
}
17 changes: 16 additions & 1 deletion router/pkg/pubsub/datasource/subscription_event_updater.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (

"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)

// SubscriptionEventUpdater is a wrapper around the SubscriptionUpdater interface
Expand Down Expand Up @@ -73,7 +74,10 @@ func (s *subscriptionEventUpdater) SetHooks(hooks Hooks) {
func (s *subscriptionEventUpdater) updateSubscription(ctx context.Context, wg *sync.WaitGroup, errCh chan error, semaphore chan struct{}, subID resolve.SubscriptionIdentifier, events []StreamEvent) {
defer wg.Done()
defer func() {
<-semaphore // Release the slot when done
if r := recover(); r != nil {
s.recoverPanic(subID, r)
}
<-semaphore // release the slot when done
}()

hooks := s.hooks.OnReceiveEvents
Expand All @@ -100,6 +104,17 @@ func (s *subscriptionEventUpdater) updateSubscription(ctx context.Context, wg *s
}
}

func (s *subscriptionEventUpdater) recoverPanic(subID resolve.SubscriptionIdentifier, err any) {
s.logger.
WithOptions(zap.AddStacktrace(zapcore.ErrorLevel)).
Error("[Recovery from handler panic]",
zap.Int64("subscription_id", subID.SubscriptionID),
zap.Any("error", err),
)

s.eventUpdater.CloseSubscription(resolve.SubscriptionCloseKindDownstreamServiceError, subID)
}

// deduplicateAndLogErrors collects errors from errCh
// and deduplicates them based on their err.Error() value.
// Afterwards it uses s.logger to log the message.
Expand Down
Loading
Loading