Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ public Mono<ShouldRetryResult> shouldRetry(Exception e) {
Exceptions.isSubStatusCode(clientException, HttpConstants.SubStatusCodes.FORBIDDEN_WRITEFORBIDDEN))
{
logger.warn("Endpoint not writable. Will refresh cache and retry. {}", e.toString());
return this.shouldRetryOnEndpointFailureAsync(false);
return this.shouldRetryOnEndpointFailureAsync(false, true);
}

// Regional endpoint is not available yet for reads (e.g. add/ online of region is in progress)
Expand All @@ -92,13 +92,13 @@ public Mono<ShouldRetryResult> shouldRetry(Exception e) {
this.isReadRequest)
{
logger.warn("Endpoint not available for reads. Will refresh cache and retry. {}", e.toString());
return this.shouldRetryOnEndpointFailureAsync(true);
return this.shouldRetryOnEndpointFailureAsync(true, false);
}

// Received Connection error (HttpRequestException), initiate the endpoint rediscovery
if (WebExceptionUtility.isNetworkFailure(e)) {
logger.warn("Endpoint not reachable. Will refresh cache and retry. {}" , e.toString());
return this.shouldRetryOnEndpointFailureAsync(this.isReadRequest);
return this.shouldRetryOnEndpointFailureAsync(this.isReadRequest, false);
}

if (clientException != null &&
Expand Down Expand Up @@ -141,7 +141,7 @@ private ShouldRetryResult shouldRetryOnSessionNotAvailable() {
}
}

