Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
230 changes: 230 additions & 0 deletions openfeature-provider/go/confidence/integration_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
package confidence

import (
"context"
"sync"
"sync/atomic"
"testing"
"time"

"github.com/open-feature/go-sdk/openfeature"
resolverv1 "github.com/spotify/confidence-resolver/openfeature-provider/go/confidence/proto/confidence/flags/resolverinternal"
"github.com/tetratelabs/wazero"
"google.golang.org/grpc"
)

// mockStateProvider provides test state for integration testing
type mockStateProvider struct {
state []byte
}

func (m *mockStateProvider) Provide(ctx context.Context) ([]byte, error) {
return m.state, nil
}

// trackingFlagLogger wraps a real GrpcWasmFlagLogger with a mocked connection
type trackingFlagLogger struct {
actualLogger WasmFlagLogger
logsSentCount int32
shutdownCalled bool
mu sync.Mutex
// Track when async operations complete
lastWriteCompleted chan struct{}
}

func (t *trackingFlagLogger) Write(ctx context.Context, request *resolverv1.WriteFlagLogsRequest) error {
atomic.AddInt32(&t.logsSentCount, int32(len(request.FlagAssigned)))
return t.actualLogger.Write(ctx, request)
}

func (t *trackingFlagLogger) Shutdown() {
t.mu.Lock()
t.shutdownCalled = true
t.mu.Unlock()
t.actualLogger.Shutdown()
}

func (t *trackingFlagLogger) GetLogsSentCount() int32 {
return atomic.LoadInt32(&t.logsSentCount)
}

func (t *trackingFlagLogger) WasShutdownCalled() bool {
t.mu.Lock()
defer t.mu.Unlock()
return t.shutdownCalled
}

// mockGrpcStubForIntegration provides a mock gRPC stub that tracks async operations
type mockGrpcStubForIntegration struct {
resolverv1.InternalFlagLoggerServiceClient
callsReceived int32
onCallReceived chan struct{}
}

func (m *mockGrpcStubForIntegration) WriteFlagLogs(ctx context.Context, req *resolverv1.WriteFlagLogsRequest, opts ...grpc.CallOption) (*resolverv1.WriteFlagLogsResponse, error) {
atomic.AddInt32(&m.callsReceived, 1)
// Signal that a call was received
select {
case m.onCallReceived <- struct{}{}:
default:
}
// Simulate some processing time to verify shutdown waits for completion
time.Sleep(50 * time.Millisecond)
return &resolverv1.WriteFlagLogsResponse{}, nil
}

func (m *mockGrpcStubForIntegration) GetCallsReceived() int32 {
return atomic.LoadInt32(&m.callsReceived)
}

