Skip to content

Commit

Permalink
Add Refresh On Support in SimpleTokenCache (#41315)
Browse files Browse the repository at this point in the history
  • Loading branch information
g2vinay authored Jul 31, 2024
1 parent 681a969 commit fee3576
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,7 @@
* @see com.azure.core.credential.TokenCredential
*/
public class SimpleTokenCache {
// The delay after a refresh to attempt another token refresh
private static final Duration REFRESH_DELAY = Duration.ofSeconds(30);
private static final String REFRESH_DELAY_STRING = String.valueOf(REFRESH_DELAY.getSeconds());

// the offset before token expiry to attempt proactive token refresh
private static final Duration REFRESH_OFFSET = Duration.ofMinutes(5);
// SimpleTokenCache is commonly used, use a static logger.
Expand All @@ -77,16 +75,29 @@ public class SimpleTokenCache {
private final Supplier<Mono<AccessToken>> tokenSupplier;
private final Predicate<AccessToken> shouldRefresh;

// The delay after a refresh to attempt another token refresh
private final Duration refreshDelay;
private final String refreshDelayString;

/**
* Creates an instance of RefreshableTokenCredential with default scheme "Bearer".
*
* @param tokenSupplier a method to get a new token
*/
public SimpleTokenCache(Supplier<Mono<AccessToken>> tokenSupplier) {
this(tokenSupplier, Duration.ofSeconds(30));

}

SimpleTokenCache(Supplier<Mono<AccessToken>> tokenSupplier, Duration refreshDelay) {
this.wip = new AtomicReference<>();
this.tokenSupplier = tokenSupplier;
this.shouldRefresh
= accessToken -> OffsetDateTime.now().isAfter(accessToken.getExpiresAt().minus(REFRESH_OFFSET));
this.shouldRefresh = accessToken -> OffsetDateTime.now()
.isAfter(accessToken.getRefreshAt() == null
? accessToken.getExpiresAt().minus(REFRESH_OFFSET)
: accessToken.getRefreshAt());
this.refreshDelay = refreshDelay;
this.refreshDelayString = String.valueOf(refreshDelay.getSeconds());
}

/**
Expand Down Expand Up @@ -137,12 +148,12 @@ public Mono<AccessToken> getToken() {
buildTokenRefreshLog(LogLevel.INFORMATIONAL, cache, now).log("Acquired a new access token");
cache = accessToken;
sinksOne.tryEmitValue(accessToken);
nextTokenRefresh = OffsetDateTime.now().plus(REFRESH_DELAY);
nextTokenRefresh = OffsetDateTime.now().plus(refreshDelay);
return Mono.just(accessToken);
} else if (signal.isOnError() && error != null) { // ERROR
buildTokenRefreshLog(LogLevel.ERROR, cache, now)
.log("Failed to acquire a new access token");
nextTokenRefresh = OffsetDateTime.now().plus(REFRESH_DELAY);
nextTokenRefresh = OffsetDateTime.now().plus(refreshDelay);
return fallback.switchIfEmpty(Mono.error(() -> error));
} else { // NO REFRESH
sinksOne.tryEmitEmpty();
Expand Down Expand Up @@ -173,7 +184,7 @@ Sinks.One<AccessToken> getWipValue() {
return wip.get();
}

private static LoggingEventBuilder buildTokenRefreshLog(LogLevel level, AccessToken cache, OffsetDateTime now) {
private LoggingEventBuilder buildTokenRefreshLog(LogLevel level, AccessToken cache, OffsetDateTime now) {
LoggingEventBuilder logBuilder = LOGGER.atLevel(level);
if (cache == null || !LOGGER.canLogAtLevel(level)) {
return logBuilder;
Expand All @@ -182,7 +193,7 @@ private static LoggingEventBuilder buildTokenRefreshLog(LogLevel level, AccessTo
Duration tte = Duration.between(now, cache.getExpiresAt());
return logBuilder.addKeyValue("expiresAt", cache.getExpiresAt())
.addKeyValue("tteSeconds", String.valueOf(tte.abs().getSeconds()))
.addKeyValue("retryAfterSeconds", REFRESH_DELAY_STRING)
.addKeyValue("retryAfterSeconds", refreshDelayString)
.addKeyValue("expired", tte.isNegative());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,17 @@

import java.time.Duration;
import java.time.OffsetDateTime;
import java.util.concurrent.atomic.AtomicLong;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;

/**
* Tests {@link SimpleTokenCache}.
*/
public class SimpleTokenCacheTests {

@Test
public void wipResetsOnCancel() {
SimpleTokenCache simpleTokenCache
Expand All @@ -31,4 +34,44 @@ public void wipResetsOnCancel() {

assertNull(simpleTokenCache.getWipValue());
}

@Test
public void testRefreshOnFlow() throws InterruptedException {
AtomicLong refreshes = new AtomicLong(0);

TokenCredential dummyCred = request -> {
refreshes.incrementAndGet();
return Mono.just(new TokenCacheTests.Token("testToken", 30000, 1000));
};

SimpleTokenCache cache
= new SimpleTokenCache(() -> dummyCred.getToken(new TokenRequestContext()), Duration.ofSeconds(0));

StepVerifier.create(cache.getToken().delayElement(Duration.ofMillis(2000)).flatMap(ignored -> cache.getToken()))
.assertNext(token -> {
assertEquals("testToken", token.getToken());
assertEquals(2, refreshes.get());
})
.verifyComplete();
}

@Test
public void testDoNotRefreshOnFlow() throws InterruptedException {
AtomicLong refreshes = new AtomicLong(0);

TokenCredential dummyCred = request -> {
refreshes.incrementAndGet();
return Mono.just(new TokenCacheTests.Token("testToken", 30000, 12000));
};

SimpleTokenCache cache
= new SimpleTokenCache(() -> dummyCred.getToken(new TokenRequestContext()), Duration.ofSeconds(1));

StepVerifier.create(cache.getToken().delayElement(Duration.ofMillis(2000)).flatMap(ignored -> cache.getToken()))
.assertNext(token -> {
assertEquals("testToken", token.getToken());
assertEquals(1, refreshes.get());
})
.verifyComplete();
}
}

0 comments on commit fee3576

Please sign in to comment.