diff --git a/lib/utils/fncache.go b/lib/utils/fncache.go index 1e4eb28749bd8..5600310f762ec 100644 --- a/lib/utils/fncache.go +++ b/lib/utils/fncache.go @@ -182,6 +182,40 @@ func (c *FnCache) Set(key, value any) { c.SetWithTTL(key, value, c.cfg.TTL) } +// GetIfExists retrieves a value from the cache without triggering a load operation. +// It returns (value, true) if a valid, non-expired entry exists, or (nil, false) +// otherwise. If an entry is currently being loaded by FnCacheGet, Get will +// return false immediately without blocking. Get returns false for entries that +// contain errors. +// For most of the cases the FnCacheGet function should be used instead. +func (c *FnCache) GetIfExists(key any) (any, bool) { + c.mu.Lock() + if c.closed { + c.mu.Unlock() + return nil, false + } + entry := c.entries[key] + c.mu.Unlock() + + if entry == nil { + return nil, false + } + + select { + case <-entry.loaded: + if c.cfg.Clock.Now().After(entry.t.Add(entry.ttl)) { + return nil, false + } + if entry.e != nil { + return nil, false + } + return entry.v, true + default: + // Entry still loading - treat as cache miss + return nil, false + } +} + // SetWithTTL places an item in the cache with an explicit TTL. func (c *FnCache) SetWithTTL(key, value any, ttl time.Duration) { c.mu.Lock() diff --git a/lib/utils/fncache_test.go b/lib/utils/fncache_test.go index 4e3e5341975f1..aebb30b1cbbeb 100644 --- a/lib/utils/fncache_test.go +++ b/lib/utils/fncache_test.go @@ -29,6 +29,7 @@ import ( "github.com/gravitational/trace" "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" apiutils "github.com/gravitational/teleport/api/utils" @@ -655,3 +656,202 @@ func TestFnCacheSet(t *testing.T) { require.NoError(t, err) require.Equal(t, 100, out) } + +// TestGetIfExists tests the GetIfExists method which retrieves values without triggering loads. +func TestGetIfExists(t *testing.T) { + t.Parallel() + + t.Run("non-existent key", func(t *testing.T) { + t.Parallel() + cache, err := NewFnCache(FnCacheConfig{ + TTL: time.Hour, + Clock: clockwork.NewFakeClock(), + }) + require.NoError(t, err) + + val, ok := cache.GetIfExists("nonexistent") + require.False(t, ok) + require.Empty(t, val) + }) + + t.Run("existing valid entry", func(t *testing.T) { + t.Parallel() + cache, err := NewFnCache(FnCacheConfig{ + TTL: time.Hour, + Clock: clockwork.NewFakeClock(), + }) + require.NoError(t, err) + + cache.Set("test-key", "test-value") + + val, ok := cache.GetIfExists("test-key") + require.True(t, ok) + require.Equal(t, "test-value", val) + }) + + t.Run("expired entry", func(t *testing.T) { + t.Parallel() + clock := clockwork.NewFakeClock() + cache, err := NewFnCache(FnCacheConfig{ + TTL: time.Hour, + Clock: clock, + }) + require.NoError(t, err) + + cache.SetWithTTL("test-key", "test-value", time.Minute) + + val, ok := cache.GetIfExists("test-key") + require.True(t, ok) + require.Equal(t, "test-value", val) + + // Advance time past the TTL + clock.Advance(2 * time.Minute) + + // Entry should now be expired and not returned + val, ok = cache.GetIfExists("test-key") + require.False(t, ok) + require.Empty(t, val) + }) + + t.Run("entry with different TTL", func(t *testing.T) { + t.Parallel() + clock := clockwork.NewFakeClock() + cache, err := NewFnCache(FnCacheConfig{ + TTL: time.Hour, + Clock: clock, + }) + require.NoError(t, err) + + // Set entries with different TTLs + cache.SetWithTTL("short-ttl", "short-value", 30*time.Minute) + cache.SetWithTTL("long-ttl", "long-value", 2*time.Hour) + + // Both should be accessible initially + val, ok := cache.GetIfExists("short-ttl") + require.True(t, ok) + require.Equal(t, "short-value", val) + + val, ok = cache.GetIfExists("long-ttl") + require.True(t, ok) + require.Equal(t, "long-value", val) + + // Advance time to expire only the short TTL entry + clock.Advance(45 * time.Minute) + + // Short TTL entry should be expired + val, ok = cache.GetIfExists("short-ttl") + require.False(t, ok) + require.Empty(t, val) + + // Long TTL entry should still be accessible + val, ok = cache.GetIfExists("long-ttl") + require.True(t, ok) + require.Equal(t, "long-value", val) + }) + + t.Run("entry loaded with error via FnCacheGet", func(t *testing.T) { + t.Parallel() + cache, err := NewFnCache(FnCacheConfig{ + TTL: time.Hour, + Clock: clockwork.NewFakeClock(), + }) + require.NoError(t, err) + ctx := context.Background() + + // Load an entry that results in an error + _, err = FnCacheGet(ctx, cache, "error-key", func(ctx context.Context) (string, error) { + return "", fmt.Errorf("load error") + }) + require.Error(t, err) + + // GetIfExists should not return the error entry + val, ok := cache.GetIfExists("error-key") + require.False(t, ok) + require.Empty(t, val) + }) + + t.Run("get after remove", func(t *testing.T) { + t.Parallel() + cache, err := NewFnCache(FnCacheConfig{ + TTL: time.Hour, + Clock: clockwork.NewFakeClock(), + }) + require.NoError(t, err) + + cache.Set("remove-test-key", "test-value") + val, ok := cache.GetIfExists("remove-test-key") + require.True(t, ok) + require.Equal(t, "test-value", val) + + cache.Remove("remove-test-key") + + val, ok = cache.GetIfExists("remove-test-key") + require.False(t, ok) + require.Empty(t, val) + }) + + t.Run("non-blocking while entry is loading", func(t *testing.T) { + t.Parallel() + cache, err := NewFnCache(FnCacheConfig{ + TTL: time.Hour, + Clock: clockwork.NewFakeClock(), + }) + require.NoError(t, err) + + loadStarted := make(chan struct{}) + loadContinue := make(chan struct{}) + + // Start a load operation that will block + go func() { + _, err := FnCacheGet(context.Background(), cache, "loading-key", func(ctx context.Context) (string, error) { + close(loadStarted) + <-loadContinue + return "loaded-value", nil + }) + assert.NoError(t, err) + }() + + <-loadStarted + + // GetIfExists should return immediately with false, not block + val, ok := cache.GetIfExists("loading-key") + require.False(t, ok) + require.Empty(t, val) + + close(loadContinue) + + // Now it should be available + require.Eventually(t, func() bool { + val, ok := cache.GetIfExists("loading-key") + return ok && val == "loaded-value" + }, time.Second, 10*time.Millisecond) + }) + + t.Run("concurrent GetIfExists Remove/Set on same key", func(t *testing.T) { + t.Parallel() + cache, err := NewFnCache(FnCacheConfig{ + TTL: time.Hour, + Clock: clockwork.NewFakeClock(), + }) + require.NoError(t, err) + + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(3) + go func() { + defer wg.Done() + cache.GetIfExists("key") + }() + go func() { + defer wg.Done() + cache.Remove("key") + }() + + go func() { + defer wg.Done() + cache.Set("key", "value") + }() + } + wg.Wait() + }) +}