private Mono<ShouldRetryResult> shouldRetryOnEndpointFailureAsync(boolean isReadRequest) {
private Mono<ShouldRetryResult> shouldRetryOnEndpointFailureAsync(boolean isReadRequest , boolean forceRefresh) {
if (!this.enableEndpointDiscovery || this.failoverRetryCount > MaxRetryCount) {
logger.warn("ShouldRetryOnEndpointFailureAsync() Not retrying. Retry count = {}", this.failoverRetryCount);
return Mono.just(ShouldRetryResult.noRetry());
Expand Down Expand Up @@ -173,7 +173,7 @@ private Mono<ShouldRetryResult> shouldRetryOnEndpointFailureAsync(boolean isRead
retryDelay = Duration.ofMillis(ClientRetryPolicy.RetryIntervalInMS);
}
this.retryContext = new RetryContext(this.failoverRetryCount, false);
return this.globalEndpointManager.refreshLocationAsync(null)
return this.globalEndpointManager.refreshLocationAsync(null, forceRefresh)
.then(Mono.just(ShouldRetryResult.retryAfter(retryDelay)));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ public class GlobalEndpointManager implements AutoCloseable {
private final ConnectionPolicy connectionPolicy;
private final DatabaseAccountManagerInternal owner;
private final AtomicBoolean isRefreshing;
private final AtomicBoolean refreshInBackground;
private final ExecutorService executor = Executors.newSingleThreadExecutor();
private final Scheduler scheduler = Schedulers.fromExecutor(executor);
private volatile boolean isClosed;
Expand All @@ -63,6 +64,7 @@ public GlobalEndpointManager(DatabaseAccountManagerInternal owner, ConnectionPol
this.connectionPolicy = connectionPolicy;

this.isRefreshing = new AtomicBoolean(false);
this.refreshInBackground = new AtomicBoolean(false);
this.isClosed = false;
} catch (Exception e) {
throw new IllegalArgumentException(e);
Expand Down Expand Up @@ -129,9 +131,24 @@ public void close() {
logger.debug("GlobalEndpointManager closed.");
}

public Mono<Void> refreshLocationAsync(DatabaseAccount databaseAccount) {
public Mono<Void> refreshLocationAsync(DatabaseAccount databaseAccount, boolean forceRefresh) {
return Mono.defer(() -> {
logger.debug("refreshLocationAsync() invoked");

if (forceRefresh) {
Mono<DatabaseAccount> databaseAccountObs = getDatabaseAccountFromAnyLocationsAsync(
this.defaultEndpoint,
new ArrayList<>(this.connectionPolicy.preferredLocations()),
this::getDatabaseAccountAsync);

return databaseAccountObs.map(dbAccount -> {
this.locationCache.onDatabaseAccountRead(dbAccount);
return dbAccount;
}).flatMap(dbAccount -> {
return Mono.empty();
});
}

if (!isRefreshing.compareAndSet(false, true)) {
logger.debug("in the middle of another refresh. Not invoking a new refresh.");
return Mono.empty();
Expand Down Expand Up @@ -164,17 +181,23 @@ private Mono<Void> refreshLocationPrivateAsync(DatabaseAccount databaseAccount)

return databaseAccountObs.map(dbAccount -> {
this.locationCache.onDatabaseAccountRead(dbAccount);
this.isRefreshing.set(false);
return dbAccount;
}).flatMap(dbAccount -> {
// trigger a startRefreshLocationTimerAsync don't wait on it.
this.startRefreshLocationTimerAsync();
if (!this.refreshInBackground.get()) {
this.startRefreshLocationTimerAsync();
}
return Mono.empty();
});
}

// trigger a startRefreshLocationTimerAsync don't wait on it.
this.startRefreshLocationTimerAsync();
if (!this.refreshInBackground.get()) {
this.startRefreshLocationTimerAsync();
}

this.isRefreshing.set(false);
return Mono.empty();
} else {
logger.debug("shouldRefreshEndpoints: false, nothing to do.");
Expand All @@ -201,6 +224,8 @@ private Mono<Void> startRefreshLocationTimerAsync(boolean initialization) {

int delayInMillis = initialization ? 0: this.backgroundRefreshLocationTimeIntervalInMS;

this.refreshInBackground.set(true);

return Mono.delay(Duration.ofMillis(delayInMillis))
.flatMap(
t -> {
Expand All @@ -216,6 +241,7 @@ private Mono<Void> startRefreshLocationTimerAsync(boolean initialization) {

return databaseAccountObs.flatMap(dbAccount -> {
logger.debug("db account retrieved");
this.refreshInBackground.set(false);
return this.refreshLocationPrivateAsync(dbAccount);
});
}).onErrorResume(ex -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ private void initializeGatewayConfigurationReader() {

// TODO: add support for openAsync
// https://msdata.visualstudio.com/CosmosDB/_workitems/edit/332589
this.globalEndpointManager.refreshLocationAsync(databaseAccount).block();
this.globalEndpointManager.refreshLocationAsync(databaseAccount, false).block();
}

public void init() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,9 +184,20 @@ public boolean shouldRefreshEndpoints(Utils.ValueHolder canRefreshInBackground)
if (this.enableEndpointDiscovery) {

boolean shouldRefresh = this.useMultipleWriteLocations && !this.enableMultipleWriteLocations;
List<URL> readLocationEndpoints = currentLocationInfo.readEndpoints;
if (this.isEndpointUnavailable(readLocationEndpoints.get(0), OperationType.Read)) {
// Since most preferred read endpoint is unavailable, we can only refresh in background if
// we have an alternate read endpoint
canRefreshInBackground.v = anyEndpointsAvailable(readLocationEndpoints,OperationType.Read);
logger.debug("shouldRefreshEndpoints = true, since the first read endpoint " +
"[{}] is not available for read. canRefreshInBackground = [{}]",
readLocationEndpoints.get(0),
canRefreshInBackground.v);
return true;
}

if (!Strings.isNullOrEmpty(mostPreferredLocation)) {
Utils.ValueHolder<URL> mostPreferredReadEndpointHolder = new Utils.ValueHolder<>();
List<URL> readLocationEndpoints = currentLocationInfo.readEndpoints;
logger.debug("getReadEndpoints [{}]", readLocationEndpoints);

if (Utils.tryGetValue(currentLocationInfo.availableReadEndpointByLocation, mostPreferredLocation, mostPreferredReadEndpointHolder)) {
Expand Down Expand Up @@ -218,7 +229,7 @@ public boolean shouldRefreshEndpoints(Utils.ValueHolder canRefreshInBackground)
if (this.isEndpointUnavailable(writeLocationEndpoints.get(0), OperationType.Write)) {
// Since most preferred write endpoint is unavailable, we can only refresh in background if
// we have an alternate write endpoint
canRefreshInBackground.v = writeLocationEndpoints.size() > 1;
canRefreshInBackground.v = anyEndpointsAvailable(writeLocationEndpoints,OperationType.Write);
logger.debug("shouldRefreshEndpoints = true, most preferred location " +
"[{}] endpoint [{}] is not available for write. canRefreshInBackground = [{}]",
mostPreferredLocation,
Expand Down Expand Up @@ -305,6 +316,18 @@ private boolean isEndpointUnavailable(URL endpoint, OperationType expectedAvaila
}
}

private boolean anyEndpointsAvailable(List<URL> endpoints, OperationType expectedAvailableOperations) {
Utils.ValueHolder<LocationUnavailabilityInfo> unavailabilityInfoHolder = new Utils.ValueHolder<>();
boolean anyEndpointsAvailable = false;
for (URL endpoint : endpoints) {
if (!isEndpointUnavailable(endpoint, expectedAvailableOperations)) {
anyEndpointsAvailable = true;
break;
}
}
return anyEndpointsAvailable;
}

private void markEndpointUnavailable(
URL unavailableEndpoint,
OperationType unavailableOperationType) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import com.azure.data.cosmos.RetryOptions;
import io.netty.handler.timeout.ReadTimeoutException;
import io.reactivex.subscribers.TestSubscriber;
import org.mockito.Matchers;
import org.mockito.Mockito;
import org.testng.annotations.Test;
import reactor.core.publisher.Mono;
Expand All @@ -22,7 +23,7 @@ public void networkFailureOnRead() throws Exception {
RetryOptions retryOptions = new RetryOptions();
GlobalEndpointManager endpointManager = Mockito.mock(GlobalEndpointManager.class);
Mockito.doReturn(new URL("http://localhost")).when(endpointManager).resolveServiceEndpoint(Mockito.any(RxDocumentServiceRequest.class));
Mockito.doReturn(Mono.empty()).when(endpointManager).refreshLocationAsync(Mockito.eq(null));
Mockito.doReturn(Mono.empty()).when(endpointManager).refreshLocationAsync(Mockito.eq(null), Mockito.eq(false));
ClientRetryPolicy clientRetryPolicy = new ClientRetryPolicy(endpointManager, true, retryOptions);

Exception exception = ReadTimeoutException.INSTANCE;
Expand Down Expand Up @@ -52,7 +53,7 @@ public void networkFailureOnWrite() throws Exception {
RetryOptions retryOptions = new RetryOptions();
GlobalEndpointManager endpointManager = Mockito.mock(GlobalEndpointManager.class);
Mockito.doReturn(new URL("http://localhost")).when(endpointManager).resolveServiceEndpoint(Mockito.any(RxDocumentServiceRequest.class));
Mockito.doReturn(Mono.empty()).when(endpointManager).refreshLocationAsync(Mockito.eq(null));
Mockito.doReturn(Mono.empty()).when(endpointManager).refreshLocationAsync(Mockito.eq(null), Mockito.eq(false));
ClientRetryPolicy clientRetryPolicy = new ClientRetryPolicy(endpointManager, true, retryOptions);

Exception exception = ReadTimeoutException.INSTANCE;
Expand Down Expand Up @@ -80,7 +81,7 @@ public void onBeforeSendRequestNotInvoked() {
RetryOptions retryOptions = new RetryOptions();
GlobalEndpointManager endpointManager = Mockito.mock(GlobalEndpointManager.class);

Mockito.doReturn(Mono.empty()).when(endpointManager).refreshLocationAsync(Mockito.eq(null));
Mockito.doReturn(Mono.empty()).when(endpointManager).refreshLocationAsync(Mockito.eq(null), Mockito.eq(false));
ClientRetryPolicy clientRetryPolicy = new ClientRetryPolicy(endpointManager, true, retryOptions);

Exception exception = ReadTimeoutException.INSTANCE;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ public class RenameCollectionAwareClientRetryPolicyTest {
@Test(groups = "unit", timeOut = TIMEOUT)
public void onBeforeSendRequestNotInvoked() {
GlobalEndpointManager endpointManager = Mockito.mock(GlobalEndpointManager.class);
Mockito.doReturn(Mono.empty()).when(endpointManager).refreshLocationAsync(Mockito.eq(null));
Mockito.doReturn(Mono.empty()).when(endpointManager).refreshLocationAsync(Mockito.eq(null), Mockito.eq(false));

IRetryPolicyFactory retryPolicyFactory = new RetryPolicy(endpointManager, ConnectionPolicy.defaultPolicy());
RxClientCollectionCache rxClientCollectionCache = Mockito.mock(RxClientCollectionCache.class);
Expand Down Expand Up @@ -51,7 +51,7 @@ public void onBeforeSendRequestNotInvoked() {
@Test(groups = "unit", timeOut = TIMEOUT)
public void shouldRetryWithNotFoundStatusCode() {
GlobalEndpointManager endpointManager = Mockito.mock(GlobalEndpointManager.class);
Mockito.doReturn(Mono.empty()).when(endpointManager).refreshLocationAsync(Mockito.eq(null));
Mockito.doReturn(Mono.empty()).when(endpointManager).refreshLocationAsync(Mockito.eq(null),Mockito.eq(false));
IRetryPolicyFactory retryPolicyFactory = new RetryPolicy(endpointManager, ConnectionPolicy.defaultPolicy());
RxClientCollectionCache rxClientCollectionCache = Mockito.mock(RxClientCollectionCache.class);

Expand All @@ -77,7 +77,7 @@ public void shouldRetryWithNotFoundStatusCode() {
@Test(groups = "unit", timeOut = TIMEOUT)
public void shouldRetryWithNotFoundStatusCodeAndReadSessionNotAvailableSubStatusCode() {
GlobalEndpointManager endpointManager = Mockito.mock(GlobalEndpointManager.class);
Mockito.doReturn(Mono.empty()).when(endpointManager).refreshLocationAsync(Mockito.eq(null));
Mockito.doReturn(Mono.empty()).when(endpointManager).refreshLocationAsync(Mockito.eq(null), Mockito.eq(false));
IRetryPolicyFactory retryPolicyFactory = new RetryPolicy(endpointManager, ConnectionPolicy.defaultPolicy());
RxClientCollectionCache rxClientCollectionCache = Mockito.mock(RxClientCollectionCache.class);

Expand Down Expand Up @@ -114,7 +114,7 @@ public void shouldRetryWithNotFoundStatusCodeAndReadSessionNotAvailableSubStatus
@Test(groups = "unit", timeOut = TIMEOUT)
public void shouldRetryWithGenericException() {
GlobalEndpointManager endpointManager = Mockito.mock(GlobalEndpointManager.class);
Mockito.doReturn(Mono.empty()).when(endpointManager).refreshLocationAsync(Mockito.eq(null));
Mockito.doReturn(Mono.empty()).when(endpointManager).refreshLocationAsync(Mockito.eq(null), Mockito.eq(false));
IRetryPolicyFactory retryPolicyFactory = new RetryPolicy(endpointManager, ConnectionPolicy.defaultPolicy());
RxClientCollectionCache rxClientCollectionCache = Mockito.mock(RxClientCollectionCache.class);

Expand Down
Loading