-
Notifications
You must be signed in to change notification settings - Fork 2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
9 changed files
with
667 additions
and
26 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
178 changes: 178 additions & 0 deletions
178
...e-experimental/src/main/java/com/azure/core/experimental/credential/AccessTokenCache.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,178 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
package com.azure.core.experimental.credential; | ||
|
||
import com.azure.core.credential.AccessToken; | ||
import com.azure.core.util.logging.ClientLogger; | ||
import reactor.core.publisher.Flux; | ||
import reactor.core.publisher.Mono; | ||
import reactor.core.publisher.MonoProcessor; | ||
import reactor.core.publisher.Signal; | ||
|
||
import java.time.Duration; | ||
import java.time.OffsetDateTime; | ||
import java.util.Objects; | ||
import java.util.concurrent.atomic.AtomicReference; | ||
import java.util.function.Function; | ||
import java.util.function.Predicate; | ||
import java.util.function.Supplier; | ||
|
||
/** | ||
* A token cache that supports caching a token and refreshing it. | ||
*/ | ||
public class AccessTokenCache { | ||
// The delay after a refresh to attempt another token refresh | ||
private static final Duration REFRESH_DELAY = Duration.ofSeconds(30); | ||
// the offset before token expiry to attempt proactive token refresh | ||
private static final Duration REFRESH_OFFSET = Duration.ofMinutes(5); | ||
private volatile AccessToken cache; | ||
private volatile OffsetDateTime nextTokenRefresh = OffsetDateTime.now(); | ||
private final AtomicReference<MonoProcessor<AccessToken>> wip; | ||
private final Supplier<Mono<AccessToken>> tokenSupplier; | ||
private final Predicate<AccessToken> shouldRefresh; | ||
private final ClientLogger logger = new ClientLogger(AccessTokenCache.class); | ||
|
||
/** | ||
* Creates an instance of AccessTokenCache with default scheme "Bearer". | ||
* | ||
* @param tokenSupplier a method to get a new token | ||
*/ | ||
public AccessTokenCache(Supplier<Mono<AccessToken>> tokenSupplier) { | ||
Objects.requireNonNull(tokenSupplier, "The token supplier cannot be null"); | ||
this.wip = new AtomicReference<>(); | ||
this.tokenSupplier = tokenSupplier; | ||
this.shouldRefresh = accessToken -> OffsetDateTime.now() | ||
.isAfter(accessToken.getExpiresAt().minus(REFRESH_OFFSET)); | ||
} | ||
|
||
/** | ||
* Asynchronously get a token from either the cache or replenish the cache with a new token. | ||
* @return a Publisher that emits an AccessToken | ||
*/ | ||
public Mono<AccessToken> getToken() { | ||
return getToken(this.tokenSupplier, false); | ||
} | ||
|
||
/** | ||
* Asynchronously get a token from either the cache or replenish the cache with a new token. | ||
* | ||
* @param tokenSupplier The method to get a new token | ||
* @param forceRefresh The flag indicating if the cache needs to be skipped and a token needs to be fetched via the | ||
* credential. | ||
* @return The Publisher that emits an AccessToken | ||
*/ | ||
public Mono<AccessToken> getToken(Supplier<Mono<AccessToken>> tokenSupplier, boolean forceRefresh) { | ||
return Mono.defer(retrieveToken(tokenSupplier, forceRefresh)) | ||
// Keep resubscribing as long as Mono.defer [token acquisition] emits empty(). | ||
.repeatWhenEmpty((Flux<Long> longFlux) -> longFlux.concatMap(ignored -> Flux.just(true))); | ||
} | ||
|
||
private Supplier<Mono<? extends AccessToken>> retrieveToken(Supplier<Mono<AccessToken>> tokenSupplier, | ||
boolean forceRefresh) { | ||
return () -> { | ||
try { | ||
if (wip.compareAndSet(null, MonoProcessor.create())) { | ||
final MonoProcessor<AccessToken> monoProcessor = wip.get(); | ||
OffsetDateTime now = OffsetDateTime.now(); | ||
Mono<AccessToken> tokenRefresh; | ||
Mono<AccessToken> fallback; | ||
if (forceRefresh) { | ||
tokenRefresh = Mono.defer(tokenSupplier); | ||
fallback = Mono.empty(); | ||
} else if (cache != null && !shouldRefresh.test(cache)) { | ||
// fresh cache & no need to refresh | ||
tokenRefresh = Mono.empty(); | ||
fallback = Mono.just(cache); | ||
} else if (cache == null || cache.isExpired()) { | ||
// no token to use | ||
if (now.isAfter(nextTokenRefresh)) { | ||
// refresh immediately | ||
tokenRefresh = Mono.defer(tokenSupplier); | ||
} else { | ||
// wait for timeout, then refresh | ||
tokenRefresh = Mono.defer(tokenSupplier) | ||
.delaySubscription(Duration.between(now, nextTokenRefresh)); | ||
} | ||
// cache doesn't exist or expired, no fallback | ||
fallback = Mono.empty(); | ||
} else { | ||
// token available, but close to expiry | ||
if (now.isAfter(nextTokenRefresh)) { | ||
// refresh immediately | ||
tokenRefresh = Mono.defer(tokenSupplier); | ||
} else { | ||
// still in timeout, do not refresh | ||
tokenRefresh = Mono.empty(); | ||
} | ||
// cache hasn't expired, ignore refresh error this time | ||
fallback = Mono.just(cache); | ||
} | ||
return tokenRefresh | ||
.materialize() | ||
.flatMap(processTokenRefreshResult(monoProcessor, now, fallback)) | ||
.doOnError(monoProcessor::onError) | ||
.doFinally(ignored -> wip.set(null)); | ||
} else if (cache != null && !cache.isExpired() && !forceRefresh) { | ||
// another thread might be refreshing the token proactively, but the current token is still valid | ||
return Mono.just(cache); | ||
} else { | ||
// another thread is definitely refreshing the expired token | ||
//If this thread, needs to force refresh, then it needs to resubscribe. | ||
if (forceRefresh) { | ||
return Mono.empty(); | ||
} | ||
MonoProcessor<AccessToken> monoProcessor = wip.get(); | ||
if (monoProcessor == null) { | ||
// the refreshing thread has finished | ||
return Mono.just(cache); | ||
} else { | ||
// wait for refreshing thread to finish but defer to updated cache in case just missed onNext() | ||
return monoProcessor.switchIfEmpty(Mono.defer(() -> Mono.just(cache))); | ||
} | ||
} | ||
} catch (Throwable t) { | ||
return Mono.error(t); | ||
} | ||
}; | ||
} | ||
|
||
private Function<Signal<AccessToken>, Mono<? extends AccessToken>> processTokenRefreshResult( | ||
MonoProcessor<AccessToken> monoProcessor, OffsetDateTime now, Mono<AccessToken> fallback) { | ||
return signal -> { | ||
AccessToken accessToken = signal.get(); | ||
Throwable error = signal.getThrowable(); | ||
if (signal.isOnNext() && accessToken != null) { // SUCCESS | ||
logger.info(refreshLog(cache, now, "Acquired a new access token")); | ||
cache = accessToken; | ||
monoProcessor.onNext(accessToken); | ||
monoProcessor.onComplete(); | ||
nextTokenRefresh = OffsetDateTime.now().plus(REFRESH_DELAY); | ||
return Mono.just(accessToken); | ||
} else if (signal.isOnError() && error != null) { // ERROR | ||
logger.error(refreshLog(cache, now, "Failed to acquire a new access token")); | ||
nextTokenRefresh = OffsetDateTime.now().plus(REFRESH_DELAY); | ||
return fallback.switchIfEmpty(Mono.error(error)); | ||
} else { // NO REFRESH | ||
monoProcessor.onComplete(); | ||
return fallback; | ||
} | ||
}; | ||
} | ||
|
||
private static String refreshLog(AccessToken cache, OffsetDateTime now, String log) { | ||
StringBuilder info = new StringBuilder(log); | ||
if (cache == null) { | ||
info.append("."); | ||
} else { | ||
Duration tte = Duration.between(now, cache.getExpiresAt()); | ||
info.append(" at ").append(tte.abs().getSeconds()).append(" seconds ") | ||
.append(tte.isNegative() ? "after" : "before").append(" expiry. ") | ||
.append("Retry may be attempted after ").append(REFRESH_DELAY.getSeconds()).append(" seconds."); | ||
if (!tte.isNegative()) { | ||
info.append(" The token currently cached will be used."); | ||
} | ||
} | ||
return info.toString(); | ||
} | ||
} |
7 changes: 7 additions & 0 deletions
7
...-core-experimental/src/main/java/com/azure/core/experimental/credential/package-info.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
/** | ||
* Package containing experimental credential classes for authentication purposes. | ||
*/ | ||
package com.azure.core.experimental.credential; |
165 changes: 165 additions & 0 deletions
165
...ava/com/azure/core/experimental/http/policy/BearerTokenAuthenticationChallengePolicy.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,165 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
package com.azure.core.experimental.http.policy; | ||
|
||
import com.azure.core.credential.AccessToken; | ||
import com.azure.core.credential.TokenCredential; | ||
import com.azure.core.credential.TokenRequestContext; | ||
import com.azure.core.experimental.credential.AccessTokenCache; | ||
import com.azure.core.experimental.implementation.AuthenticationChallenge; | ||
import com.azure.core.http.HttpPipelineCallContext; | ||
import com.azure.core.http.HttpPipelineNextPolicy; | ||
import com.azure.core.http.HttpResponse; | ||
import com.azure.core.http.policy.HttpPipelinePolicy; | ||
import reactor.core.publisher.Mono; | ||
|
||
import java.nio.charset.StandardCharsets; | ||
import java.util.ArrayList; | ||
import java.util.Base64; | ||
import java.util.HashMap; | ||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.Objects; | ||
import java.util.function.Supplier; | ||
import java.util.regex.Matcher; | ||
import java.util.regex.Pattern; | ||
|
||
/** | ||
* The pipeline policy that applies a token credential to an HTTP request | ||
* with "Bearer" scheme. | ||
*/ | ||
public class BearerTokenAuthenticationChallengePolicy implements HttpPipelinePolicy { | ||
private static final String AUTHORIZATION_HEADER = "Authorization"; | ||
private static final String BEARER = "Bearer"; | ||
public static final Pattern AUTHENTICATION_CHALLENGE_PATTERN = | ||
Pattern.compile("(\\w+) ((?:\\w+=\".*?\"(?:, )?)+)(?:, )?"); | ||
public static final Pattern AUTHENTICATION_CHALLENGE_PARAMS_PATTERN = | ||
Pattern.compile("(?:(\\w+)=\"([^\"\"]*)\")+"); | ||
public static final String WWW_AUTHENTICATE = "WWW-Authenticate"; | ||
public static final String CLAIMS_PARAMETER = "claims"; | ||
|
||
private final TokenCredential credential; | ||
private final String[] scopes; | ||
private final Supplier<Mono<AccessToken>> defaultTokenSupplier; | ||
private final AccessTokenCache cache; | ||
|
||
/** | ||
* Creates BearerTokenAuthenticationChallengePolicy. | ||
* | ||
* @param credential the token credential to authenticate the request | ||
* @param scopes the scopes of authentication the credential should get token for | ||
*/ | ||
public BearerTokenAuthenticationChallengePolicy(TokenCredential credential, String... scopes) { | ||
Objects.requireNonNull(credential); | ||
this.credential = credential; | ||
this.scopes = scopes; | ||
this.defaultTokenSupplier = () -> credential.getToken(new TokenRequestContext().addScopes(scopes)); | ||
this.cache = new AccessTokenCache(defaultTokenSupplier); | ||
} | ||
|
||
/** | ||
* | ||
* Executed before sending the initial request and authenticates the request. | ||
* | ||
* @param context The request context. | ||
* @return A {@link Mono} containing {@link Void} | ||
*/ | ||
public Mono<Void> onBeforeRequest(HttpPipelineCallContext context) { | ||
return authenticateRequest(context, defaultTokenSupplier, false); | ||
} | ||
|
||
/** | ||
* Handles the authentication challenge in the event a 401 response with a WWW-Authenticate authentication | ||
* challenge header is received after the initial request. | ||
* | ||
* @param context The request context. | ||
* @param response The Http Response containing the authentication challenge header. | ||
* @return A {@link Mono} containing the status, whether the challenge was successfully extracted and handled. | ||
* if true then a follow up request needs to be sent authorized with the challenge based bearer token. | ||
*/ | ||
public Mono<Boolean> onChallenge(HttpPipelineCallContext context, HttpResponse response) { | ||
String authHeader = response.getHeaderValue(WWW_AUTHENTICATE); | ||
if (response.getStatusCode() == 401 && authHeader != null) { | ||
List<AuthenticationChallenge> challenges = parseChallenges(authHeader); | ||
for (AuthenticationChallenge authenticationChallenge : challenges) { | ||
Map<String, String> extractedChallengeParams = | ||
parseChallengeParams(authenticationChallenge.getChallengeParameters()); | ||
if (extractedChallengeParams.containsKey(CLAIMS_PARAMETER)) { | ||
String claims = new String(Base64.getUrlDecoder() | ||
.decode(extractedChallengeParams.get(CLAIMS_PARAMETER)), StandardCharsets.UTF_8); | ||
return authenticateRequest(context, | ||
() -> credential.getToken(new TokenRequestContext() | ||
.addScopes(scopes).setClaims(claims)), true) | ||
.flatMap(b -> Mono.just(true)); | ||
} | ||
} | ||
} | ||
return Mono.just(false); | ||
} | ||
|
||
@Override | ||
public Mono<HttpResponse> process(HttpPipelineCallContext context, HttpPipelineNextPolicy next) { | ||
if ("http".equals(context.getHttpRequest().getUrl().getProtocol())) { | ||
return Mono.error(new RuntimeException("token credentials require a URL using the HTTPS protocol scheme")); | ||
} | ||
HttpPipelineNextPolicy nextPolicy = next.clone(); | ||
|
||
return onBeforeRequest(context) | ||
.then(next.process()) | ||
.flatMap(httpResponse -> { | ||
String authHeader = httpResponse.getHeaderValue(WWW_AUTHENTICATE); | ||
if (httpResponse.getStatusCode() == 401 && authHeader != null) { | ||
return onChallenge(context, httpResponse).flatMap(retry -> { | ||
if (retry) { | ||
return nextPolicy.process(); | ||
} else { | ||
return Mono.just(httpResponse); | ||
} | ||
}); | ||
} | ||
return Mono.just(httpResponse); | ||
}); | ||
} | ||
|
||
/** | ||
* Get the {@link AccessTokenCache} holding the cached access tokens and the logic to retrieve and refresh | ||
* access tokens. | ||
* | ||
* @return the {@link AccessTokenCache} | ||
*/ | ||
public AccessTokenCache getTokenCache() { | ||
return cache; | ||
} | ||
|
||
private Mono<Void> authenticateRequest(HttpPipelineCallContext context, Supplier<Mono<AccessToken>> tokenSupplier, | ||
boolean forceTokenRefresh) { | ||
return cache.getToken(tokenSupplier, forceTokenRefresh) | ||
.flatMap(token -> { | ||
context.getHttpRequest().getHeaders().set(AUTHORIZATION_HEADER, BEARER + " " + token.getToken()); | ||
return Mono.empty(); | ||
}); | ||
} | ||
|
||
List<AuthenticationChallenge> parseChallenges(String header) { | ||
Matcher matcher = AUTHENTICATION_CHALLENGE_PATTERN.matcher(header); | ||
|
||
List<AuthenticationChallenge> challenges = new ArrayList<>(); | ||
while (matcher.find()) { | ||
challenges.add(new AuthenticationChallenge(matcher.group(1), matcher.group(2))); | ||
} | ||
|
||
return challenges; | ||
} | ||
|
||
Map<String, String> parseChallengeParams(String challengeParams) { | ||
Matcher matcher = AUTHENTICATION_CHALLENGE_PARAMS_PATTERN.matcher(challengeParams); | ||
|
||
Map<String, String> challengeParameters = new HashMap<>(); | ||
while (matcher.find()) { | ||
challengeParameters.put(matcher.group(1), matcher.group(2)); | ||
} | ||
return challengeParameters; | ||
} | ||
} | ||
|
7 changes: 7 additions & 0 deletions
7
...core-experimental/src/main/java/com/azure/core/experimental/http/policy/package-info.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
/** | ||
* Package containing experimental http policies. | ||
*/ | ||
package com.azure.core.experimental.http.policy; |
Oops, something went wrong.