diff --git a/lib/trino-collect/src/main/java/io/trino/collect/cache/EvictableCache.java b/lib/trino-collect/src/main/java/io/trino/collect/cache/EvictableCache.java index 0ea1de7965cc..52d0943fe08c 100644 --- a/lib/trino-collect/src/main/java/io/trino/collect/cache/EvictableCache.java +++ b/lib/trino-collect/src/main/java/io/trino/collect/cache/EvictableCache.java @@ -77,7 +77,7 @@ class EvictableCache tokens.remove(token.getKey(), token); } }), - new TokenCacheLoader<>(cacheLoader)); + new TokenCacheLoader<>(cacheLoader, tokens)); } @SuppressModernizer // CacheBuilder.build(CacheLoader) is forbidden, advising to use this class as a safety-adding wrapper. @@ -104,8 +104,13 @@ public V get(K key, Callable valueLoader) { Token newToken = new Token<>(key); Token token = tokens.computeIfAbsent(key, ignored -> newToken); + Callable valueLoaderImpl = () -> { + // revive token if it got expired before reloading + tokens.computeIfAbsent(token.getKey(), ignored -> token); + return valueLoader.call(); + }; try { - return dataCache.get(token, valueLoader); + return dataCache.get(token, valueLoaderImpl); } catch (Throwable e) { if (newToken == token) { @@ -394,16 +399,20 @@ private static class TokenCacheLoader extends CacheLoader, V> { private final CacheLoader delegate; + private final ConcurrentHashMap> tokens; - public TokenCacheLoader(CacheLoader delegate) + public TokenCacheLoader(CacheLoader delegate, ConcurrentHashMap> tokens) { this.delegate = requireNonNull(delegate, "delegate is null"); + this.tokens = requireNonNull(tokens, "tokens is null"); } @Override public V load(Token token) throws Exception { + // revive token if it got expired before reloading + tokens.computeIfAbsent(token.getKey(), ignored -> token); return delegate.load(token.getKey()); } diff --git a/lib/trino-collect/src/test/java/io/trino/collect/cache/TestEvictableCache.java b/lib/trino-collect/src/test/java/io/trino/collect/cache/TestEvictableCache.java index 79f03ac6f853..16498d8155ff 100644 --- a/lib/trino-collect/src/test/java/io/trino/collect/cache/TestEvictableCache.java +++ b/lib/trino-collect/src/test/java/io/trino/collect/cache/TestEvictableCache.java @@ -23,6 +23,7 @@ import org.testng.annotations.DataProvider; import org.testng.annotations.Test; +import java.time.Instant; import java.util.ArrayList; import java.util.List; import java.util.Set; @@ -43,8 +44,10 @@ import static io.trino.collect.cache.CacheStatsAssertions.assertCacheStats; import static java.lang.Math.toIntExact; import static java.lang.String.format; +import static java.time.temporal.ChronoUnit.MILLIS; import static java.util.concurrent.Executors.newFixedThreadPool; import static java.util.concurrent.TimeUnit.DAYS; +import static java.util.concurrent.TimeUnit.MILLISECONDS; import static java.util.concurrent.TimeUnit.SECONDS; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -108,6 +111,26 @@ public void testEvictByWeight() assertThat(cache.asMap().values().stream().mapToInt(String::length).sum()).as("values length sum").isLessThanOrEqualTo(20); } + @Test(timeOut = TEST_TIMEOUT_MILLIS) + public void testLoadOnceWithTimeEviction() + throws Exception + { + int ttl = 50; + int expectedCalls = 10; + + AtomicInteger counter = new AtomicInteger(); + Cache cache = EvictableCacheBuilder.newBuilder() + .expireAfterWrite(ttl, MILLISECONDS) + .build(); + + Instant until = Instant.now().plus(ttl * expectedCalls, MILLIS); + while (until.isAfter(Instant.now())) { + cache.get("foo", counter::incrementAndGet); + Thread.sleep(1); + } + assertThat(counter.get()).isEqualTo(expectedCalls); + } + @Test(timeOut = TEST_TIMEOUT_MILLIS) public void testReplace() throws Exception diff --git a/lib/trino-collect/src/test/java/io/trino/collect/cache/TestEvictableLoadingCache.java b/lib/trino-collect/src/test/java/io/trino/collect/cache/TestEvictableLoadingCache.java index 7563455a13f1..616ca4741f2e 100644 --- a/lib/trino-collect/src/test/java/io/trino/collect/cache/TestEvictableLoadingCache.java +++ b/lib/trino-collect/src/test/java/io/trino/collect/cache/TestEvictableLoadingCache.java @@ -24,6 +24,7 @@ import org.testng.annotations.DataProvider; import org.testng.annotations.Test; +import java.time.Instant; import java.util.ArrayList; import java.util.List; import java.util.Map; @@ -47,7 +48,9 @@ import static io.trino.collect.cache.CacheStatsAssertions.assertCacheStats; import static java.lang.Math.toIntExact; import static java.lang.String.format; +import static java.time.temporal.ChronoUnit.MILLIS; import static java.util.concurrent.Executors.newFixedThreadPool; +import static java.util.concurrent.TimeUnit.MILLISECONDS; import static java.util.concurrent.TimeUnit.SECONDS; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -112,6 +115,29 @@ public void testEvictByWeight() assertThat(cache.asMap().values().stream().mapToInt(String::length).sum()).as("values length sum").isLessThanOrEqualTo(20); } + @Test(timeOut = TEST_TIMEOUT_MILLIS) + public void testLoadOnceWithTimeEviction() + throws Exception + { + int ttl = 50; + int expectedCalls = 10; + + AtomicInteger counter = new AtomicInteger(); + LoadingCache cache = EvictableCacheBuilder.newBuilder() + .expireAfterWrite(ttl, MILLISECONDS) + .build(CacheLoader.from(k -> { + counter.incrementAndGet(); + return k; + })); + + Instant until = Instant.now().plus(ttl * expectedCalls, MILLIS); + while (until.isAfter(Instant.now())) { + cache.get("foo"); + Thread.sleep(1); + } + assertThat(counter.get()).isEqualTo(expectedCalls); + } + @Test(timeOut = TEST_TIMEOUT_MILLIS, dataProvider = "testDisabledCacheDataProvider") public void testDisabledCache(String behavior) throws Exception