Skip to content

Commit

Permalink
Add CAE Support. (#18785)
Browse files Browse the repository at this point in the history
  • Loading branch information
g2vinay authored Feb 5, 2021
1 parent e5fd329 commit 6dfa793
Show file tree
Hide file tree
Showing 9 changed files with 667 additions and 26 deletions.
5 changes: 4 additions & 1 deletion sdk/core/azure-core-experimental/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
# Release History

## 1.0.0-beta.10 (Unreleased)
## 1.0.0-beta.10 (2021-02-05)

### New Features

- Added challenge based authentication support via `BearerTokenAuthenticationChallengePolicy` and `AccessTokenCache` classes.

## 1.0.0-beta.9 (2021-01-11)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package com.azure.core.experimental.credential;

import com.azure.core.credential.AccessToken;
import com.azure.core.util.logging.ClientLogger;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.publisher.MonoProcessor;
import reactor.core.publisher.Signal;

import java.time.Duration;
import java.time.OffsetDateTime;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.function.Supplier;

/**
* A token cache that supports caching a token and refreshing it.
*/
public class AccessTokenCache {
// The delay after a refresh to attempt another token refresh
private static final Duration REFRESH_DELAY = Duration.ofSeconds(30);
// the offset before token expiry to attempt proactive token refresh
private static final Duration REFRESH_OFFSET = Duration.ofMinutes(5);
private volatile AccessToken cache;
private volatile OffsetDateTime nextTokenRefresh = OffsetDateTime.now();
private final AtomicReference<MonoProcessor<AccessToken>> wip;
private final Supplier<Mono<AccessToken>> tokenSupplier;
private final Predicate<AccessToken> shouldRefresh;
private final ClientLogger logger = new ClientLogger(AccessTokenCache.class);

/**
* Creates an instance of AccessTokenCache with default scheme "Bearer".
*
* @param tokenSupplier a method to get a new token
*/
public AccessTokenCache(Supplier<Mono<AccessToken>> tokenSupplier) {
Objects.requireNonNull(tokenSupplier, "The token supplier cannot be null");
this.wip = new AtomicReference<>();
this.tokenSupplier = tokenSupplier;
this.shouldRefresh = accessToken -> OffsetDateTime.now()
.isAfter(accessToken.getExpiresAt().minus(REFRESH_OFFSET));
}

/**
* Asynchronously get a token from either the cache or replenish the cache with a new token.
* @return a Publisher that emits an AccessToken
*/
public Mono<AccessToken> getToken() {
return getToken(this.tokenSupplier, false);
}

/**
* Asynchronously get a token from either the cache or replenish the cache with a new token.
*
* @param tokenSupplier The method to get a new token
* @param forceRefresh The flag indicating if the cache needs to be skipped and a token needs to be fetched via the
* credential.
* @return The Publisher that emits an AccessToken
*/
public Mono<AccessToken> getToken(Supplier<Mono<AccessToken>> tokenSupplier, boolean forceRefresh) {
return Mono.defer(retrieveToken(tokenSupplier, forceRefresh))
// Keep resubscribing as long as Mono.defer [token acquisition] emits empty().
.repeatWhenEmpty((Flux<Long> longFlux) -> longFlux.concatMap(ignored -> Flux.just(true)));
}

private Supplier<Mono<? extends AccessToken>> retrieveToken(Supplier<Mono<AccessToken>> tokenSupplier,
boolean forceRefresh) {
return () -> {
try {
if (wip.compareAndSet(null, MonoProcessor.create())) {
final MonoProcessor<AccessToken> monoProcessor = wip.get();
OffsetDateTime now = OffsetDateTime.now();
Mono<AccessToken> tokenRefresh;
Mono<AccessToken> fallback;
if (forceRefresh) {
tokenRefresh = Mono.defer(tokenSupplier);
fallback = Mono.empty();
} else if (cache != null && !shouldRefresh.test(cache)) {
// fresh cache & no need to refresh
tokenRefresh = Mono.empty();
fallback = Mono.just(cache);
} else if (cache == null || cache.isExpired()) {
// no token to use
if (now.isAfter(nextTokenRefresh)) {
// refresh immediately
tokenRefresh = Mono.defer(tokenSupplier);
} else {
// wait for timeout, then refresh
tokenRefresh = Mono.defer(tokenSupplier)
.delaySubscription(Duration.between(now, nextTokenRefresh));
}
// cache doesn't exist or expired, no fallback
fallback = Mono.empty();
} else {
// token available, but close to expiry
if (now.isAfter(nextTokenRefresh)) {
// refresh immediately
tokenRefresh = Mono.defer(tokenSupplier);
} else {
// still in timeout, do not refresh
tokenRefresh = Mono.empty();
}
// cache hasn't expired, ignore refresh error this time
fallback = Mono.just(cache);
}
return tokenRefresh
.materialize()
.flatMap(processTokenRefreshResult(monoProcessor, now, fallback))
.doOnError(monoProcessor::onError)
.doFinally(ignored -> wip.set(null));
} else if (cache != null && !cache.isExpired() && !forceRefresh) {
// another thread might be refreshing the token proactively, but the current token is still valid
return Mono.just(cache);
} else {
// another thread is definitely refreshing the expired token
//If this thread, needs to force refresh, then it needs to resubscribe.
if (forceRefresh) {
return Mono.empty();
}
MonoProcessor<AccessToken> monoProcessor = wip.get();
if (monoProcessor == null) {
// the refreshing thread has finished
return Mono.just(cache);
} else {
// wait for refreshing thread to finish but defer to updated cache in case just missed onNext()
return monoProcessor.switchIfEmpty(Mono.defer(() -> Mono.just(cache)));
}
}
} catch (Throwable t) {
return Mono.error(t);
}
};
}

private Function<Signal<AccessToken>, Mono<? extends AccessToken>> processTokenRefreshResult(
MonoProcessor<AccessToken> monoProcessor, OffsetDateTime now, Mono<AccessToken> fallback) {
return signal -> {
AccessToken accessToken = signal.get();
Throwable error = signal.getThrowable();
if (signal.isOnNext() && accessToken != null) { // SUCCESS
logger.info(refreshLog(cache, now, "Acquired a new access token"));
cache = accessToken;
monoProcessor.onNext(accessToken);
monoProcessor.onComplete();
nextTokenRefresh = OffsetDateTime.now().plus(REFRESH_DELAY);
return Mono.just(accessToken);
} else if (signal.isOnError() && error != null) { // ERROR
logger.error(refreshLog(cache, now, "Failed to acquire a new access token"));
nextTokenRefresh = OffsetDateTime.now().plus(REFRESH_DELAY);
return fallback.switchIfEmpty(Mono.error(error));
} else { // NO REFRESH
monoProcessor.onComplete();
return fallback;
}
};
}

private static String refreshLog(AccessToken cache, OffsetDateTime now, String log) {
StringBuilder info = new StringBuilder(log);
if (cache == null) {
info.append(".");
} else {
Duration tte = Duration.between(now, cache.getExpiresAt());
info.append(" at ").append(tte.abs().getSeconds()).append(" seconds ")
.append(tte.isNegative() ? "after" : "before").append(" expiry. ")
.append("Retry may be attempted after ").append(REFRESH_DELAY.getSeconds()).append(" seconds.");
if (!tte.isNegative()) {
info.append(" The token currently cached will be used.");
}
}
return info.toString();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

/**
* Package containing experimental credential classes for authentication purposes.
*/
package com.azure.core.experimental.credential;
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package com.azure.core.experimental.http.policy;

import com.azure.core.credential.AccessToken;
import com.azure.core.credential.TokenCredential;
import com.azure.core.credential.TokenRequestContext;
import com.azure.core.experimental.credential.AccessTokenCache;
import com.azure.core.experimental.implementation.AuthenticationChallenge;
import com.azure.core.http.HttpPipelineCallContext;
import com.azure.core.http.HttpPipelineNextPolicy;
import com.azure.core.http.HttpResponse;
import com.azure.core.http.policy.HttpPipelinePolicy;
import reactor.core.publisher.Mono;

import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Base64;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Supplier;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
* The pipeline policy that applies a token credential to an HTTP request
* with "Bearer" scheme.
*/
public class BearerTokenAuthenticationChallengePolicy implements HttpPipelinePolicy {
private static final String AUTHORIZATION_HEADER = "Authorization";
private static final String BEARER = "Bearer";
public static final Pattern AUTHENTICATION_CHALLENGE_PATTERN =
Pattern.compile("(\\w+) ((?:\\w+=\".*?\"(?:, )?)+)(?:, )?");
public static final Pattern AUTHENTICATION_CHALLENGE_PARAMS_PATTERN =
Pattern.compile("(?:(\\w+)=\"([^\"\"]*)\")+");
public static final String WWW_AUTHENTICATE = "WWW-Authenticate";
public static final String CLAIMS_PARAMETER = "claims";

private final TokenCredential credential;
private final String[] scopes;
private final Supplier<Mono<AccessToken>> defaultTokenSupplier;
private final AccessTokenCache cache;

/**
* Creates BearerTokenAuthenticationChallengePolicy.
*
* @param credential the token credential to authenticate the request
* @param scopes the scopes of authentication the credential should get token for
*/
public BearerTokenAuthenticationChallengePolicy(TokenCredential credential, String... scopes) {
Objects.requireNonNull(credential);
this.credential = credential;
this.scopes = scopes;
this.defaultTokenSupplier = () -> credential.getToken(new TokenRequestContext().addScopes(scopes));
this.cache = new AccessTokenCache(defaultTokenSupplier);
}

/**
*
* Executed before sending the initial request and authenticates the request.
*
* @param context The request context.
* @return A {@link Mono} containing {@link Void}
*/
public Mono<Void> onBeforeRequest(HttpPipelineCallContext context) {
return authenticateRequest(context, defaultTokenSupplier, false);
}

/**
* Handles the authentication challenge in the event a 401 response with a WWW-Authenticate authentication
* challenge header is received after the initial request.
*
* @param context The request context.
* @param response The Http Response containing the authentication challenge header.
* @return A {@link Mono} containing the status, whether the challenge was successfully extracted and handled.
* if true then a follow up request needs to be sent authorized with the challenge based bearer token.
*/
public Mono<Boolean> onChallenge(HttpPipelineCallContext context, HttpResponse response) {
String authHeader = response.getHeaderValue(WWW_AUTHENTICATE);
if (response.getStatusCode() == 401 && authHeader != null) {
List<AuthenticationChallenge> challenges = parseChallenges(authHeader);
for (AuthenticationChallenge authenticationChallenge : challenges) {
Map<String, String> extractedChallengeParams =
parseChallengeParams(authenticationChallenge.getChallengeParameters());
if (extractedChallengeParams.containsKey(CLAIMS_PARAMETER)) {
String claims = new String(Base64.getUrlDecoder()
.decode(extractedChallengeParams.get(CLAIMS_PARAMETER)), StandardCharsets.UTF_8);
return authenticateRequest(context,
() -> credential.getToken(new TokenRequestContext()
.addScopes(scopes).setClaims(claims)), true)
.flatMap(b -> Mono.just(true));
}
}
}
return Mono.just(false);
}

@Override
public Mono<HttpResponse> process(HttpPipelineCallContext context, HttpPipelineNextPolicy next) {
if ("http".equals(context.getHttpRequest().getUrl().getProtocol())) {
return Mono.error(new RuntimeException("token credentials require a URL using the HTTPS protocol scheme"));
}
HttpPipelineNextPolicy nextPolicy = next.clone();

return onBeforeRequest(context)
.then(next.process())
.flatMap(httpResponse -> {
String authHeader = httpResponse.getHeaderValue(WWW_AUTHENTICATE);
if (httpResponse.getStatusCode() == 401 && authHeader != null) {
return onChallenge(context, httpResponse).flatMap(retry -> {
if (retry) {
return nextPolicy.process();
} else {
return Mono.just(httpResponse);
}
});
}
return Mono.just(httpResponse);
});
}

/**
* Get the {@link AccessTokenCache} holding the cached access tokens and the logic to retrieve and refresh
* access tokens.
*
* @return the {@link AccessTokenCache}
*/
public AccessTokenCache getTokenCache() {
return cache;
}

private Mono<Void> authenticateRequest(HttpPipelineCallContext context, Supplier<Mono<AccessToken>> tokenSupplier,
boolean forceTokenRefresh) {
return cache.getToken(tokenSupplier, forceTokenRefresh)
.flatMap(token -> {
context.getHttpRequest().getHeaders().set(AUTHORIZATION_HEADER, BEARER + " " + token.getToken());
return Mono.empty();
});
}

List<AuthenticationChallenge> parseChallenges(String header) {
Matcher matcher = AUTHENTICATION_CHALLENGE_PATTERN.matcher(header);

List<AuthenticationChallenge> challenges = new ArrayList<>();
while (matcher.find()) {
challenges.add(new AuthenticationChallenge(matcher.group(1), matcher.group(2)));
}

return challenges;
}

Map<String, String> parseChallengeParams(String challengeParams) {
Matcher matcher = AUTHENTICATION_CHALLENGE_PARAMS_PATTERN.matcher(challengeParams);

Map<String, String> challengeParameters = new HashMap<>();
while (matcher.find()) {
challengeParameters.put(matcher.group(1), matcher.group(2));
}
return challengeParameters;
}
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

/**
* Package containing experimental http policies.
*/
package com.azure.core.experimental.http.policy;
Loading

0 comments on commit 6dfa793

Please sign in to comment.