// TestIntegration_OpenFeatureShutdownFlushesLogs tests the full integration:
// - Real OpenFeature SDK
// - Real provider with all components
// - Mock state provider (using test state)
// - Actual GrpcWasmFlagLogger with mocked gRPC connection
// - Verifies logs are flushed and gRPC calls complete on openfeature.Shutdown()
// This test specifically verifies the shutdown bug fix where the GrpcWasmFlagLogger's
// async goroutines complete before Shutdown() returns, ensuring no data loss.
func TestIntegration_OpenFeatureShutdownFlushesLogs(t *testing.T) {
// Load test state
testState := loadTestResolverState(t)
accountID := loadTestAccountID(t)

ctx := context.Background()

// Create mock state provider
stateProvider := &mockStateProvider{
state: testState,
}

// Create tracking logger with actual GrpcWasmFlagLogger and mocked connection
mockStub := &mockGrpcStubForIntegration{
onCallReceived: make(chan struct{}, 100), // Buffer to prevent blocking
}
actualGrpcLogger := NewGrpcWasmFlagLogger(mockStub)

trackingLogger := &trackingFlagLogger{
actualLogger: actualGrpcLogger,
lastWriteCompleted: make(chan struct{}, 1),
}

// Create provider with test state
provider, err := createProviderWithTestState(ctx, stateProvider, accountID, trackingLogger)
if err != nil {
t.Fatalf("Failed to create provider: %v", err)
}

// Register with OpenFeature
err = openfeature.SetProviderAndWait(provider)
if err != nil {
t.Fatalf("Failed to set provider: %v", err)
}

// Create client and evaluate flags
client := openfeature.NewClient("integration-test")
evalCtx := openfeature.NewEvaluationContext(
"tutorial_visitor",
map[string]interface{}{
"visitor_id": "tutorial_visitor",
},
)

// Evaluate the tutorial-feature flag (this should generate logs)
// This flag exists in the test state and should resolve successfully
numEvaluations := 5
for i := 0; i < numEvaluations; i++ {
result, _ := client.ObjectValueDetails(ctx, "tutorial-feature", map[string]interface{}{}, evalCtx)
if i == 0 {
t.Logf("First evaluation result: %+v", result)
}
}

// Now shutdown - this should flush all logs
openfeature.Shutdown()

// Verify shutdown was called
if !trackingLogger.WasShutdownCalled() {
t.Error("Expected logger shutdown to be called")
}

// Verify logs were flushed
finalLogCount := trackingLogger.GetLogsSentCount()
if finalLogCount == 0 {
t.Error("Expected logs to be flushed during shutdown, but no logs were sent")
}

// Verify that the mock gRPC connection actually received the calls
// This proves that the connection completed before shutdown returned
grpcCallsReceived := mockStub.GetCallsReceived()
if grpcCallsReceived == 0 {
t.Error("Expected mock gRPC connection to receive calls, but none were received")
}

t.Logf("Successfully flushed %d log entries via %d gRPC calls during shutdown", finalLogCount, grpcCallsReceived)
}

// createProviderWithTestState creates a provider with mock state provider and tracking logger
func createProviderWithTestState(
ctx context.Context,
stateProvider StateProvider,
accountID string,
logger WasmFlagLogger,
) (*LocalResolverProvider, error) {
// Create wazero runtime
runtimeConfig := wazero.NewRuntimeConfig()
runtime := wazero.NewRuntimeWithConfig(ctx, runtimeConfig)

// Create factory with custom state provider and logger
factory, err := NewLocalResolverFactoryWithStateProviderAndLogger(
ctx,
runtime,
defaultWasmBytes,
stateProvider,
accountID,
logger,
)
if err != nil {
return nil, err
}

// Create provider with the client secret from test state
// The test state includes client secret: mkjJruAATQWjeY7foFIWfVAcBWnci2YF
provider := NewLocalResolverProvider(factory, "mkjJruAATQWjeY7foFIWfVAcBWnci2YF")
return provider, nil
}

// NewLocalResolverFactoryWithStateProviderAndLogger creates a factory with custom state provider and logger for testing
func NewLocalResolverFactoryWithStateProviderAndLogger(
ctx context.Context,
runtime wazero.Runtime,
wasmBytes []byte,
stateProvider StateProvider,
accountId string,
flagLogger WasmFlagLogger,
) (*LocalResolverFactory, error) {
// Get initial state from provider
initialState, err := stateProvider.Provide(ctx)
if err != nil {
initialState = []byte{}
}

// Create SwapWasmResolverApi with initial state
resolverAPI, err := NewSwapWasmResolverApi(ctx, runtime, wasmBytes, flagLogger, initialState, accountId)
if err != nil {
return nil, err
}

// Create factory
factory := &LocalResolverFactory{
resolverAPI: resolverAPI,
stateProvider: stateProvider,
accountId: accountId,
flagLogger: flagLogger,
logPollInterval: getPollIntervalSeconds(),
}

// Start scheduled tasks
factory.startScheduledTasks(ctx)

return factory, nil
}
19 changes: 16 additions & 3 deletions openfeature-provider/go/confidence/local_resolver_factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"log"
"os"
"strconv"
"sync"
"time"

