Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions openfeature-provider/go/confidence/local_resolver_factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"log"
"os"
"strconv"
"sync"
"time"

adminv1 "github.com/spotify/confidence-resolver/openfeature-provider/go/confidence/proto/confidence/flags/admin/v1"
Expand Down Expand Up @@ -33,6 +34,7 @@ type LocalResolverFactory struct {
flagLogger WasmFlagLogger
cancelFunc context.CancelFunc
logPollInterval time.Duration
wg sync.WaitGroup
}

// NewLocalResolverFactory creates a new LocalResolverFactory with gRPC clients and WASM bytes
Expand Down Expand Up @@ -182,7 +184,9 @@ func (f *LocalResolverFactory) startScheduledTasks(parentCtx context.Context) {
f.cancelFunc = cancel

// Ticker for state fetching and log flushing using StateProvider
f.wg.Add(1)
go func() {
defer f.wg.Done()
ticker := time.NewTicker(f.logPollInterval)
defer ticker.Stop()

Expand Down Expand Up @@ -212,15 +216,24 @@ func (f *LocalResolverFactory) startScheduledTasks(parentCtx context.Context) {

// Shutdown stops all scheduled tasks and cleans up resources
func (f *LocalResolverFactory) Shutdown(ctx context.Context) {
log.Println("Shutting down local resolver factory")
if f.cancelFunc != nil {
f.cancelFunc()
log.Println("Cancelled scheduled tasks")
}
if f.flagLogger != nil {
f.flagLogger.Shutdown()
}
// Close resolver API first (which flushes final logs)
if f.resolverAPI != nil {
f.resolverAPI.Close(ctx)
log.Println("Closed resolver API")
}
// Wait for background goroutines to exit
f.wg.Wait()
// Then shutdown flag logger (which waits for log sends to complete)
if f.flagLogger != nil {
f.flagLogger.Shutdown()
log.Println("Shut down flag logger")
}
log.Println("Local resolver factory shut down")
}

// GetSwapResolverAPI returns the SwapWasmResolverApi
Expand Down
171 changes: 171 additions & 0 deletions openfeature-provider/go/confidence/local_resolver_factory_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
package confidence

import (
"context"
"testing"

resolverv1 "github.com/spotify/confidence-resolver/openfeature-provider/go/confidence/proto/confidence/flags/resolverinternal"
)

// mockWasmFlagLoggerForFactory is a mock implementation for testing factory shutdown
type mockWasmFlagLoggerForFactory struct {
shutdownCalled bool
writeCalled bool
onShutdown func()
onWrite func(ctx context.Context, request *resolverv1.WriteFlagLogsRequest) error
}

func (m *mockWasmFlagLoggerForFactory) Write(ctx context.Context, request *resolverv1.WriteFlagLogsRequest) error {
m.writeCalled = true
if m.onWrite != nil {
return m.onWrite(ctx, request)
}
return nil
}

func (m *mockWasmFlagLoggerForFactory) Shutdown() {
m.shutdownCalled = true
if m.onShutdown != nil {
m.onShutdown()
}
}

func TestLocalResolverFactory_ShutdownOrder(t *testing.T) {
// Track the order in which shutdown methods are called
var callOrder []string

mockLogger := &mockWasmFlagLoggerForFactory{
onShutdown: func() {
callOrder = append(callOrder, "logger")
},
}

factory := &LocalResolverFactory{
cancelFunc: func() {
callOrder = append(callOrder, "cancel")
},
flagLogger: mockLogger,
resolverAPI: nil,
}

ctx := context.Background()
factory.Shutdown(ctx)

// Verify shutdown was called
if !mockLogger.shutdownCalled {
t.Error("Expected flag logger Shutdown to be called")
}

// Verify order: cancel should be called before logger shutdown
if len(callOrder) != 2 {
t.Errorf("Expected 2 shutdown calls, got %d", len(callOrder))
}
if len(callOrder) >= 2 {
if callOrder[0] != "cancel" {
t.Errorf("Expected cancel to be called first, but got %s", callOrder[0])
}
if callOrder[1] != "logger" {
t.Errorf("Expected logger to be called second, but got %s", callOrder[1])
}
}
}

// mockResolverAPI is a mock implementation for testing shutdown order
type mockResolverAPI struct {
closeCalled bool
onClose func()
}

func (m *mockResolverAPI) Close(ctx context.Context) {
m.closeCalled = true
if m.onClose != nil {
m.onClose()
}
}

func TestLocalResolverFactory_ShutdownOrderWithResolver(t *testing.T) {
// This test verifies the critical shutdown order:
// 1. Cancel context
// 2. Wait for background tasks
// 3. Close resolver API (which flushes final logs)
// 4. Shutdown logger (which waits for log sends to complete)
//
// This order ensures logs generated during resolver Close are actually sent.

var callOrder []string
var logsSent bool

mockLogger := &mockWasmFlagLoggerForFactory{
onWrite: func(ctx context.Context, request *resolverv1.WriteFlagLogsRequest) error {
callOrder = append(callOrder, "logger-write")
logsSent = true
return nil
},
onShutdown: func() {
callOrder = append(callOrder, "logger-shutdown")
// At this point, logs should already be sent
if !logsSent {
t.Error("Logger shutdown called before logs were sent!")
}
},
}

mockResolver := &mockResolverAPI{
onClose: func() {
callOrder = append(callOrder, "resolver-close")
// Simulate resolver flushing logs on close
mockLogger.Write(context.Background(), &resolverv1.WriteFlagLogsRequest{})
},
}

factory := &LocalResolverFactory{
cancelFunc: func() {
callOrder = append(callOrder, "cancel")
},
flagLogger: mockLogger,
resolverAPI: (*SwapWasmResolverApi)(nil), // Can't easily mock this, test order instead
}

// Manually test the shutdown sequence - simulating the CORRECT order
// This test verifies our fix works correctly

if factory.cancelFunc != nil {
factory.cancelFunc()
}

// Wait for background tasks (part of our fix)
factory.wg.Wait()

// Close resolver FIRST (which generates logs)
mockResolver.Close(context.Background())

// Then shutdown logger (which waits for logs to be sent)
if factory.flagLogger != nil {
factory.flagLogger.Shutdown()
}

// Verify the CORRECT order: cancel → resolver-close → logger-write → logger-shutdown
expectedOrder := []string{"cancel", "resolver-close", "logger-write", "logger-shutdown"}
if len(callOrder) != len(expectedOrder) {
t.Errorf("Expected %d calls, got %d: %v", len(expectedOrder), len(callOrder), callOrder)
}

for i, expected := range expectedOrder {
if i < len(callOrder) && callOrder[i] != expected {
t.Errorf("Expected call %d to be '%s', got '%s'", i, expected, callOrder[i])
}
}

// Verify logs were sent before logger shutdown
if !logsSent {
t.Error("Expected logs to be sent during shutdown")
}

// Verify all components were called
if !mockResolver.closeCalled {
t.Error("Expected resolver Close to be called")
}
if !mockLogger.shutdownCalled {
t.Error("Expected logger Shutdown to be called")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,14 @@ func (p *LocalResolverProvider) Hooks() []openfeature.Hook {
return []openfeature.Hook{}
}

// Shutdown closes the provider and cleans up resources
// Init initializes the provider (part of StateHandler interface)
func (p *LocalResolverProvider) Init(evaluationContext openfeature.EvaluationContext) error {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The absense of this function was the reason that openfeature.Shutdown() didn't work.

// Provider is already initialized in NewProvider, nothing to do here
// TODO move the bulk of the initialization to this place.
return nil
}

// Shutdown closes the provider and cleans up resources (part of StateHandler interface)
func (p *LocalResolverProvider) Shutdown() {
if p.factory != nil {
p.factory.Shutdown(context.Background())
Expand Down
2 changes: 1 addition & 1 deletion openfeature-provider/go/confidence/resolver_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ func (r *ResolverApi) FlushLogs() error {
}

// Write logs via the flag logger
if r.flagLogger != nil && (len(logRequest.FlagAssigned) > 0 || len(logRequest.ClientResolveInfo) > 0 || len(logRequest.FlagResolveInfo) > 0) {
if len(logRequest.FlagAssigned) > 0 || len(logRequest.ClientResolveInfo) > 0 || len(logRequest.FlagResolveInfo) > 0 {
if err := r.flagLogger.Write(ctx, logRequest); err != nil {
log.Printf("Failed to write flushed logs: %v", err)
}
Expand Down
12 changes: 7 additions & 5 deletions openfeature-provider/go/demo/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func main() {
if err != nil {
log.Fatalf("Failed to create provider: %v", err)
}
defer provider.Shutdown()
defer openfeature.Shutdown()
log.Println("Confidence provider created successfully")

// Register with OpenFeature
Expand Down Expand Up @@ -75,8 +75,8 @@ func main() {

// Run 5 concurrent threads continuously for 5 second
var wg sync.WaitGroup
numThreads := 5
runDuration := 5 * time.Second
numThreads := 1
runDuration := 3 * time.Second

log.Printf("Starting %d threads to run for %v to test reload and flush...", numThreads, runDuration)
log.Println("")
Expand Down Expand Up @@ -113,8 +113,8 @@ func main() {
}
iteration++

// Small sleep to avoid tight loop
time.Sleep(1 * time.Millisecond)
// large sleep to avoid tight loop
time.Sleep(250 * time.Millisecond)
}

// Update shared counters atomically
Expand All @@ -140,6 +140,8 @@ func main() {
log.Printf("Average latency: %.2f ms/request", duration.Seconds()*1000/float64(totalRequests))
log.Println("Check logs above for per-thread statistics and state reload/flush messages")
log.Println("")

log.Println("At the end of main... shutting down...")
}

func getEnvOrDefault(key, defaultValue string) string {
Expand Down
Loading