Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,18 @@
import com.azure.core.amqp.exception.AmqpResponseCode;
import com.azure.core.exception.AzureException;
import com.azure.core.util.logging.ClientLogger;
import reactor.core.Disposable;
import reactor.core.publisher.EmitterProcessor;
import reactor.core.publisher.Flux;
import reactor.core.publisher.FluxSink;
import reactor.core.publisher.Mono;
import reactor.core.publisher.ReplayProcessor;

import java.time.Duration;
import java.time.OffsetDateTime;
import java.time.ZoneOffset;
import java.util.Timer;
import java.util.TimerTask;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;

/**
* Manages the re-authorization of the client to the token audience against the CBS node.
Expand All @@ -30,27 +31,19 @@ public class ActiveClientTokenManager implements TokenManager {
private final Mono<ClaimsBasedSecurityNode> cbsNode;
private final String tokenAudience;
private final String scopes;
private final Timer timer;
private final Flux<AmqpResponseCode> authorizationResults;
private FluxSink<AmqpResponseCode> sink;
private final ReplayProcessor<AmqpResponseCode> authorizationResults = ReplayProcessor.create(1);
private final FluxSink<AmqpResponseCode> authorizationResultsSink =
authorizationResults.sink(FluxSink.OverflowStrategy.BUFFER);
private final EmitterProcessor<Duration> durationSource = EmitterProcessor.create();
private final FluxSink<Duration> durationSourceSink = durationSource.sink();
private final AtomicReference<Duration> lastRefreshInterval = new AtomicReference<>(Duration.ofMinutes(1));

// last refresh interval in milliseconds.
private AtomicLong lastRefreshInterval = new AtomicLong();
private volatile Disposable subscription;

public ActiveClientTokenManager(Mono<ClaimsBasedSecurityNode> cbsNode, String tokenAudience, String scopes) {
this.timer = new Timer(tokenAudience + "-tokenManager");
this.cbsNode = cbsNode;
this.tokenAudience = tokenAudience;
this.scopes = scopes;
this.authorizationResults = Flux.create(sink -> {
if (hasDisposed.get()) {
sink.complete();
} else {
this.sink = sink;
}
});

lastRefreshInterval.set(Duration.ofMinutes(1).getSeconds() * 1000);
}

/**
Expand Down Expand Up @@ -82,15 +75,18 @@ public Mono<Long> authorize() {

// We want to refresh the token when 90% of the time before expiry has elapsed.
final long refreshSeconds = (long) Math.floor(between.getSeconds() * 0.9);

// This converts it to milliseconds
final long refreshIntervalMS = refreshSeconds * 1000;

lastRefreshInterval.set(refreshIntervalMS);

// If this is the first time authorize is called, the task will not have been scheduled yet.
if (!hasScheduled.getAndSet(true)) {
logger.info("Scheduling refresh token task.");
scheduleRefreshTokenTask(refreshIntervalMS);
logger.info("Scheduling refresh token task");

final Duration firstInterval = Duration.ofMillis(refreshIntervalMS);
lastRefreshInterval.set(firstInterval);
authorizationResultsSink.next(AmqpResponseCode.ACCEPTED);
subscription = scheduleRefreshTokenTask(firstInterval);
}

return refreshIntervalMS;
Expand All @@ -99,52 +95,51 @@ public Mono<Long> authorize() {

@Override
public void close() {
if (!hasDisposed.getAndSet(true)) {
if (this.sink != null) {
this.sink.complete();
}

this.timer.cancel();
if (hasDisposed.getAndSet(true)) {
return;
}
}

private void scheduleRefreshTokenTask(Long refreshIntervalInMS) {
try {
timer.schedule(new RefreshAuthorizationToken(), refreshIntervalInMS);
} catch (IllegalStateException e) {
logger.warning("Unable to schedule RefreshAuthorizationToken task.", e);
hasScheduled.set(false);
authorizationResultsSink.complete();
durationSourceSink.complete();

if (subscription != null) {
subscription.dispose();
}
}

private class RefreshAuthorizationToken extends TimerTask {
@Override
public void run() {
logger.info("Refreshing authorization token.");
authorize().subscribe(
(Long refreshIntervalInMS) -> {

if (hasDisposed.get()) {
logger.info("Token manager has been disposed of. Not rescheduling.");
return;
}

logger.info("Authorization successful. Refreshing token in {} ms.", refreshIntervalInMS);
sink.next(AmqpResponseCode.ACCEPTED);

scheduleRefreshTokenTask(refreshIntervalInMS);
}, error -> {
if ((error instanceof AmqpException) && ((AmqpException) error).isTransient()) {
logger.error("Error is transient. Rescheduling authorization task.", error);
scheduleRefreshTokenTask(lastRefreshInterval.get());
} else {
logger.error("Error occurred while refreshing token that is not retriable. Not scheduling"
+ " refresh task. Use ActiveClientTokenManager.authorize() to schedule task again.", error);
hasScheduled.set(false);
}

sink.error(error);
private Disposable scheduleRefreshTokenTask(Duration initialRefresh) {
// EmitterProcessor can queue up an initial refresh interval before any subscribers are received.
durationSourceSink.next(initialRefresh);

return Flux.switchOnNext(durationSource.map(Flux::interval))
.flatMap(delay -> {
logger.info("Refreshing token.");
return authorize();
})
.onErrorContinue(
error -> (error instanceof AmqpException) && ((AmqpException) error).isTransient(),
(amqpException, interval) -> {
final Duration lastRefresh = lastRefreshInterval.get();

logger.error("Error is transient. Rescheduling authorization task at interval {} ms.",
lastRefresh.toMillis(), amqpException);
durationSourceSink.next(lastRefreshInterval.get());
})
.subscribe(interval -> {
logger.info("Authorization successful. Refreshing token in {} ms.", interval);
authorizationResultsSink.next(AmqpResponseCode.ACCEPTED);

final Duration nextRefresh = Duration.ofMillis(interval);
lastRefreshInterval.set(nextRefresh);
durationSourceSink.next(Duration.ofMillis(interval));
}, error -> {
logger.error("Error occurred while refreshing token that is not retriable. Not scheduling"
+ " refresh task. Use ActiveClientTokenManager.authorize() to schedule task again.", error);
hasScheduled.set(false);
durationSourceSink.complete();
authorizationResultsSink.error(error);
}, () -> {
logger.info("Completed refresh token task.");
});
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ public AzureTokenManagerProvider(CbsAuthorizationType authorizationType, String
public TokenManager getTokenManager(Mono<ClaimsBasedSecurityNode> cbsNodeMono, String resource) {
final String scopes = getResourceString(resource);
final String tokenAudience = String.format(Locale.US, TOKEN_AUDIENCE_FORMAT, fullyQualifiedNamespace, resource);

logger.info("Creating new token manager for audience[{}], scopes[{}]", tokenAudience, scopes);
return new ActiveClientTokenManager(cbsNodeMono, tokenAudience, scopes);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import com.azure.core.amqp.exception.AmqpResponseCode;
import com.azure.core.exception.AzureException;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.Mock;
Expand All @@ -25,7 +26,7 @@
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.when;

public class ActiveClientTokenManagerTest {
class ActiveClientTokenManagerTest {
private static final String AUDIENCE = "an-audience-test";
private static final String SCOPES = "scopes-test";
private static final Duration TIMEOUT = Duration.ofSeconds(4);
Expand All @@ -34,12 +35,12 @@ public class ActiveClientTokenManagerTest {
private ClaimsBasedSecurityNode cbsNode;

@BeforeEach
public void setup() {
void setup() {
MockitoAnnotations.initMocks(this);
}

@AfterEach
public void teardown() {
void teardown() {
Mockito.framework().clearInlineMocks();
cbsNode = null;
}
Expand All @@ -48,7 +49,7 @@ public void teardown() {
* Verify that we can get successes and errors from CBS node.
*/
@Test
public void getAuthorizationResults() {
void getAuthorizationResults() {
// Arrange
final Mono<ClaimsBasedSecurityNode> cbsNodeMono = Mono.fromCallable(() -> cbsNode);
when(cbsNode.authorize(any(), any())).thenReturn(getNextExpiration(3));
Expand All @@ -60,8 +61,9 @@ public void getAuthorizationResults() {
.then(() -> tokenManager.authorize().block(TIMEOUT))
.expectNext(AmqpResponseCode.ACCEPTED)
.expectNext(AmqpResponseCode.ACCEPTED)
.then(tokenManager::close)
.verifyComplete();
.then(() -> tokenManager.close())
.expectComplete()
.verify();
}

/**
Expand All @@ -70,7 +72,7 @@ public void getAuthorizationResults() {
*/
@SuppressWarnings("unchecked")
@Test
public void getAuthorizationResultsSuccessFailure() {
void getAuthorizationResultsSuccessFailure() {
// Arrange
final Mono<ClaimsBasedSecurityNode> cbsNodeMono = Mono.fromCallable(() -> cbsNode);
final IllegalArgumentException error = new IllegalArgumentException("Some error");
Expand All @@ -83,6 +85,7 @@ public void getAuthorizationResultsSuccessFailure() {
StepVerifier.create(tokenManager.getAuthorizationResults())
.then(() -> tokenManager.authorize().block(TIMEOUT))
.expectNext(AmqpResponseCode.ACCEPTED)
.expectNext(AmqpResponseCode.ACCEPTED)
.expectError(IllegalArgumentException.class)
.verifyThenAssertThat()
.hasNotDroppedElements()
Expand All @@ -95,7 +98,7 @@ public void getAuthorizationResultsSuccessFailure() {
* Verify that we cannot authorize with CBS node when it has already been disposed of.
*/
@Test
public void cannotAuthorizeDisposedInstance() {
void cannotAuthorizeDisposedInstance() {
// Arrange
final Mono<ClaimsBasedSecurityNode> cbsNodeMono = Mono.fromCallable(() -> cbsNode);
when(cbsNode.authorize(any(), any())).thenReturn(getNextExpiration(2));
Expand All @@ -114,31 +117,64 @@ public void cannotAuthorizeDisposedInstance() {
*/
@SuppressWarnings("unchecked")
@Test
public void getAuthorizationResultsRetriableError() {
void getAuthorizationResultsRetriableError() {
// Arrange
final Mono<ClaimsBasedSecurityNode> cbsNodeMono = Mono.fromCallable(() -> cbsNode);
final AmqpException error = new AmqpException(true, AmqpErrorCondition.TIMEOUT_ERROR, "Timed out",
final AmqpException error = new AmqpException(false, AmqpErrorCondition.ARGUMENT_ERROR,
"Non-retryable argument error",
new AmqpErrorContext("Test-context-namespace"));

when(cbsNode.authorize(any(), any())).thenReturn(getNextExpiration(3), Mono.error(error),
getNextExpiration(5), getNextExpiration(10),
getNextExpiration(45));
when(cbsNode.authorize(any(), any())).thenReturn(getNextExpiration(5), Mono.error(error),
getNextExpiration(5));

// Act & Assert
try (ActiveClientTokenManager tokenManager = new ActiveClientTokenManager(cbsNodeMono, AUDIENCE, SCOPES)) {
StepVerifier.create(tokenManager.getAuthorizationResults())
.then(() -> tokenManager.authorize().block(TIMEOUT))
.expectError(AmqpException.class)
.verify();

StepVerifier.create(tokenManager.getAuthorizationResults())
.expectNext(AmqpResponseCode.ACCEPTED)
.expectNext(AmqpResponseCode.ACCEPTED)
.then(tokenManager::close)
.verifyComplete();
.expectErrorSatisfies(exception -> {
Assertions.assertTrue(exception instanceof AmqpException);

AmqpException amqpException = (AmqpException) exception;
Assertions.assertFalse(amqpException.isTransient());
Assertions.assertEquals(error.getErrorCondition(), amqpException.getErrorCondition());
})
.verify(Duration.ofSeconds(30));
}
}


/**
* Verify that the ActiveClientTokenManager does not get more authorization tasks.
*/
@SuppressWarnings("unchecked")
@Test
void getAuthorizationResultsNonRetriableError() {
// Arrange
final Mono<ClaimsBasedSecurityNode> cbsNodeMono = Mono.fromCallable(() -> cbsNode);
final AmqpException error = new AmqpException(true, AmqpErrorCondition.TIMEOUT_ERROR, "Test CBS node error.",
new AmqpErrorContext("Test-context-namespace"));

when(cbsNode.authorize(any(), any())).thenReturn(getNextExpiration(5), Mono.error(error),
getNextExpiration(5), getNextExpiration(10),
getNextExpiration(45));

// Act & Assert
final ActiveClientTokenManager tokenManager = new ActiveClientTokenManager(cbsNodeMono, AUDIENCE, SCOPES);

StepVerifier.create(tokenManager.getAuthorizationResults())
.then(() -> tokenManager.authorize().block(TIMEOUT))
.expectNext(AmqpResponseCode.ACCEPTED)
.expectNext(AmqpResponseCode.ACCEPTED)
.then(() -> {
System.out.println("Closing");
tokenManager.close();
})
.expectComplete()
.verify(Duration.ofSeconds(30));
}


private Mono<OffsetDateTime> getNextExpiration(long secondsToWait) {
return Mono.fromCallable(() -> OffsetDateTime.now(ZoneOffset.UTC).plusSeconds(secondsToWait));
}
Expand Down
Loading