diff --git a/sdk/cosmos/azure-cosmos/CHANGELOG.md b/sdk/cosmos/azure-cosmos/CHANGELOG.md
index fed3081db795..d9b36f4dab5c 100644
--- a/sdk/cosmos/azure-cosmos/CHANGELOG.md
+++ b/sdk/cosmos/azure-cosmos/CHANGELOG.md
@@ -3,12 +3,14 @@
### 4.14.3 (Unreleased)
#### Features Added
+* Added support for Per Partition Automatic Failover. To enable this feature, you must follow the guide [here](https://learn.microsoft.com/azure/cosmos-db/how-to-configure-per-partition-automatic-failover). See [PR 41588](https://github.com/Azure/azure-sdk-for-python/pull/41588).
#### Breaking Changes
#### Bugs Fixed
#### Other Changes
+* Added cross-regional retries for 503 (Service Unavailable) errors. See [PR 41588](https://github.com/Azure/azure-sdk-for-python/pull/41588).
### 4.14.2 (2025-11-14)
diff --git a/sdk/cosmos/azure-cosmos/README.md b/sdk/cosmos/azure-cosmos/README.md
index 013a4bd02a63..ca76ad346ce9 100644
--- a/sdk/cosmos/azure-cosmos/README.md
+++ b/sdk/cosmos/azure-cosmos/README.md
@@ -940,6 +940,11 @@ requests to another region:
- `AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED`: Default is a `90` percent failure rate.
- After a partition reaches a 90 percent failure rate for all requests, the SDK will send requests routed to that partition to another region.
+### Per Partition Automatic Failover (Public Preview)
+Per partition automatic failover enables the SDK to automatically redirect write requests at the partition level to another region based on service-side signals. This feature is available
+only for single write region accounts that have at least one read-only region. When per partition automatic failover is enabled, the per partition circuit breaker and cross-region hedging are enabled by default,
+meaning all of their configurable options also apply to per partition automatic failover. To enable this feature, follow the guide [here](https://learn.microsoft.com/azure/cosmos-db/how-to-configure-per-partition-automatic-failover).
+
## Troubleshooting
### General
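
To illustrate the configuration surface described above, here is a minimal sketch of a client whose circuit breaker thresholds (which also govern per partition automatic failover) are tuned through the documented environment variables. The endpoint and key are placeholders; PPAF itself is enabled on the account side, per the linked guide.

```python
import os

from azure.cosmos import CosmosClient

# Lower the per-partition failure-rate tolerance from its 90 percent default.
os.environ["AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED"] = "80"
# Consecutive timeout/5xx errors tolerated per partition before PPAF marks
# the current region unavailable (default is 10).
os.environ["AZURE_COSMOS_TIMEOUT_ERROR_THRESHOLD_FOR_PPAF"] = "5"

# Placeholder endpoint/credential; the SDK picks up PPAF from account metadata.
client = CosmosClient("https://<account>.documents.azure.com:443/", credential="<key>")
```
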
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py
index d066135500d1..4cf500a129bd 100644
--- a/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py
+++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py
@@ -46,9 +46,13 @@
if TYPE_CHECKING:
from ._cosmos_client_connection import CosmosClientConnection
from .aio._cosmos_client_connection_async import CosmosClientConnection as AsyncClientConnection
+ from ._global_partition_endpoint_manager_per_partition_automatic_failover import (
+ _GlobalPartitionEndpointManagerForPerPartitionAutomaticFailover)
from ._request_object import RequestObject
+ from ._routing.routing_range import PartitionKeyRangeWrapper
# pylint: disable=protected-access
+#cspell:ignore PPAF, ppaf
_COMMON_OPTIONS = {
'initial_headers': 'initialHeaders',
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py
index d6a23050c226..0a5e961f7aa9 100644
--- a/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py
+++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py
@@ -22,8 +22,9 @@
"""Class for defining internal constants in the Azure Cosmos database service.
"""
-
+from enum import IntEnum
from typing_extensions import Literal
+# cspell:ignore PPAF
# cspell:ignore reranker
@@ -40,6 +41,7 @@ class _Constants:
Name: Literal["name"] = "name"
DatabaseAccountEndpoint: Literal["databaseAccountEndpoint"] = "databaseAccountEndpoint"
DefaultEndpointsRefreshTime: int = 5 * 60 * 1000 # milliseconds
+ EnablePerPartitionFailoverBehavior: Literal["enablePerPartitionFailoverBehavior"] = "enablePerPartitionFailoverBehavior" #pylint: disable=line-too-long
# ServiceDocument Resource
EnableMultipleWritableLocations: Literal["enableMultipleWriteLocations"] = "enableMultipleWriteLocations"
@@ -74,6 +76,10 @@ class _Constants:
FAILURE_PERCENTAGE_TOLERATED = "AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED"
FAILURE_PERCENTAGE_TOLERATED_DEFAULT: int = 90
# -------------------------------------------------------------------------
+ # Only applicable when per partition automatic failover is enabled --------
+ TIMEOUT_ERROR_THRESHOLD_PPAF = "AZURE_COSMOS_TIMEOUT_ERROR_THRESHOLD_FOR_PPAF"
+ TIMEOUT_ERROR_THRESHOLD_PPAF_DEFAULT: int = 10
+ # -------------------------------------------------------------------------
# Error code translations
ERROR_TRANSLATIONS: dict[int, str] = {
@@ -99,3 +105,22 @@ class Kwargs:
"""Whether to retry write operations if they fail. Used either at client level or request level."""
EXCLUDED_LOCATIONS: Literal["excludedLocations"] = "excludedLocations"
+
+ class UserAgentFeatureFlags(IntEnum):
+    """
+    User agent feature flags.
+    Each flag represents one bit in a number that encodes which features are enabled: the first feature flag
+    is 1, the second 2, the third 4, and so on. When constructing the user agent suffix, the enabled flags are
+    combined into a single number, which is converted to a hex string prefixed with "F" (hex keeps the suffix
+    short, since the user agent length is limited) and appended to the user agent suffix. To determine which
+    features are enabled, decode the hex string back to a number and check which bits are set.
+
+    New feature flags should align with the .NET SDK, which is the source of truth for flag assignments:
+    https://github.com/Azure/azure-cosmos-dotnet-v3/blob/master/Microsoft.Azure.Cosmos/src/Diagnostics/UserAgentFeatureFlags.cs
+
+    Example:
+    If the user agent suffix contains "F3", flags 1 and 2 are enabled.
+    """
+ PER_PARTITION_AUTOMATIC_FAILOVER = 1
+ PER_PARTITION_CIRCUIT_BREAKER = 2
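
As a worked example of the encoding described in the docstring above, here is a self-contained sketch (mirroring the enum locally rather than importing SDK internals) of how a suffix such as "F3" is produced and decoded:

```python
from enum import IntEnum

class UserAgentFeatureFlags(IntEnum):
    PER_PARTITION_AUTOMATIC_FAILOVER = 1
    PER_PARTITION_CIRCUIT_BREAKER = 2

# Encode: OR the enabled flags into one number, render it as hex after "F".
flags = (UserAgentFeatureFlags.PER_PARTITION_AUTOMATIC_FAILOVER
         | UserAgentFeatureFlags.PER_PARTITION_CIRCUIT_BREAKER)
suffix = f"F{flags:X}"  # -> "F3"

# Decode: parse the hex portion back to a number and test individual bits.
decoded = int(suffix[1:], 16)
assert decoded & UserAgentFeatureFlags.PER_PARTITION_AUTOMATIC_FAILOVER
assert decoded & UserAgentFeatureFlags.PER_PARTITION_CIRCUIT_BREAKER
```
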
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py
index 4e08365e2c47..c80c76864e98 100644
--- a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py
+++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py
@@ -49,7 +49,7 @@
HttpResponse # pylint: disable=no-legacy-azure-core-http-response-import
from . import _base as base
-from ._global_partition_endpoint_manager_circuit_breaker import _GlobalPartitionEndpointManagerForCircuitBreaker
+from ._global_partition_endpoint_manager_per_partition_automatic_failover import _GlobalPartitionEndpointManagerForPerPartitionAutomaticFailover # pylint: disable=line-too-long
from . import _query_iterable as query_iterable
from . import _runtime_constants as runtime_constants
from . import _session
@@ -176,7 +176,7 @@ def __init__( # pylint: disable=too-many-statements
self.last_response_headers: CaseInsensitiveDict = CaseInsensitiveDict()
self.UseMultipleWriteLocations = False
- self._global_endpoint_manager = _GlobalPartitionEndpointManagerForCircuitBreaker(self)
+ self._global_endpoint_manager = _GlobalPartitionEndpointManagerForPerPartitionAutomaticFailover(self)
retry_policy = None
if isinstance(self.connection_policy.ConnectionRetryConfiguration, HTTPPolicy):
@@ -2688,12 +2688,15 @@ def GetDatabaseAccount(
database_account._ReadableLocations = result[Constants.ReadableLocations]
if Constants.EnableMultipleWritableLocations in result:
database_account._EnableMultipleWritableLocations = result[
- Constants.EnableMultipleWritableLocations
- ]
+ Constants.EnableMultipleWritableLocations]
self.UseMultipleWriteLocations = (
self.connection_policy.UseMultipleWriteLocations and database_account._EnableMultipleWritableLocations
)
+
+ if Constants.EnablePerPartitionFailoverBehavior in result:
+ database_account._EnablePerPartitionFailoverBehavior = result[Constants.EnablePerPartitionFailoverBehavior]
+
if response_hook:
response_hook(last_response_headers, result)
return database_account
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_http_logging_policy.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_http_logging_policy.py
index 0741ab6c97be..aa9707a9a441 100644
--- a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_http_logging_policy.py
+++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_http_logging_policy.py
@@ -180,7 +180,7 @@ def _get_client_settings(global_endpoint_manager: Optional[_GlobalEndpointManage
gem_client = global_endpoint_manager.client
if gem_client and gem_client.connection_policy:
connection_policy: ConnectionPolicy = gem_client.connection_policy
- client_preferred_regions = connection_policy.PreferredLocations
+ client_preferred_regions = global_endpoint_manager.location_cache.effective_preferred_locations
client_excluded_regions = connection_policy.ExcludedLocations
if global_endpoint_manager.location_cache:
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_endpoint_discovery_retry_policy.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_endpoint_discovery_retry_policy.py
index f113efaafc42..3357c097c63a 100644
--- a/sdk/cosmos/azure-cosmos/azure/cosmos/_endpoint_discovery_retry_policy.py
+++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_endpoint_discovery_retry_policy.py
@@ -23,16 +23,9 @@
Azure Cosmos database service.
"""
-import logging
-from azure.cosmos.documents import _OperationType
-
-logger = logging.getLogger(__name__)
-logger.setLevel(logging.INFO)
-log_formatter = logging.Formatter("%(levelname)s:%(message)s")
-log_handler = logging.StreamHandler()
-log_handler.setFormatter(log_formatter)
-logger.addHandler(log_handler)
+# cspell:ignore PPAF
+from azure.cosmos.documents import _OperationType
class EndpointDiscoveryRetryPolicy(object):
"""The endpoint discovery retry policy class used for geo-replicated database accounts
@@ -43,8 +36,9 @@ class EndpointDiscoveryRetryPolicy(object):
Max_retry_attempt_count = 120
Retry_after_in_milliseconds = 1000
- def __init__(self, connection_policy, global_endpoint_manager, *args):
+ def __init__(self, connection_policy, global_endpoint_manager, pk_range_wrapper, *args):
self.global_endpoint_manager = global_endpoint_manager
+ self.pk_range_wrapper = pk_range_wrapper
self._max_retry_attempt_count = EndpointDiscoveryRetryPolicy.Max_retry_attempt_count
self.failover_retry_count = 0
self.retry_after_in_milliseconds = EndpointDiscoveryRetryPolicy.Retry_after_in_milliseconds
@@ -70,6 +64,22 @@ def ShouldRetry(self, exception): # pylint: disable=unused-argument
self.failover_retry_count += 1
+ # set the refresh_needed flag to ensure that endpoint list is
+ # refreshed with new writable and readable locations
+ self.global_endpoint_manager.refresh_needed = True
+
+ # If per partition automatic failover is applicable, we mark the current endpoint as unavailable
+ # and resolve the service endpoint for the partition range - otherwise, continue the default retry logic
+ if self.global_endpoint_manager.is_per_partition_automatic_failover_applicable(self.request):
+ partition_level_info = self.global_endpoint_manager.partition_range_to_failover_info[self.pk_range_wrapper]
+ location = self.global_endpoint_manager.location_cache.get_location_from_endpoint(
+ str(self.request.location_endpoint_to_route))
+ regional_endpoint = (self.global_endpoint_manager.location_cache.
+ account_read_regional_routing_contexts_by_location.get(location))
+ partition_level_info.unavailable_regional_endpoints[location] = regional_endpoint
+ self.global_endpoint_manager.resolve_service_endpoint_for_partition(self.request, self.pk_range_wrapper)
+ return True
+
if self.request.location_endpoint_to_route:
context = self.__class__.__name__
if _OperationType.IsReadOnlyOperation(self.request.operation_type):
@@ -82,16 +92,11 @@ def ShouldRetry(self, exception): # pylint: disable=unused-argument
self.request.location_endpoint_to_route,
True, context)
- # set the refresh_needed flag to ensure that endpoint list is
- # refreshed with new writable and readable locations
- self.global_endpoint_manager.refresh_needed = True
-
# clear previous location-based routing directive
self.request.clear_route_to_location()
# set location-based routing directive based on retry count
- # simulating single master writes by ensuring usePreferredLocations
- # is set to false
+ # simulating single master writes by ensuring usePreferredLocations is set to false
# reasoning being that 403.3 is only expected for write region failover in single writer account
# and we must rely on account locations as they are the source of truth
self.request.route_to_location_with_preferred_location_flag(self.failover_retry_count, False)
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py
index 34d27d4b7907..e301e4c4d49f 100644
--- a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py
+++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py
@@ -36,6 +36,8 @@
if TYPE_CHECKING:
from azure.cosmos._cosmos_client_connection import CosmosClientConnection
+#cspell:ignore ppcb
+
class _GlobalPartitionEndpointManagerForCircuitBreaker(_GlobalEndpointManager):
"""
This internal class implements the logic for partition endpoint management for
@@ -93,16 +95,17 @@ def create_pk_range_wrapper(self, request: RequestObject) -> Optional[PartitionK
return PartitionKeyRangeWrapper(partition_range, container_rid)
- def record_failure(
+ def record_ppcb_failure(
self,
- request: RequestObject
- ) -> None:
+ request: RequestObject,
+ pk_range_wrapper: Optional[PartitionKeyRangeWrapper] = None)-> None:
if self.is_circuit_breaker_applicable(request):
- pk_range_wrapper = self.create_pk_range_wrapper(request)
+ if pk_range_wrapper is None:
+ pk_range_wrapper = self.create_pk_range_wrapper(request)
if pk_range_wrapper:
self.global_partition_endpoint_manager_core.record_failure(request, pk_range_wrapper)
- def resolve_service_endpoint_for_partition(
+ def _resolve_service_endpoint_for_partition_circuit_breaker(
self,
request: RequestObject,
pk_range_wrapper: Optional[PartitionKeyRangeWrapper]
@@ -113,11 +116,12 @@ def resolve_service_endpoint_for_partition(
pk_range_wrapper)
return self._resolve_service_endpoint(request)
- def record_success(
+ def record_ppcb_success(
self,
- request: RequestObject
- ) -> None:
- if self.global_partition_endpoint_manager_core.is_circuit_breaker_applicable(request):
- pk_range_wrapper = self.create_pk_range_wrapper(request)
+ request: RequestObject,
+ pk_range_wrapper: Optional[PartitionKeyRangeWrapper] = None) -> None:
+ if self.is_circuit_breaker_applicable(request):
+ if pk_range_wrapper is None:
+ pk_range_wrapper = self.create_pk_range_wrapper(request)
if pk_range_wrapper:
self.global_partition_endpoint_manager_core.record_success(request, pk_range_wrapper)
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py
index 93faf9b7a8c5..91fd67805a18 100644
--- a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py
+++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py
@@ -19,6 +19,8 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
+# pylint: disable=protected-access
+
"""Internal class for global endpoint manager for circuit breaker.
"""
import logging
@@ -60,7 +62,10 @@ def is_circuit_breaker_applicable(self, request: RequestObject) -> bool:
return False
circuit_breaker_enabled = os.environ.get(Constants.CIRCUIT_BREAKER_ENABLED_CONFIG,
- Constants.CIRCUIT_BREAKER_ENABLED_CONFIG_DEFAULT) == "True"
+ Constants.CIRCUIT_BREAKER_ENABLED_CONFIG_DEFAULT).lower() == "true"
+ if not circuit_breaker_enabled and self.client._global_endpoint_manager is not None:
+ if self.client._global_endpoint_manager._database_account_cache is not None:
+ circuit_breaker_enabled = self.client._global_endpoint_manager._database_account_cache._EnablePerPartitionFailoverBehavior is True # pylint: disable=line-too-long
if not circuit_breaker_enabled:
return False
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_per_partition_automatic_failover.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_per_partition_automatic_failover.py
new file mode 100644
index 000000000000..0547cb41df32
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_per_partition_automatic_failover.py
@@ -0,0 +1,241 @@
+# The MIT License (MIT)
+# Copyright (c) 2025 Microsoft Corporation
+
+"""Class for global endpoint manager for per partition automatic failover. This class inherits the circuit breaker
+endpoint manager, since enabling per partition automatic failover also enables the circuit breaker logic.
+"""
+import logging
+import threading
+import os
+
+from typing import TYPE_CHECKING, Optional
+
+from azure.cosmos.http_constants import ResourceType
+from azure.cosmos._constants import _Constants as Constants
+from azure.cosmos._global_partition_endpoint_manager_circuit_breaker import \
+ _GlobalPartitionEndpointManagerForCircuitBreaker
+from azure.cosmos._partition_health_tracker import _PPAFPartitionThresholdsTracker
+from azure.cosmos.documents import _OperationType
+from azure.cosmos._request_object import RequestObject
+from azure.cosmos._routing.routing_range import PartitionKeyRangeWrapper
+
+if TYPE_CHECKING:
+ from azure.cosmos._cosmos_client_connection import CosmosClientConnection
+ from azure.cosmos._location_cache import RegionalRoutingContext
+
+logger = logging.getLogger("azure.cosmos._GlobalPartitionEndpointManagerForPerPartitionAutomaticFailover")
+
+# pylint: disable=name-too-long, protected-access, too-many-nested-blocks
+#cspell:ignore PPAF, ppaf, ppcb
+
+class PartitionLevelFailoverInfo:
+ """
+ Holds information about the partition level regional failover.
+ Used to track the partition key range and the regions where it is available.
+ """
+ def __init__(self) -> None:
+ self.unavailable_regional_endpoints: dict[str, "RegionalRoutingContext"] = {}
+ self._lock = threading.Lock()
+ self.current_region: Optional[str] = None
+
+ def try_move_to_next_location(
+ self,
+ available_account_regional_endpoints: dict[str, "RegionalRoutingContext"],
+ endpoint_region: str,
+ request: RequestObject) -> bool:
+ """
+ Tries to move to the next available regional endpoint for the partition key range.
+ :param dict[str, RegionalRoutingContext] available_account_regional_endpoints: The available regional endpoints
+ :param str endpoint_region: The current regional endpoint
+ :param RequestObject request: The request object containing the routing context.
+ :return: True if the move was successful, False otherwise.
+ :rtype: bool
+ """
+ with self._lock:
+ if endpoint_region != self.current_region and self.current_region is not None:
+ regional_endpoint = available_account_regional_endpoints[self.current_region].primary_endpoint
+ request.route_to_location(regional_endpoint)
+ return True
+
+            for location in available_account_regional_endpoints:
+                if location == self.current_region:
+                    continue
+
+                if location in self.unavailable_regional_endpoints:
+                    continue
+
+                self.current_region = location
+                logger.warning("PPAF - Moving to next available region: %s", self.current_region)
+                regional_endpoint = available_account_regional_endpoints[self.current_region].primary_endpoint
+                request.route_to_location(regional_endpoint)
+                return True
+
+ return False
+
+class _GlobalPartitionEndpointManagerForPerPartitionAutomaticFailover(_GlobalPartitionEndpointManagerForCircuitBreaker):
+ """
+ This internal class implements the logic for partition endpoint management for
+ geo-replicated database accounts.
+ """
+ def __init__(self, client: "CosmosClientConnection") -> None:
+ super(_GlobalPartitionEndpointManagerForPerPartitionAutomaticFailover, self).__init__(client)
+ self.partition_range_to_failover_info: dict[PartitionKeyRangeWrapper, PartitionLevelFailoverInfo] = {}
+ self.ppaf_thresholds_tracker = _PPAFPartitionThresholdsTracker()
+ self._threshold_lock = threading.Lock()
+
+ def is_per_partition_automatic_failover_enabled(self) -> bool:
+ if not self._database_account_cache or not self._database_account_cache._EnablePerPartitionFailoverBehavior:
+ return False
+ return True
+
+ def is_per_partition_automatic_failover_applicable(self, request: RequestObject) -> bool:
+ if not self.is_per_partition_automatic_failover_enabled():
+ return False
+
+ if not request:
+ return False
+
+ if (self.location_cache.can_use_multiple_write_locations_for_request(request)
+ or _OperationType.IsReadOnlyOperation(request.operation_type)):
+ return False
+
+ # if we have at most one region available in the account, we cannot do per partition automatic failover
+ available_regions = self.location_cache.account_read_regional_routing_contexts_by_location
+ if len(available_regions) <= 1:
+ return False
+
+        # PPAF only applies to document requests and stored procedure executions;
+        # return False for all other resource/operation types
+ if (request.resource_type != ResourceType.Document and
+ request.operation_type != _OperationType.ExecuteJavaScript):
+ return False
+
+ return True
+
+ def try_ppaf_failover_threshold(
+ self,
+ pk_range_wrapper: "PartitionKeyRangeWrapper",
+ request: "RequestObject"):
+        """Verifies whether the per-partition failover threshold has been reached for consecutive errors. If so,
+        it resets the failure count and marks the current region as unavailable for the given partition key range,
+        so that the next endpoint resolution moves the request to the next available region.
+
+ :param PartitionKeyRangeWrapper pk_range_wrapper: The wrapper containing the partition key range information
+ for the request.
+ :param RequestObject request: The request object containing the routing context.
+ :returns: None
+ """
+ # If PPAF is enabled, we track consecutive failures for certain exceptions, and only fail over at a partition
+ # level after the threshold is reached
+ if request and self.is_per_partition_automatic_failover_applicable(request):
+ if (self.ppaf_thresholds_tracker.get_pk_failures(pk_range_wrapper)
+ >= int(os.environ.get(Constants.TIMEOUT_ERROR_THRESHOLD_PPAF,
+ Constants.TIMEOUT_ERROR_THRESHOLD_PPAF_DEFAULT))):
+ # If the PPAF threshold is reached, we reset the count and mark the endpoint unavailable
+ # Once we mark the endpoint unavailable, the PPAF endpoint manager will try to move to the next
+ # available region for the partition key range
+ with self._threshold_lock:
+ # Check for count again, since a previous request may have now reset the count
+ if (self.ppaf_thresholds_tracker.get_pk_failures(pk_range_wrapper)
+ >= int(os.environ.get(Constants.TIMEOUT_ERROR_THRESHOLD_PPAF,
+ Constants.TIMEOUT_ERROR_THRESHOLD_PPAF_DEFAULT))):
+ self.ppaf_thresholds_tracker.clear_pk_failures(pk_range_wrapper)
+ partition_level_info = self.partition_range_to_failover_info[pk_range_wrapper]
+ location = self.location_cache.get_location_from_endpoint(
+ str(request.location_endpoint_to_route))
+ logger.warning("PPAF - Failover threshold reached for partition key range: %s for region: %s", #pylint: disable=line-too-long
+ pk_range_wrapper, location)
+                        # store the full regional routing context, matching the declared
+                        # dict[str, RegionalRoutingContext] value type used elsewhere
+                        regional_context = (self.location_cache.
+                                            account_read_regional_routing_contexts_by_location.
+                                            get(location))
+                        partition_level_info.unavailable_regional_endpoints[location] = regional_context
+
+ def resolve_service_endpoint_for_partition(
+ self,
+ request: RequestObject,
+ pk_range_wrapper: Optional[PartitionKeyRangeWrapper]
+ ) -> str:
+ """Resolves the endpoint to be used for the request. In a PPAF-enabled account, this method checks whether
+ the partition key range has any unavailable regions, and if so, it tries to move to the next available region.
+        If all regions are unavailable, it invalidates the cached failover info and starts over from the main write
+        region in the account configuration.
+
+ :param PartitionKeyRangeWrapper pk_range_wrapper: The wrapper containing the partition key range information
+ for the request.
+ :param RequestObject request: The request object containing the routing context.
+ :returns: The regional endpoint to be used for the request.
+ :rtype: str
+ """
+ if self.is_per_partition_automatic_failover_applicable(request) and pk_range_wrapper:
+ # If per partition automatic failover is applicable, we check partition unavailability
+ if pk_range_wrapper in self.partition_range_to_failover_info:
+ partition_failover_info = self.partition_range_to_failover_info[pk_range_wrapper]
+ if request.location_endpoint_to_route is not None:
+ endpoint_region = self.location_cache.get_location_from_endpoint(request.location_endpoint_to_route)
+ if endpoint_region in partition_failover_info.unavailable_regional_endpoints:
+ available_account_regional_endpoints = self.location_cache.account_read_regional_routing_contexts_by_location #pylint: disable=line-too-long
+ if (partition_failover_info.current_region is not None and
+ endpoint_region != partition_failover_info.current_region):
+                            # this request is not yet routing to the region currently in
+                            # use for this partition, so redirect it there
+ regional_endpoint = available_account_regional_endpoints[
+ partition_failover_info.current_region].primary_endpoint
+ request.route_to_location(regional_endpoint)
+ else:
+ if (len(self.location_cache.account_read_regional_routing_contexts_by_location)
+ == len(partition_failover_info.unavailable_regional_endpoints)):
+ # If no other region is available, we invalidate the cache and start once again
+ # from our main write region in the account configurations
+ logger.warning("PPAF - All available regions for partition %s are unavailable."
+ " Refreshing cache.", pk_range_wrapper)
+ self.partition_range_to_failover_info[pk_range_wrapper] = PartitionLevelFailoverInfo()
+ request.clear_route_to_location()
+ else:
+ # If the current region is unavailable, we try to move to the next available region
+ partition_failover_info.try_move_to_next_location(
+ self.location_cache.account_read_regional_routing_contexts_by_location,
+ endpoint_region,
+ request)
+ else:
+ # Update the current regional endpoint to whatever the request is routing to
+ partition_failover_info.current_region = endpoint_region
+ else:
+ partition_failover_info = PartitionLevelFailoverInfo()
+ endpoint_region = self.location_cache.get_location_from_endpoint(
+ request.location_endpoint_to_route)
+ partition_failover_info.current_region = endpoint_region
+ self.partition_range_to_failover_info[pk_range_wrapper] = partition_failover_info
+ return self._resolve_service_endpoint_for_partition_circuit_breaker(request, pk_range_wrapper)
+
+ def record_failure(self,
+ request: RequestObject,
+ pk_range_wrapper: Optional[PartitionKeyRangeWrapper] = None) -> None:
+ """Records a failure for the given partition key range and request.
+ :param RequestObject request: The request object containing the routing context.
+ :param PartitionKeyRangeWrapper pk_range_wrapper: The wrapper containing the partition key range information
+ for the request.
+ :return: None
+ """
+ if self.is_per_partition_automatic_failover_applicable(request):
+ if pk_range_wrapper is None:
+ pk_range_wrapper = self.create_pk_range_wrapper(request)
+ if pk_range_wrapper:
+ self.ppaf_thresholds_tracker.add_failure(pk_range_wrapper)
+ else:
+ self.record_ppcb_failure(request, pk_range_wrapper)
+
+ def record_success(self,
+ request: RequestObject,
+ pk_range_wrapper: Optional[PartitionKeyRangeWrapper] = None) -> None:
+ """Records a success for the given partition key range and request, effectively clearing the failure count.
+ :param RequestObject request: The request object containing the routing context.
+ :param PartitionKeyRangeWrapper pk_range_wrapper: The wrapper containing the partition key range information
+ for the request.
+ :return: None
+ """
+ if self.is_per_partition_automatic_failover_applicable(request):
+ if pk_range_wrapper is None:
+ pk_range_wrapper = self.create_pk_range_wrapper(request)
+ if pk_range_wrapper:
+ self.ppaf_thresholds_tracker.clear_pk_failures(pk_range_wrapper)
+ else:
+ self.record_ppcb_success(request, pk_range_wrapper)
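
The threshold logic in `try_ppaf_failover_threshold` above follows a double-checked locking pattern: an unlocked read of the counter first, then a re-check under the lock before acting, since a concurrent request may already have crossed the threshold and reset it. A standalone sketch of the pattern with hypothetical names (plain dict and lock, not the SDK types):

```python
import threading

_counts: dict[str, int] = {}
_lock = threading.Lock()
THRESHOLD = 10

def crossed_threshold(partition: str) -> bool:
    """Return True exactly once when the failure count crosses THRESHOLD."""
    if _counts.get(partition, 0) >= THRESHOLD:  # cheap unlocked read
        with _lock:
            # Re-check under the lock: another thread may have already
            # triggered the failover and reset the counter.
            if _counts.get(partition, 0) >= THRESHOLD:
                _counts.pop(partition, None)  # reset; caller marks region unavailable
                return True
    return False
```
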
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py
index cf8239488712..40b46cf0f42f 100644
--- a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py
+++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py
@@ -488,7 +488,7 @@ def update_location_cache(self, write_locations=None, read_locations=None, enabl
)
def get_preferred_regional_routing_contexts(
- self, endpoints_by_location, orderedLocations, expected_available_operation, fallback_endpoint
+ self, endpoints_by_location, ordered_locations, expected_available_operation, fallback_endpoint
):
regional_endpoints = []
# if enableEndpointDiscovery is false, we always use the defaultEndpoint that
@@ -522,7 +522,7 @@ def get_preferred_regional_routing_contexts(
if not regional_endpoints:
regional_endpoints.append(fallback_endpoint)
else:
- for location in orderedLocations:
+ for location in ordered_locations:
if location and location in endpoints_by_location:
# location is empty during manual failover
regional_endpoint = endpoints_by_location[location]
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py
index 378540b89119..50f4c79bceb4 100644
--- a/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py
+++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py
@@ -44,6 +44,8 @@
LAST_UNAVAILABILITY_CHECK_TIME_STAMP = "lastUnavailabilityCheckTimeStamp"
HEALTH_STATUS = "healthStatus"
+#cspell:ignore PPAF
+
class _PartitionHealthInfo(object):
"""
This internal class keeps the health and statistics for a partition.
@@ -299,3 +301,28 @@ def _reset_partition_health_tracker_stats(self) -> None:
for locations in self.pk_range_wrapper_to_health_info.values():
for health_info in locations.values():
health_info.reset_failure_rate_health_stats()
+
+class _PPAFPartitionThresholdsTracker(object):
+ """
+    This internal class tracks consecutive failures per partition against a threshold, in the context of
+    per-partition automatic failover. The tracker is only used for 408, 5xx and ServiceResponseError errors,
+    as a defensive measure to avoid failing over too early without confirmation from the service.
+    """
+
+ def __init__(self) -> None:
+ self.pk_range_wrapper_to_failure_count: dict[PartitionKeyRangeWrapper, int] = {}
+ self._failure_lock = threading.Lock()
+
+ def add_failure(self, pk_range_wrapper: PartitionKeyRangeWrapper) -> None:
+ with self._failure_lock:
+ if pk_range_wrapper not in self.pk_range_wrapper_to_failure_count:
+ self.pk_range_wrapper_to_failure_count[pk_range_wrapper] = 0
+ self.pk_range_wrapper_to_failure_count[pk_range_wrapper] += 1
+
+    def clear_pk_failures(self, pk_range_wrapper: PartitionKeyRangeWrapper) -> None:
+        # take the same lock as add_failure so a concurrent increment cannot
+        # interleave with the check-and-delete
+        with self._failure_lock:
+            self.pk_range_wrapper_to_failure_count.pop(pk_range_wrapper, None)
+
+ def get_pk_failures(self, pk_range_wrapper: PartitionKeyRangeWrapper) -> int:
+ return self.pk_range_wrapper_to_failure_count.get(pk_range_wrapper, 0)
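
In practice the tracker above pairs with `record_failure`/`record_success` in the PPAF endpoint manager: failures accumulate per partition key range and a single success clears the count. A small usage sketch (using a plain object as a stand-in for a real `PartitionKeyRangeWrapper`):

```python
tracker = _PPAFPartitionThresholdsTracker()
pk_range = object()  # stand-in for a real PartitionKeyRangeWrapper

tracker.add_failure(pk_range)
tracker.add_failure(pk_range)
assert tracker.get_pk_failures(pk_range) == 2

# One successful request resets the consecutive-failure count entirely.
tracker.clear_pk_failures(pk_range)
assert tracker.get_pk_failures(pk_range) == 0
```
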
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py
index 9b9153308db2..b52a957c4be0 100644
--- a/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py
+++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py
@@ -30,7 +30,7 @@
from azure.core.pipeline import PipelineRequest
from azure.core.pipeline.policies import RetryPolicy
-from . import _container_recreate_retry_policy, _health_check_retry_policy
+from . import _container_recreate_retry_policy, _health_check_retry_policy, _service_unavailable_retry_policy
from . import _default_retry_policy
from . import _endpoint_discovery_retry_policy
from . import _gone_retry_policy
@@ -46,6 +46,7 @@
# pylint: disable=protected-access, disable=too-many-lines, disable=too-many-statements, disable=too-many-branches
+# cspell:ignore PPAF,ppaf,ppcb
# args [0] is the request object
# args [1] is the connection policy
@@ -65,11 +66,12 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs): # pylin
:rtype: tuple of (dict, dict)
"""
pk_range_wrapper = None
- if args and global_endpoint_manager.is_circuit_breaker_applicable(args[0]):
+ if args and (global_endpoint_manager.is_per_partition_automatic_failover_applicable(args[0]) or
+ global_endpoint_manager.is_circuit_breaker_applicable(args[0])):
pk_range_wrapper = global_endpoint_manager.create_pk_range_wrapper(args[0])
# instantiate all retry policies here to be applied for each request execution
endpointDiscovery_retry_policy = _endpoint_discovery_retry_policy.EndpointDiscoveryRetryPolicy(
- client.connection_policy, global_endpoint_manager, *args
+ client.connection_policy, global_endpoint_manager, pk_range_wrapper, *args
)
health_check_retry_policy = _health_check_retry_policy.HealthCheckRetryPolicy(
client.connection_policy, *args
@@ -96,8 +98,11 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs): # pylin
service_request_retry_policy = _service_request_retry_policy.ServiceRequestRetryPolicy(
client.connection_policy, global_endpoint_manager, pk_range_wrapper, *args,
)
+ service_unavailable_retry_policy = _service_unavailable_retry_policy._ServiceUnavailableRetryPolicy(
+ client.connection_policy, global_endpoint_manager, pk_range_wrapper, *args)
# Get logger
logger = kwargs.get("logger", logging.getLogger("azure.cosmos._retry_utility"))
+
# HttpRequest we would need to modify for Container Recreate Retry Policy
request = None
if args and len(args) > 3:
@@ -115,7 +120,7 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs): # pylin
try:
if args:
result = ExecuteFunction(function, global_endpoint_manager, *args, **kwargs)
- global_endpoint_manager.record_success(args[0])
+ global_endpoint_manager.record_success(args[0], pk_range_wrapper)
else:
result = ExecuteFunction(function, *args, **kwargs)
if not client.last_response_headers:
@@ -202,10 +207,15 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs): # pylin
retry_policy.container_rid = cached_container["_rid"]
request.headers[retry_policy._intended_headers] = retry_policy.container_rid
- elif e.status_code == StatusCodes.REQUEST_TIMEOUT or e.status_code >= StatusCodes.INTERNAL_SERVER_ERROR:
+ elif e.status_code == StatusCodes.SERVICE_UNAVAILABLE:
if args:
# record the failure for circuit breaker tracking
- global_endpoint_manager.record_failure(args[0])
+ global_endpoint_manager.record_ppcb_failure(args[0], pk_range_wrapper)
+ retry_policy = service_unavailable_retry_policy
+ elif e.status_code == StatusCodes.REQUEST_TIMEOUT or e.status_code >= StatusCodes.INTERNAL_SERVER_ERROR:
+ if args:
+ # record the failure for ppaf/circuit breaker tracking
+ global_endpoint_manager.record_failure(args[0], pk_range_wrapper)
retry_policy = timeout_failover_retry_policy
else:
retry_policy = defaultRetry_policy
@@ -238,6 +248,8 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs): # pylin
if not health_check_retry_policy.ShouldRetry(e):
raise e
else:
+ if args:
+ global_endpoint_manager.record_failure(args[0], pk_range_wrapper)
_handle_service_request_retries(client, service_request_retry_policy, e, *args)
except ServiceResponseError as e:
@@ -246,7 +258,7 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs): # pylin
raise e
else:
if args:
- global_endpoint_manager.record_failure(args[0])
+ global_endpoint_manager.record_failure(args[0], pk_range_wrapper)
_handle_service_response_retries(request, client, service_response_retry_policy, e, *args)
def ExecuteFunction(function, *args, **kwargs):
@@ -283,7 +295,8 @@ def _handle_service_request_retries(
raise exception
def _handle_service_response_retries(request, client, response_retry_policy, exception, *args):
- if request and (_has_read_retryable_headers(request.headers) or (args and is_write_retryable(args[0], client))):
+ if request and (_has_read_retryable_headers(request.headers) or (args and (is_write_retryable(args[0], client) or
+ client._global_endpoint_manager.is_per_partition_automatic_failover_applicable(args[0])))):
# we resolve the request endpoint to the next preferred region
# once we are out of preferred regions we stop retrying
retry_policy = response_retry_policy
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_service_response_retry_policy.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_service_response_retry_policy.py
index 1192e14351cf..bafa9f1a2777 100644
--- a/sdk/cosmos/azure-cosmos/azure/cosmos/_service_response_retry_policy.py
+++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_service_response_retry_policy.py
@@ -6,6 +6,7 @@
from the service, and as such we do not know what the output of the operation was. As such, we
only do cross regional retries for read operations.
"""
+#cspell:ignore PPAF, ppaf
from azure.cosmos.documents import _OperationType
@@ -42,7 +43,9 @@ def ShouldRetry(self):
return False
if self.request:
-
+ # We track consecutive failures for per partition automatic failover, and only fail over at a partition
+ # level after the threshold is reached
+ self.global_endpoint_manager.try_ppaf_failover_threshold(self.pk_range_wrapper, self.request)
if not _OperationType.IsReadOnlyOperation(self.request.operation_type) and not self.request.retry_write > 0:
return False
if self.request.retry_write > 0 and self.failover_retry_count + 1 >= self.max_write_retry_count:
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_service_unavailable_retry_policy.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_service_unavailable_retry_policy.py
new file mode 100644
index 000000000000..a210f9348f89
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_service_unavailable_retry_policy.py
@@ -0,0 +1,79 @@
+# The MIT License (MIT)
+# Copyright (c) Microsoft Corporation. All rights reserved.
+
+"""Internal class for service unavailable errors implementation in the Azure Cosmos database service.
+
+Service unavailable errors can occur when a request does not make it to the service, or when there is an issue with
+the service. In either case, we know the request did not get processed successfully, so service unavailable errors are
+ retried in the next available preferred region.
+"""
+from azure.cosmos.documents import _OperationType
+from azure.cosmos.exceptions import CosmosHttpResponseError
+
+#cspell:ignore ppaf
+
+class _ServiceUnavailableRetryPolicy(object):
+ def __init__(
+ self,
+ connection_policy,
+ global_endpoint_manager,
+ pk_range_wrapper,
+ *args):
+ self.retry_after_in_milliseconds = 500
+ self.global_endpoint_manager = global_endpoint_manager
+ self.pk_range_wrapper = pk_range_wrapper
+ self.retry_count = 0
+ self.connection_policy = connection_policy
+ self.request = args[0] if args else None
+ # If an account only has 1 region, then we still want to retry once on the same region
+ self._max_retry_attempt_count = len(self.global_endpoint_manager.
+ location_cache.read_regional_routing_contexts) + 1
+ if self.request and _OperationType.IsWriteOperation(self.request.operation_type):
+ self._max_retry_attempt_count = len(self.global_endpoint_manager.location_cache.
+ write_regional_routing_contexts) + 1
+
+ def ShouldRetry(self, _exception: CosmosHttpResponseError):
+ """Returns true if the request should retry based on the passed-in exception.
+
+        :param exceptions.CosmosHttpResponseError _exception: The exception raised by the failed request.
+ :returns: a boolean stating whether the request should be retried
+ :rtype: bool
+ """
+        # unlike response errors, 503s are retried for writes as well, since we know
+        # the request was not processed successfully
+ if not self.connection_policy.EnableEndpointDiscovery:
+ return False
+
+ self.retry_count += 1
+ # Check if the next retry about to be done is safe
+ if self.retry_count >= self._max_retry_attempt_count:
+ return False
+
+ if self.request:
+ # If per partition automatic failover is applicable, we mark the current endpoint as unavailable
+ # and resolve the service endpoint for the partition range - otherwise, continue the default retry logic
+ if self.global_endpoint_manager.is_per_partition_automatic_failover_applicable(self.request):
+ partition_level_info = self.global_endpoint_manager.partition_range_to_failover_info[
+ self.pk_range_wrapper]
+ location = self.global_endpoint_manager.location_cache.get_location_from_endpoint(
+ str(self.request.location_endpoint_to_route))
+ regional_context = (self.global_endpoint_manager.location_cache.
+ account_read_regional_routing_contexts_by_location.get(location))
+ partition_level_info.unavailable_regional_endpoints[location] = regional_context
+ self.global_endpoint_manager.resolve_service_endpoint_for_partition(self.request, self.pk_range_wrapper)
+ return True
+ location_endpoint = self.resolve_next_region_service_endpoint()
+ self.request.route_to_location(location_endpoint)
+ return True
+
+ # This function prepares the request to go to the next region
+ def resolve_next_region_service_endpoint(self):
+ # clear previous location-based routing directive
+ self.request.clear_route_to_location()
+ # clear the last routed endpoint within same region since we are going to a new region now
+ self.request.last_routed_location_endpoint_within_region = None
+ # set location-based routing directive based on retry count
+ # ensuring usePreferredLocations is set to True for retry
+ self.request.route_to_location_with_preferred_location_flag(self.retry_count, True)
+ # Resolve the endpoint for the request and pin the resolution to the resolved endpoint
+ # This enables marking the endpoint unavailability on endpoint failover/unreachability
+ return self.global_endpoint_manager.resolve_service_endpoint_for_partition(self.request, self.pk_range_wrapper)
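
The retry budget in this policy is bounded by region count: with N regional routing contexts the policy allows N + 1 attempts, so even a single-region account gets one same-region retry. A simplified sketch of that bound (plain integers, not the SDK classes):

```python
def should_retry(retry_count: int, num_regions: int) -> tuple[bool, int]:
    """Mirror of the 503 retry bound: allow num_regions + 1 attempts in total."""
    max_attempts = num_regions + 1
    retry_count += 1
    return retry_count < max_attempts, retry_count

count, decisions = 0, []
for _ in range(4):
    retry, count = should_retry(count, num_regions=2)
    decisions.append(retry)
assert decisions == [True, True, False, False]  # two retries allowed, then stop
```
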
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_session_retry_policy.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_session_retry_policy.py
index f10366ac4c7f..e11cb4838047 100644
--- a/sdk/cosmos/azure-cosmos/azure/cosmos/_session_retry_policy.py
+++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_session_retry_policy.py
@@ -22,7 +22,7 @@
"""Internal class for session read/write unavailable retry policy implementation
in the Azure Cosmos database service.
"""
-
+# cspell:disable
from azure.cosmos.documents import _OperationType
class _SessionRetryPolicy(object):
@@ -60,16 +60,12 @@ def ShouldRetry(self, _exception):
:returns: a boolean stating whether the request should be retried
:rtype: bool
"""
- if not self.request:
+ if not self.request or not self.endpoint_discovery_enable:
return False
self.session_token_retry_count += 1
# clear previous location-based routing directive
self.request.clear_route_to_location()
- if not self.endpoint_discovery_enable:
- # if endpoint discovery is disabled, the request cannot be retried anywhere else
- return False
-
if self.can_use_multiple_write_locations:
if _OperationType.IsReadOnlyOperation(self.request.operation_type):
locations = self.global_endpoint_manager.get_ordered_read_locations()
@@ -105,6 +101,22 @@ def ShouldRetry(self, _exception):
self.request.route_to_location_with_preferred_location_flag(self.session_token_retry_count - 1, False)
self.request.should_clear_session_token_on_session_read_failure = True
+        # For PPAF, the retry should be routed to the current write region for the affected partition.
+ if self.global_endpoint_manager.is_per_partition_automatic_failover_enabled():
+ pk_failover_info = self.global_endpoint_manager.partition_range_to_failover_info.get(self.pk_range_wrapper)
+ if pk_failover_info is not None:
+ location = self.global_endpoint_manager.location_cache.get_location_from_endpoint(
+ str(self.request.location_endpoint_to_route))
+ if location in pk_failover_info.unavailable_regional_endpoints:
+ # If the request endpoint is unavailable, we need to resolve the endpoint for the request using the
+ # partition-level failover info
+ if pk_failover_info.current_region is not None:
+ location_endpoint = (self.global_endpoint_manager.location_cache.
+ account_read_regional_routing_contexts_by_location.
+ get(pk_failover_info.current_region).primary_endpoint)
+ self.request.route_to_location(location_endpoint)
+ return True
+
# Resolve the endpoint for the request and pin the resolution to the resolved endpoint
# This enables marking the endpoint unavailability on endpoint failover/unreachability
self.location_endpoint = (self.global_endpoint_manager
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py
index 1b2290981307..55f9ac40a00c 100644
--- a/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py
+++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py
@@ -28,9 +28,8 @@
from urllib.parse import urlparse
from azure.core.exceptions import DecodeError # type: ignore
-from . import exceptions
-from . import http_constants
-from . import _retry_utility
+from . import exceptions, http_constants, _retry_utility
+from ._utils import get_user_agent_features
def _is_readable_stream(obj):
@@ -80,7 +79,7 @@ def _Request(global_endpoint_manager, request_params, connection_policy, pipelin
:rtype: tuple of (dict, dict)
"""
- # pylint: disable=protected-access
+ # pylint: disable=protected-access, too-many-branches
connection_timeout = connection_policy.RequestTimeout
connection_timeout = kwargs.pop("connection_timeout", connection_timeout)
@@ -111,8 +110,9 @@ def _Request(global_endpoint_manager, request_params, connection_policy, pipelin
base_url = request_params.endpoint_override
else:
pk_range_wrapper = None
- if global_endpoint_manager.is_circuit_breaker_applicable(request_params):
- # Circuit breaker is applicable, so we need to use the endpoint from the request
+ if (global_endpoint_manager.is_circuit_breaker_applicable(request_params) or
+ global_endpoint_manager.is_per_partition_automatic_failover_applicable(request_params)):
+            # Circuit breaker or per-partition failover is applicable, so we need to use the endpoint from the request
pk_range_wrapper = global_endpoint_manager.create_pk_range_wrapper(request_params)
base_url = global_endpoint_manager.resolve_service_endpoint_for_partition(request_params, pk_range_wrapper)
if not request.url.startswith(base_url):
@@ -120,6 +120,15 @@ def _Request(global_endpoint_manager, request_params, connection_policy, pipelin
parse_result = urlparse(request.url)
+ # Add relevant enabled features to user agent for debugging
+ if request.headers[http_constants.HttpHeaders.ThinClientProxyResourceType] == http_constants.ResourceType.Document:
+ user_agent_features = get_user_agent_features(global_endpoint_manager)
+ if len(user_agent_features) > 0:
+ user_agent = kwargs.pop("user_agent", global_endpoint_manager.client._user_agent)
+ user_agent = "{} {}".format(user_agent, user_agent_features)
+ kwargs.update({"user_agent": user_agent})
+ kwargs.update({"user_agent_overwrite": True})
+
# The requests library now expects header values to be strings only starting 2.11,
# and will raise an error on validation if they are not, so casting all header values to strings.
request.headers.update({header: str(value) for header, value in request.headers.items()})
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py
index 4fa82e83e2b1..801dd350e121 100644
--- a/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py
+++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py
@@ -6,6 +6,7 @@
"""
from azure.cosmos.documents import _OperationType
+# cspell:ignore PPAF, ppaf
class _TimeoutFailoverRetryPolicy(object):
@@ -19,13 +20,14 @@ def __init__(self, connection_policy, global_endpoint_manager, pk_range_wrapper,
# If an account only has 1 region, then we still want to retry once on the same region
# We want this to be the default retry attempts as paging through a query means there are requests without
# a request object
- self._max_retry_attempt_count = len(self.global_endpoint_manager.location_cache
- .read_regional_routing_contexts) + 1
+ self._max_retry_attempt_count = len(self.global_endpoint_manager.
+ location_cache.read_regional_routing_contexts) + 1
# If the request is a write operation, we only want to retry as many times as retry_write
if self.request and _OperationType.IsWriteOperation(self.request.operation_type):
self._max_retry_attempt_count = self.request.retry_write
self.retry_count = 0
self.connection_policy = connection_policy
+ self.request = args[0] if args else None
def ShouldRetry(self, _exception):
"""Returns true if the request should retry based on the passed-in exception.
@@ -34,6 +36,8 @@ def ShouldRetry(self, _exception):
:returns: a boolean stating whether the request should be retried
:rtype: bool
"""
+ self.global_endpoint_manager.try_ppaf_failover_threshold(self.pk_range_wrapper, self.request)
+
# we retry only if the request is a read operation or if it is a write operation with retry enabled
if self.request and not self.is_operation_retryable():
return False
@@ -48,7 +52,7 @@ def ShouldRetry(self, _exception):
# second check here ensures we only do cross-regional retries for read requests
# non-idempotent write retries should only be retried once, using preferred locations if available (MM)
- if self.request and (_OperationType.IsReadOnlyOperation(self.request.operation_type)
+ if self.request and (self.is_operation_retryable()
or self.global_endpoint_manager.can_use_multiple_write_locations(self.request)):
location_endpoint = self.resolve_next_region_service_endpoint()
self.request.route_to_location(location_endpoint)
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_utils.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_utils.py
index 9144afca613d..0587556f198e 100644
--- a/sdk/cosmos/azure-cosmos/azure/cosmos/_utils.py
+++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_utils.py
@@ -27,10 +27,13 @@
import base64
import json
import time
+import os
from typing import Any, Optional, Tuple
-
+from ._constants import _Constants
from ._version import VERSION
+# cspell:ignore ppcb
+# pylint: disable=protected-access
def get_user_agent(suffix: Optional[str] = None) -> str:
os_name = safe_user_agent_header(platform.platform())
@@ -146,3 +149,26 @@ def valid_key_value_exist(
:rtype: bool
"""
return key in kwargs and kwargs[key] is not invalid_value
+
+
+def get_user_agent_features(global_endpoint_manager: Any) -> str:
+ """
+ Check the account and client configurations in order to add feature flags
+ to the user agent using bitmask logic and hex encoding (matching .NET/Java).
+
+ :param Any global_endpoint_manager: The GlobalEndpointManager instance.
+ :return: A string representing the user agent feature flags.
+ :rtype: str
+ """
+ feature_flag = 0
+ # Bitwise OR for feature flags
+ if global_endpoint_manager._database_account_cache is not None:
+ if global_endpoint_manager._database_account_cache._EnablePerPartitionFailoverBehavior is True:
+ feature_flag |= _Constants.UserAgentFeatureFlags.PER_PARTITION_AUTOMATIC_FAILOVER
+ ppcb_check = os.environ.get(
+ _Constants.CIRCUIT_BREAKER_ENABLED_CONFIG,
+ _Constants.CIRCUIT_BREAKER_ENABLED_CONFIG_DEFAULT
+ ).lower()
+ if ppcb_check == "true" or feature_flag > 0:
+ feature_flag |= _Constants.UserAgentFeatureFlags.PER_PARTITION_CIRCUIT_BREAKER
+ return f"| F{feature_flag:X}" if feature_flag > 0 else ""
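
Note the coupling encoded above: the circuit breaker bit is set not only when its environment flag is "true" but whenever any other feature bit is already set, since enabling PPAF implicitly enables the circuit breaker. A pure-function sketch of that rule (a hypothetical helper mirroring the logic, not the SDK API):

```python
def feature_suffix(ppaf_enabled: bool, ppcb_env: str) -> str:
    flag = 0
    if ppaf_enabled:
        flag |= 1  # PER_PARTITION_AUTOMATIC_FAILOVER
    if ppcb_env.lower() == "true" or flag > 0:
        flag |= 2  # PER_PARTITION_CIRCUIT_BREAKER, implied by any other feature
    return f"| F{flag:X}" if flag > 0 else ""

assert feature_suffix(False, "False") == ""
assert feature_suffix(False, "True") == "| F2"
assert feature_suffix(True, "False") == "| F3"
```
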
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py
index c19d9b494abb..2d8c7e313a62 100644
--- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py
+++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py
@@ -32,6 +32,7 @@
from .. import http_constants
from . import _retry_utility_async
from .._synchronized_request import _request_body_from_data, _replace_url_prefix
+from .._utils import get_user_agent_features
async def _Request(global_endpoint_manager, request_params, connection_policy, pipeline_client, request, **kwargs): # pylint: disable=too-many-statements
@@ -49,7 +50,7 @@ async def _Request(global_endpoint_manager, request_params, connection_policy, p
:rtype: tuple of (dict, dict)
"""
- # pylint: disable=protected-access
+ # pylint: disable=protected-access, too-many-branches
connection_timeout = connection_policy.RequestTimeout
read_timeout = connection_policy.ReadTimeout
@@ -80,8 +81,9 @@ async def _Request(global_endpoint_manager, request_params, connection_policy, p
base_url = request_params.endpoint_override
else:
pk_range_wrapper = None
- if global_endpoint_manager.is_circuit_breaker_applicable(request_params):
- # Circuit breaker is applicable, so we need to use the endpoint from the request
+ if (global_endpoint_manager.is_circuit_breaker_applicable(request_params) or
+ global_endpoint_manager.is_per_partition_automatic_failover_applicable(request_params)):
+            # Circuit breaker or per-partition failover is applicable, so we need to use the endpoint from the request
pk_range_wrapper = await global_endpoint_manager.create_pk_range_wrapper(request_params)
base_url = global_endpoint_manager.resolve_service_endpoint_for_partition(request_params, pk_range_wrapper)
if not request.url.startswith(base_url):
@@ -89,6 +91,15 @@ async def _Request(global_endpoint_manager, request_params, connection_policy, p
parse_result = urlparse(request.url)
+ # Add relevant enabled features to user agent for debugging
+ if request.headers[http_constants.HttpHeaders.ThinClientProxyResourceType] == http_constants.ResourceType.Document:
+ user_agent_features = get_user_agent_features(global_endpoint_manager)
+ if len(user_agent_features) > 0:
+ user_agent = kwargs.pop("user_agent", global_endpoint_manager.client._user_agent)
+ user_agent = "{} {}".format(user_agent, user_agent_features)
+ kwargs.update({"user_agent": user_agent})
+ kwargs.update({"user_agent_overwrite": True})
+
# The requests library now expects header values to be strings only starting 2.11,
# and will raise an error on validation if they are not, so casting all header values to strings.
request.headers.update({header: str(value) for header, value in request.headers.items()})
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py
index 67d5d4efa3e9..a33f16f0df6d 100644
--- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py
+++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py
@@ -44,10 +44,8 @@
DistributedTracingPolicy,
ProxyPolicy)
from azure.core.utils import CaseInsensitiveDict
-from azure.cosmos.aio._global_partition_endpoint_manager_circuit_breaker_async import (
- _GlobalPartitionEndpointManagerForCircuitBreakerAsync)
-from ._read_items_helper_async import ReadItemsHelperAsync
-
+from azure.cosmos.aio._global_partition_endpoint_manager_per_partition_automatic_failover_async import (
+ _GlobalPartitionEndpointManagerForPerPartitionAutomaticFailoverAsync)
from .. import _base as base
from .._base import _build_properties_cache
from .. import documents
@@ -79,6 +77,7 @@
from ._auth_policy_async import AsyncCosmosBearerTokenCredentialPolicy
from .._cosmos_http_logging_policy import CosmosHttpLoggingPolicy
from .._range_partition_resolver import RangePartitionResolver
+from ._read_items_helper_async import ReadItemsHelperAsync
@@ -180,7 +179,7 @@ def __init__( # pylint: disable=too-many-statements
# Keeps the latest response headers from the server.
self.last_response_headers: CaseInsensitiveDict = CaseInsensitiveDict()
self.UseMultipleWriteLocations = False
- self._global_endpoint_manager = _GlobalPartitionEndpointManagerForCircuitBreakerAsync(self)
+ self._global_endpoint_manager = _GlobalPartitionEndpointManagerForPerPartitionAutomaticFailoverAsync(self)
retry_policy = None
if isinstance(self.connection_policy.ConnectionRetryConfiguration, AsyncHTTPPolicy):
@@ -472,6 +471,10 @@ async def GetDatabaseAccount(
self.UseMultipleWriteLocations = (
self.connection_policy.UseMultipleWriteLocations and database_account._EnableMultipleWritableLocations
)
+
+ if Constants.EnablePerPartitionFailoverBehavior in result:
+ database_account._EnablePerPartitionFailoverBehavior = result[Constants.EnablePerPartitionFailoverBehavior]
+
return database_account
async def health_check(
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py
index 02e88fc242ea..954a324cc13f 100644
--- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py
+++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py
@@ -36,7 +36,7 @@
if TYPE_CHECKING:
from azure.cosmos.aio._cosmos_client_connection_async import CosmosClientConnection
-
+# cspell:ignore ppcb
# pylint: disable=protected-access
class _GlobalPartitionEndpointManagerForCircuitBreakerAsync(_GlobalEndpointManager):
"""
@@ -94,16 +94,17 @@ async def create_pk_range_wrapper(self, request: RequestObject) -> Optional[Part
def is_circuit_breaker_applicable(self, request: RequestObject) -> bool:
return self.global_partition_endpoint_manager_core.is_circuit_breaker_applicable(request)
- async def record_failure(
+ async def record_ppcb_failure(
self,
- request: RequestObject
- ) -> None:
+ request: RequestObject,
+ pk_range_wrapper: Optional[PartitionKeyRangeWrapper] = None) -> None:
if self.is_circuit_breaker_applicable(request):
- pk_range_wrapper = await self.create_pk_range_wrapper(request)
+ if pk_range_wrapper is None:
+ pk_range_wrapper = await self.create_pk_range_wrapper(request)
if pk_range_wrapper:
self.global_partition_endpoint_manager_core.record_failure(request, pk_range_wrapper)
- def resolve_service_endpoint_for_partition(
+ def _resolve_service_endpoint_for_partition_circuit_breaker(
self,
request: RequestObject,
pk_range_wrapper: Optional[PartitionKeyRangeWrapper]
@@ -114,11 +115,12 @@ def resolve_service_endpoint_for_partition(
pk_range_wrapper)
return self._resolve_service_endpoint(request)
- async def record_success(
+ async def record_ppcb_success(
self,
- request: RequestObject
- ) -> None:
+ request: RequestObject,
+ pk_range_wrapper: Optional[PartitionKeyRangeWrapper] = None) -> None:
if self.is_circuit_breaker_applicable(request):
- pk_range_wrapper = await self.create_pk_range_wrapper(request)
+ if pk_range_wrapper is None:
+ pk_range_wrapper = await self.create_pk_range_wrapper(request)
if pk_range_wrapper:
self.global_partition_endpoint_manager_core.record_success(request, pk_range_wrapper)
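A hedged sketch of the calling pattern the new optional `pk_range_wrapper` parameter enables: resolve the wrapper once per execution and thread it through both record calls instead of recomputing it on every retry (`gem` and `do_request` are placeholders):

```python
async def execute_with_ppcb_tracking(gem, request, do_request):
    # Resolve the partition key range wrapper once, as the retry utility now does.
    pk_range_wrapper = None
    if gem.is_circuit_breaker_applicable(request):
        pk_range_wrapper = await gem.create_pk_range_wrapper(request)
    try:
        result = await do_request()
        await gem.record_ppcb_success(request, pk_range_wrapper)
        return result
    except Exception:
        await gem.record_ppcb_failure(request, pk_range_wrapper)
        raise
```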
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_per_partition_automatic_failover_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_per_partition_automatic_failover_async.py
new file mode 100644
index 000000000000..c96b46ca46b3
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_per_partition_automatic_failover_async.py
@@ -0,0 +1,242 @@
+# The MIT License (MIT)
+# Copyright (c) 2025 Microsoft Corporation
+
+"""Class for global endpoint manager for per partition automatic failover. This class inherits the circuit breaker
+endpoint manager, since enabling per partition automatic failover also enables the circuit breaker logic.
+"""
+import logging
+import threading
+import os
+
+from typing import TYPE_CHECKING, Optional
+
+from azure.cosmos.http_constants import ResourceType
+from azure.cosmos._constants import _Constants as Constants
+from azure.cosmos.aio._global_partition_endpoint_manager_circuit_breaker_async import \
+ _GlobalPartitionEndpointManagerForCircuitBreakerAsync
+from azure.cosmos.documents import _OperationType
+from azure.cosmos._partition_health_tracker import _PPAFPartitionThresholdsTracker
+from azure.cosmos._request_object import RequestObject
+from azure.cosmos._routing.routing_range import PartitionKeyRangeWrapper
+
+if TYPE_CHECKING:
+ from azure.cosmos.aio._cosmos_client_connection_async import CosmosClientConnection
+ from azure.cosmos._location_cache import RegionalRoutingContext
+
+logger = logging.getLogger("azure.cosmos._GlobalPartitionEndpointManagerForPerPartitionAutomaticFailover")
+
+# pylint: disable=name-too-long, protected-access, too-many-nested-blocks
+#cspell:ignore PPAF, ppaf, ppcb
+
+class PartitionLevelFailoverInfo:
+ """
+    Holds partition-level regional failover state for a partition key range:
+    the regions marked unavailable for it and the region it is currently served from.
+ """
+ def __init__(self) -> None:
+        self.unavailable_regional_endpoints: dict[str, str] = {}  # region name -> primary endpoint marked unavailable
+ self._lock = threading.Lock()
+ self.current_region: Optional[str] = None
+
+ def try_move_to_next_location(
+ self,
+ available_account_regional_endpoints: dict[str, "RegionalRoutingContext"],
+ endpoint_region: str,
+ request: RequestObject) -> bool:
+ """
+ Tries to move to the next available regional endpoint for the partition key range.
+        :param Dict[str, RegionalRoutingContext] available_account_regional_endpoints: The regional endpoints
+            available for the account, keyed by region name.
+        :param str endpoint_region: The region the request is currently routed to.
+ :param RequestObject request: The request object containing the routing context.
+ :return: True if the move was successful, False otherwise.
+ :rtype: bool
+ """
+ with self._lock:
+ if endpoint_region != self.current_region and self.current_region is not None:
+ regional_endpoint = available_account_regional_endpoints[self.current_region].primary_endpoint
+ request.route_to_location(regional_endpoint)
+ return True
+
+            for region in available_account_regional_endpoints:
+                if region == self.current_region:
+                    continue
+
+                if region in self.unavailable_regional_endpoints:
+                    continue
+
+                self.current_region = region
+                logger.warning("PPAF - Moving to next available regional endpoint: %s", self.current_region)
+                regional_endpoint = available_account_regional_endpoints[self.current_region].primary_endpoint
+                request.route_to_location(regional_endpoint)
+                return True
+
+ return False
+
+class _GlobalPartitionEndpointManagerForPerPartitionAutomaticFailoverAsync(
+ _GlobalPartitionEndpointManagerForCircuitBreakerAsync):
+ """
+ This internal class implements the logic for partition endpoint management for
+ geo-replicated database accounts.
+ """
+ def __init__(self, client: "CosmosClientConnection") -> None:
+ super(_GlobalPartitionEndpointManagerForPerPartitionAutomaticFailoverAsync, self).__init__(client)
+ self.partition_range_to_failover_info: dict[PartitionKeyRangeWrapper, PartitionLevelFailoverInfo] = {}
+ self.ppaf_thresholds_tracker = _PPAFPartitionThresholdsTracker()
+ self._threshold_lock = threading.Lock()
+
+ def is_per_partition_automatic_failover_enabled(self) -> bool:
+ if not self._database_account_cache or not self._database_account_cache._EnablePerPartitionFailoverBehavior:
+ return False
+ return True
+
+ def is_per_partition_automatic_failover_applicable(self, request: RequestObject) -> bool:
+ if not self.is_per_partition_automatic_failover_enabled():
+ return False
+
+ if not request:
+ return False
+
+ if (self.location_cache.can_use_multiple_write_locations_for_request(request)
+ or _OperationType.IsReadOnlyOperation(request.operation_type)):
+ return False
+
+ # if we have at most one region available in the account, we cannot do per partition automatic failover
+ available_regions = self.location_cache.account_read_regional_routing_contexts_by_location
+ if len(available_regions) <= 1:
+ return False
+
+        # per partition automatic failover only applies to document requests
+        # and stored procedure executions; return False for everything else
+ if (request.resource_type != ResourceType.Document and
+ request.operation_type != _OperationType.ExecuteJavaScript):
+ return False
+
+ return True
+
+ def try_ppaf_failover_threshold(
+ self,
+ pk_range_wrapper: "PartitionKeyRangeWrapper",
+            request: "RequestObject") -> None:
+ """Verifies whether the per-partition failover threshold has been reached for consecutive errors. If so,
+ it marks the current region as unavailable for the given partition key range, and moves to the next available
+ region for the request.
+
+ :param PartitionKeyRangeWrapper pk_range_wrapper: The wrapper containing the partition key range information
+ for the request.
+ :param RequestObject request: The request object containing the routing context.
+ :returns: None
+ """
+ # If PPAF is enabled, we track consecutive failures for certain exceptions, and only fail over at a partition
+ # level after the threshold is reached
+        if request and self.is_per_partition_automatic_failover_applicable(request):
+            threshold = int(os.environ.get(Constants.TIMEOUT_ERROR_THRESHOLD_PPAF,
+                                           Constants.TIMEOUT_ERROR_THRESHOLD_PPAF_DEFAULT))
+            if self.ppaf_thresholds_tracker.get_pk_failures(pk_range_wrapper) >= threshold:
+                # If the PPAF threshold is reached, we reset the count and mark the endpoint unavailable.
+                # Once the endpoint is marked unavailable, the PPAF endpoint manager will try to move to the
+                # next available region for the partition key range
+                with self._threshold_lock:
+                    # Check the count again, since a previous request may have reset it in the meantime
+                    if self.ppaf_thresholds_tracker.get_pk_failures(pk_range_wrapper) >= threshold:
+ self.ppaf_thresholds_tracker.clear_pk_failures(pk_range_wrapper)
+ partition_level_info = self.partition_range_to_failover_info[pk_range_wrapper]
+ location = self.location_cache.get_location_from_endpoint(
+ str(request.location_endpoint_to_route))
+ logger.warning("PPAF - Failover threshold reached for partition key range: %s for region: %s", #pylint: disable=line-too-long
+ pk_range_wrapper, location)
+                        regional_endpoint = (self.location_cache.
+                                             account_read_regional_routing_contexts_by_location.
+                                             get(location).primary_endpoint)
+                        partition_level_info.unavailable_regional_endpoints[location] = regional_endpoint
+
+ def resolve_service_endpoint_for_partition(
+ self,
+ request: RequestObject,
+ pk_range_wrapper: Optional[PartitionKeyRangeWrapper]
+ ) -> str:
+ """Resolves the endpoint to be used for the request. In a PPAF-enabled account, this method checks whether
+ the partition key range has any unavailable regions, and if so, it tries to move to the next available region.
+        If all regions are unavailable, it invalidates the cached failover info and starts over from the
+        account's main write region.
+
+ :param PartitionKeyRangeWrapper pk_range_wrapper: The wrapper containing the partition key range information
+ for the request.
+ :param RequestObject request: The request object containing the routing context.
+ :returns: The regional endpoint to be used for the request.
+ :rtype: str
+ """
+ if self.is_per_partition_automatic_failover_applicable(request) and pk_range_wrapper:
+ # If per partition automatic failover is applicable, we check partition unavailability
+ if pk_range_wrapper in self.partition_range_to_failover_info:
+ partition_failover_info = self.partition_range_to_failover_info[pk_range_wrapper]
+ if request.location_endpoint_to_route is not None:
+ endpoint_region = self.location_cache.get_location_from_endpoint(request.location_endpoint_to_route)
+ if endpoint_region in partition_failover_info.unavailable_regional_endpoints:
+ available_account_regional_endpoints = self.location_cache.account_read_regional_routing_contexts_by_location #pylint: disable=line-too-long
+ if (partition_failover_info.current_region is not None and
+ endpoint_region != partition_failover_info.current_region):
+                            # this request has not yet picked up the region already in use for this partition
+ regional_endpoint = available_account_regional_endpoints[
+ partition_failover_info.current_region].primary_endpoint
+ request.route_to_location(regional_endpoint)
+ else:
+ if (len(self.location_cache.account_read_regional_routing_contexts_by_location) ==
+ len(partition_failover_info.unavailable_regional_endpoints)):
+ # If no other region is available, we invalidate the cache and start once again
+ # from our main write region in the account configurations
+ logger.warning("All available regions for partition %s are unavailable."
+ " Refreshing cache.", pk_range_wrapper)
+ self.partition_range_to_failover_info[pk_range_wrapper] = PartitionLevelFailoverInfo()
+ request.clear_route_to_location()
+ else:
+ # If the current region is unavailable, we try to move to the next available region
+ partition_failover_info.try_move_to_next_location(
+ self.location_cache.account_read_regional_routing_contexts_by_location,
+ endpoint_region,
+ request)
+ else:
+ # Update the current regional endpoint to whatever the request is routing to
+ partition_failover_info.current_region = endpoint_region
+ else:
+ partition_failover_info = PartitionLevelFailoverInfo()
+ endpoint_region = self.location_cache.get_location_from_endpoint(
+ request.location_endpoint_to_route)
+ partition_failover_info.current_region = endpoint_region
+ self.partition_range_to_failover_info[pk_range_wrapper] = partition_failover_info
+ return self._resolve_service_endpoint_for_partition_circuit_breaker(request, pk_range_wrapper)
+
+ async def record_failure(self,
+ request: RequestObject,
+ pk_range_wrapper: Optional[PartitionKeyRangeWrapper] = None) -> None:
+ """Records a failure for the given partition key range and request.
+ :param RequestObject request: The request object containing the routing context.
+ :param PartitionKeyRangeWrapper pk_range_wrapper: The wrapper containing the partition key range information
+ for the request.
+ :return: None
+ """
+ if self.is_per_partition_automatic_failover_applicable(request):
+ if pk_range_wrapper is None:
+ pk_range_wrapper = await self.create_pk_range_wrapper(request)
+ if pk_range_wrapper:
+ self.ppaf_thresholds_tracker.add_failure(pk_range_wrapper)
+ else:
+ await self.record_ppcb_failure(request, pk_range_wrapper)
+
+ async def record_success(self,
+ request: RequestObject,
+ pk_range_wrapper: Optional[PartitionKeyRangeWrapper] = None) -> None:
+ """Records a success for the given partition key range and request, effectively clearing the failure count.
+ :param RequestObject request: The request object containing the routing context.
+ :param PartitionKeyRangeWrapper pk_range_wrapper: The wrapper containing the partition key range information
+ for the request.
+ :return: None
+ """
+ if self.is_per_partition_automatic_failover_applicable(request):
+ if pk_range_wrapper is None:
+ pk_range_wrapper = await self.create_pk_range_wrapper(request)
+ if pk_range_wrapper:
+ self.ppaf_thresholds_tracker.clear_pk_failures(pk_range_wrapper)
+ else:
+ await self.record_ppcb_success(request, pk_range_wrapper)
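To see the region-rotation rule in isolation, here is a self-contained toy version of what `try_move_to_next_location` does, simplified to plain strings (no routing contexts, locks, or request routing):

```python
from typing import Optional

def next_available_region(regions: list[str],
                          current: Optional[str],
                          unavailable: set[str]) -> Optional[str]:
    # Pick the first account region that is neither current nor marked unavailable.
    for region in regions:
        if region == current or region in unavailable:
            continue
        return region
    return None  # everything exhausted: the manager resets the failover info

# With West US marked unavailable and East US current, Central US is chosen next.
assert next_available_region(["East US", "West US", "Central US"],
                             "East US", {"West US"}) == "Central US"
# When no region is left, resolve_service_endpoint_for_partition clears the cache.
assert next_available_region(["East US"], "East US", set()) is None
```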
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py
index be19eedc36b7..9f58a1b8d8aa 100644
--- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py
+++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py
@@ -29,7 +29,7 @@
from azure.core.exceptions import AzureError, ClientAuthenticationError, ServiceRequestError, ServiceResponseError
from azure.core.pipeline.policies import AsyncRetryPolicy
-from .. import _default_retry_policy, _health_check_retry_policy
+from .. import _default_retry_policy, _health_check_retry_policy, _service_unavailable_retry_policy
from .. import _endpoint_discovery_retry_policy
from .. import _gone_retry_policy
from .. import _resource_throttle_retry_policy
@@ -47,6 +47,7 @@
# pylint: disable=protected-access, disable=too-many-lines, disable=too-many-statements, disable=too-many-branches
+# cspell:ignore ppaf, ppcb
# args [0] is the request object
# args [1] is the connection policy
@@ -66,11 +67,12 @@ async def ExecuteAsync(client, global_endpoint_manager, function, *args, **kwarg
:rtype: tuple of (dict, dict)
"""
pk_range_wrapper = None
- if args and global_endpoint_manager.is_circuit_breaker_applicable(args[0]):
+ if args and (global_endpoint_manager.is_per_partition_automatic_failover_applicable(args[0]) or
+ global_endpoint_manager.is_circuit_breaker_applicable(args[0])):
pk_range_wrapper = await global_endpoint_manager.create_pk_range_wrapper(args[0])
# instantiate all retry policies here to be applied for each request execution
endpointDiscovery_retry_policy = _endpoint_discovery_retry_policy.EndpointDiscoveryRetryPolicy(
- client.connection_policy, global_endpoint_manager, *args
+ client.connection_policy, global_endpoint_manager, pk_range_wrapper, *args
)
health_check_retry_policy = _health_check_retry_policy.HealthCheckRetryPolicy(
client.connection_policy,
@@ -96,8 +98,11 @@ async def ExecuteAsync(client, global_endpoint_manager, function, *args, **kwarg
service_request_retry_policy = _service_request_retry_policy.ServiceRequestRetryPolicy(
client.connection_policy, global_endpoint_manager, pk_range_wrapper, *args,
)
+ service_unavailable_retry_policy = _service_unavailable_retry_policy._ServiceUnavailableRetryPolicy(
+ client.connection_policy, global_endpoint_manager, pk_range_wrapper, *args)
# Get Logger
logger = kwargs.get("logger", logging.getLogger("azure.cosmos._retry_utility_async"))
+
# HttpRequest we would need to modify for Container Recreate Retry Policy
request = None
if args and len(args) > 3:
@@ -115,7 +120,7 @@ async def ExecuteAsync(client, global_endpoint_manager, function, *args, **kwarg
try:
if args:
result = await ExecuteFunctionAsync(function, global_endpoint_manager, *args, **kwargs)
- await global_endpoint_manager.record_success(args[0])
+ await global_endpoint_manager.record_success(args[0], pk_range_wrapper)
else:
result = await ExecuteFunctionAsync(function, *args, **kwargs)
if not client.last_response_headers:
@@ -198,13 +203,17 @@ async def ExecuteAsync(client, global_endpoint_manager, function, *args, **kwarg
if retry_policy.should_update_throughput_link(request.body, cached_container):
new_body = retry_policy._update_throughput_link(request.body)
request.body = new_body
-
retry_policy.container_rid = cached_container["_rid"]
request.headers[retry_policy._intended_headers] = retry_policy.container_rid
+ elif e.status_code == StatusCodes.SERVICE_UNAVAILABLE:
+ if args:
+ # record the failure for circuit breaker tracking
+ await global_endpoint_manager.record_ppcb_failure(args[0], pk_range_wrapper)
+ retry_policy = service_unavailable_retry_policy
elif e.status_code == StatusCodes.REQUEST_TIMEOUT or e.status_code >= StatusCodes.INTERNAL_SERVER_ERROR:
- # record the failure for circuit breaker tracking
if args:
- await global_endpoint_manager.record_failure(args[0])
+ # record the failure for ppaf/circuit breaker tracking
+ await global_endpoint_manager.record_failure(args[0], pk_range_wrapper)
retry_policy = timeout_failover_retry_policy
else:
retry_policy = defaultRetry_policy
@@ -252,12 +261,12 @@ async def ExecuteAsync(client, global_endpoint_manager, function, *args, **kwarg
_handle_service_request_retries(client, service_request_retry_policy, e, *args)
else:
if args:
- await global_endpoint_manager.record_failure(args[0])
+ await global_endpoint_manager.record_failure(args[0], pk_range_wrapper)
_handle_service_response_retries(request, client, service_response_retry_policy, e, *args)
# in case customer is not using aiohttp
except ImportError:
if args:
- await global_endpoint_manager.record_failure(args[0])
+ await global_endpoint_manager.record_failure(args[0], pk_range_wrapper)
_handle_service_response_retries(request, client, service_response_retry_policy, e, *args)
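A simplified sketch of the dispatch order the hunks above establish for status-code-driven retries (constants inlined; the real code also records ppaf/ppcb failures alongside the policy choice, and checks several other conditions first):

```python
def pick_retry_policy(status_code: int, policies: dict) -> object:
    # 503 now routes to the new cross-regional ServiceUnavailable policy;
    # 408 and 5xx keep flowing through the timeout-failover policy.
    if status_code == 503:
        return policies["service_unavailable"]
    if status_code == 408 or status_code >= 500:
        return policies["timeout_failover"]
    return policies["default"]

policies = {"service_unavailable": "su", "timeout_failover": "tf", "default": "d"}
assert pick_retry_policy(503, policies) == "su"
assert pick_retry_policy(500, policies) == "tf"
assert pick_retry_policy(409, policies) == "d"
```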
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/cosmos_client.py b/sdk/cosmos/azure-cosmos/azure/cosmos/cosmos_client.py
index 35a48b5f0229..74576e29b6a1 100644
--- a/sdk/cosmos/azure-cosmos/azure/cosmos/cosmos_client.py
+++ b/sdk/cosmos/azure-cosmos/azure/cosmos/cosmos_client.py
@@ -95,7 +95,9 @@ def _build_connection_policy(kwargs: dict[str, Any]) -> ConnectionPolicy:
policy.EnableEndpointDiscovery = kwargs.pop('enable_endpoint_discovery', policy.EnableEndpointDiscovery)
policy.PreferredLocations = kwargs.pop('preferred_locations', policy.PreferredLocations)
# TODO: Consider storing callback method instead, such as 'Supplier' in JAVA SDK
- policy.ExcludedLocations = kwargs.pop('excluded_locations', policy.ExcludedLocations)
+ excluded_locations = kwargs.pop('excluded_locations', policy.ExcludedLocations)
+ if excluded_locations:
+ policy.ExcludedLocations = excluded_locations
policy.UseMultipleWriteLocations = kwargs.pop('multiple_write_locations', policy.UseMultipleWriteLocations)
# SSL config
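The guard's effect, reduced to a hedged two-liner: a falsy `excluded_locations` value (None or an empty list) no longer clobbers the policy default:

```python
def resolve_excluded_locations(kwargs: dict, policy_default: list) -> list:
    # Mirrors the hunk above: only override the policy value when truthy.
    excluded_locations = kwargs.pop('excluded_locations', policy_default)
    return excluded_locations if excluded_locations else policy_default

assert resolve_excluded_locations({'excluded_locations': None}, []) == []
assert resolve_excluded_locations({'excluded_locations': ['West US']}, []) == ['West US']
```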
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/documents.py b/sdk/cosmos/azure-cosmos/azure/cosmos/documents.py
index d3f28233a6fa..e4edfbb77e85 100644
--- a/sdk/cosmos/azure-cosmos/azure/cosmos/documents.py
+++ b/sdk/cosmos/azure-cosmos/azure/cosmos/documents.py
@@ -79,6 +79,7 @@ def __init__(self) -> None:
self._WritableLocations: list[dict[str, str]] = []
self._ReadableLocations: list[dict[str, str]] = []
self._EnableMultipleWritableLocations = False
+ self._EnablePerPartitionFailoverBehavior = False
@property
def WritableLocations(self) -> list[dict[str, str]]:
diff --git a/sdk/cosmos/azure-cosmos/docs/ErrorCodesAndRetries.md b/sdk/cosmos/azure-cosmos/docs/ErrorCodesAndRetries.md
index 9018f3592e67..adea6378a21f 100644
--- a/sdk/cosmos/azure-cosmos/docs/ErrorCodesAndRetries.md
+++ b/sdk/cosmos/azure-cosmos/docs/ErrorCodesAndRetries.md
@@ -2,20 +2,20 @@
The Cosmos DB Python SDK has several default policies that will deal with retrying certain errors and exceptions. More information on these can be found below.
-| Status code | Cause of exception and retry behavior |
-| :--- |:--- |
-| 400 | For all operations: <br> - This exception is encountered when the request is invalid, which could be for any of the following reasons: <br> - Syntax error in query text <br> - Malformed JSON document for a write request <br> - Incorrectly formatted REST API request body, etc. <br> - The client does NOT retry the request when a Bad Request (400) exception is thrown by the server. |
-| 401 | For all operations: - This is an unauthorized exception due to invalid auth tokens being used for the request. The client does NOT retry requests when this exception is encountered. |
-| 403 | - For Substatus 3 (Write Forbidden) and Substatus 1008 (Database Account Not Found): <br> - This exception occurs when a geo-replicated database account runs into writable/readable location changes (say, after a failover). <br> - This exception can occur regardless of the Consistency level set for the account. <br> - The client refreshes its location endpoints and retries requests when the user has enabled endpoint discovery in their client (default behavior). <br> - For all other cases: <br> - The client does NOT retry requests when this exception is encountered. |
+| Status code | Cause of exception and retry behavior |
+|:------------|:--------------------------------------|
+| 400 | For all operations: - This exception is encountered when the request is invalid, which could be for any of the following reasons: <br> - Syntax error in query text <br> - Malformed JSON document for a write request <br> - Incorrectly formatted REST API request body, etc. <br> - The client does NOT retry the request when a Bad Request (400) exception is thrown by the server. |
+| 401 | For all operations: - This is an unauthorized exception due to invalid auth tokens being used for the request. The client does NOT retry requests when this exception is encountered. |
+| 403 | - For Substatus 3 (Write Forbidden) and Substatus 1008 (Database Account Not Found): <br> - This exception occurs when a geo-replicated database account runs into writable/readable location changes (say, after a failover). <br> - This exception can occur regardless of the Consistency level set for the account. <br> - The client refreshes its location endpoints and retries requests when the user has enabled endpoint discovery in their client (default behavior). <br> - For all other cases: <br> - The client does NOT retry requests when this exception is encountered. |
| 404/1002 | - For write operations: <br> - If multiple write locations are enabled for the account, the SDK will fetch the write endpoints and retry once per each of these. <br> - The client refreshes its location endpoints and retries requests when the user has enabled endpoint discovery in their client (default behavior). <br> - If the account does not have multiple write locations enabled, the SDK will retry only once in the account primary region. <br> - For read operations: <br> - If multiple write locations are enabled for the account, the SDK will fetch the read endpoints and retry once per each of these. <br> - The client refreshes its location endpoints and retries requests when the user has enabled endpoint discovery in their client (default behavior). <br> - If the account does not have multiple write locations enabled, the SDK will retry only once in the account primary region. |
-| 408 | - For Write Operations: <br> - Timeout exceptions can be encountered by both the client as well as the server. Server-side timeout exceptions are not retried for write operations, as it is not possible to determine whether the write was in fact successfully committed on the server. For a client-generated timeout exception, either the request was sent over the wire to the server and the network request timeout was exceeded while waiting for a response, or the request was never sent over the wire to the server, which resulted in a client-generated timeout. The client does NOT retry for either. <br> - For Query and Point Read Operations: <br> - The SDK will retry on the next preferred region, if any is available. |
-| 409 | - For Write Operations: <br> - This exception occurs when an attempt is made by the application to Create/Insert an Item that already exists. <br> - This exception can occur regardless of the Consistency level set for the account. <br> - This exception can occur for write operations when an attempt is made to create an existing item or when a unique key constraint violation occurs. <br> - The client does NOT retry on Conflict exceptions. <br> - For Query and Point Read Operations: <br> - N/A as this exception is only encountered for Create/Insert operations. |
+| 408 | - For Write Operations: <br> - Timeout exceptions can be encountered by both the client as well as the server. Server-side timeout exceptions are not retried for write operations, as it is not possible to determine whether the write was in fact successfully committed on the server. For a client-generated timeout exception, either the request was sent over the wire to the server and the network request timeout was exceeded while waiting for a response, or the request was never sent over the wire to the server, which resulted in a client-generated timeout. The client does NOT retry for either. <br> - For Query and Point Read Operations: <br> - The SDK will retry on the next preferred region, if any is available. |
+| 409 | - For Write Operations: <br> - This exception occurs when an attempt is made by the application to Create/Insert an Item that already exists. <br> - This exception can occur regardless of the Consistency level set for the account. <br> - This exception can occur for write operations when an attempt is made to create an existing item or when a unique key constraint violation occurs. <br> - The client does NOT retry on Conflict exceptions. <br> - For Query and Point Read Operations: <br> - N/A as this exception is only encountered for Create/Insert operations. |
| 410/1002 | - For all operations: <br> - This exception occurs when a partition is split (or merged in the future) and no longer exists, and can occur regardless of the Consistency level set for the account. <br> - The SDK will refresh its partition key range cache and trigger a single retry, fetching the new ranges from the gateway once it finds an empty cache. |
-| 412 | - For Write Operations: <br> - This exception is encountered when the etag that is sent to the server for validation prior to updating an Item does not match the etag of the Item on the server. <br> - The client does NOT retry this operation locally or against any of the remote regions for the account, as retries would not help alleviate the etag mismatch. <br> - The application would need to trigger a retry by first reading the Item, fetching the latest etag, and issuing the Upsert/Replace operation. <br> - This operation can continue to fail with the same exception when multiple updates are executed concurrently for the same Item. <br> - An upper bound on the number of retries before handing off the Item to a dead letter queue should be implemented by the application. <br> - For Query and Point Read Operations: <br> - N/A as this exception is only encountered for Create/Insert/Replace/Upsert operations. |
-| 429 | For all Operations: - By default, the client retries the request a maximum of 9 times (or for a maximum of 30 seconds, whichever limit is reached first). <br> - The client can also be initialized with a custom retry policy, which overrides the two limits mentioned above. <br> - After all the retries are exhausted, the client bubbles up the exception to the application. <br> - **For a multi-region account**, the client does NOT retry the request against a remote region for the account. <br> - When the application receives a Request Rate Too Large exception (429), the application would need to instrument its own retry logic and dead letter queues. |
-| 449 | - For Write Operations: <br> - This exception is encountered when a resource is concurrently updated on the server, which can happen due to concurrent writes, user-triggered writes while conflicts are concurrently being resolved, etc. <br> - Only one update can be executed at a time per item; the other concurrent requests will fail with a Concurrent Execution Exception (449). <br> - The client does NOT retry requests that failed with a 449. <br> - For Query and Point Read Operations: <br> - N/A as this exception is only encountered for Create/Insert/Replace/Upsert operations. |
-| 500 | For all Operations: - The occurrence of an Internal Server Error (500) is extremely rare, and the client will retry a request that encounters this exception on the next preferred regions. |
-| 503 | When a Service Unavailable exception is encountered: - The request will be retried by the SDK on the next preferred regions. |
+| 412 | - For Write Operations: <br> - This exception is encountered when the etag that is sent to the server for validation prior to updating an Item does not match the etag of the Item on the server. <br> - The client does NOT retry this operation locally or against any of the remote regions for the account, as retries would not help alleviate the etag mismatch. <br> - The application would need to trigger a retry by first reading the Item, fetching the latest etag, and issuing the Upsert/Replace operation. <br> - This operation can continue to fail with the same exception when multiple updates are executed concurrently for the same Item. <br> - An upper bound on the number of retries before handing off the Item to a dead letter queue should be implemented by the application. <br> - For Query and Point Read Operations: <br> - N/A as this exception is only encountered for Create/Insert/Replace/Upsert operations. |
+| 429 | For all Operations: - By default, the client retries the request a maximum of 9 times (or for a maximum of 30 seconds, whichever limit is reached first). <br> - The client can also be initialized with a custom retry policy, which overrides the two limits mentioned above. <br> - After all the retries are exhausted, the client bubbles up the exception to the application. <br> - **For a multi-region account**, the client does NOT retry the request against a remote region for the account. <br> - When the application receives a Request Rate Too Large exception (429), the application would need to instrument its own retry logic and dead letter queues. |
+| 449 | - For Write Operations: <br> - This exception is encountered when a resource is concurrently updated on the server, which can happen due to concurrent writes, user-triggered writes while conflicts are concurrently being resolved, etc. <br> - Only one update can be executed at a time per item; the other concurrent requests will fail with a Concurrent Execution Exception (449). <br> - The client does NOT retry requests that failed with a 449. <br> - For Query and Point Read Operations: <br> - N/A as this exception is only encountered for Create/Insert/Replace/Upsert operations. |
+| 500 | - For Write Operations: <br> - The client does NOT retry write requests. <br> - For Read Operations: <br> - The request will be retried by the SDK on the next preferred regions. |
+| 503 | When a Service Unavailable exception is encountered, for all Operations: - The request will be retried by the SDK on the next preferred regions. |
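Since the 429 row advises application-level instrumentation once SDK retries are exhausted, a hedged sketch of that pattern (the container, item, and dead-letter structures are placeholders, and the bare loop omits backoff for brevity):

```python
from azure.cosmos import exceptions

def upsert_with_dead_letter(container, item, dead_letter, max_attempts=3):
    # After the SDK's own throttle retries are exhausted, a 429 surfaces here;
    # the table above recommends the application add its own retry/dead-letter step.
    for _ in range(max_attempts):
        try:
            return container.upsert_item(item)
        except exceptions.CosmosHttpResponseError as e:
            if e.status_code != 429:
                raise
    dead_letter.append(item)  # hand off after exhausting application retries
    return None
```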
### Connection Issues Retry Flow And Marking Unavailable
diff --git a/sdk/cosmos/azure-cosmos/pytest.ini b/sdk/cosmos/azure-cosmos/pytest.ini
index aabe78b51f08..a5e006ea027e 100644
--- a/sdk/cosmos/azure-cosmos/pytest.ini
+++ b/sdk/cosmos/azure-cosmos/pytest.ini
@@ -7,3 +7,4 @@ markers =
cosmosMultiRegion: marks tests running on a Cosmos DB live account with multi-region and multi-write enabled.
cosmosCircuitBreaker: marks tests running on Cosmos DB live account with per partition circuit breaker enabled and multi-write enabled.
cosmosCircuitBreakerMultiRegion: marks tests running on Cosmos DB live account with one write region and multiple read regions and per partition circuit breaker enabled.
+ cosmosPerPartitionAutomaticFailover: marks tests running on Cosmos DB live account with one write region and multiple read regions and per partition automatic failover enabled.
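For completeness, the new marker is applied the same way as the existing ones; a minimal hedged example:

```python
import pytest

@pytest.mark.cosmosPerPartitionAutomaticFailover
def test_ppaf_smoke():
    # Collected only when the PPAF-enabled live-account suite is selected,
    # e.g. via `pytest -m cosmosPerPartitionAutomaticFailover`.
    assert True
```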
diff --git a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py
index 51b18738729e..f71a21003c98 100644
--- a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py
+++ b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py
@@ -26,13 +26,14 @@
import logging
import sys
from time import sleep
-from typing import Callable, Optional, Any, MutableMapping
+from typing import Callable, Optional, Any, MutableMapping, Mapping, Tuple, Sequence
from azure.core.pipeline.transport import HttpRequest, HttpResponse
from azure.core.pipeline.transport._requests_basic import RequestsTransport, RequestsTransportResponse
from requests import Session
from azure.cosmos import documents
+from azure.cosmos._constants import _Constants as Constants
import test_config
from azure.cosmos.exceptions import CosmosHttpResponseError
@@ -63,8 +64,29 @@ def error_with_counter(self, error: Exception) -> Exception:
self.counters[ERROR_WITH_COUNTER] += 1
return error
- def add_fault(self, predicate: Callable[[HttpRequest], bool], fault_factory: Callable[[HttpRequest], Exception]):
- self.faults.append({"predicate": predicate, "apply": fault_factory})
+ def add_fault(self,
+ predicate: Callable[[HttpRequest], bool],
+ fault_factory: Callable[[HttpRequest], Exception],
+ max_inner_count: Optional[int] = None,
+                  after_max_count: Optional[RequestsTransportResponse] = None):
+ """ Adds a fault to the transport that will be applied when the predicate matches the request.
+ :param Callable predicate: A callable that takes an HttpRequest and returns True if the fault should be applied.
+ :param Callable fault_factory: A callable that takes an HttpRequest and returns an Exception to be raised.
+        :param int max_inner_count: Optional maximum number of times the fault is applied before the counter
+            resets. If None, the fault will be applied every time the predicate matches.
+        :param RequestsTransportResponse after_max_count: Optional response returned as-is once the maximum
+            number of faults has been applied. Can only be used if `max_inner_count` is not None.
+ """
+ if max_inner_count is not None:
+ if after_max_count is not None:
+ self.faults.append({"predicate": predicate, "apply": fault_factory, "after_max_count": after_max_count,
+ "max_count": max_inner_count, "current_count": 0})
+ else:
+ self.faults.append({"predicate": predicate, "apply": fault_factory,
+ "max_count": max_inner_count, "current_count": 0})
+ else:
+ self.faults.append({"predicate": predicate, "apply": fault_factory})
def add_response_transformation(self, predicate: Callable[[HttpRequest], bool], response_transformation: Callable[[HttpRequest, Callable[[HttpRequest], RequestsTransportResponse]], RequestsTransportResponse]):
self.responseTransformations.append({
@@ -85,6 +107,16 @@ def send(self, request: HttpRequest, *, proxies: Optional[MutableMapping[str, st
# find the first fault Factory with matching predicate if any
first_fault_factory = FaultInjectionTransport.__first_item(iter(self.faults), lambda f: f["predicate"](request))
if first_fault_factory:
+ if "max_count" in first_fault_factory:
+ FaultInjectionTransport.logger.info(f"Found fault factory with max count {first_fault_factory['max_count']}")
+ if first_fault_factory["current_count"] >= first_fault_factory["max_count"]:
+ first_fault_factory["current_count"] = 0 # reset counter
+ if "after_max_count" in first_fault_factory:
+ FaultInjectionTransport.logger.info("Max count reached, returning after_max_count")
+ return first_fault_factory["after_max_count"]
+ FaultInjectionTransport.logger.info("Max count reached, skipping fault injection")
+ return super().send(request, proxies=proxies, **kwargs)
+ first_fault_factory["current_count"] += 1
FaultInjectionTransport.logger.info("--> FaultInjectionTransport.ApplyFaultInjection")
injected_error = first_fault_factory["apply"](request)
FaultInjectionTransport.logger.info("Found to-be-injected error {}".format(injected_error))
@@ -132,12 +164,21 @@ def print_call_stack():
frame = frame.f_back
@staticmethod
- def predicate_req_payload_contains_id(r: HttpRequest, id_value: str):
+ def predicate_req_payload_contains_id(r: HttpRequest, id_value: str) -> bool:
if r.body is None:
return False
return '"id":"{}"'.format(id_value) in r.body
+ @staticmethod
+ def predicate_req_payload_contains_field(r: HttpRequest, field_name: str, field_value: Optional[str]) -> bool:
+ if r.body is None:
+ return False
+ if field_value is None:
+ return '"{}":"'.format(field_name) in r.body
+ else:
+ return '"{}":"{}"'.format(field_name, field_value) in r.body
+
@staticmethod
def predicate_req_for_document_with_id(r: HttpRequest, id_value: str) -> bool:
return (FaultInjectionTransport.predicate_url_contains_id(r, id_value)
@@ -163,15 +204,8 @@ def predicate_is_resource_type(r: HttpRequest, resource_type: str) -> bool:
@staticmethod
def predicate_is_operation_type(r: HttpRequest, operation_type: str) -> bool:
is_operation_type = r.headers.get(HttpHeaders.ThinClientProxyOperationType) == operation_type
-
return is_operation_type
- @staticmethod
- def predicate_is_resource_type(r: HttpRequest, resource_type: str) -> bool:
- is_resource_type = r.headers.get(HttpHeaders.ThinClientProxyResourceType) == resource_type
-
- return is_resource_type
-
@staticmethod
def predicate_is_write_operation(r: HttpRequest, uri_prefix: str) -> bool:
is_write_document_operation = documents._OperationType.IsWriteOperation(
@@ -225,7 +259,8 @@ def error_service_response() -> Exception:
def transform_topology_swr_mrr(
write_region_name: str,
read_region_name: str,
- inner: Callable[[], RequestsTransportResponse]) -> RequestsTransportResponse:
+ inner: Callable[[], RequestsTransportResponse],
+ enable_per_partition_failover: bool = False) -> RequestsTransportResponse:
response = inner()
if not FaultInjectionTransport.predicate_is_database_account_call(response.request):
@@ -241,6 +276,28 @@ def transform_topology_swr_mrr(
writable_locations[0]["name"] = write_region_name
readable_locations.append({"name": read_region_name, "databaseAccountEndpoint" : test_config.TestConfig.local_host})
FaultInjectionTransport.logger.info("Transformed Account Topology: {}".format(result))
+ # TODO: need to verify below behavior against actual Cosmos DB service response
+ if enable_per_partition_failover:
+ result["enablePerPartitionFailoverBehavior"] = True
+ request: HttpRequest = response.request
+ return FaultInjectionTransport.MockHttpResponse(request, 200, result)
+
+ return response
+
+ @staticmethod
+ def transform_topology_ppaf_enabled( # cspell:disable-line
+ inner: Callable[[], RequestsTransportResponse]) -> RequestsTransportResponse:
+
+ response = inner()
+ if not FaultInjectionTransport.predicate_is_database_account_call(response.request):
+ return response
+
+ data = response.body()
+ if response.status_code == 200 and data:
+ data = data.decode("utf-8")
+ result = json.loads(data)
+ result[Constants.EnablePerPartitionFailoverBehavior] = True
+ FaultInjectionTransport.logger.info("Transformed Account Topology: {}".format(result))
request: HttpRequest = response.request
return FaultInjectionTransport.MockHttpResponse(request, 200, result)
@@ -283,8 +340,25 @@ def transform_topology_mwr(
return response
+ class MockHttpRequest(HttpRequest):
+ def __init__(
+ self,
+ url: str,
+ method: str = "GET",
+ headers: Optional[Mapping[str, str]] = None,
+ files: Optional[Any] = None,
+ data: Optional[Any] = None,
+ ) -> None:
+ self.method = method
+ self.url = url
+ self.headers: Optional[MutableMapping[str, str]] = headers
+ self.files: Optional[Any] = files
+ self.data: Optional[Any] = data
+ self.multipart_mixed_info: Optional[
+ Tuple[Sequence[Any], Sequence[Any], Optional[str], dict[str, Any]]] = None
+
class MockHttpResponse(RequestsTransportResponse):
- def __init__(self, request: HttpRequest, status_code: int, content:Optional[dict[str, Any]]):
+ def __init__(self, request: HttpRequest, status_code: int, content: Optional[Any] = None):
self.request: HttpRequest = request
# This is actually never None, and set by all implementations after the call to
# __init__ of this class. This class is also a legacy impl, so it's risky to change it
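A hedged usage sketch of the extended `add_fault` hook, following the shapes used by the PPAF tests later in this diff (the predicate, injected error, and URL are illustrative):

```python
from azure.cosmos.exceptions import CosmosHttpResponseError
from _fault_injection_transport import FaultInjectionTransport

transport = FaultInjectionTransport()
mock_request = FaultInjectionTransport.MockHttpRequest(url="https://localhost:8081")
canned_success = FaultInjectionTransport.MockHttpResponse(mock_request, 200)

# Inject a 503 on matching write requests twice, then hand back the canned
# success response instead of continuing to fault.
transport.add_fault(
    predicate=lambda r: FaultInjectionTransport.predicate_is_write_operation(r, "com"),
    fault_factory=lambda r: FaultInjectionTransport.error_after_delay(
        0, CosmosHttpResponseError(status_code=503, message="Injected error.")),
    max_inner_count=2,
    after_max_count=canned_success,
)
```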
diff --git a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py
index 305d62fe8c2d..f487a6180aff 100644
--- a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py
+++ b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py
@@ -60,8 +60,29 @@ async def error_with_counter(self, error: Exception) -> Exception:
self.counters[ERROR_WITH_COUNTER] += 1
return error
- def add_fault(self, predicate: Callable[[HttpRequest], bool], fault_factory: Callable[[HttpRequest], Awaitable[Exception]]):
- self.faults.append({"predicate": predicate, "apply": fault_factory})
+ def add_fault(self,
+ predicate: Callable[[HttpRequest], bool],
+ fault_factory: Callable[[HttpRequest], Awaitable[Exception]],
+ max_inner_count: Optional[int] = None,
+                  after_max_count: Optional[AioHttpTransportResponse] = None):
+ """ Adds a fault to the transport that will be applied when the predicate matches the request.
+ :param Callable predicate: A callable that takes an HttpRequest and returns True if the fault should be applied.
+ :param Callable fault_factory: A callable that takes an HttpRequest and returns an Exception to be raised.
+        :param int max_inner_count: Optional maximum number of times the fault is applied before the counter
+            resets. If None, the fault will be applied every time the predicate matches.
+        :param AioHttpTransportResponse after_max_count: Optional response returned as-is once the maximum
+            number of faults has been applied. Can only be used if `max_inner_count` is not None.
+ """
+ if max_inner_count is not None:
+ if after_max_count is not None:
+ self.faults.append({"predicate": predicate, "apply": fault_factory, "after_max_count": after_max_count,
+ "max_count": max_inner_count, "current_count": 0})
+ else:
+ self.faults.append({"predicate": predicate, "apply": fault_factory,
+ "max_count": max_inner_count, "current_count": 0})
+ else:
+ self.faults.append({"predicate": predicate, "apply": fault_factory})
def add_response_transformation(self, predicate: Callable[[HttpRequest], bool], response_transformation: Callable[[HttpRequest, Callable[[HttpRequest], AioHttpTransportResponse]], AioHttpTransportResponse]):
self.responseTransformations.append({
@@ -82,6 +103,16 @@ async def send(self, request: HttpRequest, *, stream: bool = False, proxies: Opt
# find the first fault Factory with matching predicate if any
first_fault_factory = FaultInjectionTransportAsync.__first_item(iter(self.faults), lambda f: f["predicate"](request))
if first_fault_factory:
+ if "max_count" in first_fault_factory:
+ FaultInjectionTransportAsync.logger.info(f"Found fault factory with max count {first_fault_factory['max_count']}")
+ if first_fault_factory["current_count"] >= first_fault_factory["max_count"]:
+ first_fault_factory["current_count"] = 0 # reset counter
+ if "after_max_count" in first_fault_factory:
+ FaultInjectionTransportAsync.logger.info("Max count reached, returning after_max_count")
+ return first_fault_factory["after_max_count"]
+ FaultInjectionTransportAsync.logger.info("Max count reached, skipping fault injection")
+ return await super().send(request, proxies=proxies, **config)
+ first_fault_factory["current_count"] += 1
FaultInjectionTransportAsync.logger.info("--> FaultInjectionTransportAsync.ApplyFaultInjection")
injected_error = await first_fault_factory["apply"](request)
FaultInjectionTransportAsync.logger.info("Found to-be-injected error {}".format(injected_error))
@@ -238,6 +269,26 @@ async def transform_topology_swr_mrr(
return response
+ @staticmethod
+ async def transform_topology_ppaf_enabled( # cspell:disable-line
+ inner: Callable[[], Awaitable[AioHttpTransportResponse]]) -> AioHttpTransportResponse:
+
+ response = await inner()
+ if not FaultInjectionTransportAsync.predicate_is_database_account_call(response.request):
+ return response
+
+ data = response.body()
+ if response.status_code == 200 and data:
+ data = data.decode("utf-8")
+ result = json.loads(data)
+ # TODO: need to verify below behavior against actual Cosmos DB service response
+ result["enablePerPartitionFailoverBehavior"] = True
+ FaultInjectionTransportAsync.logger.info("Transformed Account Topology: {}".format(result))
+ request: HttpRequest = response.request
+ return FaultInjectionTransportAsync.MockHttpResponse(request, 200, result)
+
+ return response
+
@staticmethod
async def transform_topology_mwr(
first_region_name: str,
@@ -276,7 +327,7 @@ async def transform_topology_mwr(
return response
class MockHttpResponse(AioHttpTransportResponse):
- def __init__(self, request: HttpRequest, status_code: int, content:Optional[Dict[str, Any]]):
+ def __init__(self, request: HttpRequest, status_code: int, content: Optional[Any]=None):
self.request: HttpRequest = request
# This is actually never None, and set by all implementations after the call to
# __init__ of this class. This class is also a legacy impl, so it's risky to change it
diff --git a/sdk/cosmos/azure-cosmos/tests/test_circuit_breaker_emulator.py b/sdk/cosmos/azure-cosmos/tests/test_circuit_breaker_emulator.py
index 80aa9a060594..4f7b9a6d3254 100644
--- a/sdk/cosmos/azure-cosmos/tests/test_circuit_breaker_emulator.py
+++ b/sdk/cosmos/azure-cosmos/tests/test_circuit_breaker_emulator.py
@@ -151,6 +151,8 @@ def test_write_consecutive_failure_threshold_delete_all_items_by_pk_sm(self, set
@pytest.mark.parametrize("error", create_errors())
def test_write_consecutive_failure_threshold_delete_all_items_by_pk_mm(self, setup_teardown, error):
+ if hasattr(error, "status_code") and error.status_code == 503:
+ pytest.skip("ServiceUnavailableError will do a cross-region retry, so it has to be special cased.")
error_lambda = lambda r: FaultInjectionTransport.error_after_delay(0, error)
setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = self.setup_info(error_lambda, mm=True)
fault_injection_container = custom_setup['col']
@@ -206,6 +208,8 @@ def test_write_consecutive_failure_threshold_delete_all_items_by_pk_mm(self, set
@pytest.mark.parametrize("error", create_errors())
def test_write_failure_rate_threshold_delete_all_items_by_pk_mm(self, setup_teardown, error):
+ if hasattr(error, "status_code") and error.status_code == 503:
+ pytest.skip("ServiceUnavailableError will do a cross-region retry, so it has to be special cased.")
error_lambda = lambda r: FaultInjectionTransport.error_after_delay(0, error)
setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = self.setup_info(error_lambda, mm=True)
fault_injection_container = custom_setup['col']
diff --git a/sdk/cosmos/azure-cosmos/tests/test_circuit_breaker_emulator_async.py b/sdk/cosmos/azure-cosmos/tests/test_circuit_breaker_emulator_async.py
index fde8eb420838..c15051859c50 100644
--- a/sdk/cosmos/azure-cosmos/tests/test_circuit_breaker_emulator_async.py
+++ b/sdk/cosmos/azure-cosmos/tests/test_circuit_breaker_emulator_async.py
@@ -156,6 +156,8 @@ async def test_write_consecutive_failure_threshold_delete_all_items_by_pk_sm_asy
@pytest.mark.parametrize("error", create_errors())
async def test_write_consecutive_failure_threshold_delete_all_items_by_pk_mm_async(self, setup_teardown, error):
+ if hasattr(error, "status_code") and error.status_code == 503:
+ pytest.skip("ServiceUnavailableError will do a cross-region retry, so it has to be special cased.")
error_lambda = lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay(0, error))
setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = await self.setup_info(error_lambda, mm=True)
fault_injection_container = custom_setup['col']
@@ -212,6 +214,8 @@ async def test_write_consecutive_failure_threshold_delete_all_items_by_pk_mm_asy
@pytest.mark.parametrize("error", create_errors())
async def test_write_failure_rate_threshold_delete_all_items_by_pk_mm_async(self, setup_teardown, error):
+ if hasattr(error, "status_code") and error.status_code == 503:
+ pytest.skip("ServiceUnavailableError will do a cross-region retry, so it has to be special cased.")
error_lambda = lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay(0, error))
setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = await self.setup_info(error_lambda, mm=True)
fault_injection_container = custom_setup['col']
diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py
index d54bafc32936..a385a2ca3e1f 100644
--- a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py
+++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py
@@ -441,6 +441,7 @@ def test_delete_item(self, test_data):
MOCK_HANDLER.reset()
# API call: delete_item
+ container.upsert_item(body)
if request_excluded_locations is None:
container.delete_item(item_id, PARTITION_KEY_VALUES)
else:
diff --git a/sdk/cosmos/azure-cosmos/tests/test_location_cache.py b/sdk/cosmos/azure-cosmos/tests/test_location_cache.py
index 090717519222..0a8f4c519cbb 100644
--- a/sdk/cosmos/azure-cosmos/tests/test_location_cache.py
+++ b/sdk/cosmos/azure-cosmos/tests/test_location_cache.py
@@ -218,9 +218,9 @@ def test_get_applicable_regional_endpoints_excluded_regions(self, test_type):
location_cache.perform_on_database_account_read(database_account)
# Init requests and set excluded regions on requests
- write_doc_request = RequestObject(ResourceType.Document, _OperationType.Create, None)
+ write_doc_request = RequestObject(ResourceType.Document, _OperationType.Create, {})
write_doc_request.excluded_locations = excluded_locations_on_requests
- read_doc_request = RequestObject(ResourceType.Document, _OperationType.Read, None)
+ read_doc_request = RequestObject(ResourceType.Document, _OperationType.Read, {})
read_doc_request.excluded_locations = excluded_locations_on_requests
# Test if read endpoints were correctly filtered on client level
@@ -250,7 +250,7 @@ def test_set_excluded_locations_for_requests(self):
options: Mapping[str, Any] = {"excludedLocations": excluded_locations}
expected_excluded_locations = excluded_locations
- read_doc_request = RequestObject(ResourceType.Document, _OperationType.Create, None)
+ read_doc_request = RequestObject(ResourceType.Document, _OperationType.Create, {})
read_doc_request.set_excluded_location_from_options(options)
actual_excluded_locations = read_doc_request.excluded_locations
assert actual_excluded_locations == expected_excluded_locations
diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_automatic_failover.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_automatic_failover.py
new file mode 100644
index 000000000000..437c25556e05
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_automatic_failover.py
@@ -0,0 +1,279 @@
+# The MIT License (MIT)
+# Copyright (c) Microsoft Corporation. All rights reserved.
+import unittest
+import uuid
+
+import pytest
+import test_config
+from azure.core.exceptions import ServiceResponseError
+from azure.cosmos import CosmosClient
+from azure.cosmos.exceptions import CosmosHttpResponseError
+from _fault_injection_transport import FaultInjectionTransport
+from test_per_partition_circuit_breaker_mm import (REGION_1, REGION_2, PK_VALUE, BATCH,
+ write_operations_errors_and_boolean, perform_write_operation)
+
+# cspell:disable
+
+def create_failover_errors():
+ errors = []
+ error_codes = [403, 503]
+ for error_code in error_codes:
+ errors.append(CosmosHttpResponseError(
+ status_code=error_code,
+ message="Some injected error.",
+ sub_status=3))
+ return errors
+
+def create_threshold_errors():
+ errors = []
+ error_codes = [408, 500, 502, 504]
+ for error_code in error_codes:
+ errors.append(CosmosHttpResponseError(
+ status_code=error_code,
+ message="Some injected error."))
+ errors.append(ServiceResponseError(message="Injected Service Response Error."))
+ return errors
+
+# These tests assume that the configured live account has one main write region and one secondary read region.
+
+@pytest.mark.cosmosPerPartitionAutomaticFailover
+class TestPerPartitionAutomaticFailover:
+ host = test_config.TestConfig.host
+ master_key = test_config.TestConfig.masterKey
+ TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID
+ TEST_CONTAINER_MULTI_PARTITION_ID = test_config.TestConfig.TEST_MULTI_PARTITION_CONTAINER_ID
+
+ def setup_method_with_custom_transport(self, custom_transport, default_endpoint=host, **kwargs):
+ regions = [REGION_1, REGION_2]
+ container_id = kwargs.pop("container_id", None)
+ exclude_client_regions = kwargs.pop("exclude_client_regions", False)
+ excluded_regions = []
+ if exclude_client_regions:
+ excluded_regions = [REGION_2]
+ if not container_id:
+ container_id = self.TEST_CONTAINER_MULTI_PARTITION_ID
+ client = CosmosClient(default_endpoint, self.master_key, consistency_level="Session",
+ preferred_locations=regions,
+ excluded_locations=excluded_regions,
+ transport=custom_transport, **kwargs)
+ db = client.get_database_client(self.TEST_DATABASE_ID)
+ container = db.get_container_client(container_id)
+ return {"client": client, "db": db, "col": container}
+
+ def setup_info(self, error=None, max_count=None, is_batch=False, exclude_client_regions=False, session_error=False, **kwargs):
+ custom_transport = FaultInjectionTransport()
+ # two documents targeted to same partition, one will always fail and the other will succeed
+ doc_fail_id = str(uuid.uuid4())
+ doc_success_id = str(uuid.uuid4())
+ predicate = lambda r: (FaultInjectionTransport.predicate_req_for_document_with_id(r, doc_fail_id) and
+ FaultInjectionTransport.predicate_is_write_operation(r, "com"))
+ # The MockRequest only gets used to create the MockHttpResponse
+ mock_request = FaultInjectionTransport.MockHttpRequest(url=self.host)
+ if is_batch:
+ success_response = FaultInjectionTransport.MockHttpResponse(mock_request, 200, [{"statusCode": 200}],)
+ else:
+ success_response = FaultInjectionTransport.MockHttpResponse(mock_request, 200)
+ if error:
+ custom_transport.add_fault(predicate=predicate, fault_factory=error, max_inner_count=max_count,
+ after_max_count=success_response)
+ if session_error:
+ read_predicate = lambda r: (FaultInjectionTransport.predicate_is_operation_type(r, "Read")
+ and FaultInjectionTransport.predicate_req_for_document_with_id(r, doc_fail_id))
+ read_error = CosmosHttpResponseError(
+ status_code=404,
+ message="Some injected error.",
+ sub_status=1002)
+ error_lambda = lambda r: FaultInjectionTransport.error_after_delay(0, read_error)
+ success_response = FaultInjectionTransport.MockHttpResponse(mock_request, 200)
+ custom_transport.add_fault(predicate=read_predicate, fault_factory=error_lambda, max_inner_count=max_count,
+ after_max_count=success_response)
+ is_get_account_predicate = lambda r: FaultInjectionTransport.predicate_is_database_account_call(r)
+ # Set the database account response to have PPAF enabled
+ ppaf_enabled_database_account = \
+ lambda r, inner: FaultInjectionTransport.transform_topology_ppaf_enabled(inner=inner)
+ custom_transport.add_response_transformation(
+ is_get_account_predicate,
+ ppaf_enabled_database_account)
+ setup = self.setup_method_with_custom_transport(None, default_endpoint=self.host,
+ exclude_client_regions=exclude_client_regions, **kwargs)
+ custom_setup = self.setup_method_with_custom_transport(custom_transport, default_endpoint=self.host,
+ exclude_client_regions=exclude_client_regions, **kwargs)
+ return setup, doc_fail_id, doc_success_id, custom_setup, custom_transport, predicate
+
+ @pytest.mark.parametrize("write_operation, error, exclude_regions", write_operations_errors_and_boolean(create_failover_errors()))
+ def test_ppaf_partition_info_cache_and_routing(self, write_operation, error, exclude_regions):
+        # This test validates that the partition info cache is updated correctly upon failures, and that the
+        # per-partition automatic failover logic routes requests to the next available regional endpoint on
+        # failover errors (403.3 and 503). We also verify that this logic is unaffected by user excluded regions,
+        # since write-region routing is entirely taken care of on the service.
+ error_lambda = lambda r: FaultInjectionTransport.error_after_delay(0, error)
+ setup, doc_fail_id, doc_success_id, custom_setup, custom_transport, predicate = self.setup_info(error_lambda, 1,
+ write_operation == BATCH, exclude_client_regions=exclude_regions)
+ container = setup['col']
+ fault_injection_container = custom_setup['col']
+ global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager
+
+ # Create a document to populate the per-partition GEM partition range info cache
+ fault_injection_container.create_item(body={'id': doc_success_id, 'pk': PK_VALUE,
+ 'name': 'sample document', 'key': 'value'})
+ pk_range_wrapper = list(global_endpoint_manager.partition_range_to_failover_info.keys())[0]
+ initial_region = global_endpoint_manager.partition_range_to_failover_info[pk_range_wrapper].current_region
+
+ # Based on our configuration, we should have had one error followed by a success - marking only the previous endpoint as unavailable
+ perform_write_operation(
+ write_operation,
+ container,
+ fault_injection_container,
+ doc_fail_id,
+ PK_VALUE)
+ partition_info = global_endpoint_manager.partition_range_to_failover_info[pk_range_wrapper]
+ # Verify that the partition is marked as unavailable, and that the current regional endpoint is not the same
+ assert len(partition_info.unavailable_regional_endpoints) == 1
+ assert initial_region in partition_info.unavailable_regional_endpoints
+ assert initial_region != partition_info.current_region # west us 3 != west us
+
+ # Now we run another request to see how the cache gets updated
+ perform_write_operation(
+ write_operation,
+ container,
+ fault_injection_container,
+ str(uuid.uuid4()),
+ PK_VALUE)
+ partition_info = global_endpoint_manager.partition_range_to_failover_info[pk_range_wrapper]
+ # Verify that the cache is empty, since the request going to the second regional endpoint failed
+ # Once we reach the point of all available regions being marked as unavailable, the cache is cleared
+ assert len(partition_info.unavailable_regional_endpoints) == 0
+ assert initial_region not in partition_info.unavailable_regional_endpoints
+ assert partition_info.current_region is None
+
+
+ @pytest.mark.parametrize("write_operation, error, exclude_regions", write_operations_errors_and_boolean(create_threshold_errors()))
+ def test_ppaf_partition_thresholds_and_routing(self, write_operation, error, exclude_regions):
+ # This test validates the consecutive failures logic is properly handled for per-partition automatic failover,
+ # and that the per-partition automatic failover logic routes requests to the next available regional endpoint
+ # after enough consecutive failures have occurred. We also verify that this logic is unaffected by user excluded
+ # regions, since write-region routing is entirely taken care of on the service.
+ error_lambda = lambda r: FaultInjectionTransport.error_after_delay(0, error)
+ setup, doc_fail_id, doc_success_id, custom_setup, custom_transport, predicate = self.setup_info(error=error_lambda,
+ exclude_client_regions=exclude_regions)
+ container = setup['col']
+ fault_injection_container = custom_setup['col']
+ global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager
+
+ # Create a document to populate the per-partition GEM partition range info cache
+ fault_injection_container.create_item(body={'id': doc_success_id, 'pk': PK_VALUE,
+ 'name': 'sample document', 'key': 'value'})
+ pk_range_wrapper = list(global_endpoint_manager.partition_range_to_failover_info.keys())[0]
+ initial_region = global_endpoint_manager.partition_range_to_failover_info[pk_range_wrapper].current_region
+
+ consecutive_failures = 6
+ for i in range(consecutive_failures):
+ # We perform the write operation multiple times to check the consecutive failures logic
+ with pytest.raises((CosmosHttpResponseError, ServiceResponseError)) as exc_info:
+ perform_write_operation(write_operation,
+ container,
+ fault_injection_container,
+ doc_fail_id,
+ PK_VALUE)
+ assert exc_info.value == error
+ # Verify that the threshold for consecutive failures is updated
+ pk_range_wrappers = list(global_endpoint_manager.ppaf_thresholds_tracker.pk_range_wrapper_to_failure_count.keys())
+ assert len(pk_range_wrappers) == 1
+ failure_count = global_endpoint_manager.ppaf_thresholds_tracker.pk_range_wrapper_to_failure_count[pk_range_wrappers[0]]
+ assert failure_count == consecutive_failures
+
+ # Verify that a single success to the same partition resets the consecutive failures count
+ perform_write_operation(write_operation,
+ container,
+ fault_injection_container,
+ str(uuid.uuid4()),
+ PK_VALUE)
+
+ failure_count = global_endpoint_manager.ppaf_thresholds_tracker.pk_range_wrapper_to_failure_count.get(pk_range_wrappers[0], 0)
+ assert failure_count == 0
+
+ # Run enough failed requests to the partition to trigger the failover logic
+ for i in range(12):
+ with pytest.raises((CosmosHttpResponseError, ServiceResponseError)) as exc_info:
+ perform_write_operation(write_operation,
+ container,
+ fault_injection_container,
+ doc_fail_id,
+ PK_VALUE)
+ assert exc_info.value == error
+ # We should have marked the previous endpoint as unavailable after 10 successive failures
+ partition_info = global_endpoint_manager.partition_range_to_failover_info[pk_range_wrapper]
+ # Verify that the partition is marked as unavailable, and that the current regional endpoint is not the same
+ assert len(partition_info.unavailable_regional_endpoints) == 1
+ assert initial_region in partition_info.unavailable_regional_endpoints
+ assert initial_region != partition_info.current_region # west us 3 != west us
+
+ # 12 failures - 10 to trigger failover, 2 more to start counting again
+ failure_count = global_endpoint_manager.ppaf_thresholds_tracker.pk_range_wrapper_to_failure_count[pk_range_wrappers[0]]
+ assert failure_count == 2
+
+ @pytest.mark.parametrize("write_operation, error, exclude_regions", write_operations_errors_and_boolean(create_failover_errors()))
+ def test_ppaf_session_unavailable_retry(self, write_operation, error, exclude_regions):
+ # Account config has 2 regions: West US 3 (A) and West US (B). This test validates that after marking the write
+ # region (A) as unavailable, the next request is retried to the read region (B) and succeeds. The next read request
+ # should see that the write region (A) is unavailable for the partition, and should retry to the read region (B) as well.
+ # We also verify that this logic is unaffected by user excluded regions, since write-region routing is entirely
+ # taken care of on the service.
+ error_lambda = lambda r: FaultInjectionTransport.error_after_delay(0, error)
+ setup, doc_fail_id, doc_success_id, custom_setup, custom_transport, predicate = self.setup_info(error_lambda, max_count=1,
+ is_batch=write_operation==BATCH,
+ session_error=True, exclude_client_regions=exclude_regions)
+ container = setup['col']
+ fault_injection_container = custom_setup['col']
+ global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager
+
+ # Create a document to populate the per-partition GEM partition range info cache
+ fault_injection_container.create_item(body={'id': doc_success_id, 'pk': PK_VALUE,
+ 'name': 'sample document', 'key': 'value'})
+ pk_range_wrapper = list(global_endpoint_manager.partition_range_to_failover_info.keys())[0]
+ initial_region = global_endpoint_manager.partition_range_to_failover_info[pk_range_wrapper].current_region
+
+ # Verify the region that is being used for the read requests
+ read_response = fault_injection_container.read_item(doc_success_id, PK_VALUE)
+ uri = read_response.get_response_headers().get('Content-Location')
+ region = fault_injection_container.client_connection._global_endpoint_manager.location_cache.get_location_from_endpoint(uri)
+ assert region == REGION_1 # first preferred region
+
+ # Based on our configuration, we should have had one error followed by a success - marking only the previous endpoint as unavailable
+ perform_write_operation(
+ write_operation,
+ container,
+ fault_injection_container,
+ doc_fail_id,
+ PK_VALUE)
+ partition_info = global_endpoint_manager.partition_range_to_failover_info[pk_range_wrapper]
+ # Verify that the partition is marked as unavailable, and that the current regional endpoint is not the same
+ assert len(partition_info.unavailable_regional_endpoints) == 1
+ assert initial_region in partition_info.unavailable_regional_endpoints
+ assert initial_region != partition_info.current_region # west us 3 != west us
+
+ # Now we run a read request that runs into a 404.1002 error, which should retry to the read region
+ # We verify that the read request was going to the correct region by using the raw_response_hook
+ fault_injection_container.read_item(doc_fail_id, PK_VALUE, raw_response_hook=session_retry_hook)
+
+ def test_ppaf_user_agent_feature_flag(self):
+ # Simple test to verify the user agent suffix is being updated with the relevant feature flags
+ setup, doc_fail_id, doc_success_id, custom_setup, custom_transport, predicate = self.setup_info()
+ fault_injection_container = custom_setup['col']
+ # Create a document to check the response headers
+ fault_injection_container.upsert_item(body={'id': doc_success_id, 'pk': PK_VALUE, 'name': 'sample document', 'key': 'value'},
+ raw_response_hook=ppaf_user_agent_hook)
+
+def session_retry_hook(raw_response):
+ if raw_response.http_request.headers.get('x-ms-thinclient-proxy-resource-type') != 'databaseaccount':
+ # This hook is used to verify the request routing that happens after the session retry logic
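+        # Regional endpoints embed the lowercased region name, e.g. "myaccount-westus.documents.azure.com"
+        # (hypothetical account name), so checking for "-westus." verifies the request went to REGION_2.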
+ region_string = "-" + REGION_2.replace(' ', '').lower() + "."
+ assert region_string in raw_response.http_request.url
+
+def ppaf_user_agent_hook(raw_response):
+ # Used to verify the user agent feature flags
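+    # 'F3' is assumed to be the feature-flag token emitted when per partition automatic
+    # failover is active (the circuit-breaker-only tests assert 'F2' instead).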
+ user_agent = raw_response.http_request.headers.get('user-agent')
+ assert user_agent.endswith('| F3')
+
+if __name__ == '__main__':
+ unittest.main()
\ No newline at end of file
diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_automatic_failover_async.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_automatic_failover_async.py
new file mode 100644
index 000000000000..860727c18b0b
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_automatic_failover_async.py
@@ -0,0 +1,263 @@
+# The MIT License (MIT)
+# Copyright (c) Microsoft Corporation. All rights reserved.
+import unittest
+import uuid
+
+import asyncio
+
+import pytest
+from typing import Dict, Any, Optional
+
+import test_config
+from azure.core.pipeline.transport._aiohttp import AioHttpTransport
+from azure.core.exceptions import ServiceResponseError
+from azure.cosmos.exceptions import CosmosHttpResponseError
+from azure.cosmos.aio import CosmosClient
+from _fault_injection_transport import FaultInjectionTransport
+from _fault_injection_transport_async import FaultInjectionTransportAsync
+from test_per_partition_automatic_failover import create_failover_errors, create_threshold_errors, session_retry_hook, ppaf_user_agent_hook
+from test_per_partition_circuit_breaker_mm import REGION_1, REGION_2, PK_VALUE, BATCH, write_operations_errors_and_boolean
+from test_per_partition_circuit_breaker_mm_async import perform_write_operation
+
+#cspell:ignore PPAF, ppaf
+
+# These tests assume that the configured live account has one main write region and one secondary read region.
+
+@pytest.mark.cosmosPerPartitionAutomaticFailover
+@pytest.mark.asyncio
+class TestPerPartitionAutomaticFailoverAsync:
+ host = test_config.TestConfig.host
+ master_key = test_config.TestConfig.masterKey
+ TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID
+ TEST_CONTAINER_MULTI_PARTITION_ID = test_config.TestConfig.TEST_MULTI_PARTITION_CONTAINER_ID
+
+ async def setup_method_with_custom_transport(self, custom_transport: Optional[AioHttpTransport],
+ default_endpoint=host, read_first=False, **kwargs):
+ regions = [REGION_2, REGION_1] if read_first else [REGION_1, REGION_2]
+ container_id = kwargs.pop("container_id", None)
+ exclude_client_regions = kwargs.pop("exclude_client_regions", False)
+ excluded_regions = []
+ if exclude_client_regions:
+ excluded_regions = [REGION_2]
+ if not container_id:
+ container_id = self.TEST_CONTAINER_MULTI_PARTITION_ID
+ client = CosmosClient(default_endpoint, self.master_key, consistency_level="Session",
+ preferred_locations=regions,
+ excluded_locations=excluded_regions,
+ transport=custom_transport, **kwargs)
+ db = client.get_database_client(self.TEST_DATABASE_ID)
+ container = db.get_container_client(container_id)
+ await client.__aenter__()
+ return {"client": client, "db": db, "col": container}
+
+ @staticmethod
+ async def cleanup_method(initialized_objects: Dict[str, Any]):
+ method_client: CosmosClient = initialized_objects["client"]
+ await method_client.close()
+
+ async def setup_info(self, error=None, max_count=None, is_batch=False, exclude_client_regions=False, session_error=False, **kwargs):
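+        # Mirrors the synchronous setup_info: one clean client for control operations and one
+        # fault-injected client whose database account response reports PPAF as enabled.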
+ custom_transport = FaultInjectionTransportAsync()
+        # two documents targeted to the same partition: faults are injected only for doc_fail_id, while doc_success_id always succeeds
+ doc_fail_id = str(uuid.uuid4())
+ doc_success_id = str(uuid.uuid4())
+ predicate = lambda r: (FaultInjectionTransportAsync.predicate_req_for_document_with_id(r, doc_fail_id) and
+ FaultInjectionTransportAsync.predicate_is_write_operation(r, "com"))
+ # The MockRequest only gets used to create the MockHttpResponse
+ mock_request = FaultInjectionTransport.MockHttpRequest(url=self.host)
+ if is_batch:
+ success_response = FaultInjectionTransportAsync.MockHttpResponse(mock_request, 200, [{"statusCode": 200}],)
+ else:
+ success_response = FaultInjectionTransportAsync.MockHttpResponse(mock_request, 200)
+ if error:
+ custom_transport.add_fault(predicate=predicate, fault_factory=error, max_inner_count=max_count,
+ after_max_count=success_response)
+ if session_error:
+ read_predicate = lambda r: (FaultInjectionTransportAsync.predicate_is_operation_type(r, "Read")
+ and FaultInjectionTransportAsync.predicate_req_for_document_with_id(r, doc_fail_id))
+ read_error = CosmosHttpResponseError(
+ status_code=404,
+ message="Some injected error.",
+ sub_status=1002)
+ error_lambda = lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay(0, read_error))
+ success_response = FaultInjectionTransportAsync.MockHttpResponse(mock_request, 200)
+ custom_transport.add_fault(predicate=read_predicate, fault_factory=error_lambda, max_inner_count=max_count,
+ after_max_count=success_response)
+ is_get_account_predicate = lambda r: FaultInjectionTransportAsync.predicate_is_database_account_call(r)
+ # Set the database account response to have PPAF enabled
+ ppaf_enabled_database_account = \
+ lambda r, inner: FaultInjectionTransportAsync.transform_topology_ppaf_enabled(inner=inner)
+ custom_transport.add_response_transformation(
+ is_get_account_predicate,
+ ppaf_enabled_database_account)
+ setup = await self.setup_method_with_custom_transport(None, default_endpoint=self.host,
+ exclude_client_regions=exclude_client_regions, **kwargs)
+ custom_setup = await self.setup_method_with_custom_transport(custom_transport, default_endpoint=self.host,
+ exclude_client_regions=exclude_client_regions, **kwargs)
+ return setup, doc_fail_id, doc_success_id, custom_setup, custom_transport, predicate
+
+ @pytest.mark.parametrize("write_operation, error, exclude_regions", write_operations_errors_and_boolean(create_failover_errors()))
+ async def test_ppaf_partition_info_cache_and_routing_async(self, write_operation, error, exclude_regions):
+ # This test validates that the partition info cache is updated correctly upon failures, and that the
+ # per-partition automatic failover logic routes requests to the next available regional endpoint.
+ # We also verify that this logic is unaffected by user excluded regions, since write-region routing is
+ # entirely taken care of on the service.
+ error_lambda = lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay(0, error))
+ setup, doc_fail_id, doc_success_id, custom_setup, custom_transport, predicate = await self.setup_info(error_lambda, 1,
+ write_operation == BATCH, exclude_client_regions=exclude_regions)
+ container = setup['col']
+ fault_injection_container = custom_setup['col']
+ global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager
+
+ # Create a document to populate the per-partition GEM partition range info cache
+ await fault_injection_container.create_item(body={'id': doc_success_id, 'pk': PK_VALUE,
+ 'name': 'sample document', 'key': 'value'})
+ pk_range_wrapper = list(global_endpoint_manager.partition_range_to_failover_info.keys())[0]
+ initial_region = global_endpoint_manager.partition_range_to_failover_info[pk_range_wrapper].current_region
+
+ # Based on our configuration, we should have had one error followed by a success - marking only the previous endpoint as unavailable
+ await perform_write_operation(
+ write_operation,
+ container,
+ fault_injection_container,
+ doc_fail_id,
+ PK_VALUE)
+ partition_info = global_endpoint_manager.partition_range_to_failover_info[pk_range_wrapper]
+ # Verify that the partition is marked as unavailable, and that the current regional endpoint is not the same
+ assert len(partition_info.unavailable_regional_endpoints) == 1
+ assert initial_region in partition_info.unavailable_regional_endpoints
+ assert initial_region != partition_info.current_region # west us 3 != west us
+
+ # Now we run another request to see how the cache gets updated
+ await perform_write_operation(
+ write_operation,
+ container,
+ fault_injection_container,
+ doc_fail_id,
+ PK_VALUE)
+ partition_info = global_endpoint_manager.partition_range_to_failover_info[pk_range_wrapper]
+ # Verify that the cache is empty, since the request going to the second regional endpoint failed
+ # Once we reach the point of all available regions being marked as unavailable, the cache is cleared
+ assert len(partition_info.unavailable_regional_endpoints) == 0
+ assert initial_region not in partition_info.unavailable_regional_endpoints
+ assert partition_info.current_region is None
+
+ @pytest.mark.parametrize("write_operation, error, exclude_regions", write_operations_errors_and_boolean(create_threshold_errors()))
+ async def test_ppaf_partition_thresholds_and_routing_async(self, write_operation, error, exclude_regions):
+        # This test validates that the consecutive failures logic is properly handled for per-partition automatic
+        # failover, and that the per-partition automatic failover logic routes requests to the next available
+        # regional endpoint after enough consecutive failures have occurred. We also verify that this logic is
+        # unaffected by user excluded regions, since write-region routing is entirely taken care of on the service.
+ error_lambda = lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay(0, error))
+ setup, doc_fail_id, doc_success_id, custom_setup, custom_transport, predicate = await self.setup_info(error_lambda,
+ exclude_client_regions=exclude_regions,)
+ container = setup['col']
+ fault_injection_container = custom_setup['col']
+ global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager
+
+ # Create a document to populate the per-partition GEM partition range info cache
+ await fault_injection_container.create_item(body={'id': doc_success_id, 'pk': PK_VALUE,
+ 'name': 'sample document', 'key': 'value'})
+ pk_range_wrapper = list(global_endpoint_manager.partition_range_to_failover_info.keys())[0]
+ initial_region = global_endpoint_manager.partition_range_to_failover_info[pk_range_wrapper].current_region
+
+ consecutive_failures = 6
+ for i in range(consecutive_failures):
+ # We perform the write operation multiple times to check the consecutive failures logic
+ with pytest.raises((CosmosHttpResponseError, ServiceResponseError)) as exc_info:
+ await perform_write_operation(write_operation,
+ container,
+ fault_injection_container,
+ doc_fail_id,
+ PK_VALUE)
+ assert exc_info.value == error
+
+ # Verify that the threshold for consecutive failures is updated
+ pk_range_wrappers = list(global_endpoint_manager.ppaf_thresholds_tracker.pk_range_wrapper_to_failure_count.keys())
+ assert len(pk_range_wrappers) == 1
+ failure_count = global_endpoint_manager.ppaf_thresholds_tracker.pk_range_wrapper_to_failure_count[pk_range_wrappers[0]]
+ assert failure_count == consecutive_failures
+
+ # Verify that a single success to the same partition resets the consecutive failures count
+ await perform_write_operation(write_operation,
+ container,
+ fault_injection_container,
+ str(uuid.uuid4()),
+ PK_VALUE)
+
+ failure_count = global_endpoint_manager.ppaf_thresholds_tracker.pk_range_wrapper_to_failure_count.get(pk_range_wrappers[0], 0)
+ assert failure_count == 0
+
+ # Run enough failed requests to the partition to trigger the failover logic
+ for i in range(12):
+ with pytest.raises((CosmosHttpResponseError, ServiceResponseError)) as exc_info:
+ await perform_write_operation(write_operation,
+ container,
+ fault_injection_container,
+ doc_fail_id,
+ PK_VALUE)
+ assert exc_info.value == error
+ # We should have marked the previous endpoint as unavailable after 10 successive failures
+ partition_info = global_endpoint_manager.partition_range_to_failover_info[pk_range_wrapper]
+ # Verify that the partition is marked as unavailable, and that the current regional endpoint is not the same
+ assert len(partition_info.unavailable_regional_endpoints) == 1
+ assert initial_region in partition_info.unavailable_regional_endpoints
+ assert initial_region != partition_info.current_region # west us 3 != west us
+
+ # 12 failures - 10 to trigger failover, 2 more to start counting again
+ failure_count = global_endpoint_manager.ppaf_thresholds_tracker.pk_range_wrapper_to_failure_count[pk_range_wrappers[0]]
+ assert failure_count == 2
+
+ @pytest.mark.parametrize("write_operation, error, exclude_regions", write_operations_errors_and_boolean(create_failover_errors()))
+ async def test_ppaf_session_unavailable_retry_async(self, write_operation, error, exclude_regions):
+ # Account config has 2 regions: West US 3 (A) and West US (B). This test validates that after marking the write
+ # region (A) as unavailable, the next request is retried to the read region (B) and succeeds. The next read request
+ # should see that the write region (A) is unavailable for the partition, and should retry to the read region (B) as well.
+ # We also verify that this logic is unaffected by user excluded regions, since write-region routing is
+ # entirely taken care of on the service.
+ error_lambda = lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay(0, error))
+ setup, doc_fail_id, doc_success_id, custom_setup, custom_transport, predicate = await self.setup_info(error_lambda, max_count=1,
+ is_batch=write_operation==BATCH,
+ session_error=True, exclude_client_regions=exclude_regions)
+ container = setup['col']
+ fault_injection_container = custom_setup['col']
+ global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager
+
+ # Create a document to populate the per-partition GEM partition range info cache
+ await fault_injection_container.create_item(body={'id': doc_success_id, 'pk': PK_VALUE,
+ 'name': 'sample document', 'key': 'value'})
+ pk_range_wrapper = list(global_endpoint_manager.partition_range_to_failover_info.keys())[0]
+ initial_region = global_endpoint_manager.partition_range_to_failover_info[pk_range_wrapper].current_region
+
+ # Verify the region that is being used for the read requests
+ read_response = await fault_injection_container.read_item(doc_success_id, PK_VALUE)
+ uri = read_response.get_response_headers().get('Content-Location')
+ region = fault_injection_container.client_connection._global_endpoint_manager.location_cache.get_location_from_endpoint(uri)
+ assert region == REGION_1 # first preferred region
+
+ # Based on our configuration, we should have had one error followed by a success - marking only the previous endpoint as unavailable
+ await perform_write_operation(
+ write_operation,
+ container,
+ fault_injection_container,
+ doc_fail_id,
+ PK_VALUE)
+ partition_info = global_endpoint_manager.partition_range_to_failover_info[pk_range_wrapper]
+ # Verify that the partition is marked as unavailable, and that the current regional endpoint is not the same
+ assert len(partition_info.unavailable_regional_endpoints) == 1
+ assert initial_region in partition_info.unavailable_regional_endpoints
+ assert initial_region != partition_info.current_region # west us 3 != west us
+
+ # Now we run a read request that runs into a 404.1002 error, which should retry to the read region
+ # We verify that the read request was going to the correct region by using the raw_response_hook
+        await fault_injection_container.read_item(doc_fail_id, PK_VALUE, raw_response_hook=session_retry_hook)
+
+ async def test_ppaf_user_agent_feature_flag_async(self):
+ # Simple test to verify the user agent suffix is being updated with the relevant feature flags
+ setup, doc_fail_id, doc_success_id, custom_setup, custom_transport, predicate = await self.setup_info()
+ fault_injection_container = custom_setup['col']
+ # Create a document to check the response headers
+ await fault_injection_container.upsert_item(body={'id': doc_success_id, 'pk': PK_VALUE, 'name': 'sample document', 'key': 'value'},
+ raw_response_hook=ppaf_user_agent_hook)
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py
index 1f0a0884bd7b..700e2112621b 100644
--- a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py
+++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py
@@ -50,9 +50,9 @@ def read_operations_and_errors():
return params
-def write_operations_and_errors():
+def write_operations_and_errors(error_list=None):
write_operations = [CREATE, UPSERT, REPLACE, DELETE, PATCH, BATCH]
- errors = create_errors()
+ errors = error_list or create_errors()
params = []
for write_operation in write_operations:
for error in errors:
@@ -60,6 +60,17 @@ def write_operations_and_errors():
return params
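+# Cartesian product of write operations x injected errors x an exclude-regions flag;
+# parametrizes the PPAF tests both with and without client-side excluded locations.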
+def write_operations_errors_and_boolean(error_list=None):
+ write_operations = [CREATE, UPSERT, REPLACE, DELETE, PATCH, BATCH]
+ errors = error_list or create_errors()
+ params = []
+ for write_operation in write_operations:
+ for error in errors:
+ for boolean in [True, False]:
+ params.append((write_operation, error, boolean))
+
+ return params
+
def operations():
write_operations = [CREATE, UPSERT, REPLACE, DELETE, PATCH, BATCH]
read_operations = [READ, QUERY_PK, CHANGE_FEED_PK, CHANGE_FEED_EPK]
@@ -69,9 +80,9 @@ def operations():
return operations
-def create_errors():
+def create_errors(errors=None):
-    errors = []
+    errors = list(errors) if errors is not None else []
- error_codes = [408, 500, 502, 503]
+ error_codes = [408, 500, 502, 504]
for error_code in error_codes:
errors.append(CosmosHttpResponseError(
status_code=error_code,
@@ -97,7 +108,8 @@ def validate_unhealthy_partitions(global_endpoint_manager,
def validate_response_uri(response, expected_uri):
request = response.get_response_headers()["_request"]
assert request.url.startswith(expected_uri)
-def perform_write_operation(operation, container, fault_injection_container, doc_id, pk, expected_uri):
+
+def perform_write_operation(operation, container, fault_injection_container, doc_id, pk, expected_uri=None):
doc = {'id': doc_id,
'pk': pk,
'name': 'sample document',
@@ -107,7 +119,7 @@ def perform_write_operation(operation, container, fault_injection_container, doc
elif operation == UPSERT:
resp = fault_injection_container.upsert_item(body=doc)
elif operation == REPLACE:
- container.create_item(body=doc)
+ container.upsert_item(body=doc)
sleep(1)
new_doc = {'id': doc_id,
'pk': pk,
@@ -115,11 +127,11 @@ def perform_write_operation(operation, container, fault_injection_container, doc
'key': 'value'}
resp = fault_injection_container.replace_item(item=doc['id'], body=new_doc)
elif operation == DELETE:
- container.create_item(body=doc)
+ container.upsert_item(body=doc)
sleep(1)
resp = fault_injection_container.delete_item(item=doc['id'], partition_key=doc['pk'])
elif operation == PATCH:
- container.create_item(body=doc)
+ container.upsert_item(body=doc)
sleep(1)
operations = [{"op": "incr", "path": "/company", "value": 3}]
resp = fault_injection_container.patch_item(item=doc['id'], partition_key=doc['pk'], patch_operations=operations)
@@ -133,9 +145,9 @@ def perform_write_operation(operation, container, fault_injection_container, doc
resp = fault_injection_container.execute_item_batch(batch_operations, partition_key=doc['pk'])
# this will need to be emulator only
elif operation == DELETE_ALL_ITEMS_BY_PARTITION_KEY:
- container.create_item(body=doc)
+ container.upsert_item(body=doc)
resp = fault_injection_container.delete_all_items_by_partition_key(pk)
- if resp:
+ if resp and expected_uri:
validate_response_uri(resp, expected_uri)
def perform_read_operation(operation, container, doc_id, pk, expected_uri):
@@ -395,10 +407,11 @@ def setup_info(self, error, **kwargs):
return container, doc, expected_uri, uri_down, fault_injection_container, custom_transport, predicate
def test_stat_reset(self):
+ status_code = 500
error_lambda = lambda r: FaultInjectionTransport.error_after_delay(
0,
CosmosHttpResponseError(
- status_code=503,
+ status_code=status_code,
message="Some injected error.")
)
container, doc, expected_uri, uri_down, fault_injection_container, custom_transport, predicate = \
@@ -425,7 +438,7 @@ def test_stat_reset(self):
PK_VALUE,
expected_uri)
except CosmosHttpResponseError as e:
- assert e.status_code == 503
+ assert e.status_code == status_code
validate_unhealthy_partitions(global_endpoint_manager, 0)
validate_stats(global_endpoint_manager, 2, 2, 2, 2, 0, 0)
sleep(25)
@@ -487,6 +500,14 @@ def test_service_request_error(self, read_operation, write_operation):
# there shouldn't be region marked as unavailable
assert len(global_endpoint_manager.location_cache.location_unavailability_info_by_endpoint) == 1
+ def test_circuit_breaker_user_agent_feature_flag_mm(self):
+ # Simple test to verify the user agent suffix is being updated with the relevant feature flags
+ custom_setup = self.setup_method_with_custom_transport(None)
+ container = custom_setup['col']
+ # Create a document to check the response headers
+ container.upsert_item(body={'id': str(uuid.uuid4()), 'pk': PK_VALUE, 'name': 'sample document', 'key': 'value'},
+ raw_response_hook=user_agent_hook)
+
# test cosmos client timeout
if __name__ == '__main__':
@@ -509,3 +530,8 @@ def validate_stats(global_endpoint_manager,
assert health_info.write_failure_count == expected_write_failure_count
assert health_info.read_success_count == expected_read_success_count
assert health_info.write_success_count == expected_write_success_count
+
+def user_agent_hook(raw_response):
+ # Used to verify the user agent feature flags
+ user_agent = raw_response.http_request.headers.get('user-agent')
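+    # 'F2' is assumed to be the feature-flag token emitted when only the per partition
+    # circuit breaker is enabled (the PPAF tests assert 'F3').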
+ assert user_agent.endswith('| F2')
\ No newline at end of file
diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py
index 7aef745545b5..90131646c17a 100644
--- a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py
+++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py
@@ -4,7 +4,7 @@
import os
import unittest
import uuid
-from typing import Any
+from typing import Any, Optional
import pytest
from azure.core.pipeline.transport._aiohttp import AioHttpTransport
@@ -18,12 +18,12 @@
from test_per_partition_circuit_breaker_mm import create_doc, read_operations_and_errors, \
write_operations_and_errors, operations, REGION_1, REGION_2, CHANGE_FEED, CHANGE_FEED_PK, CHANGE_FEED_EPK, READ, \
CREATE, READ_ALL_ITEMS, DELETE_ALL_ITEMS_BY_PARTITION_KEY, QUERY, QUERY_PK, BATCH, UPSERT, REPLACE, PATCH, DELETE, \
- PK_VALUE, validate_unhealthy_partitions, validate_response_uri
+ PK_VALUE, validate_unhealthy_partitions, validate_response_uri, user_agent_hook
from test_per_partition_circuit_breaker_mm import validate_stats
COLLECTION = "created_collection"
-async def perform_write_operation(operation, container, fault_injection_container, doc_id, pk, expected_uri):
+async def perform_write_operation(operation, container, fault_injection_container, doc_id, pk, expected_uri=None):
doc = {'id': doc_id,
'pk': pk,
'name': 'sample document',
@@ -33,7 +33,7 @@ async def perform_write_operation(operation, container, fault_injection_containe
elif operation == UPSERT:
resp = await fault_injection_container.upsert_item(body=doc)
elif operation == REPLACE:
- await container.create_item(body=doc)
+ await container.upsert_item(body=doc)
new_doc = {'id': doc_id,
'pk': pk,
-                   'name': 'sample document' + str(uuid),
+                   'name': 'sample document' + str(uuid.uuid4()),
@@ -41,11 +41,11 @@ async def perform_write_operation(operation, container, fault_injection_containe
await asyncio.sleep(1)
resp = await fault_injection_container.replace_item(item=doc['id'], body=new_doc)
elif operation == DELETE:
- await container.create_item(body=doc)
+ await container.upsert_item(body=doc)
await asyncio.sleep(1)
resp = await fault_injection_container.delete_item(item=doc['id'], partition_key=doc['pk'])
elif operation == PATCH:
- await container.create_item(body=doc)
+ await container.upsert_item(body=doc)
await asyncio.sleep(1)
operations = [{"op": "incr", "path": "/company", "value": 3}]
resp = await fault_injection_container.patch_item(item=doc['id'], partition_key=doc['pk'], patch_operations=operations)
@@ -59,9 +59,9 @@ async def perform_write_operation(operation, container, fault_injection_containe
resp = await fault_injection_container.execute_item_batch(batch_operations, partition_key=doc['pk'])
# this will need to be emulator only
elif operation == DELETE_ALL_ITEMS_BY_PARTITION_KEY:
- await container.create_item(body=doc)
+ await container.upsert_item(body=doc)
resp = await fault_injection_container.delete_all_items_by_partition_key(pk)
- if resp:
+ if resp and expected_uri:
validate_response_uri(resp, expected_uri)
async def perform_read_operation(operation, container, doc_id, pk, expected_uri):
@@ -111,7 +111,7 @@ class TestPerPartitionCircuitBreakerMMAsync:
TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID
TEST_CONTAINER_MULTI_PARTITION_ID = test_config.TestConfig.TEST_MULTI_PARTITION_CONTAINER_ID
- async def setup_method_with_custom_transport(self, custom_transport: AioHttpTransport, default_endpoint=host, **kwargs):
+    async def setup_method_with_custom_transport(self, custom_transport: Optional[AioHttpTransport], default_endpoint=host, **kwargs):
container_id = kwargs.pop("container_id", None)
if not container_id:
container_id = self.TEST_CONTAINER_MULTI_PARTITION_ID
@@ -339,10 +339,11 @@ async def test_read_failure_rate_threshold_async(self, read_operation, error):
await cleanup_method([custom_setup, setup])
async def test_stat_reset_async(self):
+ status_code = 500
error_lambda = lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay(
0,
CosmosHttpResponseError(
- status_code=503,
+ status_code=status_code,
message="Some injected error.")
))
setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = \
@@ -371,7 +372,7 @@ async def test_stat_reset_async(self):
PK_VALUE,
expected_uri)
except CosmosHttpResponseError as e:
- assert e.status_code == 503
+ assert e.status_code == status_code
validate_unhealthy_partitions(global_endpoint_manager, 0)
validate_stats(global_endpoint_manager, 2, 2, 2, 2, 0, 0)
await asyncio.sleep(25)
@@ -480,5 +481,13 @@ async def concurrent_upsert():
_partition_health_tracker.INITIAL_UNAVAILABLE_TIME_MS = original_unavailable_time
await cleanup_method([custom_setup, setup])
+ async def test_circuit_breaker_user_agent_feature_flag_mm_async(self):
+ # Simple test to verify the user agent suffix is being updated with the relevant feature flags
+ custom_setup = await self.setup_method_with_custom_transport(None)
+ container = custom_setup['col']
+ # Create a document to check the response headers
+ await container.upsert_item(body={'id': str(uuid.uuid4()), 'pk': PK_VALUE, 'name': 'sample document', 'key': 'value'},
+ raw_response_hook=user_agent_hook)
+
if __name__ == '__main__':
unittest.main()
diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr.py
index d8f25b2c50b7..0ec3df11d270 100644
--- a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr.py
+++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr.py
@@ -15,7 +15,7 @@
from azure.cosmos.exceptions import CosmosHttpResponseError
from _fault_injection_transport import FaultInjectionTransport
from test_per_partition_circuit_breaker_mm import create_doc, write_operations_and_errors, operations, REGION_1, \
- REGION_2, PK_VALUE, perform_write_operation, perform_read_operation, CREATE, READ, validate_stats
+ REGION_2, PK_VALUE, perform_write_operation, perform_read_operation, CREATE, READ, validate_stats, user_agent_hook
COLLECTION = "created_collection"
@@ -71,10 +71,11 @@ def setup_info(self, error, **kwargs):
return setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate
def test_stat_reset(self):
+ status_code = 500
error_lambda = lambda r: FaultInjectionTransport.error_after_delay(
0,
CosmosHttpResponseError(
- status_code=503,
+ status_code=status_code,
message="Some injected error.")
)
setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = \
@@ -103,7 +104,7 @@ def test_stat_reset(self):
PK_VALUE,
expected_uri)
except CosmosHttpResponseError as e:
- assert e.status_code == 503
+ assert e.status_code == status_code
validate_unhealthy_partitions(global_endpoint_manager, 0)
validate_stats(global_endpoint_manager, 0, 2, 2, 0, 0, 0)
sleep(25)
@@ -234,6 +235,14 @@ def test_service_request_error(self, read_operation, write_operation):
# there shouldn't be region marked as unavailable
assert len(global_endpoint_manager.location_cache.location_unavailability_info_by_endpoint) == 1
+ def test_circuit_breaker_user_agent_feature_flag_sm(self):
+ # Simple test to verify the user agent suffix is being updated with the relevant feature flags
+ custom_setup = self.setup_method_with_custom_transport(None)
+ container = custom_setup['col']
+ # Create a document to check the response headers
+ container.upsert_item(body={'id': str(uuid.uuid4()), 'pk': PK_VALUE, 'name': 'sample document', 'key': 'value'},
+ raw_response_hook=user_agent_hook)
+
# test cosmos client timeout
if __name__ == '__main__':
diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr_async.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr_async.py
index 8e66d2c827ea..2d43fb492b8c 100644
--- a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr_async.py
+++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr_async.py
@@ -4,7 +4,7 @@
import os
import unittest
import uuid
-from typing import Dict, Any
+from typing import Any, Optional
import pytest
from azure.core.pipeline.transport._aiohttp import AioHttpTransport
@@ -17,7 +17,7 @@
from _fault_injection_transport_async import FaultInjectionTransportAsync
from test_per_partition_circuit_breaker_mm_async import perform_write_operation, cleanup_method, perform_read_operation
from test_per_partition_circuit_breaker_mm import create_doc, write_operations_and_errors, operations, REGION_1, \
- REGION_2, PK_VALUE, READ, validate_stats, CREATE
+ REGION_2, PK_VALUE, READ, validate_stats, CREATE, user_agent_hook
from test_per_partition_circuit_breaker_sm_mrr import validate_unhealthy_partitions
COLLECTION = "created_collection"
@@ -31,7 +31,7 @@ class TestPerPartitionCircuitBreakerSmMrrAsync:
TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID
TEST_CONTAINER_MULTI_PARTITION_ID = test_config.TestConfig.TEST_MULTI_PARTITION_CONTAINER_ID
- async def setup_method_with_custom_transport(self, custom_transport: AioHttpTransport, default_endpoint=host, **kwargs):
+    async def setup_method_with_custom_transport(self, custom_transport: Optional[AioHttpTransport], default_endpoint=host, **kwargs):
container_id = kwargs.pop("container_id", None)
if not container_id:
container_id = self.TEST_CONTAINER_MULTI_PARTITION_ID
@@ -43,7 +43,7 @@ async def setup_method_with_custom_transport(self, custom_transport: AioHttpTran
return {"client": client, "db": db, "col": container}
@staticmethod
- async def cleanup_method(initialized_objects: Dict[str, Any]):
+ async def cleanup_method(initialized_objects: dict[str, Any]):
method_client: CosmosClient = initialized_objects["client"]
await method_client.close()
@@ -135,10 +135,11 @@ async def test_write_failure_rate_threshold_async(self, write_operation, error):
await cleanup_method([custom_setup, setup])
async def test_stat_reset_async(self):
+ status_code = 500
error_lambda = lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay(
0,
CosmosHttpResponseError(
- status_code=503,
+ status_code=status_code,
message="Some injected error.")
))
setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = \
@@ -167,7 +168,7 @@ async def test_stat_reset_async(self):
PK_VALUE,
expected_uri)
except CosmosHttpResponseError as e:
- assert e.status_code == 503
+ assert e.status_code == status_code
validate_unhealthy_partitions(global_endpoint_manager, 0)
validate_stats(global_endpoint_manager, 0, 2, 2, 0, 0, 0)
await asyncio.sleep(25)
@@ -233,6 +234,14 @@ async def test_service_request_error_async(self, read_operation, write_operation
assert len(global_endpoint_manager.location_cache.location_unavailability_info_by_endpoint) == 1
await cleanup_method([custom_setup, setup])
+ async def test_circuit_breaker_user_agent_feature_flag_sm_async(self):
+ # Simple test to verify the user agent suffix is being updated with the relevant feature flags
+ custom_setup = await self.setup_method_with_custom_transport(None)
+ container = custom_setup['col']
+ # Create a document to check the response headers
+ await container.upsert_item(body={'id': str(uuid.uuid4()), 'pk': PK_VALUE, 'name': 'sample document', 'key': 'value'},
+ raw_response_hook=user_agent_hook)
+
# test cosmos client timeout
if __name__ == '__main__':
diff --git a/sdk/cosmos/azure-cosmos/tests/test_service_retry_policies_async.py b/sdk/cosmos/azure-cosmos/tests/test_service_retry_policies_async.py
index eb3082fdc976..7e2235b66b33 100644
--- a/sdk/cosmos/azure-cosmos/tests/test_service_retry_policies_async.py
+++ b/sdk/cosmos/azure-cosmos/tests/test_service_retry_policies_async.py
@@ -10,7 +10,7 @@
from azure.core.exceptions import ServiceRequestError, ServiceResponseError
import test_config
-from azure.cosmos import DatabaseAccount, _location_cache
+from azure.cosmos import DatabaseAccount
from azure.cosmos._location_cache import RegionalRoutingContext
from azure.cosmos._request_object import RequestObject
from azure.cosmos.aio import CosmosClient, _retry_utility_async, _global_endpoint_manager_async
diff --git a/sdk/cosmos/azure-cosmos/tests/test_timeout_and_failover_retry_policy.py b/sdk/cosmos/azure-cosmos/tests/test_timeout_and_failover_retry_policy.py
index dfe1044535e4..585f58a85978 100644
--- a/sdk/cosmos/azure-cosmos/tests/test_timeout_and_failover_retry_policy.py
+++ b/sdk/cosmos/azure-cosmos/tests/test_timeout_and_failover_retry_policy.py
@@ -79,9 +79,10 @@ def test_timeout_failover_retry_policy_for_read_failure(self, setup, error_code)
created_document = setup[COLLECTION].create_item(body=document_definition)
self.original_execute_function = _retry_utility.ExecuteFunction
+ num_exceptions = max(2, len(setup[COLLECTION].client_connection._global_endpoint_manager.location_cache.read_regional_routing_contexts))
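+        # One injected failure per read regional routing context (at least two) should
+        # exhaust the cross-regional retries before the client surfaces the error.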
try:
- # should retry once and then fail
- mf = self.MockExecuteFunction(self.original_execute_function, 2, error_code)
+ # should retry and then fail
+ mf = self.MockExecuteFunction(self.original_execute_function, num_exceptions, error_code)
_retry_utility.ExecuteFunction = mf
setup[COLLECTION].read_item(item=created_document['id'],
partition_key=created_document['pk'])
@@ -99,9 +100,11 @@ def test_timeout_failover_retry_policy_for_write_failure(self, setup, error_code
'key': 'value'}
self.original_execute_function = _retry_utility.ExecuteFunction
+ num_exceptions_503 = max(2, len(setup[COLLECTION].client_connection._global_endpoint_manager.location_cache.write_regional_routing_contexts))
try:
- # timeouts should fail immediately for writes
- mf = self.MockExecuteFunction(self.original_execute_function,0, error_code)
+ # timeouts should fail immediately for writes - except for 503s, which should retry on every preferred location
+ num_exceptions = num_exceptions_503 if error_code == 503 else 0
+ mf = self.MockExecuteFunction(self.original_execute_function, num_exceptions, error_code)
_retry_utility.ExecuteFunction = mf
try:
setup[COLLECTION].create_item(body=document_definition)
diff --git a/sdk/cosmos/azure-cosmos/tests/test_timeout_and_failover_retry_policy_async.py b/sdk/cosmos/azure-cosmos/tests/test_timeout_and_failover_retry_policy_async.py
index f3a12546f467..e21ed080b913 100644
--- a/sdk/cosmos/azure-cosmos/tests/test_timeout_and_failover_retry_policy_async.py
+++ b/sdk/cosmos/azure-cosmos/tests/test_timeout_and_failover_retry_policy_async.py
@@ -79,9 +79,10 @@ async def test_timeout_failover_retry_policy_for_read_failure_async(self, setup,
created_document = await setup[COLLECTION].create_item(body=document_definition)
self.original_execute_function = _retry_utility_async.ExecuteFunctionAsync
+ num_exceptions = max(2, len(setup[COLLECTION].client_connection._global_endpoint_manager.location_cache.read_regional_routing_contexts))
try:
- # should retry once and then succeed
- mf = self.MockExecuteFunction(self.original_execute_function, 2, error_code)
+ # should retry and then succeed
+ mf = self.MockExecuteFunction(self.original_execute_function, num_exceptions, error_code)
_retry_utility_async.ExecuteFunctionAsync = mf
await setup[COLLECTION].read_item(item=created_document['id'],
partition_key=created_document['pk'])
@@ -131,9 +132,11 @@ async def test_timeout_failover_retry_policy_for_write_failure_async(self, setup
'key': 'value'}
self.original_execute_function = _retry_utility_async.ExecuteFunctionAsync
+ num_exceptions_503 = max(2, len(setup[COLLECTION].client_connection._global_endpoint_manager.location_cache.write_regional_routing_contexts))
try:
- # timeouts should fail immediately for writes
- mf = self.MockExecuteFunction(self.original_execute_function,0, error_code)
+ # timeouts should fail immediately for writes - except for 503s, which should retry on every preferred location
+ num_exceptions = num_exceptions_503 if error_code == 503 else 0
+ mf = self.MockExecuteFunction(self.original_execute_function, num_exceptions, error_code)
_retry_utility_async.ExecuteFunctionAsync = mf
try:
await setup[COLLECTION].create_item(body=document_definition)
diff --git a/sdk/cosmos/azure-cosmos/tests/workloads/r_w_q_workload.py b/sdk/cosmos/azure-cosmos/tests/workloads/r_w_q_workload.py
index 5e1db6425142..0d7730c6f86e 100644
--- a/sdk/cosmos/azure-cosmos/tests/workloads/r_w_q_workload.py
+++ b/sdk/cosmos/azure-cosmos/tests/workloads/r_w_q_workload.py
@@ -3,6 +3,9 @@
import sys
from azure.cosmos import documents
+from datetime import datetime, timezone
+import time
+from workload_utils import _get_upsert_item
from workload_utils import *
from workload_configs import *
sys.path.append(r"/")
@@ -10,7 +13,27 @@
from azure.cosmos.aio import CosmosClient as AsyncClient
import asyncio
+async def log_request_counts(counter):
+ while True:
+ await asyncio.sleep(300) # 5 minutes
+ count = counter["count"]
+ duration = counter["upsert_time"] + counter["read_time"]
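+        # perf_counter() deltas are in seconds; the prints below convert to milliseconds.
+        # The upsert/read split assumes the loop issues one upsert and one read per iteration.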
+ print("Current UTC time:", datetime.now(timezone.utc))
+ print(f"Executed {count} requests in the last 5 minutes")
+ print(f"Errors in the last 5 minutes: {counter['error_count']}")
+ print(f"Per-request latency: {duration / count if count > 0 else 0} ms")
+ print(f"Upsert latency: {counter['upsert_time'] / (count / 2) if count > 0 else 0} ms")
+ print(f"Read latency: {counter['read_time'] / (count / 2) if count > 0 else 0} ms")
+ print("-------------------------------")
+ counter["count"] = 0 # reset for next interval
+ counter["upsert_time"] = 0
+ counter["read_time"] = 0
+ counter["error_count"] = 0
+
async def run_workload(client_id, client_logger):
+ counter = {"count": 0, "upsert_time": 0, "read_time": 0, "error_count": 0}
+ # Start background task
+ asyncio.create_task(log_request_counts(counter))
connectionPolicy = documents.ConnectionPolicy()
connectionPolicy.UseMultipleWriteLocations = USE_MULTIPLE_WRITABLE_LOCATIONS
async with AsyncClient(COSMOS_URI, COSMOS_CREDENTIAL, connection_policy=connectionPolicy,
@@ -23,15 +46,32 @@ async def run_workload(client_id, client_logger):
while True:
try:
- await upsert_item_concurrently(cont, REQUEST_EXCLUDED_LOCATIONS, CONCURRENT_REQUESTS)
- await read_item_concurrently(cont, REQUEST_EXCLUDED_LOCATIONS, CONCURRENT_REQUESTS)
- await query_items_concurrently(cont, REQUEST_EXCLUDED_LOCATIONS, CONCURRENT_QUERIES)
+ upsert_start = time.perf_counter()
+ up_item = _get_upsert_item()
+ await cont.upsert_item(up_item)
+ elapsed = time.perf_counter() - upsert_start
+ counter["count"] += 1
+ counter["upsert_time"] += elapsed
+
+ read_start = time.perf_counter()
+ item = get_existing_random_item()
+ await cont.read_item(item["id"], item[PARTITION_KEY])
+ elapsed = time.perf_counter() - read_start
+ counter["count"] += 1
+ counter["read_time"] += elapsed
+
+ # await upsert_item_concurrently(cont, REQUEST_EXCLUDED_LOCATIONS, CONCURRENT_REQUESTS)
+ # await read_item_concurrently(cont, REQUEST_EXCLUDED_LOCATIONS, CONCURRENT_REQUESTS)
+ # await query_items_concurrently(cont, REQUEST_EXCLUDED_LOCATIONS, CONCURRENT_QUERIES)
except Exception as e:
+ counter["error_count"] += 1
client_logger.info("Exception in application layer")
- client_logger.error(e)
if __name__ == "__main__":
file_name = os.path.basename(__file__)
prefix, logger = create_logger(file_name)
+ create_inner_logger()
+ utc_now = datetime.now(timezone.utc)
+ print("Current UTC time:", utc_now)
asyncio.run(run_workload(prefix, logger))
diff --git a/sdk/cosmos/azure-cosmos/tests/workloads/workload_utils.py b/sdk/cosmos/azure-cosmos/tests/workloads/workload_utils.py
index 6a0f95128e5d..fe3d69b3bfbe 100644
--- a/sdk/cosmos/azure-cosmos/tests/workloads/workload_utils.py
+++ b/sdk/cosmos/azure-cosmos/tests/workloads/workload_utils.py
@@ -3,6 +3,7 @@
import asyncio
import os
import random
+import sys
import uuid
from datetime import datetime
from logging.handlers import RotatingFileHandler
@@ -160,15 +161,27 @@ def create_logger(file_name):
handler = RotatingFileHandler(
"log-" + get_user_agent(prefix) + '.log',
maxBytes=1024 * 1024 * 10, # 10 mb
- backupCount=2
+ backupCount=5
)
logger.setLevel(LOG_LEVEL)
# create filters for the logger handler to reduce the noise
workload_logger_filter = WorkloadLoggerFilter()
- handler.addFilter(workload_logger_filter)
+    # handler.addFilter(workload_logger_filter)  # noise filter intentionally disabled to keep full diagnostics
logger.addHandler(handler)
return prefix, logger
+def create_inner_logger(file_name="internal_logger_tues"):
+ logger = logging.getLogger("internal_requests")
+ prefix = os.path.splitext(file_name)[0] + "-" + str(os.getpid())
+    # Create a rotating file handler (one log file per process, keyed by pid)
+    handler = RotatingFileHandler(
+        "log-" + prefix + '.log',
+ maxBytes=1024 * 1024 * 10, # 10 mb
+ backupCount=5
+ )
+ logger.setLevel(LOG_LEVEL)
+ logger.addHandler(handler)
+
class WorkloadLoggerFilter(logging.Filter):
def filter(self, record):
diff --git a/sdk/cosmos/live-platform-matrix.json b/sdk/cosmos/live-platform-matrix.json
index dc9216653246..108af0029374 100644
--- a/sdk/cosmos/live-platform-matrix.json
+++ b/sdk/cosmos/live-platform-matrix.json
@@ -59,6 +59,23 @@
}
}
},
+ {
+ "PerPartitionAutomaticFailoverTestConfig": {
+ "Ubuntu2004_313_partition_automatic_failover": {
+ "OSVmImage": "env:LINUXVMIMAGE",
+ "Pool": "env:LINUXPOOL",
+ "PythonVersion": "3.13",
+ "CoverageArg": "--disablecov",
+ "TestSamples": "false",
+ "TestMarkArgument": "cosmosPerPartitionAutomaticFailover"
+ }
+ },
+ "ArmConfig": {
+ "MultiRegion": {
+ "ArmTemplateParameters": "@{ defaultConsistencyLevel = 'Session'; enableMultipleRegions = $true;}"
+ }
+ }
+ },
{
"MacTestConfig": {
"macos311_search_query": {