adminv1 "github.com/spotify/confidence-resolver/openfeature-provider/go/confidence/proto/confidence/flags/admin/v1"
Expand Down Expand Up @@ -33,6 +34,7 @@ type LocalResolverFactory struct {
flagLogger WasmFlagLogger
cancelFunc context.CancelFunc
logPollInterval time.Duration
wg sync.WaitGroup
}

// NewLocalResolverFactory creates a new LocalResolverFactory with gRPC clients and WASM bytes
Expand Down Expand Up @@ -182,7 +184,9 @@ func (f *LocalResolverFactory) startScheduledTasks(parentCtx context.Context) {
f.cancelFunc = cancel

// Ticker for state fetching and log flushing using StateProvider
f.wg.Add(1)
go func() {
defer f.wg.Done()
ticker := time.NewTicker(f.logPollInterval)
defer ticker.Stop()

Expand Down Expand Up @@ -212,15 +216,24 @@ func (f *LocalResolverFactory) startScheduledTasks(parentCtx context.Context) {

// Shutdown stops all scheduled tasks and cleans up resources
func (f *LocalResolverFactory) Shutdown(ctx context.Context) {
log.Println("Shutting down local resolver factory")
if f.cancelFunc != nil {
f.cancelFunc()
log.Println("Cancelled scheduled tasks")
}
if f.flagLogger != nil {
f.flagLogger.Shutdown()
}
// Close resolver API first (which flushes final logs)
if f.resolverAPI != nil {
f.resolverAPI.Close(ctx)
log.Println("Closed resolver API")
}
// Wait for background goroutines to exit
f.wg.Wait()
// Then shutdown flag logger (which waits for log sends to complete)
if f.flagLogger != nil {
f.flagLogger.Shutdown()
log.Println("Shut down flag logger")
}
log.Println("Local resolver factory shut down")
}

// GetSwapResolverAPI returns the SwapWasmResolverApi
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,14 @@ func (p *LocalResolverProvider) Hooks() []openfeature.Hook {
return []openfeature.Hook{}
}

// Shutdown closes the provider and cleans up resources
// Init initializes the provider (part of StateHandler interface)
func (p *LocalResolverProvider) Init(evaluationContext openfeature.EvaluationContext) error {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The absense of this function was the reason that openfeature.Shutdown() didn't work.

// Provider is already initialized in NewProvider, nothing to do here
// TODO move the bulk of the initialization to this place.
return nil
}

// Shutdown closes the provider and cleans up resources (part of StateHandler interface)
func (p *LocalResolverProvider) Shutdown() {
if p.factory != nil {
p.factory.Shutdown(context.Background())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import (

// Helper to load test data from the data directory
func loadTestResolverState(t *testing.T) []byte {
dataPath := filepath.Join("..", "..", "data", "resolver_state_current.pb")
dataPath := filepath.Join("..", "..", "..", "data", "resolver_state_current.pb")
data, err := os.ReadFile(dataPath)
if err != nil {
t.Skipf("Skipping test - could not load test resolver state: %v", err)
Expand All @@ -27,7 +27,7 @@ func loadTestResolverState(t *testing.T) []byte {
}

func loadTestAccountID(t *testing.T) string {
dataPath := filepath.Join("..", "..", "data", "account_id")
dataPath := filepath.Join("..", "..", "..", "data", "account_id")
data, err := os.ReadFile(dataPath)
if err != nil {
t.Skipf("Skipping test - could not load test account ID: %v", err)
Expand Down
4 changes: 3 additions & 1 deletion openfeature-provider/go/demo/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func main() {
if err != nil {
log.Fatalf("Failed to create provider: %v", err)
}
defer provider.Shutdown()
defer openfeature.Shutdown()
log.Println("Confidence provider created successfully")

// Register with OpenFeature
Expand Down Expand Up @@ -140,6 +140,8 @@ func main() {
log.Printf("Average latency: %.2f ms/request", duration.Seconds()*1000/float64(totalRequests))
log.Println("Check logs above for per-thread statistics and state reload/flush messages")
log.Println("")

log.Println("At the end of main... shutting down...")
}

func getEnvOrDefault(key, defaultValue string) string {
Expand Down
Loading