diff --git a/sdk/cosmos/azure-cosmos/CHANGELOG.md b/sdk/cosmos/azure-cosmos/CHANGELOG.md index 6deaab9d60a6..ac91577e49b2 100644 --- a/sdk/cosmos/azure-cosmos/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos/CHANGELOG.md @@ -5,8 +5,10 @@ #### Features Added * Added ability to set a user agent suffix at the client level. See [PR 40904](https://github.com/Azure/azure-sdk-for-python/pull/40904) * Added ability to use request level `excluded_locations` on metadata calls, such as getting container properties. See [PR 40905](https://github.com/Azure/azure-sdk-for-python/pull/40905) +* Per partition circuit breaker support. It can be enabled through the environment variable `AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER`. See [PR 40302](https://github.com/Azure/azure-sdk-for-python/pull/40302). #### Bugs Fixed +* Fixed how resource tokens are parsed for metadata calls in the lifecycle of a document operation. See [PR 40302](https://github.com/Azure/azure-sdk-for-python/pull/40302). * Fixed issue where Query Change Feed did not return items if the container uses legacy Hash V1 Partition Keys. This also fixes issues with not being able to change feed query for Specific Partition Key Values for HPK. 
See [PR 41270](https://github.com/Azure/azure-sdk-for-python/pull/41270/) #### Other Changes diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py index 58dbb3486c40..5b7fd0254421 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py @@ -877,8 +877,8 @@ def _format_batch_operations( return final_operations -def _set_properties_cache(properties: Dict[str, Any]) -> Dict[str, Any]: +def _build_properties_cache(properties: Dict[str, Any], container_link: str) -> Dict[str, Any]: return { "_self": properties.get("_self", None), "_rid": properties.get("_rid", None), - "partitionKey": properties.get("partitionKey", None) + "partitionKey": properties.get("partitionKey", None), "container_link": container_link } diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py index cb82d5bd7de2..d0e0f54ae04c 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py @@ -51,6 +51,16 @@ class _Constants: HS_MAX_ITEMS_CONFIG_DEFAULT: int = 1000 MAX_ITEM_BUFFER_VS_CONFIG: str = "AZURE_COSMOS_MAX_ITEM_BUFFER_VECTOR_SEARCH" MAX_ITEM_BUFFER_VS_CONFIG_DEFAULT: int = 50000 + CIRCUIT_BREAKER_ENABLED_CONFIG: str = "AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER" + CIRCUIT_BREAKER_ENABLED_CONFIG_DEFAULT: str = "False" + # Only applicable when circuit breaker is enabled ------------------------- + CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_READ: str = "AZURE_COSMOS_CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_READ" + CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_READ_DEFAULT: int = 10 + CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_WRITE: str = "AZURE_COSMOS_CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_WRITE" + CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_WRITE_DEFAULT: int = 5 + FAILURE_PERCENTAGE_TOLERATED = "AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED" + FAILURE_PERCENTAGE_TOLERATED_DEFAULT: int = 90 + # 
------------------------------------------------------------------------- # Error code translations ERROR_TRANSLATIONS: Dict[int, str] = { diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_container_recreate_retry_policy.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_container_recreate_retry_policy.py index a85e3c081ddf..53ee57b8c3f8 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_container_recreate_retry_policy.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_container_recreate_retry_policy.py @@ -72,10 +72,8 @@ def ShouldRetry(self, exception: Optional[Any]) -> bool: def __find_container_link_with_rid(self, container_properties_caches: Optional[Dict[str, Any]], rid: str) -> \ Optional[str]: if container_properties_caches: - for key, inner_dict in container_properties_caches.items(): - is_match = next((k for k, v in inner_dict.items() if v == rid), None) - if is_match: - return key + if rid in container_properties_caches: + return container_properties_caches[rid]["container_link"] # If we cannot get the container link at all it might mean the cache was somehow deleted, this isn't # a container request so this retry is not needed. Return None. return None diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py index c085977a5ccd..3ee11eeccdd8 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py @@ -48,7 +48,7 @@ HttpResponse # pylint: disable=no-legacy-azure-core-http-response-import from . import _base as base -from . import _global_endpoint_manager as global_endpoint_manager +from ._global_partition_endpoint_manager_circuit_breaker import _GlobalPartitionEndpointManagerForCircuitBreaker from . import _query_iterable as query_iterable from . import _runtime_constants as runtime_constants from . import _session @@ -57,7 +57,7 @@ from . import documents from . 
import http_constants, exceptions from ._auth_policy import CosmosBearerTokenCredentialPolicy -from ._base import _set_properties_cache +from ._base import _build_properties_cache from ._change_feed.change_feed_iterable import ChangeFeedIterable from ._change_feed.change_feed_state import ChangeFeedState from ._constants import _Constants as Constants @@ -168,7 +168,7 @@ def __init__( # pylint: disable=too-many-statements self.last_response_headers: CaseInsensitiveDict = CaseInsensitiveDict() self.UseMultipleWriteLocations = False - self._global_endpoint_manager = global_endpoint_manager._GlobalEndpointManager(self) + self._global_endpoint_manager = _GlobalPartitionEndpointManagerForCircuitBreaker(self) retry_policy = None if isinstance(self.connection_policy.ConnectionRetryConfiguration, HTTPPolicy): @@ -262,6 +262,7 @@ def _set_container_properties_cache(self, container_link: str, properties: Optio :type properties: Optional[Dict[str, Any]]""" if properties: self.__container_properties_cache[container_link] = properties + self.__container_properties_cache[properties["_rid"]] = properties else: self.__container_properties_cache[container_link] = {} @@ -1295,8 +1296,13 @@ def CreateItem( if base.IsItemContainerLink(database_or_container_link): options = self._AddPartitionKey(database_or_container_link, document, options) - return self.Create(document, path, http_constants.ResourceType.Document, collection_id, None, - options, **kwargs) + return self.Create(document, + path, + http_constants.ResourceType.Document, + collection_id, + None, + options, + **kwargs) def UpsertItem( self, @@ -1332,8 +1338,13 @@ def UpsertItem( collection_id, document, path = self._GetContainerIdWithPathForItem( database_or_container_link, document, options ) - return self.Upsert(document, path, http_constants.ResourceType.Document, collection_id, None, - options, **kwargs) + return self.Upsert(document, + path, + http_constants.ResourceType.Document, + collection_id, + None, + options, + 
**kwargs) PartitionResolverErrorMessage = ( "Couldn't find any partition resolvers for the database link provided. " @@ -2020,8 +2031,13 @@ def ReplaceItem( collection_link = base.GetItemContainerLink(document_link) options = self._AddPartitionKey(collection_link, new_document, options) - return self.Replace(new_document, path, http_constants.ResourceType.Document, document_id, None, - options, **kwargs) + return self.Replace(new_document, + path, + http_constants.ResourceType.Document, + document_id, + None, + options, + **kwargs) def PatchItem( self, @@ -2052,7 +2068,9 @@ def PatchItem( headers = base.GetHeaders(self, self.default_headers, "patch", path, document_id, resource_type, documents._OperationType.Patch, options) # Patch will use WriteEndpoint since it uses PUT operation - request_params = RequestObject(resource_type, documents._OperationType.Patch) + request_params = RequestObject(resource_type, + documents._OperationType.Patch, + headers) request_params.set_excluded_location_from_options(options) request_data = {} if options.get("filterPredicate"): @@ -2142,7 +2160,9 @@ def _Batch( headers = base.GetHeaders(self, initial_headers, "post", path, collection_id, http_constants.ResourceType.Document, documents._OperationType.Batch, options) - request_params = RequestObject(http_constants.ResourceType.Document, documents._OperationType.Batch) + request_params = RequestObject(http_constants.ResourceType.Document, + documents._OperationType.Batch, + headers) request_params.set_excluded_location_from_options(options) return cast( Tuple[List[Dict[str, Any]], CaseInsensitiveDict], @@ -2203,7 +2223,9 @@ def DeleteAllItemsByPartitionKey( collection_id = base.GetResourceIdOrFullNameFromLink(collection_link) headers = base.GetHeaders(self, self.default_headers, "post", path, collection_id, http_constants.ResourceType.PartitionKey, documents._OperationType.Delete, options) - request_params = RequestObject(http_constants.ResourceType.PartitionKey, 
documents._OperationType.Delete) + request_params = RequestObject(http_constants.ResourceType.PartitionKey, + documents._OperationType.Delete, + headers) request_params.set_excluded_location_from_options(options) _, last_response_headers = self.__Post( path=path, @@ -2377,7 +2399,7 @@ def ExecuteStoredProcedure( # ExecuteStoredProcedure will use WriteEndpoint since it uses POST operation request_params = RequestObject(http_constants.ResourceType.StoredProcedure, - documents._OperationType.ExecuteJavaScript) + documents._OperationType.ExecuteJavaScript, headers) result, self.last_response_headers = self.__Post(path, request_params, params, headers, **kwargs) return result @@ -2573,7 +2595,9 @@ def GetDatabaseAccount( headers = base.GetHeaders(self, self.default_headers, "get", "", "", "", documents._OperationType.Read,{}, client_id=self.client_id) - request_params = RequestObject(http_constants.ResourceType.DatabaseAccount, documents._OperationType.Read, + request_params = RequestObject(http_constants.ResourceType.DatabaseAccount, + documents._OperationType.Read, + headers, url_connection) result, last_response_headers = self.__Get("", request_params, headers, **kwargs) self.last_response_headers = last_response_headers @@ -2623,7 +2647,9 @@ def _GetDatabaseAccountCheck( headers = base.GetHeaders(self, self.default_headers, "get", "", "", "", documents._OperationType.Read,{}, client_id=self.client_id) - request_params = RequestObject(http_constants.ResourceType.DatabaseAccount, documents._OperationType.Read, + request_params = RequestObject(http_constants.ResourceType.DatabaseAccount, + documents._OperationType.Read, + headers, url_connection) self.__Get("", request_params, headers, **kwargs) @@ -2663,7 +2689,7 @@ def Create( options) # Create will use WriteEndpoint since it uses POST operation - request_params = RequestObject(typ, documents._OperationType.Create) + request_params = RequestObject(typ, documents._OperationType.Create, headers) 
request_params.set_excluded_location_from_options(options) result, last_response_headers = self.__Post(path, request_params, body, headers, **kwargs) self.last_response_headers = last_response_headers @@ -2710,7 +2736,7 @@ def Upsert( headers[http_constants.HttpHeaders.IsUpsert] = True # Upsert will use WriteEndpoint since it uses POST operation - request_params = RequestObject(typ, documents._OperationType.Upsert) + request_params = RequestObject(typ, documents._OperationType.Upsert, headers) request_params.set_excluded_location_from_options(options) result, last_response_headers = self.__Post(path, request_params, body, headers, **kwargs) self.last_response_headers = last_response_headers @@ -2754,7 +2780,7 @@ def Replace( headers = base.GetHeaders(self, initial_headers, "put", path, id, typ, documents._OperationType.Replace, options) # Replace will use WriteEndpoint since it uses PUT operation - request_params = RequestObject(typ, documents._OperationType.Replace) + request_params = RequestObject(typ, documents._OperationType.Replace, headers) request_params.set_excluded_location_from_options(options) result, last_response_headers = self.__Put(path, request_params, resource, headers, **kwargs) self.last_response_headers = last_response_headers @@ -2796,7 +2822,7 @@ def Read( initial_headers = initial_headers or self.default_headers headers = base.GetHeaders(self, initial_headers, "get", path, id, typ, documents._OperationType.Read, options) # Read will use ReadEndpoint since it uses GET operation - request_params = RequestObject(typ, documents._OperationType.Read) + request_params = RequestObject(typ, documents._OperationType.Read, headers) request_params.set_excluded_location_from_options(options) result, last_response_headers = self.__Get(path, request_params, headers, **kwargs) self.last_response_headers = last_response_headers @@ -2836,7 +2862,7 @@ def DeleteResource( headers = base.GetHeaders(self, initial_headers, "delete", path, id, typ, 
documents._OperationType.Delete, options) # Delete will use WriteEndpoint since it uses DELETE operation - request_params = RequestObject(typ, documents._OperationType.Delete) + request_params = RequestObject(typ, documents._OperationType.Delete, headers) request_params.set_excluded_location_from_options(options) result, last_response_headers = self.__Delete(path, request_params, headers, **kwargs) self.last_response_headers = last_response_headers @@ -3069,12 +3095,8 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: initial_headers = self.default_headers.copy() # Copy to make sure that default_headers won't be changed. if query is None: + op_typ = documents._OperationType.QueryPlan if is_query_plan else documents._OperationType.ReadFeed # Query operations will use ReadEndpoint even though it uses GET(for feed requests) - request_params = RequestObject( - resource_type, - documents._OperationType.QueryPlan if is_query_plan else documents._OperationType.ReadFeed - ) - request_params.set_excluded_location_from_options(options) headers = base.GetHeaders( self, initial_headers, @@ -3082,11 +3104,18 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: path, resource_id, resource_type, - request_params.operation_type, + op_typ, options, partition_key_range_id ) + request_params = RequestObject( + resource_type, + op_typ, + headers + ) + request_params.set_excluded_location_from_options(options) + change_feed_state: Optional[ChangeFeedState] = options.get("changeFeedState") if change_feed_state is not None: feed_options = {} @@ -3115,8 +3144,6 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: raise SystemError("Unexpected query compatibility mode.") # Query operations will use ReadEndpoint even though it uses POST(for regular query operations) - request_params = RequestObject(resource_type, documents._OperationType.SqlQuery) - 
request_params.set_excluded_location_from_options(options) req_headers = base.GetHeaders( self, initial_headers, @@ -3129,6 +3156,9 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: partition_key_range_id ) + request_params = RequestObject(resource_type, documents._OperationType.SqlQuery, req_headers) + request_params.set_excluded_location_from_options(options) + # check if query has prefix partition key isPrefixPartitionQuery = kwargs.pop("isPrefixPartitionQuery", None) if isPrefixPartitionQuery and "partitionKeyDefinition" in kwargs: @@ -3364,7 +3394,7 @@ def _refresh_container_properties_cache(self, container_link: str): # If container properties cache is stale, refresh it by reading the container. container = self.ReadContainer(container_link, options=None) # Only cache Container Properties that will not change in the lifetime of the container - self._set_container_properties_cache(container_link, _set_properties_cache(container)) + self._set_container_properties_cache(container_link, _build_properties_cache(container, container_link)) def _UpdateSessionIfRequired( self, @@ -3407,5 +3437,5 @@ def _get_partition_key_definition( else: container = self.ReadContainer(collection_link, options) partition_key_definition = container.get("partitionKey") - self.__container_properties_cache[collection_link] = _set_properties_cache(container) + self._set_container_properties_cache(collection_link, _build_properties_cache(container, collection_link)) return partition_key_definition diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py index 135a076a80f0..c4612b629c53 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py @@ -30,8 +30,10 @@ from . import _constants as constants from . 
import exceptions +from ._request_object import RequestObject from .documents import DatabaseAccount from ._location_cache import LocationCache +from ._utils import current_time_millis # pylint: disable=protected-access @@ -45,7 +47,6 @@ class _GlobalEndpointManager(object): # pylint: disable=too-many-instance-attrib def __init__(self, client): self.client = client - self.EnableEndpointDiscovery = client.connection_policy.EnableEndpointDiscovery self.PreferredLocations = client.connection_policy.PreferredLocations self.DefaultEndpoint = client.url_connection self.refresh_time_interval_in_ms = self.get_refresh_time_interval_in_ms_stub() @@ -67,7 +68,10 @@ def get_write_endpoint(self): def get_read_endpoint(self): return self.location_cache.get_read_regional_routing_context() - def resolve_service_endpoint(self, request): + def _resolve_service_endpoint( + self, + request: RequestObject + ) -> str: return self.location_cache.resolve_service_endpoint(request) def mark_endpoint_unavailable_for_read(self, endpoint, refresh_cache): @@ -93,7 +97,7 @@ def update_location_cache(self): self.location_cache.update_location_cache() def refresh_endpoint_list(self, database_account, **kwargs): - if self.location_cache.current_time_millis() - self.last_refresh_time > self.refresh_time_interval_in_ms: + if current_time_millis() - self.last_refresh_time > self.refresh_time_interval_in_ms: self.refresh_needed = True if self.refresh_needed: with self.refresh_lock: @@ -109,11 +113,11 @@ def _refresh_endpoint_list_private(self, database_account=None, **kwargs): if database_account: self.location_cache.perform_on_database_account_read(database_account) self.refresh_needed = False - self.last_refresh_time = self.location_cache.current_time_millis() + self.last_refresh_time = current_time_millis() else: if self.location_cache.should_refresh_endpoints() or self.refresh_needed: self.refresh_needed = False - self.last_refresh_time = self.location_cache.current_time_millis() + 
self.last_refresh_time = current_time_millis() # this will perform getDatabaseAccount calls to check endpoint health self._endpoints_health_check(**kwargs) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py new file mode 100644 index 000000000000..2eda20c926d0 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py @@ -0,0 +1,121 @@ +# The MIT License (MIT) +# Copyright (c) 2021 Microsoft Corporation + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Internal class for global endpoint manager for circuit breaker. 
+""" +from typing import TYPE_CHECKING, Optional + +from azure.cosmos.partition_key import PartitionKey +from azure.cosmos._global_partition_endpoint_manager_circuit_breaker_core import \ + _GlobalPartitionEndpointManagerForCircuitBreakerCore + +from azure.cosmos._global_endpoint_manager import _GlobalEndpointManager +from azure.cosmos._request_object import RequestObject +from azure.cosmos._routing.routing_range import PartitionKeyRangeWrapper, Range +from azure.cosmos.http_constants import HttpHeaders + +if TYPE_CHECKING: + from azure.cosmos._cosmos_client_connection import CosmosClientConnection + +class _GlobalPartitionEndpointManagerForCircuitBreaker(_GlobalEndpointManager): + """ + This internal class implements the logic for partition endpoint management for + geo-replicated database accounts. + """ + + def __init__(self, client: "CosmosClientConnection"): + super(_GlobalPartitionEndpointManagerForCircuitBreaker, self).__init__(client) + self.global_partition_endpoint_manager_core = ( + _GlobalPartitionEndpointManagerForCircuitBreakerCore(client, self.location_cache)) + + def is_circuit_breaker_applicable(self, request: RequestObject) -> bool: + return self.global_partition_endpoint_manager_core.is_circuit_breaker_applicable(request) + + + def create_pk_range_wrapper(self, request: RequestObject) -> Optional[PartitionKeyRangeWrapper]: + if HttpHeaders.IntendedCollectionRID in request.headers: + container_rid = request.headers[HttpHeaders.IntendedCollectionRID] + else: + self.global_partition_endpoint_manager_core.log_warn_or_debug( + "Illegal state: the request does not contain container information. 
" + "Circuit breaker cannot be performed.") + return None + properties = self.client._container_properties_cache[container_rid] # pylint: disable=protected-access + # get relevant information from container cache to get the overlapping ranges + container_link = properties["container_link"] + partition_key_definition = properties["partitionKey"] + partition_key = PartitionKey(path=partition_key_definition["paths"], + kind=partition_key_definition["kind"], + version=partition_key_definition["version"]) + + if HttpHeaders.PartitionKey in request.headers: + partition_key_value = request.headers[HttpHeaders.PartitionKey] + # get the partition key range for the given partition key + epk_range = [partition_key._get_epk_range_for_partition_key(partition_key_value)] # pylint: disable=protected-access + partition_ranges = (self.client._routing_map_provider # pylint: disable=protected-access + .get_overlapping_ranges(container_link, epk_range)) + partition_range = Range.PartitionKeyRangeToRange(partition_ranges[0]) + elif HttpHeaders.PartitionKeyRangeID in request.headers: + pk_range_id = request.headers[HttpHeaders.PartitionKeyRangeID] + epk_range =(self.client._routing_map_provider # pylint: disable=protected-access + .get_range_by_partition_key_range_id(container_link, pk_range_id)) + if not epk_range: + self.global_partition_endpoint_manager_core.log_warn_or_debug( + "Illegal state: partition key range cache not initialized correctly. " + "Circuit breaker cannot be performed.") + return None + partition_range = Range.PartitionKeyRangeToRange(epk_range) + else: + self.global_partition_endpoint_manager_core.log_warn_or_debug( + "Illegal state: the request does not contain partition information. 
" + "Circuit breaker cannot be performed.") + return None + + return PartitionKeyRangeWrapper(partition_range, container_rid) + + def record_failure( + self, + request: RequestObject + ) -> None: + if self.is_circuit_breaker_applicable(request): + pk_range_wrapper = self.create_pk_range_wrapper(request) + if pk_range_wrapper: + self.global_partition_endpoint_manager_core.record_failure(request, pk_range_wrapper) + + def resolve_service_endpoint_for_partition( + self, + request: RequestObject, + pk_range_wrapper: Optional[PartitionKeyRangeWrapper] + ) -> str: + if self.is_circuit_breaker_applicable(request) and pk_range_wrapper: + self.global_partition_endpoint_manager_core.check_stale_partition_info(request, pk_range_wrapper) + request = self.global_partition_endpoint_manager_core.add_excluded_locations_to_request(request, + pk_range_wrapper) + return self._resolve_service_endpoint(request) + + def record_success( + self, + request: RequestObject + ) -> None: + if self.global_partition_endpoint_manager_core.is_circuit_breaker_applicable(request): + pk_range_wrapper = self.create_pk_range_wrapper(request) + if pk_range_wrapper: + self.global_partition_endpoint_manager_core.record_success(request, pk_range_wrapper) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py new file mode 100644 index 000000000000..93faf9b7a8c5 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py @@ -0,0 +1,121 @@ +# The MIT License (MIT) +# Copyright (c) 2021 Microsoft Corporation + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, 
sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Internal class for global endpoint manager for circuit breaker. +""" +import logging +import os + +from azure.cosmos import documents + +from azure.cosmos._partition_health_tracker import _PartitionHealthTracker +from azure.cosmos._routing.routing_range import PartitionKeyRangeWrapper +from azure.cosmos._location_cache import EndpointOperationType, LocationCache +from azure.cosmos._request_object import RequestObject +from azure.cosmos.http_constants import ResourceType, HttpHeaders +from azure.cosmos._constants import _Constants as Constants + +logger = logging.getLogger("azure.cosmos._GlobalPartitionEndpointManagerForCircuitBreakerCore") +WARN_LEVEL_LOGGING_THRESHOLD = 10 + +class _GlobalPartitionEndpointManagerForCircuitBreakerCore(object): + """ + This internal class implements the logic for partition endpoint management for + geo-replicated database accounts. 
+ """ + + def __init__(self, client, location_cache: LocationCache): + self.partition_health_tracker = _PartitionHealthTracker() + self.location_cache = location_cache + self.client = client + self.log_count = 0 + + def log_warn_or_debug(self, message: str) -> None: + self.log_count += 1 + if self.log_count >= WARN_LEVEL_LOGGING_THRESHOLD: + logger.debug(message) + else: + logger.warning(message) + + def is_circuit_breaker_applicable(self, request: RequestObject) -> bool: + if not request: + return False + + circuit_breaker_enabled = os.environ.get(Constants.CIRCUIT_BREAKER_ENABLED_CONFIG, + Constants.CIRCUIT_BREAKER_ENABLED_CONFIG_DEFAULT) == "True" + if not circuit_breaker_enabled: + return False + + if (not self.location_cache.can_use_multiple_write_locations_for_request(request) + and documents._OperationType.IsWriteOperation(request.operation_type)): # pylint: disable=protected-access + return False + + if (request.resource_type not in (ResourceType.Document, ResourceType.PartitionKey) + or request.operation_type == documents._OperationType.QueryPlan): # pylint: disable=protected-access + return False + + # this is for certain cross partition queries and read all items where we cannot discern partition information + if (HttpHeaders.PartitionKeyRangeID not in request.headers + and HttpHeaders.PartitionKey not in request.headers): + return False + + return True + + def record_failure( + self, + request: RequestObject, + pk_range_wrapper: PartitionKeyRangeWrapper + ) -> None: + #convert operation_type to EndpointOperationType + endpoint_operation_type = (EndpointOperationType.WriteType if ( + documents._OperationType.IsWriteOperation(request.operation_type)) # pylint: disable=protected-access + else EndpointOperationType.ReadType) + location = self.location_cache.get_location_from_endpoint(str(request.location_endpoint_to_route)) + self.partition_health_tracker.add_failure(pk_range_wrapper, endpoint_operation_type, str(location)) + + def 
check_stale_partition_info( + self, + request: RequestObject, + pk_range_wrapper: PartitionKeyRangeWrapper + ) -> None: + self.partition_health_tracker.check_stale_partition_info(request, pk_range_wrapper) + + + def add_excluded_locations_to_request( + self, + request: RequestObject, + pk_range_wrapper: PartitionKeyRangeWrapper + ) -> RequestObject: + request.set_excluded_locations_from_circuit_breaker( + self.partition_health_tracker.get_unhealthy_locations(request, pk_range_wrapper) + ) + return request + + def record_success( + self, + request: RequestObject, + pk_range_wrapper: PartitionKeyRangeWrapper + ) -> None: + #convert operation_type to either Read or Write + endpoint_operation_type = EndpointOperationType.WriteType if ( + documents._OperationType.IsWriteOperation(request.operation_type)) else EndpointOperationType.ReadType # pylint: disable=protected-access + location = self.location_cache.get_location_from_endpoint(str(request.location_endpoint_to_route)) + self.partition_health_tracker.add_success(pk_range_wrapper, endpoint_operation_type, location) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py index ccb1528048be..90578c63e5dd 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py @@ -24,7 +24,6 @@ """ import collections import logging -import time from typing import Set, Mapping, List from urllib.parse import urlparse @@ -152,8 +151,6 @@ def _get_applicable_regional_routing_contexts(regional_routing_contexts: List[Re return applicable_regional_routing_contexts class LocationCache(object): # pylint: disable=too-many-public-methods,too-many-instance-attributes - def current_time_millis(self): - return int(round(time.time() * 1000)) def __init__( self, @@ -180,6 +177,11 @@ def get_write_regional_routing_contexts(self): def get_read_regional_routing_contexts(self): return 
self.read_regional_routing_contexts + def get_location_from_endpoint(self, endpoint: str) -> str: + if endpoint in self.account_locations_by_read_endpoints: + return self.account_locations_by_read_endpoints[endpoint] + return self.account_write_locations[0] + def get_write_regional_routing_context(self): return self.get_write_regional_routing_contexts()[0].get_primary() @@ -209,8 +211,15 @@ def _get_configured_excluded_locations(self, request: RequestObject) -> List[str # If excluded locations were configured on request, use request level excluded locations. excluded_locations = request.excluded_locations if excluded_locations is None: - # If excluded locations were only configured on client(connection_policy), use client level - excluded_locations = self.connection_policy.ExcludedLocations + if self.connection_policy.ExcludedLocations: + # If excluded locations were only configured on client(connection_policy), use client level + # make copy of excluded locations to avoid modifying the original list + excluded_locations = list(self.connection_policy.ExcludedLocations) + else: + excluded_locations = [] + for excluded_location in request.excluded_locations_circuit_breaker: + if excluded_location not in excluded_locations: + excluded_locations.append(excluded_location) return excluded_locations def _get_applicable_read_regional_routing_contexts(self, request: RequestObject) -> List[RegionalRoutingContext]: @@ -434,7 +443,6 @@ def update_location_cache(self, write_locations=None, read_locations=None, enabl EndpointOperationType.ReadType, self.write_regional_routing_contexts[0] ) - self.last_cache_update_timestamp = self.current_time_millis() # pylint: disable=attribute-defined-outside-init def get_preferred_regional_routing_contexts( self, endpoints_by_location, orderedLocations, expected_available_operation, fallback_endpoint diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py 
# The MIT License (MIT)
# Copyright (c) 2021 Microsoft Corporation

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

"""Internal class for partition health tracker for circuit breaker.
"""
import logging
import threading
import os
from typing import Dict, Any, List
from azure.cosmos._routing.routing_range import PartitionKeyRangeWrapper
from azure.cosmos._location_cache import EndpointOperationType
from azure.cosmos._request_object import RequestObject
from ._utils import current_time_millis
from ._constants import _Constants as Constants

# Failure-rate checks only apply once a (partition, region) pair has seen
# at least this many requests in the current stats window.
MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100
MAX_UNAVAILABLE_TIME = 1200 * 1000  # milliseconds
REFRESH_INTERVAL = 60 * 1000  # milliseconds
INITIAL_UNAVAILABLE_TIME = 60 * 1000  # milliseconds
# partition is unhealthy if sdk tried to recover and failed
UNHEALTHY = "unhealthy"
# partition is unhealthy tentative when it initially marked unavailable
UNHEALTHY_TENTATIVE = "unhealthy_tentative"
# unavailability info keys
UNAVAILABLE_INTERVAL = "unavailableInterval"
LAST_UNAVAILABILITY_CHECK_TIME_STAMP = "lastUnavailabilityCheckTimeStamp"
HEALTH_STATUS = "healthStatus"

class _PartitionHealthInfo(object):
    """
    This internal class keeps the health and statistics for a partition.

    ``unavailability_info`` is empty while the partition is considered healthy;
    once populated it holds the health status, the back-off interval and the
    timestamp of the last unavailability check.
    """

    def __init__(self) -> None:
        self.write_failure_count: int = 0
        self.read_failure_count: int = 0
        self.write_success_count: int = 0
        self.read_success_count: int = 0
        self.read_consecutive_failure_count: int = 0
        self.write_consecutive_failure_count: int = 0
        self.unavailability_info: Dict[str, Any] = {}

    def reset_health_stats(self) -> None:
        """Reset the rolling success/failure counters (not the unavailability info)."""
        self.write_failure_count = 0
        self.read_failure_count = 0
        self.write_success_count = 0
        self.read_success_count = 0
        self.read_consecutive_failure_count = 0
        self.write_consecutive_failure_count = 0

    def transition_health_status(self, target_health_status: str, curr_time: int) -> None:
        """Move this partition's health state to *target_health_status* at *curr_time*.

        UNHEALTHY doubles the unavailability back-off (capped at
        MAX_UNAVAILABLE_TIME) and restamps the check time; it assumes
        ``unavailability_info`` is already populated. UNHEALTHY_TENTATIVE
        (re)initializes ``unavailability_info`` with the initial interval.
        """
        if target_health_status == UNHEALTHY:
            self.unavailability_info[HEALTH_STATUS] = UNHEALTHY
            # double the back-off up to the cap and restamp the check time
            self.unavailability_info[UNAVAILABLE_INTERVAL] = \
                min(self.unavailability_info[UNAVAILABLE_INTERVAL] * 2,
                    MAX_UNAVAILABLE_TIME)
            self.unavailability_info[LAST_UNAVAILABILITY_CHECK_TIME_STAMP] \
                = curr_time
        elif target_health_status == UNHEALTHY_TENTATIVE:
            self.unavailability_info = {
                LAST_UNAVAILABILITY_CHECK_TIME_STAMP: curr_time,
                UNAVAILABLE_INTERVAL: INITIAL_UNAVAILABLE_TIME,
                HEALTH_STATUS: UNHEALTHY_TENTATIVE
            }

    def __str__(self) -> str:
        return (f"{self.__class__.__name__}: {self.unavailability_info}\n"
                f"write failure count: {self.write_failure_count}\n"
                f"read failure count: {self.read_failure_count}\n"
                f"write success count: {self.write_success_count}\n"
                f"read success count: {self.read_success_count}\n"
                f"write consecutive failure count: {self.write_consecutive_failure_count}\n"
                f"read consecutive failure count: {self.read_consecutive_failure_count}\n")

def _has_exceeded_failure_rate_threshold(
        successes: int,
        failures: int,
        failure_rate_threshold: int,
) -> bool:
    """Return True when the failure percentage meets/exceeds the threshold.

    Below MINIMUM_REQUESTS_FOR_FAILURE_RATE total requests the sample is too
    small and the check is skipped.
    """
    if successes + failures < MINIMUM_REQUESTS_FOR_FAILURE_RATE:
        return False
    failure_rate = failures / (failures + successes) * 100
    return failure_rate >= failure_rate_threshold

def _should_mark_healthy_tentative(partition_health_info: _PartitionHealthInfo, curr_time: int) -> bool:
    """Return True when the unavailable partition's back-off has elapsed
    and a recovery probe should be attempted."""
    elapsed_time = (curr_time -
                    partition_health_info.unavailability_info[LAST_UNAVAILABILITY_CHECK_TIME_STAMP])
    current_health_status = partition_health_info.unavailability_info[HEALTH_STATUS]
    stale_partition_unavailability_check = partition_health_info.unavailability_info[UNAVAILABLE_INTERVAL]
    # check if the partition key range is still unavailable
    return ((current_health_status == UNHEALTHY and elapsed_time > stale_partition_unavailability_check)
            or (current_health_status == UNHEALTHY_TENTATIVE and elapsed_time > INITIAL_UNAVAILABLE_TIME))

logger = logging.getLogger("azure.cosmos._PartitionHealthTracker")

class _PartitionHealthTracker(object):
    """
    This internal class implements the logic for tracking health thresholds for a partition.
    """

    def __init__(self) -> None:
        # partition -> regions -> health info
        self.pk_range_wrapper_to_health_info: Dict[PartitionKeyRangeWrapper, Dict[str, _PartitionHealthInfo]] = {}
        self.last_refresh = current_time_millis()
        self.stale_partition_lock = threading.Lock()

    def _transition_health_status_on_failure(
            self,
            pk_range_wrapper: PartitionKeyRangeWrapper,
            location: str
    ) -> None:
        logger.warning("%s has been marked as unavailable.", pk_range_wrapper)
        current_time = current_time_millis()
        if pk_range_wrapper not in self.pk_range_wrapper_to_health_info:
            # healthy -> unhealthy tentative
            partition_health_info = _PartitionHealthInfo()
            partition_health_info.transition_health_status(UNHEALTHY_TENTATIVE, current_time)
            self.pk_range_wrapper_to_health_info[pk_range_wrapper] = {
                location: partition_health_info
            }
        else:
            region_to_partition_health = self.pk_range_wrapper_to_health_info[pk_range_wrapper]
            if location in region_to_partition_health and region_to_partition_health[location].unavailability_info:
                # non-empty unavailability_info: the location was already marked
                # unavailable (tentative/recovering) -> escalate to unhealthy
                region_to_partition_health[location].transition_health_status(UNHEALTHY, current_time)
            else:
                # no unavailability info yet for this location:
                # healthy -> unhealthy tentative
                partition_health_info = _PartitionHealthInfo()
                partition_health_info.transition_health_status(UNHEALTHY_TENTATIVE, current_time)
                self.pk_range_wrapper_to_health_info[pk_range_wrapper][location] = partition_health_info

    def _transition_health_status_on_success(
            self,
            pk_range_wrapper: PartitionKeyRangeWrapper,
            location: str
    ) -> None:
        # guard both levels so a success for an untracked location is a no-op
        if (pk_range_wrapper in self.pk_range_wrapper_to_health_info
                and location in self.pk_range_wrapper_to_health_info[pk_range_wrapper]):
            # healthy tentative -> healthy
            self.pk_range_wrapper_to_health_info[pk_range_wrapper][location].unavailability_info = {}

    def check_stale_partition_info(
            self,
            request: RequestObject,
            pk_range_wrapper: PartitionKeyRangeWrapper
    ) -> None:
        current_time = current_time_millis()

        if pk_range_wrapper in self.pk_range_wrapper_to_health_info:
            for location, partition_health_info in self.pk_range_wrapper_to_health_info[pk_range_wrapper].items():
                if partition_health_info.unavailability_info:
                    if _should_mark_healthy_tentative(partition_health_info, current_time):
                        # unhealthy or unhealthy tentative -> healthy tentative
                        # only one request should be used to recover
                        with self.stale_partition_lock:
                            # re-check under the lock so a single request wins the probe
                            if _should_mark_healthy_tentative(partition_health_info, current_time):
                                # restamp/back off the entry; the request below
                                # becomes the single recovery attempt
                                partition_health_info.transition_health_status(UNHEALTHY, current_time)
                                request.healthy_tentative_location = location

        if current_time - self.last_refresh > REFRESH_INTERVAL:
            # all partition stats reset every minute
            # Fix: stamp the refresh time; without this the stats were wiped on
            # every call once the first interval elapsed, so the consecutive
            # failure thresholds could never be reached.
            self.last_refresh = current_time
            self._reset_partition_health_tracker_stats()


    def get_unhealthy_locations(
            self,
            request: RequestObject,
            pk_range_wrapper: PartitionKeyRangeWrapper
    ) -> List[str]:
        """Return the locations to exclude for this partition, skipping the
        location currently being probed for recovery (if any)."""
        excluded_locations = []
        if pk_range_wrapper in self.pk_range_wrapper_to_health_info:
            for location, partition_health_info in self.pk_range_wrapper_to_health_info[pk_range_wrapper].items():
                if (partition_health_info.unavailability_info and
                        not (request.healthy_tentative_location and request.healthy_tentative_location == location)):
                    health_status = partition_health_info.unavailability_info[HEALTH_STATUS]
                    if health_status in (UNHEALTHY_TENTATIVE, UNHEALTHY):
                        excluded_locations.append(location)
        return excluded_locations

    def add_failure(
            self,
            pk_range_wrapper: PartitionKeyRangeWrapper,
            operation_type: str,
            location: str
    ) -> None:
        """Record a failed request and trip the breaker if a threshold is crossed."""
        # Retrieve the failure rate threshold from the environment.
        failure_rate_threshold = int(os.environ.get(Constants.FAILURE_PERCENTAGE_TOLERATED,
                                                    Constants.FAILURE_PERCENTAGE_TOLERATED_DEFAULT))

        # Ensure that the health info dictionary is properly initialized.
        if pk_range_wrapper not in self.pk_range_wrapper_to_health_info:
            self.pk_range_wrapper_to_health_info[pk_range_wrapper] = {}
        if location not in self.pk_range_wrapper_to_health_info[pk_range_wrapper]:
            self.pk_range_wrapper_to_health_info[pk_range_wrapper][location] = _PartitionHealthInfo()

        health_info = self.pk_range_wrapper_to_health_info[pk_range_wrapper][location]

        # Determine attribute names and environment variables based on the operation type.
        if operation_type == EndpointOperationType.WriteType:
            success_attr = 'write_success_count'
            failure_attr = 'write_failure_count'
            consecutive_attr = 'write_consecutive_failure_count'
            env_key = Constants.CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_WRITE
            default_consecutive_threshold = Constants.CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_WRITE_DEFAULT
        else:
            success_attr = 'read_success_count'
            failure_attr = 'read_failure_count'
            consecutive_attr = 'read_consecutive_failure_count'
            env_key = Constants.CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_READ
            default_consecutive_threshold = Constants.CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_READ_DEFAULT

        # Increment failure and consecutive failure counts.
        setattr(health_info, failure_attr, getattr(health_info, failure_attr) + 1)
        setattr(health_info, consecutive_attr, getattr(health_info, consecutive_attr) + 1)

        # Retrieve the consecutive failure threshold from the environment.
        consecutive_failure_threshold = int(os.environ.get(env_key, default_consecutive_threshold))

        # Call the threshold checker with the current stats.
        self._check_thresholds(
            pk_range_wrapper,
            getattr(health_info, success_attr),
            getattr(health_info, failure_attr),
            getattr(health_info, consecutive_attr),
            location,
            failure_rate_threshold,
            consecutive_failure_threshold
        )

    def _check_thresholds(
            self,
            pk_range_wrapper: PartitionKeyRangeWrapper,
            successes: int,
            failures: int,
            consecutive_failures: int,
            location: str,
            failure_rate_threshold: int,
            consecutive_failure_threshold: int,
    ) -> None:
        # check the failure rate was not exceeded
        if _has_exceeded_failure_rate_threshold(
                successes,
                failures,
                failure_rate_threshold
        ):
            self._transition_health_status_on_failure(pk_range_wrapper, location)

        # check the consecutive failure threshold was not exceeded
        if consecutive_failures >= consecutive_failure_threshold:
            self._transition_health_status_on_failure(pk_range_wrapper, location)

    def add_success(self, pk_range_wrapper: PartitionKeyRangeWrapper, operation_type: str, location: str) -> None:
        """Record a successful request, clearing consecutive failures and any
        tentative unavailability for the location."""
        # Ensure that the health info dictionary is initialized.
        if pk_range_wrapper not in self.pk_range_wrapper_to_health_info:
            self.pk_range_wrapper_to_health_info[pk_range_wrapper] = {}
        if location not in self.pk_range_wrapper_to_health_info[pk_range_wrapper]:
            self.pk_range_wrapper_to_health_info[pk_range_wrapper][location] = _PartitionHealthInfo()

        health_info = self.pk_range_wrapper_to_health_info[pk_range_wrapper][location]

        if operation_type == EndpointOperationType.WriteType:
            health_info.write_success_count += 1
            health_info.write_consecutive_failure_count = 0
        else:
            health_info.read_success_count += 1
            health_info.read_consecutive_failure_count = 0
        self._transition_health_status_on_success(pk_range_wrapper, location)

    def _reset_partition_health_tracker_stats(self) -> None:
        for locations in self.pk_range_wrapper_to_health_info.values():
            for health_info in locations.values():
                health_info.reset_health_stats()
""" -from typing import Optional, Mapping, Any +from typing import Optional, Mapping, Any, Dict, List -class RequestObject(object): - def __init__(self, resource_type: str, operation_type: str, endpoint_override: Optional[str] = None) -> None: +class RequestObject(object): # pylint: disable=too-many-instance-attributes + def __init__( + self, + resource_type: str, + operation_type: str, + headers: Dict[str, Any], + endpoint_override: Optional[str] = None, + ) -> None: self.resource_type = resource_type self.operation_type = operation_type self.endpoint_override = endpoint_override self.should_clear_session_token_on_session_read_failure: bool = False # pylint: disable=name-too-long + self.headers = headers self.use_preferred_locations: Optional[bool] = None self.location_index_to_route: Optional[int] = None self.location_endpoint_to_route: Optional[str] = None self.last_routed_location_endpoint_within_region: Optional[str] = None - self.excluded_locations = None + self.excluded_locations: Optional[List[str]] = None + self.excluded_locations_circuit_breaker: List[str] = [] + self.healthy_tentative_location: Optional[str] = None def route_to_location_with_preferred_location_flag( # pylint: disable=name-too-long self, @@ -70,3 +79,6 @@ def _can_set_excluded_location(self, options: Mapping[str, Any]) -> bool: def set_excluded_location_from_options(self, options: Mapping[str, Any]) -> None: if self._can_set_excluded_location(options): self.excluded_locations = options['excludedLocations'] + + def set_excluded_locations_from_circuit_breaker(self, excluded_locations: List[str]) -> None: # pylint: disable=name-too-long + self.excluded_locations_circuit_breaker = excluded_locations diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py index d9ca17b6a80c..91145ef217ba 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py @@ -45,8 
+45,11 @@ # pylint: disable=protected-access, disable=too-many-lines, disable=too-many-statements, disable=too-many-branches - -def Execute(client, global_endpoint_manager, function, *args, **kwargs): +# args [0] is the request object +# args [1] is the connection policy +# args [2] is the pipeline client +# args [3] is the http request +def Execute(client, global_endpoint_manager, function, *args, **kwargs): # pylint: disable=too-many-locals """Executes the function with passed parameters applying all retry policies :param object client: @@ -59,6 +62,9 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs): :returns: the result of running the passed in function as a (result, headers) tuple :rtype: tuple of (dict, dict) """ + pk_range_wrapper = None + if args and global_endpoint_manager.is_circuit_breaker_applicable(args[0]): + pk_range_wrapper = global_endpoint_manager.create_pk_range_wrapper(args[0]) # instantiate all retry policies here to be applied for each request execution endpointDiscovery_retry_policy = _endpoint_discovery_retry_policy.EndpointDiscoveryRetryPolicy( client.connection_policy, global_endpoint_manager, *args @@ -74,19 +80,19 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs): defaultRetry_policy = _default_retry_policy.DefaultRetryPolicy(*args) sessionRetry_policy = _session_retry_policy._SessionRetryPolicy( - client.connection_policy.EnableEndpointDiscovery, global_endpoint_manager, *args + client.connection_policy.EnableEndpointDiscovery, global_endpoint_manager, pk_range_wrapper, *args ) partition_key_range_gone_retry_policy = _gone_retry_policy.PartitionKeyRangeGoneRetryPolicy(client, *args) timeout_failover_retry_policy = _timeout_failover_retry_policy._TimeoutFailoverRetryPolicy( - client.connection_policy, global_endpoint_manager, *args + client.connection_policy, global_endpoint_manager, pk_range_wrapper, *args ) service_response_retry_policy = 
_service_response_retry_policy.ServiceResponseRetryPolicy( - client.connection_policy, global_endpoint_manager, *args, + client.connection_policy, global_endpoint_manager, pk_range_wrapper, *args, ) service_request_retry_policy = _service_request_retry_policy.ServiceRequestRetryPolicy( - client.connection_policy, global_endpoint_manager, *args, + client.connection_policy, global_endpoint_manager, pk_range_wrapper, *args, ) # HttpRequest we would need to modify for Container Recreate Retry Policy request = None @@ -105,6 +111,7 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs): try: if args: result = ExecuteFunction(function, global_endpoint_manager, *args, **kwargs) + global_endpoint_manager.record_success(args[0]) else: result = ExecuteFunction(function, *args, **kwargs) if not client.last_response_headers: @@ -173,9 +180,10 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs): retry_policy.container_rid = cached_container["_rid"] request.headers[retry_policy._intended_headers] = retry_policy.container_rid - elif e.status_code == StatusCodes.REQUEST_TIMEOUT: - retry_policy = timeout_failover_retry_policy - elif e.status_code >= StatusCodes.INTERNAL_SERVER_ERROR: + elif e.status_code == StatusCodes.REQUEST_TIMEOUT or e.status_code >= StatusCodes.INTERNAL_SERVER_ERROR: + if args: + # record the failure for circuit breaker tracking + global_endpoint_manager.record_failure(args[0]) retry_policy = timeout_failover_retry_policy else: retry_policy = defaultRetry_policy @@ -208,6 +216,8 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs): if not database_account_retry_policy.ShouldRetry(e): raise e else: + if args: + global_endpoint_manager.record_failure(args[0]) _handle_service_request_retries(client, service_request_retry_policy, e, *args) except ServiceResponseError as e: @@ -215,6 +225,8 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs): if not 
database_account_retry_policy.ShouldRetry(e): raise e else: + if args: + global_endpoint_manager.record_failure(args[0]) _handle_service_response_retries(request, client, service_response_retry_policy, e, *args) def ExecuteFunction(function, *args, **kwargs): @@ -236,7 +248,12 @@ def _has_database_account_header(request_headers): return True return False -def _handle_service_request_retries(client, request_retry_policy, exception, *args): +def _handle_service_request_retries( + client, + request_retry_policy, + exception, + *args +): # we resolve the request endpoint to the next preferred region # once we are out of preferred regions we stop retrying retry_policy = request_retry_policy @@ -292,7 +309,8 @@ def send(self, request): """ absolute_timeout = request.context.options.pop('timeout', None) per_request_timeout = request.context.options.pop('connection_timeout', 0) - + request_params = request.context.options.pop('request_params', None) + global_endpoint_manager = request.context.options.pop('global_endpoint_manager', None) retry_error = None retry_active = True response = None @@ -317,7 +335,8 @@ def send(self, request): # the request ran into a socket timeout or failed to establish a new connection # since request wasn't sent, raise exception immediately to be dealt with in client retry policies # This logic is based on the _retry.py file from azure-core - if not _has_database_account_header(request.http_request.headers): + if (not _has_database_account_header(request.http_request.headers) + and not request_params.healthy_tentative_location): if retry_settings['connect'] > 0: retry_active = self.increment(retry_settings, response=request, error=err) if retry_active: @@ -328,11 +347,14 @@ def send(self, request): retry_error = err # Only read operations can be safely retried with ServiceResponseError if (not _has_read_retryable_headers(request.http_request.headers) or - _has_database_account_header(request.http_request.headers)): + 
_has_database_account_header(request.http_request.headers) or + request_params.healthy_tentative_location): raise err - # This logic is based on the _retry.py file from azure-core if retry_settings['read'] > 0: + # record the failure for circuit breaker tracking for retries in connection retry policy + # retries in the execute function will mark those failures + global_endpoint_manager.record_failure(request_params) retry_active = self.increment(retry_settings, response=request, error=err) if retry_active: self.sleep(retry_settings, request.context.transport) @@ -342,9 +364,11 @@ def send(self, request): raise err except AzureError as err: retry_error = err - if _has_database_account_header(request.http_request.headers): + if (_has_database_account_header(request.http_request.headers) or + request_params.healthy_tentative_location): raise err if _has_read_retryable_headers(request.http_request.headers) and retry_settings['read'] > 0: + global_endpoint_manager.record_failure(request_params) retry_active = self.increment(retry_settings, response=request, error=err) if retry_active: self.sleep(retry_settings, request.context.transport) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/aio/routing_map_provider.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/aio/routing_map_provider.py index 5af0f082af2f..88b740be7290 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/aio/routing_map_provider.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/aio/routing_map_provider.py @@ -22,6 +22,7 @@ """Internal class for partition key range cache implementation in the Azure Cosmos database service. """ +from typing import Dict, Any, Optional from ... import _base from ..collection_routing_map import CollectionRoutingMap @@ -59,14 +60,24 @@ async def get_overlapping_ranges(self, collection_link, partition_key_ranges, fe :return: List of overlapping partition key ranges. 
:rtype: list """ - cl = self._documentClient - collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) + await self.init_collection_routing_map_if_needed(collection_link, str(collection_id), feed_options, **kwargs) + + return self._collection_routing_map_by_item[collection_id].get_overlapping_ranges(partition_key_ranges) + async def init_collection_routing_map_if_needed( + self, + collection_link: str, + collection_id: str, + feed_options: Optional[Dict[str, Any]] = None, + **kwargs: Dict[str, Any] + ): collection_routing_map = self._collection_routing_map_by_item.get(collection_id) if collection_routing_map is None: collection_pk_ranges = [pk async for pk in - cl._ReadPartitionKeyRanges(collection_link, feed_options, **kwargs)] + self._documentClient._ReadPartitionKeyRanges(collection_link, + feed_options, + **kwargs)] # for large collections, a split may complete between the read partition key ranges query page responses, # causing the partitionKeyRanges to have both the children ranges and their parents. Therefore, we need # to discard the parent ranges to have a valid routing map. 
@@ -75,7 +86,18 @@ async def get_overlapping_ranges(self, collection_link, partition_key_ranges, fe [(r, True) for r in collection_pk_ranges], collection_id ) self._collection_routing_map_by_item[collection_id] = collection_routing_map - return collection_routing_map.get_overlapping_ranges(partition_key_ranges) + + async def get_range_by_partition_key_range_id( + self, + collection_link: str, + partition_key_range_id: int, + **kwargs: Dict[str, Any] + ) -> Optional[Dict[str, Any]]: + collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) + await self.init_collection_routing_map_if_needed(collection_link, str(collection_id), **kwargs) + + return self._collection_routing_map_by_item[collection_id].get_range_by_partition_key_range_id( + partition_key_range_id) @staticmethod def _discard_parent_ranges(partitionKeyRanges): diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_map_provider.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_map_provider.py index 901cf9f20899..544397db326d 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_map_provider.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_map_provider.py @@ -22,6 +22,7 @@ """Internal class for partition key range cache implementation in the Azure Cosmos database service. """ +from typing import Dict, Any, Optional from .. 
import _base from .collection_routing_map import CollectionRoutingMap @@ -50,6 +51,27 @@ def __init__(self, client): # keeps the cached collection routing map by collection id self._collection_routing_map_by_item = {} + def init_collection_routing_map_if_needed( + self, + collection_link: str, + collection_id: str, + feed_options: Optional[Dict[str, Any]] = None, + **kwargs: Dict[str, Any] + ): + collection_routing_map = self._collection_routing_map_by_item.get(collection_id) + if not collection_routing_map: + collection_pk_ranges = list(self._documentClient._ReadPartitionKeyRanges(collection_link, + feed_options, + **kwargs)) + # for large collections, a split may complete between the read partition key ranges query page responses, + # causing the partitionKeyRanges to have both the children ranges and their parents. Therefore, we need + # to discard the parent ranges to have a valid routing map. + collection_pk_ranges = PartitionKeyRangeCache._discard_parent_ranges(collection_pk_ranges) + collection_routing_map = CollectionRoutingMap.CompleteRoutingMap( + [(r, True) for r in collection_pk_ranges], collection_id + ) + self._collection_routing_map_by_item[collection_id] = collection_routing_map + def get_overlapping_ranges(self, collection_link, partition_key_ranges, feed_options = None, **kwargs): """Given a partition key range and a collection, return the list of overlapping partition key ranges. @@ -60,22 +82,22 @@ def get_overlapping_ranges(self, collection_link, partition_key_ranges, feed_opt :return: List of overlapping partition key ranges. 
:rtype: list """ - cl = self._documentClient + collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) + self.init_collection_routing_map_if_needed(collection_link, str(collection_id), feed_options, **kwargs) + return self._collection_routing_map_by_item[collection_id].get_overlapping_ranges(partition_key_ranges) + + def get_range_by_partition_key_range_id( + self, + collection_link: str, + partition_key_range_id: int, + **kwargs: Dict[str, Any] + ) -> Optional[Dict[str, Any]]: collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) + self.init_collection_routing_map_if_needed(collection_link, str(collection_id), **kwargs) - collection_routing_map = self._collection_routing_map_by_item.get(collection_id) - if collection_routing_map is None: - collection_pk_ranges = list(cl._ReadPartitionKeyRanges(collection_link, feed_options, **kwargs)) - # for large collections, a split may complete between the read partition key ranges query page responses, - # causing the partitionKeyRanges to have both the children ranges and their parents. Therefore, we need - # to discard the parent ranges to have a valid routing map. 
class PartitionKeyRangeWrapper(object):
    """Internal class for a representation of a unique partition for an account.

    Pairs a partition key ``Range`` with the resource id of the collection it
    belongs to. Instances are value objects used as dictionary keys (hence the
    paired ``__eq__``/``__hash__``).
    """

    def __init__(self, partition_key_range: "Range", collection_rid: str) -> None:
        # quoted annotation matches the file's existing 'Range' forward-reference style
        self.partition_key_range = partition_key_range
        self.collection_rid = collection_rid


    def __str__(self) -> str:
        # Fix: the original representation never closed the parenthesis and
        # ended with a dangling ", ".
        return (
            f"PartitionKeyRangeWrapper("
            f"partition_key_range={self.partition_key_range}, "
            f"collection_rid={self.collection_rid})"
        )

    def __eq__(self, other):
        if not isinstance(other, PartitionKeyRangeWrapper):
            return False
        return self.partition_key_range == other.partition_key_range and self.collection_rid == other.collection_rid

    def __hash__(self):
        # hash over the same fields __eq__ compares, keeping the eq/hash contract
        return hash((self.partition_key_range, self.collection_rid))
b/sdk/cosmos/azure-cosmos/azure/cosmos/_service_request_retry_policy.py index 030774beff95..5b4faf75df84 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_service_request_retry_policy.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_service_request_retry_policy.py @@ -13,9 +13,10 @@ class ServiceRequestRetryPolicy(object): - def __init__(self, connection_policy, global_endpoint_manager, *args): + def __init__(self, connection_policy, global_endpoint_manager, pk_range_wrapper, *args): self.args = args self.global_endpoint_manager = global_endpoint_manager + self.pk_range_wrapper = pk_range_wrapper self.total_retries = len(self.global_endpoint_manager.location_cache.read_regional_routing_contexts) self.total_in_region_retries = 1 self.in_region_retry_count = 0 @@ -65,7 +66,7 @@ def ShouldRetry(self): self.failover_retry_count += 1 if self.failover_retry_count >= self.total_retries: return False - # # Check if it is safe to failover to another region + # Check if it is safe to failover to another region location_endpoint = self.resolve_next_region_service_endpoint() else: location_endpoint = self.resolve_current_region_service_endpoint() @@ -80,7 +81,7 @@ def ShouldRetry(self): # and we reset the in region retry count self.in_region_retry_count = 0 self.failover_retry_count += 1 - # # Check if it is safe to failover to another region + # Check if it is safe to failover to another region if self.failover_retry_count >= self.total_retries: return False location_endpoint = self.resolve_next_region_service_endpoint() @@ -96,7 +97,7 @@ def resolve_current_region_service_endpoint(self): # resolve the next service endpoint in the same region # since we maintain 2 endpoints per region for write operations self.request.route_to_location_with_preferred_location_flag(0, True) - return self.global_endpoint_manager.resolve_service_endpoint(self.request) + return self.global_endpoint_manager.resolve_service_endpoint_for_partition(self.request, self.pk_range_wrapper) # This function 
prepares the request to go to the next region def resolve_next_region_service_endpoint(self): @@ -110,7 +111,7 @@ def resolve_next_region_service_endpoint(self): self.request.route_to_location_with_preferred_location_flag(0, True) # Resolve the endpoint for the request and pin the resolution to the resolved endpoint # This enables marking the endpoint unavailability on endpoint failover/unreachability - return self.global_endpoint_manager.resolve_service_endpoint(self.request) + return self.global_endpoint_manager.resolve_service_endpoint_for_partition(self.request, self.pk_range_wrapper) def mark_endpoint_unavailable(self, unavailable_endpoint, refresh_cache: bool): if _OperationType.IsReadOnlyOperation(self.request.operation_type): diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_service_response_retry_policy.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_service_response_retry_policy.py index 83a856f39d33..59fca57e1c76 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_service_response_retry_policy.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_service_response_retry_policy.py @@ -12,15 +12,17 @@ class ServiceResponseRetryPolicy(object): - def __init__(self, connection_policy, global_endpoint_manager, *args): + def __init__(self, connection_policy, global_endpoint_manager, pk_range_wrapper, *args): self.args = args self.global_endpoint_manager = global_endpoint_manager + self.pk_range_wrapper = pk_range_wrapper self.total_retries = len(self.global_endpoint_manager.location_cache.read_regional_routing_contexts) self.failover_retry_count = 0 self.connection_policy = connection_policy self.request = args[0] if args else None if self.request: - self.location_endpoint = self.global_endpoint_manager.resolve_service_endpoint(self.request) + self.location_endpoint = (self.global_endpoint_manager + .resolve_service_endpoint_for_partition(self.request, pk_range_wrapper)) self.logger = logging.getLogger('azure.cosmos.ServiceResponseRetryPolicy') def ShouldRetry(self): @@ 
-57,4 +59,4 @@ def resolve_next_region_service_endpoint(self): self.request.route_to_location_with_preferred_location_flag(self.failover_retry_count, True) # Resolve the endpoint for the request and pin the resolution to the resolved endpoint # This enables marking the endpoint unavailability on endpoint failover/unreachability - return self.global_endpoint_manager.resolve_service_endpoint(self.request) + return self.global_endpoint_manager.resolve_service_endpoint_for_partition(self.request, self.pk_range_wrapper) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_session_retry_policy.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_session_retry_policy.py index 1614f337de5b..69b9d52f286d 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_session_retry_policy.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_session_retry_policy.py @@ -41,10 +41,11 @@ class _SessionRetryPolicy(object): Max_retry_attempt_count = 1 Retry_after_in_milliseconds = 0 - def __init__(self, endpoint_discovery_enable, global_endpoint_manager, *args): + def __init__(self, endpoint_discovery_enable, global_endpoint_manager, pk_range_wrapper, *args): self.global_endpoint_manager = global_endpoint_manager self._max_retry_attempt_count = _SessionRetryPolicy.Max_retry_attempt_count self.session_token_retry_count = 0 + self.pk_range_wrapper = pk_range_wrapper self.retry_after_in_milliseconds = _SessionRetryPolicy.Retry_after_in_milliseconds self.endpoint_discovery_enable = endpoint_discovery_enable self.request = args[0] if args else None @@ -57,7 +58,8 @@ def __init__(self, endpoint_discovery_enable, global_endpoint_manager, *args): # Resolve the endpoint for the request and pin the resolution to the resolved endpoint # This enables marking the endpoint unavailability on endpoint failover/unreachability - self.location_endpoint = self.global_endpoint_manager.resolve_service_endpoint(self.request) + self.location_endpoint = (self.global_endpoint_manager + 
.resolve_service_endpoint_for_partition(self.request, self.pk_range_wrapper)) self.request.route_to_location(self.location_endpoint) def ShouldRetry(self, _exception): @@ -98,7 +100,8 @@ def ShouldRetry(self, _exception): # Resolve the endpoint for the request and pin the resolution to the resolved endpoint # This enables marking the endpoint unavailability on endpoint failover/unreachability - self.location_endpoint = self.global_endpoint_manager.resolve_service_endpoint(self.request) + self.location_endpoint = (self.global_endpoint_manager + .resolve_service_endpoint_for_partition(self.request, self.pk_range_wrapper)) self.request.route_to_location(self.location_endpoint) return True @@ -113,6 +116,7 @@ def ShouldRetry(self, _exception): # Resolve the endpoint for the request and pin the resolution to the resolved endpoint # This enables marking the endpoint unavailability on endpoint failover/unreachability - self.location_endpoint = self.global_endpoint_manager.resolve_service_endpoint(self.request) + self.location_endpoint = (self.global_endpoint_manager + .resolve_service_endpoint_for_partition(self.request, self.pk_range_wrapper)) self.request.route_to_location(self.location_endpoint) return True diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py index 68e37caf1d9d..e41881429b20 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py @@ -65,7 +65,7 @@ def _request_body_from_data(data): return None -def _Request(global_endpoint_manager, request_params, connection_policy, pipeline_client, request, **kwargs): +def _Request(global_endpoint_manager, request_params, connection_policy, pipeline_client, request, **kwargs): # pylint: disable=too-many-statements """Makes one http request using the requests module. 
:param _GlobalEndpointManager global_endpoint_manager: @@ -90,6 +90,8 @@ def _Request(global_endpoint_manager, request_params, connection_policy, pipelin # Every request tries to perform a refresh client_timeout = kwargs.get('timeout') start_time = time.time() + if request_params.healthy_tentative_location: + read_timeout = connection_policy.RecoveryReadTimeout if request_params.resource_type != http_constants.ResourceType.DatabaseAccount: global_endpoint_manager.refresh_endpoint_list(None, **kwargs) else: @@ -104,7 +106,11 @@ def _Request(global_endpoint_manager, request_params, connection_policy, pipelin if request_params.endpoint_override: base_url = request_params.endpoint_override else: - base_url = global_endpoint_manager.resolve_service_endpoint(request_params) + pk_range_wrapper = None + if global_endpoint_manager.is_circuit_breaker_applicable(request_params): + # Circuit breaker is applicable, so we need to use the endpoint from the request + pk_range_wrapper = global_endpoint_manager.create_pk_range_wrapper(request_params) + base_url = global_endpoint_manager.resolve_service_endpoint_for_partition(request_params, pk_range_wrapper) if not request.url.startswith(base_url): request.url = _replace_url_prefix(request.url, base_url) @@ -132,6 +138,8 @@ def _Request(global_endpoint_manager, request_params, connection_policy, pipelin read_timeout=read_timeout, connection_verify=kwargs.pop("connection_verify", ca_certs), connection_cert=kwargs.pop("connection_cert", cert_files), + request_params=request_params, + global_endpoint_manager=global_endpoint_manager, **kwargs ) else: @@ -142,6 +150,8 @@ def _Request(global_endpoint_manager, request_params, connection_policy, pipelin read_timeout=read_timeout, # If SSL is disabled, verify = false connection_verify=kwargs.pop("connection_verify", is_ssl_enabled), + request_params=request_params, + global_endpoint_manager=global_endpoint_manager, **kwargs ) diff --git 
a/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py index b170fb4fd9d2..b77ce1a69f13 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py @@ -9,11 +9,11 @@ class _TimeoutFailoverRetryPolicy(object): - def __init__(self, connection_policy, global_endpoint_manager, *args): + def __init__(self, connection_policy, global_endpoint_manager, pk_range_wrapper, *args): self.retry_after_in_milliseconds = 500 - self.args = args self.global_endpoint_manager = global_endpoint_manager + self.pk_range_wrapper = pk_range_wrapper # If an account only has 1 region, then we still want to retry once on the same region self._max_retry_attempt_count = (len(self.global_endpoint_manager.location_cache.read_regional_routing_contexts) + 1) @@ -56,4 +55,4 @@ def resolve_next_region_service_endpoint(self): self.request.route_to_location_with_preferred_location_flag(self.retry_count, True) # Resolve the endpoint for the request and pin the resolution to the resolved endpoint # This enables marking the endpoint unavailability on endpoint failover/unreachability - return self.global_endpoint_manager.resolve_service_endpoint(self.request) + return self.global_endpoint_manager.resolve_service_endpoint_for_partition(self.request, self.pk_range_wrapper) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_utils.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_utils.py index aca8a6fb0913..f1899415af87 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_utils.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_utils.py @@ -26,6 +26,7 @@ import re import base64 import json +import time from typing import Any, Dict, Optional from ._version import VERSION @@ -39,7 +40,6 @@ def get_user_agent(suffix: Optional[str]) -> str: user_agent += f" {suffix}" return user_agent - def get_user_agent_async(suffix: Optional[str]) -> 
str: os_name = safe_user_agent_header(platform.platform()) python_version = safe_user_agent_header(platform.python_version()) @@ -73,3 +73,6 @@ def get_index_metrics_info(delimited_string: Optional[str]) -> Dict[str, Any]: return result except (json.JSONDecodeError, ValueError): return {} + +def current_time_millis() -> int: + return int(round(time.time() * 1000)) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py index 81430d8df42c..79e674eaa31c 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py @@ -34,7 +34,7 @@ from .._synchronized_request import _request_body_from_data, _replace_url_prefix -async def _Request(global_endpoint_manager, request_params, connection_policy, pipeline_client, request, **kwargs): +async def _Request(global_endpoint_manager, request_params, connection_policy, pipeline_client, request, **kwargs): # pylint: disable=too-many-statements """Makes one http request using the requests module. 
:param _GlobalEndpointManager global_endpoint_manager: @@ -59,6 +59,8 @@ async def _Request(global_endpoint_manager, request_params, connection_policy, p # Every request tries to perform a refresh client_timeout = kwargs.get('timeout') start_time = time.time() + if request_params.healthy_tentative_location: + read_timeout = connection_policy.RecoveryReadTimeout if request_params.resource_type != http_constants.ResourceType.DatabaseAccount: await global_endpoint_manager.refresh_endpoint_list(None, **kwargs) else: @@ -73,7 +75,11 @@ async def _Request(global_endpoint_manager, request_params, connection_policy, p if request_params.endpoint_override: base_url = request_params.endpoint_override else: - base_url = global_endpoint_manager.resolve_service_endpoint(request_params) + pk_range_wrapper = None + if global_endpoint_manager.is_circuit_breaker_applicable(request_params): + # Circuit breaker is applicable, so we need to use the endpoint from the request + pk_range_wrapper = await global_endpoint_manager.create_pk_range_wrapper(request_params) + base_url = global_endpoint_manager.resolve_service_endpoint_for_partition(request_params, pk_range_wrapper) if not request.url.startswith(base_url): request.url = _replace_url_prefix(request.url, base_url) @@ -101,6 +107,8 @@ async def _Request(global_endpoint_manager, request_params, connection_policy, p read_timeout=read_timeout, connection_verify=kwargs.pop("connection_verify", ca_certs), connection_cert=kwargs.pop("connection_cert", cert_files), + request_params=request_params, + global_endpoint_manager=global_endpoint_manager, **kwargs ) else: @@ -111,6 +119,8 @@ async def _Request(global_endpoint_manager, request_params, connection_policy, p read_timeout=read_timeout, # If SSL is disabled, verify = false connection_verify=kwargs.pop("connection_verify", is_ssl_enabled), + request_params=request_params, + global_endpoint_manager=global_endpoint_manager, **kwargs ) diff --git 
a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py index 33cd804ab278..d167a3e469ed 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py @@ -40,7 +40,7 @@ _deserialize_throughput, _replace_throughput, GenerateGuidId, - _set_properties_cache + _build_properties_cache ) from .._change_feed.feed_range_internal import FeedRangeInternalEpk from .._cosmos_responses import CosmosDict, CosmosList @@ -94,15 +94,16 @@ def __init__( self._scripts: Optional[ScriptsProxy] = None if properties: self.client_connection._set_container_properties_cache(self.container_link, - _set_properties_cache(properties)) + _build_properties_cache(properties, + self.container_link)) def __repr__(self) -> str: return "".format(self.container_link)[:1024] - async def _get_properties_with_feed_options(self, feed_options: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + async def _get_properties_with_options(self, options: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: kwargs = {} - if feed_options and "excludedLocations" in feed_options: - kwargs['excluded_locations'] = feed_options['excludedLocations'] + if options and "excludedLocations" in options: + kwargs['excluded_locations'] = options['excludedLocations'] return await self._get_properties(**kwargs) async def _get_properties(self, **kwargs: Any) -> Dict[str, Any]: @@ -150,7 +151,7 @@ async def _get_epk_range_for_partition_key( self, partition_key_value: PartitionKeyType, feed_options: Optional[Dict[str, Any]] = None) -> Range: - container_properties = await self._get_properties_with_feed_options(feed_options) + container_properties = await self._get_properties_with_options(feed_options) partition_key_definition = container_properties["partitionKey"] partition_key = PartitionKey( path=partition_key_definition["paths"], @@ -180,10 +181,6 @@ async def read( :keyword Literal["High", "Low"] priority: 
Priority based execution allows users to set a priority for each request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. - :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations - in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. - If all preferred locations were excluded, primary/hub location will be used. - This excluded_location will override existing excluded_locations in client level. :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: Raised if the container couldn't be retrieved. This includes if the container does not exist. :returns: Dict representing the retrieved container. @@ -207,7 +204,8 @@ async def read( request_options["populateQuotaInfo"] = populate_quota_info container = await self.client_connection.ReadContainer(self.container_link, options=request_options, **kwargs) # Only cache Container Properties that will not change in the lifetime of the container - self.client_connection._set_container_properties_cache(self.container_link, _set_properties_cache(container)) # pylint: disable=protected-access, line-too-long + self.client_connection._set_container_properties_cache(self.container_link, # pylint: disable=protected-access + _build_properties_cache(container, self.container_link)) return container @distributed_trace_async @@ -288,8 +286,8 @@ async def create_item( request_options["disableAutomaticIdGeneration"] = not enable_automatic_id_generation if indexing_directive is not None: request_options["indexingDirective"] = indexing_directive - if self.container_link in self.__get_client_container_caches(): - request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] + await self._get_properties_with_options(request_options) + 
request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] result = await self.client_connection.CreateItem( database_or_container_link=self.container_link, document=body, options=request_options, **kwargs @@ -363,8 +361,8 @@ async def read_item( if max_integrated_cache_staleness_in_ms is not None: validate_cache_staleness_value(max_integrated_cache_staleness_in_ms) request_options["maxIntegratedCacheStaleness"] = max_integrated_cache_staleness_in_ms - if self.container_link in self.__get_client_container_caches(): - request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] + await self._get_properties_with_options(request_options) + request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] return await self.client_connection.ReadItem(document_link=doc_link, options=request_options, **kwargs) @@ -420,6 +418,7 @@ def read_all_items( response_hook.clear() if self.container_link in self.__get_client_container_caches(): feed_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] + kwargs["containerProperties"] = self._get_properties_with_options items = self.client_connection.ReadItems( collection_link=self.container_link, feed_options=feed_options, response_hook=response_hook, **kwargs @@ -524,9 +523,9 @@ def query_items( feed_options["populateIndexMetrics"] = populate_index_metrics if enable_scan_in_query is not None: feed_options["enableScanInQuery"] = enable_scan_in_query + kwargs["containerProperties"] = self._get_properties_with_options if partition_key is not None: feed_options["partitionKey"] = self._set_partition_key(partition_key) - kwargs["containerProperties"] = self._get_properties_with_feed_options else: feed_options["enableCrossPartitionQuery"] = True if max_integrated_cache_staleness_in_ms: @@ -772,7 +771,7 @@ def query_items_change_feed( # pylint: disable=unused-argument 
change_feed_state_context["continuation"] = feed_options.pop("continuation") feed_options["changeFeedStateContext"] = change_feed_state_context - feed_options["containerProperties"] = self._get_properties_with_feed_options(feed_options) + feed_options["containerProperties"] = self._get_properties_with_options(feed_options) response_hook = kwargs.pop("response_hook", None) if hasattr(response_hook, "clear"): @@ -854,8 +853,8 @@ async def upsert_item( kwargs["throughput_bucket"] = throughput_bucket request_options = _build_options(kwargs) request_options["disableAutomaticIdGeneration"] = True - if self.container_link in self.__get_client_container_caches(): - request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] + await self._get_properties_with_options(request_options) + request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] result = await self.client_connection.UpsertItem( database_or_container_link=self.container_link, @@ -937,8 +936,8 @@ async def replace_item( kwargs["throughput_bucket"] = throughput_bucket request_options = _build_options(kwargs) request_options["disableAutomaticIdGeneration"] = True - if self.container_link in self.__get_client_container_caches(): - request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] + await self._get_properties_with_options(request_options) + request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] result = await self.client_connection.ReplaceItem( document_link=item_link, new_document=body, options=request_options, **kwargs @@ -1021,8 +1020,8 @@ async def patch_item( request_options["partitionKey"] = await self._set_partition_key(partition_key) if filter_predicate is not None: request_options["filterPredicate"] = filter_predicate - if self.container_link in self.__get_client_container_caches(): - request_options["containerRID"] = 
self.__get_client_container_caches()[self.container_link]["_rid"] + await self._get_properties_with_options(request_options) + request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] item_link = self._get_document_link(item) result = await self.client_connection.PatchItem( @@ -1093,8 +1092,8 @@ async def delete_item( kwargs["throughput_bucket"] = throughput_bucket request_options = _build_options(kwargs) request_options["partitionKey"] = await self._set_partition_key(partition_key) - if self.container_link in self.__get_client_container_caches(): - request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] + await self._get_properties_with_options(request_options) + request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] document_link = self._get_document_link(item) await self.client_connection.DeleteItem(document_link=document_link, options=request_options, **kwargs) @@ -1354,8 +1353,8 @@ async def delete_all_items_by_partition_key( request_options = _build_options(kwargs) # regardless if partition key is valid we set it as invalid partition keys are set to a default empty value request_options["partitionKey"] = await self._set_partition_key(partition_key) - if self.container_link in self.__get_client_container_caches(): - request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] + await self._get_properties_with_options(request_options) + request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] await self.client_connection.DeleteAllItemsByPartitionKey(collection_link=self.container_link, options=request_options, **kwargs) @@ -1422,8 +1421,8 @@ async def execute_item_batch( request_options = _build_options(kwargs) request_options["partitionKey"] = await self._set_partition_key(partition_key) request_options["disableAutomaticIdGeneration"] = True - if 
self.container_link in self.__get_client_container_caches(): - request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] + await self._get_properties_with_options(request_options) + request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] return await self.client_connection.Batch( collection_link=self.container_link, batch_operations=batch_operations, options=request_options, **kwargs) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py index 249033add988..cbcd3ccafba7 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py @@ -48,9 +48,11 @@ DistributedTracingPolicy, ProxyPolicy) from azure.core.utils import CaseInsensitiveDict +from azure.cosmos.aio._global_partition_endpoint_manager_circuit_breaker_async import ( + _GlobalPartitionEndpointManagerForCircuitBreakerAsync) from .. import _base as base -from .._base import _set_properties_cache +from .._base import _build_properties_cache from .. import documents from .._change_feed.aio.change_feed_iterable import ChangeFeedIterable from .._change_feed.change_feed_state import ChangeFeedState @@ -63,7 +65,6 @@ from .. import _runtime_constants as runtime_constants from .. import _request_object from . import _asynchronous_request as asynchronous_request -from . import _global_endpoint_manager_async as global_endpoint_manager_async from .._routing.aio.routing_map_provider import SmartRoutingMapProvider from ._retry_utility_async import _ConnectionRetryPolicy from .. import _session @@ -173,7 +174,7 @@ def __init__( # pylint: disable=too-many-statements # Keeps the latest response headers from the server. 
self.last_response_headers: CaseInsensitiveDict = CaseInsensitiveDict() self.UseMultipleWriteLocations = False - self._global_endpoint_manager = global_endpoint_manager_async._GlobalEndpointManager(self) + self._global_endpoint_manager = _GlobalPartitionEndpointManagerForCircuitBreakerAsync(self) retry_policy = None if isinstance(self.connection_policy.ConnectionRetryConfiguration, AsyncHTTPPolicy): @@ -262,6 +263,7 @@ def _set_container_properties_cache(self, container_link: str, properties: Optio :type properties: Optional[Dict[str, Any]]""" if properties: self.__container_properties_cache[container_link] = properties + self.__container_properties_cache[properties["_rid"]] = properties else: self.__container_properties_cache[container_link] = {} @@ -420,7 +422,8 @@ async def GetDatabaseAccount( client_id=self.client_id) # path # id # type request_params = _request_object.RequestObject(http_constants.ResourceType.DatabaseAccount, - documents._OperationType.Read, url_connection) + documents._OperationType.Read, + headers, url_connection) result, self.last_response_headers = await self.__Get("", request_params, headers, **kwargs) database_account = documents.DatabaseAccount() @@ -471,7 +474,9 @@ async def _GetDatabaseAccountCheck( client_id=self.client_id) # path # id # type request_params = _request_object.RequestObject(http_constants.ResourceType.DatabaseAccount, - documents._OperationType.Read, url_connection) + documents._OperationType.Read, + headers, + url_connection) await self.__Get("", request_params, headers, **kwargs) async def CreateDatabase( @@ -742,7 +747,8 @@ async def ExecuteStoredProcedure( # ExecuteStoredProcedure will use WriteEndpoint since it uses POST operation request_params = _request_object.RequestObject(http_constants.ResourceType.StoredProcedure, - documents._OperationType.ExecuteJavaScript) + documents._OperationType.ExecuteJavaScript, headers) + request_params.set_excluded_location_from_options(options) result, self.last_response_headers 
= await self.__Post(path, request_params, params, headers, **kwargs) return result @@ -780,7 +786,7 @@ async def Create( documents._OperationType.Create, options) # Create will use WriteEndpoint since it uses POST operation - request_params = _request_object.RequestObject(typ, documents._OperationType.Create) + request_params = _request_object.RequestObject(typ, documents._OperationType.Create, headers) request_params.set_excluded_location_from_options(options) result, last_response_headers = await self.__Post(path, request_params, body, headers, **kwargs) self.last_response_headers = last_response_headers @@ -921,7 +927,7 @@ async def Upsert( headers[http_constants.HttpHeaders.IsUpsert] = True # Upsert will use WriteEndpoint since it uses POST operation - request_params = _request_object.RequestObject(typ, documents._OperationType.Upsert) + request_params = _request_object.RequestObject(typ, documents._OperationType.Upsert, headers) request_params.set_excluded_location_from_options(options) result, last_response_headers = await self.__Post(path, request_params, body, headers, **kwargs) self.last_response_headers = last_response_headers @@ -1223,7 +1229,7 @@ async def Read( headers = base.GetHeaders(self, initial_headers, "get", path, id, typ, documents._OperationType.Read, options) # Read will use ReadEndpoint since it uses GET operation - request_params = _request_object.RequestObject(typ, documents._OperationType.Read) + request_params = _request_object.RequestObject(typ, documents._OperationType.Read, headers) request_params.set_excluded_location_from_options(options) result, last_response_headers = await self.__Get(path, request_params, headers, **kwargs) self.last_response_headers = last_response_headers @@ -1486,7 +1492,7 @@ async def PatchItem( headers = base.GetHeaders(self, initial_headers, "patch", path, document_id, typ, documents._OperationType.Patch, options) # Patch will use WriteEndpoint since it uses PUT operation - request_params = 
_request_object.RequestObject(typ, documents._OperationType.Patch) + request_params = _request_object.RequestObject(typ, documents._OperationType.Patch, headers) request_params.set_excluded_location_from_options(options) request_data = {} if options.get("filterPredicate"): @@ -1592,7 +1598,7 @@ async def Replace( headers = base.GetHeaders(self, initial_headers, "put", path, id, typ, documents._OperationType.Replace, options) # Replace will use WriteEndpoint since it uses PUT operation - request_params = _request_object.RequestObject(typ, documents._OperationType.Replace) + request_params = _request_object.RequestObject(typ, documents._OperationType.Replace, headers) request_params.set_excluded_location_from_options(options) result, last_response_headers = await self.__Put(path, request_params, resource, headers, **kwargs) self.last_response_headers = last_response_headers @@ -1917,7 +1923,7 @@ async def DeleteResource( headers = base.GetHeaders(self, initial_headers, "delete", path, id, typ, documents._OperationType.Delete, options) # Delete will use WriteEndpoint since it uses DELETE operation - request_params = _request_object.RequestObject(typ, documents._OperationType.Delete) + request_params = _request_object.RequestObject(typ, documents._OperationType.Delete, headers) request_params.set_excluded_location_from_options(options) result, last_response_headers = await self.__Delete(path, request_params, headers, **kwargs) self.last_response_headers = last_response_headers @@ -2033,7 +2039,7 @@ async def _Batch( http_constants.ResourceType.Document, documents._OperationType.Batch, options) request_params = _request_object.RequestObject(http_constants.ResourceType.Document, - documents._OperationType.Batch) + documents._OperationType.Batch, headers) request_params.set_excluded_location_from_options(options) result = await self.__Post(path, request_params, batch_operations, headers, **kwargs) return cast(Tuple[List[Dict[str, Any]], CaseInsensitiveDict], result) @@ 
-2277,6 +2283,9 @@ def QueryItems( collection_id = base.GetResourceIdOrFullNameFromLink(database_or_container_link) async def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], CaseInsensitiveDict]: + await kwargs["containerProperties"](options) + new_options = dict(options) + new_options["containerRID"] = self.__container_properties_cache[database_or_container_link]["_rid"] return ( await self.__QueryFeed( path, @@ -2285,7 +2294,7 @@ async def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], Ca lambda r: r["Documents"], lambda _, b: b, query, - options, + new_options, response_hook=response_hook, **kwargs ), @@ -2888,16 +2897,23 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: return [] initial_headers = self.default_headers.copy() + cont_prop_func = kwargs.pop("containerProperties", None) + cont_prop = None + if cont_prop_func: + cont_prop = await cont_prop_func(options) # get properties with feed options + # Copy to make sure that default_headers won't be changed. 
if query is None: + op_typ = documents._OperationType.QueryPlan if is_query_plan else documents._OperationType.ReadFeed # Query operations will use ReadEndpoint even though it uses GET(for feed requests) + headers = base.GetHeaders(self, initial_headers, "get", path, id_, typ, op_typ, + options, partition_key_range_id) request_params = _request_object.RequestObject( typ, - documents._OperationType.QueryPlan if is_query_plan else documents._OperationType.ReadFeed + op_typ, + headers ) request_params.set_excluded_location_from_options(options) - headers = base.GetHeaders(self, initial_headers, "get", path, id_, typ, request_params.operation_type, - options, partition_key_range_id) change_feed_state: Optional[ChangeFeedState] = options.get("changeFeedState") if change_feed_state is not None: @@ -2928,19 +2944,18 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: raise SystemError("Unexpected query compatibility mode.") # Query operations will use ReadEndpoint even though it uses POST(for regular query operations) - request_params = _request_object.RequestObject(typ, documents._OperationType.SqlQuery) - request_params.set_excluded_location_from_options(options) - req_headers = base.GetHeaders(self, initial_headers, "post", path, id_, typ, request_params.operation_type, + req_headers = base.GetHeaders(self, initial_headers, "post", path, id_, typ, + documents._OperationType.SqlQuery, options, partition_key_range_id) + request_params = _request_object.RequestObject(typ, documents._OperationType.SqlQuery, req_headers) + request_params.set_excluded_location_from_options(options) # check if query has prefix partition key - cont_prop = kwargs.pop("containerProperties", None) partition_key_value = options.get("partitionKey", None) is_prefix_partition_query = False partition_key_obj = None - if cont_prop: - properties = await cont_prop(options) # get properties with feed options - partition_key_definition = properties["partitionKey"] + if 
cont_prop and partition_key_value is not None: + partition_key_definition = cont_prop["partitionKey"] partition_key_obj = PartitionKey(path=partition_key_definition["paths"], kind=partition_key_definition["kind"], version=partition_key_definition["version"]) @@ -2949,9 +2964,11 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: if is_prefix_partition_query and partition_key_obj: # here get the overlapping ranges req_headers.pop(http_constants.HttpHeaders.PartitionKey, None) - feedrangeEPK = partition_key_obj._get_epk_range_for_prefix_partition_key( + feed_range_epk = partition_key_obj._get_epk_range_for_prefix_partition_key( partition_key_value) # cspell:disable-line - over_lapping_ranges = await self._routing_map_provider.get_overlapping_ranges(id_, [feedrangeEPK], options) + over_lapping_ranges = await self._routing_map_provider.get_overlapping_ranges(id_, + [feed_range_epk], + options) results: Dict[str, Any] = {} # For each over lapping range we will take a sub range of the feed range EPK that overlaps with the over # lapping physical partition. 
The EPK sub range will be one of four: @@ -2966,8 +2983,8 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: single_range = routing_range.Range.PartitionKeyRangeToRange(over_lapping_range) # Since the range min and max are all Upper Cased string Hex Values, # we can compare the values lexicographically - EPK_sub_range = routing_range.Range(range_min=max(single_range.min, feedrangeEPK.min), - range_max=min(single_range.max, feedrangeEPK.max), + EPK_sub_range = routing_range.Range(range_min=max(single_range.min, feed_range_epk.min), + range_max=min(single_range.max, feed_range_epk.max), isMinInclusive=True, isMaxInclusive=False) if single_range.min == EPK_sub_range.min and EPK_sub_range.max == single_range.max: # The Epk Sub Range spans exactly one physical partition @@ -3221,7 +3238,7 @@ async def _refresh_container_properties_cache(self, container_link: str): # If container properties cache is stale, refresh it by reading the container. container = await self.ReadContainer(container_link, options=None) # Only cache Container Properties that will not change in the lifetime of the container - self._set_container_properties_cache(container_link, _set_properties_cache(container)) + self._set_container_properties_cache(container_link, _build_properties_cache(container, container_link)) async def _GetQueryPlanThroughGateway(self, query: str, resource_link: str, excluded_locations: Optional[str] = None, @@ -3302,7 +3319,8 @@ async def DeleteAllItemsByPartitionKey( initial_headers = dict(self.default_headers) headers = base.GetHeaders(self, initial_headers, "post", path, collection_id, "partitionkey", documents._OperationType.Delete, options) - request_params = _request_object.RequestObject("partitionkey", documents._OperationType.Delete) + request_params = _request_object.RequestObject("partitionkey", documents._OperationType.Delete, + headers) request_params.set_excluded_location_from_options(options) _, last_response_headers = await 
self.__Post(path=path, request_params=request_params, req_headers=headers, body=None, **kwargs) @@ -3324,5 +3342,5 @@ async def _get_partition_key_definition( else: container = await self.ReadContainer(collection_link, options) partition_key_definition = container.get("partitionKey") - self.__container_properties_cache[collection_link] = _set_properties_cache(container) + self._set_container_properties_cache(collection_link, _build_properties_cache(container, collection_link)) return partition_key_definition diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py index 201f44f43a78..2d0184468149 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py @@ -33,11 +33,12 @@ from .. import _constants as constants from .. import exceptions from .._location_cache import LocationCache - +from .._utils import current_time_millis +from .._request_object import RequestObject # pylint: disable=protected-access -logger = logging.getLogger("azure.cosmos.aio_GlobalEndpointManager") +logger = logging.getLogger("azure.cosmos.aio._GlobalEndpointManager") class _GlobalEndpointManager(object): # pylint: disable=too-many-instance-attributes """ @@ -47,7 +48,6 @@ class _GlobalEndpointManager(object): # pylint: disable=too-many-instance-attrib def __init__(self, client): self.client = client - self.EnableEndpointDiscovery = client.connection_policy.EnableEndpointDiscovery self.PreferredLocations = client.connection_policy.PreferredLocations self.DefaultEndpoint = client.url_connection self.refresh_time_interval_in_ms = self.get_refresh_time_interval_in_ms_stub() @@ -71,7 +71,10 @@ def get_write_endpoint(self): def get_read_endpoint(self): return self.location_cache.get_read_regional_routing_context() - def resolve_service_endpoint(self, request): + def 
_resolve_service_endpoint( + self, + request: RequestObject + ) -> str: return self.location_cache.resolve_service_endpoint(request) def mark_endpoint_unavailable_for_read(self, endpoint, refresh_cache): @@ -83,6 +86,9 @@ def mark_endpoint_unavailable_for_write(self, endpoint, refresh_cache): def get_ordered_write_locations(self): return self.location_cache.get_ordered_write_locations() + def get_ordered_read_locations(self): + return self.location_cache.get_ordered_read_locations() + def can_use_multiple_write_locations(self, request): return self.location_cache.can_use_multiple_write_locations_for_request(request) @@ -101,7 +107,7 @@ async def refresh_endpoint_list(self, database_account, **kwargs): self.refresh_task = None except (Exception, asyncio.CancelledError) as exception: #pylint: disable=broad-exception-caught logger.exception("Health check task failed: %s", exception) #pylint: disable=do-not-use-logging-exception - if self.location_cache.current_time_millis() - self.last_refresh_time > self.refresh_time_interval_in_ms: + if current_time_millis() - self.last_refresh_time > self.refresh_time_interval_in_ms: self.refresh_needed = True if self.refresh_needed: async with self.refresh_lock: @@ -117,11 +123,11 @@ async def _refresh_endpoint_list_private(self, database_account=None, **kwargs): if database_account and not self.startup: self.location_cache.perform_on_database_account_read(database_account) self.refresh_needed = False - self.last_refresh_time = self.location_cache.current_time_millis() + self.last_refresh_time = current_time_millis() else: if self.location_cache.should_refresh_endpoints() or self.refresh_needed: self.refresh_needed = False - self.last_refresh_time = self.location_cache.current_time_millis() + self.last_refresh_time = current_time_millis() if not self.startup: # this will perform getDatabaseAccount calls to check endpoint health # in background diff --git 
a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py new file mode 100644 index 000000000000..78e8b551ee7a --- /dev/null +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py @@ -0,0 +1,122 @@ +# The MIT License (MIT) +# Copyright (c) 2021 Microsoft Corporation + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Internal class for global endpoint manager for circuit breaker. 
+""" +from typing import TYPE_CHECKING, Optional + +from azure.cosmos import PartitionKey +from azure.cosmos._global_partition_endpoint_manager_circuit_breaker_core import \ + _GlobalPartitionEndpointManagerForCircuitBreakerCore +from azure.cosmos._routing.routing_range import PartitionKeyRangeWrapper, Range + +from azure.cosmos.aio._global_endpoint_manager_async import _GlobalEndpointManager +from azure.cosmos._request_object import RequestObject +from azure.cosmos.http_constants import HttpHeaders + +if TYPE_CHECKING: + from azure.cosmos.aio._cosmos_client_connection_async import CosmosClientConnection + + +# pylint: disable=protected-access +class _GlobalPartitionEndpointManagerForCircuitBreakerAsync(_GlobalEndpointManager): + """ + This internal class implements the logic for partition endpoint management for + geo-replicated database accounts. + """ + + def __init__(self, client: "CosmosClientConnection"): + super(_GlobalPartitionEndpointManagerForCircuitBreakerAsync, self).__init__(client) + self.global_partition_endpoint_manager_core = ( + _GlobalPartitionEndpointManagerForCircuitBreakerCore(client, self.location_cache)) + + async def create_pk_range_wrapper(self, request: RequestObject) -> Optional[PartitionKeyRangeWrapper]: + if HttpHeaders.IntendedCollectionRID in request.headers: + container_rid = request.headers[HttpHeaders.IntendedCollectionRID] + else: + self.global_partition_endpoint_manager_core.log_warn_or_debug( + "Illegal state: the request does not contain container information. 
" + "Circuit breaker cannot be performed.") + return None + properties = self.client._container_properties_cache[container_rid] + # get relevant information from container cache to get the overlapping ranges + container_link = properties["container_link"] + partition_key_definition = properties["partitionKey"] + partition_key = PartitionKey(path=partition_key_definition["paths"], + kind=partition_key_definition["kind"], + version=partition_key_definition["version"]) + + if HttpHeaders.PartitionKey in request.headers: + partition_key_value = request.headers[HttpHeaders.PartitionKey] + # get the partition key range for the given partition key + epk_range = [partition_key._get_epk_range_for_partition_key(partition_key_value)] + partition_ranges = await (self.client._routing_map_provider + .get_overlapping_ranges(container_link, epk_range)) + partition_range = Range.PartitionKeyRangeToRange(partition_ranges[0]) + elif HttpHeaders.PartitionKeyRangeID in request.headers: + pk_range_id = request.headers[HttpHeaders.PartitionKeyRangeID] + epk_range = await (self.client._routing_map_provider + .get_range_by_partition_key_range_id(container_link, pk_range_id)) + if not epk_range: + self.global_partition_endpoint_manager_core.log_warn_or_debug( + "Illegal state: partition key range cache not initialized correctly. " + "Circuit breaker cannot be performed.") + return None + partition_range = Range.PartitionKeyRangeToRange(epk_range) + else: + self.global_partition_endpoint_manager_core.log_warn_or_debug( + "Illegal state: the request does not contain partition information. 
" + "Circuit breaker cannot be performed.") + return None + + return PartitionKeyRangeWrapper(partition_range, container_rid) + + def is_circuit_breaker_applicable(self, request: RequestObject) -> bool: + return self.global_partition_endpoint_manager_core.is_circuit_breaker_applicable(request) + + async def record_failure( + self, + request: RequestObject + ) -> None: + if self.is_circuit_breaker_applicable(request): + pk_range_wrapper = await self.create_pk_range_wrapper(request) + if pk_range_wrapper: + self.global_partition_endpoint_manager_core.record_failure(request, pk_range_wrapper) + + def resolve_service_endpoint_for_partition( + self, + request: RequestObject, + pk_range_wrapper: Optional[PartitionKeyRangeWrapper] + ): + if self.is_circuit_breaker_applicable(request) and pk_range_wrapper: + self.global_partition_endpoint_manager_core.check_stale_partition_info(request, pk_range_wrapper) + request = self.global_partition_endpoint_manager_core.add_excluded_locations_to_request(request, + pk_range_wrapper) + return self._resolve_service_endpoint(request) + + async def record_success( + self, + request: RequestObject + ) -> None: + if self.is_circuit_breaker_applicable(request): + pk_range_wrapper = await self.create_pk_range_wrapper(request) + if pk_range_wrapper: + self.global_partition_endpoint_manager_core.record_success(request, pk_range_wrapper) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py index ec7441d48e1f..33b9c0785b38 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py @@ -46,6 +46,10 @@ # pylint: disable=protected-access, disable=too-many-lines, disable=too-many-statements, disable=too-many-branches +# args [0] is the request object +# args [1] is the connection policy +# args [2] is the pipeline client +# args [3] is the http request async def 
ExecuteAsync(client, global_endpoint_manager, function, *args, **kwargs): # pylint: disable=too-many-locals """Executes the function with passed parameters applying all retry policies @@ -59,6 +63,9 @@ async def ExecuteAsync(client, global_endpoint_manager, function, *args, **kwarg :returns: the result of running the passed in function as a (result, headers) tuple :rtype: tuple of (dict, dict) """ + pk_range_wrapper = None + if args and global_endpoint_manager.is_circuit_breaker_applicable(args[0]): + pk_range_wrapper = await global_endpoint_manager.create_pk_range_wrapper(args[0]) # instantiate all retry policies here to be applied for each request execution endpointDiscovery_retry_policy = _endpoint_discovery_retry_policy.EndpointDiscoveryRetryPolicy( client.connection_policy, global_endpoint_manager, *args @@ -74,17 +81,17 @@ async def ExecuteAsync(client, global_endpoint_manager, function, *args, **kwarg defaultRetry_policy = _default_retry_policy.DefaultRetryPolicy(*args) sessionRetry_policy = _session_retry_policy._SessionRetryPolicy( - client.connection_policy.EnableEndpointDiscovery, global_endpoint_manager, *args + client.connection_policy.EnableEndpointDiscovery, global_endpoint_manager, pk_range_wrapper, *args ) partition_key_range_gone_retry_policy = _gone_retry_policy.PartitionKeyRangeGoneRetryPolicy(client, *args) timeout_failover_retry_policy = _timeout_failover_retry_policy._TimeoutFailoverRetryPolicy( - client.connection_policy, global_endpoint_manager, *args + client.connection_policy, global_endpoint_manager, pk_range_wrapper, *args ) service_response_retry_policy = _service_response_retry_policy.ServiceResponseRetryPolicy( - client.connection_policy, global_endpoint_manager, *args, + client.connection_policy, global_endpoint_manager, pk_range_wrapper, *args, ) service_request_retry_policy = _service_request_retry_policy.ServiceRequestRetryPolicy( - client.connection_policy, global_endpoint_manager, *args, + client.connection_policy, 
global_endpoint_manager, pk_range_wrapper, *args, ) # HttpRequest we would need to modify for Container Recreate Retry Policy request = None @@ -103,6 +110,7 @@ async def ExecuteAsync(client, global_endpoint_manager, function, *args, **kwarg try: if args: result = await ExecuteFunctionAsync(function, global_endpoint_manager, *args, **kwargs) + await global_endpoint_manager.record_success(args[0]) else: result = await ExecuteFunctionAsync(function, *args, **kwargs) if not client.last_response_headers: @@ -171,9 +179,10 @@ async def ExecuteAsync(client, global_endpoint_manager, function, *args, **kwarg retry_policy.container_rid = cached_container["_rid"] request.headers[retry_policy._intended_headers] = retry_policy.container_rid - elif e.status_code == StatusCodes.REQUEST_TIMEOUT: - retry_policy = timeout_failover_retry_policy - elif e.status_code >= StatusCodes.INTERNAL_SERVER_ERROR: + elif e.status_code == StatusCodes.REQUEST_TIMEOUT or e.status_code >= StatusCodes.INTERNAL_SERVER_ERROR: + # record the failure for circuit breaker tracking + if args: + await global_endpoint_manager.record_failure(args[0]) retry_policy = timeout_failover_retry_policy else: retry_policy = defaultRetry_policy @@ -220,9 +229,13 @@ async def ExecuteAsync(client, global_endpoint_manager, function, *args, **kwarg if isinstance(e.inner_exception, ClientConnectionError): _handle_service_request_retries(client, service_request_retry_policy, e, *args) else: + if args: + await global_endpoint_manager.record_failure(args[0]) _handle_service_response_retries(request, client, service_response_retry_policy, e, *args) # in case customer is not using aiohttp except ImportError: + if args: + await global_endpoint_manager.record_failure(args[0]) _handle_service_response_retries(request, client, service_response_retry_policy, e, *args) @@ -256,6 +269,8 @@ async def send(self, request): """ absolute_timeout = request.context.options.pop('timeout', None) per_request_timeout = 
request.context.options.pop('connection_timeout', 0) + request_params = request.context.options.pop('request_params', None) + global_endpoint_manager = request.context.options.pop('global_endpoint_manager', None) retry_error = None retry_active = True response = None @@ -279,7 +294,8 @@ async def send(self, request): retry_error = err # the request ran into a socket timeout or failed to establish a new connection # since request wasn't sent, raise exception immediately to be dealt with in client retry policies - if not _has_database_account_header(request.http_request.headers): + if (not _has_database_account_header(request.http_request.headers) + and not request_params.healthy_tentative_location): if retry_settings['connect'] > 0: retry_active = self.increment(retry_settings, response=request, error=err) if retry_active: @@ -288,7 +304,8 @@ async def send(self, request): raise err except ServiceResponseError as err: retry_error = err - if _has_database_account_header(request.http_request.headers): + if (_has_database_account_header(request.http_request.headers) or + request_params.healthy_tentative_location): raise err # Since this is ClientConnectionError, it is safe to be retried on both read and write requests try: @@ -299,6 +316,9 @@ async def send(self, request): or _has_read_retryable_headers(request.http_request.headers)): # This logic is based on the _retry.py file from azure-core if retry_settings['read'] > 0: + # record the failure for circuit breaker tracking for retries in connection retry policy + # retries in the execute function will mark those failures + await global_endpoint_manager.record_failure(request_params) retry_active = self.increment(retry_settings, response=request, error=err) if retry_active: await self.sleep(retry_settings, request.context.transport) @@ -310,7 +330,8 @@ async def send(self, request): raise err except AzureError as err: retry_error = err - if _has_database_account_header(request.http_request.headers): + if 
(_has_database_account_header(request.http_request.headers) or + request_params.healthy_tentative_location): raise err if _has_read_retryable_headers(request.http_request.headers) and retry_settings['read'] > 0: retry_active = self.increment(retry_settings, response=request, error=err) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/auth.py b/sdk/cosmos/azure-cosmos/azure/cosmos/auth.py index 754a9c93cdf3..221109a35bab 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/auth.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/auth.py @@ -124,8 +124,8 @@ def __get_authorization_token_using_resource_token(resource_tokens, path, resour # used for creating the auth header as the service will accept any token in this case path = urllib.parse.unquote(path) if not path and not resource_id_or_fullname: - for value in resource_tokens.values(): - return value + for resource_token in resource_tokens.values(): + return resource_token if resource_tokens.get(resource_id_or_fullname): return resource_tokens[resource_id_or_fullname] @@ -151,7 +151,9 @@ def __get_authorization_token_using_resource_token(resource_tokens, path, resour for i in range(len(path_parts), 1, -1): segment = path_parts[i - 1] sub_path = "/".join(path_parts[:i]) - if not segment in resource_types and sub_path in resource_tokens: - return resource_tokens[sub_path] + if not segment in resource_types: + for resource_path, resource_token in resource_tokens.items(): + if sub_path in resource_path: + return resource_tokens[resource_path] return None diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/container.py b/sdk/cosmos/azure-cosmos/azure/cosmos/container.py index ec9205967795..5647a66a99f7 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/container.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/container.py @@ -37,7 +37,7 @@ _deserialize_throughput, _replace_throughput, GenerateGuidId, - _set_properties_cache + _build_properties_cache ) from ._change_feed.feed_range_internal import FeedRangeInternalEpk from 
._cosmos_client_connection import CosmosClientConnection @@ -108,15 +108,16 @@ def __init__( self._scripts: Optional[ScriptsProxy] = None if properties: self.client_connection._set_container_properties_cache(self.container_link, - _set_properties_cache(properties)) + _build_properties_cache(properties, + self.container_link)) def __repr__(self) -> str: return "".format(self.container_link)[:1024] - def _get_properties_with_feed_options(self, feed_options: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + def _get_properties_with_options(self, options: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: kwargs = {} - if feed_options and "excludedLocations" in feed_options: - kwargs['excluded_locations'] = feed_options['excludedLocations'] + if options and "excludedLocations" in options: + kwargs['excluded_locations'] = options['excludedLocations'] return self._get_properties(**kwargs) def _get_properties(self, **kwargs: Any) -> Dict[str, Any]: @@ -213,7 +214,8 @@ def read( # pylint:disable=docstring-missing-param request_options["populateQuotaInfo"] = populate_quota_info container = self.client_connection.ReadContainer(self.container_link, options=request_options, **kwargs) # Only cache Container Properties that will not change in the lifetime of the container - self.client_connection._set_container_properties_cache(self.container_link, _set_properties_cache(container)) # pylint: disable=protected-access, line-too-long + self.client_connection._set_container_properties_cache(self.container_link, # pylint: disable=protected-access + _build_properties_cache(container, self.container_link)) return container @distributed_trace @@ -291,8 +293,8 @@ def read_item( # pylint:disable=docstring-missing-param if max_integrated_cache_staleness_in_ms is not None: validate_cache_staleness_value(max_integrated_cache_staleness_in_ms) request_options["maxIntegratedCacheStaleness"] = max_integrated_cache_staleness_in_ms - if self.container_link in 
self.__get_client_container_caches(): - request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] + self._get_properties_with_options(request_options) + request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] return self.client_connection.ReadItem(document_link=doc_link, options=request_options, **kwargs) @distributed_trace @@ -355,8 +357,8 @@ def read_all_items( # pylint:disable=docstring-missing-param if response_hook and hasattr(response_hook, "clear"): response_hook.clear() - if self.container_link in self.__get_client_container_caches(): - feed_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] + self._get_properties_with_options(feed_options) + feed_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] items = self.client_connection.ReadItems( collection_link=self.container_link, feed_options=feed_options, response_hook=response_hook, **kwargs) @@ -574,7 +576,7 @@ def query_items_change_feed( elif "start_time" in kwargs: change_feed_state_context["startTime"] = kwargs.pop("start_time") - container_properties = self._get_properties_with_feed_options(feed_options) + container_properties = self._get_properties_with_options(feed_options) if "partition_key" in kwargs: partition_key = kwargs.pop("partition_key") change_feed_state_context["partitionKey"] = self._set_partition_key(cast(PartitionKeyType, partition_key)) @@ -696,9 +698,9 @@ def query_items( # pylint:disable=docstring-missing-param feed_options["populateQueryMetrics"] = populate_query_metrics if populate_index_metrics is not None: feed_options["populateIndexMetrics"] = populate_index_metrics + properties = self._get_properties_with_options(feed_options) if partition_key is not None: partition_key_value = self._set_partition_key(partition_key) - properties = self._get_properties_with_feed_options(feed_options) if 
is_prefix_partition_key(properties, partition_key): kwargs["isPrefixPartitionQuery"] = True kwargs["partitionKeyDefinition"] = properties["partitionKey"] @@ -716,8 +718,7 @@ def query_items( # pylint:disable=docstring-missing-param feed_options["responseContinuationTokenLimitInKb"] = continuation_token_limit if response_hook and hasattr(response_hook, "clear"): response_hook.clear() - if self.container_link in self.__get_client_container_caches(): - feed_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] + feed_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] items = self.client_connection.QueryItems( database_or_container_link=self.container_link, query=query if parameters is None else {"query": query, "parameters": parameters}, @@ -811,10 +812,13 @@ def replace_item( # pylint:disable=docstring-missing-param ) request_options["populateQueryMetrics"] = populate_query_metrics - if self.container_link in self.__get_client_container_caches(): - request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] + self._get_properties_with_options(request_options) + request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] result = self.client_connection.ReplaceItem( - document_link=item_link, new_document=body, options=request_options, **kwargs) + document_link=item_link, + new_document=body, + options=request_options, + **kwargs) return result @distributed_trace @@ -894,8 +898,8 @@ def upsert_item( # pylint:disable=docstring-missing-param DeprecationWarning, ) request_options["populateQueryMetrics"] = populate_query_metrics - if self.container_link in self.__get_client_container_caches(): - request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] + self._get_properties_with_options(request_options) + request_options["containerRID"] = 
self.__get_client_container_caches()[self.container_link]["_rid"] result = self.client_connection.UpsertItem( database_or_container_link=self.container_link, @@ -994,8 +998,8 @@ def create_item( # pylint:disable=docstring-missing-param request_options["populateQueryMetrics"] = populate_query_metrics if indexing_directive is not None: request_options["indexingDirective"] = indexing_directive - if self.container_link in self.__get_client_container_caches(): - request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] + self._get_properties_with_options(request_options) + request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] result = self.client_connection.CreateItem( database_or_container_link=self.container_link, document=body, options=request_options, **kwargs) return result @@ -1080,11 +1084,13 @@ def patch_item( if filter_predicate is not None: request_options["filterPredicate"] = filter_predicate - if self.container_link in self.__get_client_container_caches(): - request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] + self._get_properties_with_options(request_options) + request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] item_link = self._get_document_link(item) result = self.client_connection.PatchItem( - document_link=item_link, operations=patch_operations, options=request_options, **kwargs) + document_link=item_link, + operations=patch_operations, + options=request_options, **kwargs) return result @distributed_trace @@ -1153,6 +1159,8 @@ def execute_item_batch( request_options = build_options(kwargs) request_options["partitionKey"] = self._set_partition_key(partition_key) request_options["disableAutomaticIdGeneration"] = True + container_properties = self._get_properties_with_options(request_options) + request_options["containerRID"] = container_properties["_rid"] return 
self.client_connection.Batch( collection_link=self.container_link, batch_operations=batch_operations, options=request_options, **kwargs) @@ -1231,8 +1239,8 @@ def delete_item( # pylint:disable=docstring-missing-param request_options["preTriggerInclude"] = pre_trigger_include if post_trigger_include is not None: request_options["postTriggerInclude"] = post_trigger_include - if self.container_link in self.__get_client_container_caches(): - request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] + self._get_properties_with_options(request_options) + request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] document_link = self._get_document_link(item) self.client_connection.DeleteItem(document_link=document_link, options=request_options, **kwargs) @@ -1511,8 +1519,8 @@ def delete_all_items_by_partition_key( request_options = build_options(kwargs) # regardless if partition key is valid we set it as invalid partition keys are set to a default empty value request_options["partitionKey"] = self._set_partition_key(partition_key) - if self.container_link in self.__get_client_container_caches(): - request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] + self._get_properties_with_options(request_options) + request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] self.client_connection.DeleteAllItemsByPartitionKey( collection_link=self.container_link, options=request_options, **kwargs) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/documents.py b/sdk/cosmos/azure-cosmos/azure/cosmos/documents.py index b3c55a53eba3..a0e55077aefa 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/documents.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/documents.py @@ -339,6 +339,7 @@ class ConnectionPolicy: # pylint: disable=too-many-instance-attributes __defaultRequestTimeout: int = 5 # seconds 
__defaultDBAConnectionTimeout: int = 3 # seconds __defaultReadTimeout: int = 65 # seconds + __defaultRecoveryReadTimeout: int = 6 # seconds __defaultDBAReadTimeout: int = 3 # seconds __defaultMaxBackoff: int = 1 # seconds @@ -347,6 +348,9 @@ def __init__(self) -> None: self.RequestTimeout: int = self.__defaultRequestTimeout self.DBAConnectionTimeout: int = self.__defaultDBAConnectionTimeout self.ReadTimeout: int = self.__defaultReadTimeout + # The request timeout for a request trying to recover an unavailable partition. + # This is only applicable if circuit breaker is enabled + self.RecoveryReadTimeout: int = self.__defaultRecoveryReadTimeout self.DBAReadTimeout: int = self.__defaultDBAReadTimeout self.MaxBackoff: int = self.__defaultMaxBackoff self.ConnectionMode: int = ConnectionMode.Gateway diff --git a/sdk/cosmos/azure-cosmos/pytest.ini b/sdk/cosmos/azure-cosmos/pytest.ini index 0ea65741e343..aabe78b51f08 100644 --- a/sdk/cosmos/azure-cosmos/pytest.ini +++ b/sdk/cosmos/azure-cosmos/pytest.ini @@ -4,3 +4,6 @@ markers = cosmosLong: marks tests to be run on a Cosmos DB live account. cosmosQuery: marks tests running queries on Cosmos DB live account. cosmosSplit: marks test where there are partition splits on CosmosDB live account. + cosmosMultiRegion: marks tests running on a Cosmos DB live account with multi-region and multi-write enabled. + cosmosCircuitBreaker: marks tests running on Cosmos DB live account with per partition circuit breaker enabled and multi-write enabled. + cosmosCircuitBreakerMultiRegion: marks tests running on Cosmos DB live account with one write region and multiple read regions and per partition circuit breaker enabled. 
diff --git a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py index 1413a30e78cc..2386e54fd882 100644 --- a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py +++ b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py @@ -166,6 +166,12 @@ def predicate_is_operation_type(r: HttpRequest, operation_type: str) -> bool: return is_operation_type + @staticmethod + def predicate_is_resource_type(r: HttpRequest, resource_type: str) -> bool: + is_resource_type = r.headers.get(HttpHeaders.ThinClientProxyResourceType) == resource_type + + return is_resource_type + @staticmethod def predicate_is_write_operation(r: HttpRequest, uri_prefix: str) -> bool: is_write_document_operation = documents._OperationType.IsWriteOperation( @@ -229,7 +235,7 @@ def transform_topology_mwr( first_region_name: str, second_region_name: str, inner: Callable[[], RequestsTransportResponse], - first_region_url: str = None, + first_region_url: str = test_config.TestConfig.local_host.replace("localhost", "127.0.0.1"), second_region_url: str = test_config.TestConfig.local_host ) -> RequestsTransportResponse: diff --git a/sdk/cosmos/azure-cosmos/tests/routing/test_routing_map_provider.py b/sdk/cosmos/azure-cosmos/tests/routing/test_routing_map_provider.py index 17c9b7c1dce2..9a96d024d3e9 100644 --- a/sdk/cosmos/azure-cosmos/tests/routing/test_routing_map_provider.py +++ b/sdk/cosmos/azure-cosmos/tests/routing/test_routing_map_provider.py @@ -182,7 +182,7 @@ def validate_empty_query_ranges(self, smart_routing_map_provider, *queryRangesLi self.validate_overlapping_ranges_results(queryRanges, []) def get_overlapping_ranges(self, queryRanges): - return self.smart_routing_map_provider.get_overlapping_ranges("sample collection id", queryRanges) + return self.smart_routing_map_provider.get_overlapping_ranges("dbs/db/colls/container", queryRanges) if __name__ == "__main__": diff --git 
a/sdk/cosmos/azure-cosmos/tests/test_change_feed.py b/sdk/cosmos/azure-cosmos/tests/test_change_feed.py index 415ba47a63a9..94713d543003 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_change_feed.py +++ b/sdk/cosmos/azure-cosmos/tests/test_change_feed.py @@ -1,6 +1,6 @@ # The MIT License (MIT) # Copyright (c) Microsoft Corporation. All rights reserved. - +import os import unittest import uuid from datetime import datetime, timedelta, timezone @@ -17,13 +17,17 @@ @pytest.fixture(scope="class") def setup(): config = test_config.TestConfig() + use_multiple_write_locations = False + if os.environ.get("AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER", "False") == "True": + use_multiple_write_locations = True if (config.masterKey == '[YOUR_KEY_HERE]' or config.host == '[YOUR_ENDPOINT_HERE]'): raise Exception( "You must specify your Azure Cosmos account values for " "'masterKey' and 'host' at the top of this class to run the " "tests.") - test_client = cosmos_client.CosmosClient(config.host, config.masterKey), + test_client = cosmos_client.CosmosClient(config.host, config.masterKey, + multiple_write_locations=use_multiple_write_locations), return { "created_db": test_client[0].get_database_client(config.TEST_DATABASE_ID), "is_emulator": config.is_emulator @@ -33,6 +37,7 @@ def round_time(): utc_now = datetime.now(timezone.utc) return utc_now - timedelta(microseconds=utc_now.microsecond) +@pytest.mark.cosmosCircuitBreaker @pytest.mark.cosmosQuery @pytest.mark.unittest @pytest.mark.usefixtures("setup") diff --git a/sdk/cosmos/azure-cosmos/tests/test_change_feed_async.py b/sdk/cosmos/azure-cosmos/tests/test_change_feed_async.py index b3a666ad43e9..1ab899a6bd47 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_change_feed_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_change_feed_async.py @@ -3,6 +3,7 @@ import unittest import uuid +import os from asyncio import sleep from datetime import datetime, timedelta, timezone @@ -18,13 +19,16 @@ @pytest_asyncio.fixture() async def setup(): + 
use_multiple_write_locations = False + if os.environ.get("AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER", "False") == "True": + use_multiple_write_locations = True config = test_config.TestConfig() if config.masterKey == '[YOUR_KEY_HERE]' or config.host == '[YOUR_ENDPOINT_HERE]': raise Exception( "You must specify your Azure Cosmos account values for " "'masterKey' and 'host' at the top of this class to run the " "tests.") - test_client = CosmosClient(config.host, config.masterKey) + test_client = CosmosClient(config.host, config.masterKey, multiple_write_locations=use_multiple_write_locations) created_db = await test_client.create_database_if_not_exists(config.TEST_DATABASE_ID) created_db_data = { "created_db": created_db, diff --git a/sdk/cosmos/azure-cosmos/tests/test_circuit_breaker_emulator.py b/sdk/cosmos/azure-cosmos/tests/test_circuit_breaker_emulator.py new file mode 100644 index 000000000000..30448b2aa623 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_circuit_breaker_emulator.py @@ -0,0 +1,263 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. 
+import os +import unittest +import uuid + +import pytest +from azure.core.exceptions import ServiceResponseError + +import test_config +from azure.cosmos import _partition_health_tracker, documents +from azure.cosmos import CosmosClient +from azure.cosmos.exceptions import CosmosHttpResponseError +from _fault_injection_transport import FaultInjectionTransport +from azure.cosmos.http_constants import ResourceType +from test_per_partition_circuit_breaker_mm import perform_write_operation +from test_per_partition_circuit_breaker_mm_async import (create_doc, PK_VALUE, create_errors, + DELETE_ALL_ITEMS_BY_PARTITION_KEY, + validate_unhealthy_partitions as validate_unhealthy_partitions_mm) +from test_per_partition_circuit_breaker_sm_mrr_async import validate_unhealthy_partitions as validate_unhealthy_partitions_sm_mrr + +COLLECTION = "created_collection" +@pytest.fixture(scope="class", autouse=True) +def setup_teardown(): + os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "True" + yield + os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "False" + +def create_custom_transport_mm(): + custom_transport = FaultInjectionTransport() + is_get_account_predicate = lambda r: FaultInjectionTransport.predicate_is_database_account_call(r) + emulator_as_multi_write_region_account_transformation = \ + lambda r, inner: FaultInjectionTransport.transform_topology_mwr( + first_region_name="Write Region", + second_region_name="Read Region", + inner=inner) + custom_transport.add_response_transformation( + is_get_account_predicate, + emulator_as_multi_write_region_account_transformation) + return custom_transport + + +@pytest.mark.cosmosEmulator +@pytest.mark.usefixtures("setup_teardown") +class TestCircuitBreakerEmulator: + host = test_config.TestConfig.host + master_key = test_config.TestConfig.masterKey + connectionPolicy = test_config.TestConfig.connectionPolicy + TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID + TEST_CONTAINER_SINGLE_PARTITION_ID = 
test_config.TestConfig.TEST_MULTI_PARTITION_CONTAINER_ID + + def setup_method_with_custom_transport(self, custom_transport, default_endpoint=host, **kwargs): + client = CosmosClient(default_endpoint, self.master_key, consistency_level="Session", + preferred_locations=["Write Region", "Read Region"], + transport=custom_transport, **kwargs) + db = client.get_database_client(self.TEST_DATABASE_ID) + container = db.get_container_client(self.TEST_CONTAINER_SINGLE_PARTITION_ID) + return {"client": client, "db": db, "col": container} + + + def create_custom_transport_sm_mrr(self): + custom_transport = FaultInjectionTransport() + # Inject rule to disallow writes in the read-only region + is_write_operation_in_read_region_predicate = lambda \ + r: FaultInjectionTransport.predicate_is_write_operation(r, self.host) + + custom_transport.add_fault( + is_write_operation_in_read_region_predicate, + lambda r: FaultInjectionTransport.error_write_forbidden()) + + # Inject topology transformation that would make Emulator look like a single write region + # account with two read regions + is_get_account_predicate = lambda r: FaultInjectionTransport.predicate_is_database_account_call(r) + emulator_as_multi_region_sm_account_transformation = \ + lambda r, inner: FaultInjectionTransport.transform_topology_swr_mrr( + write_region_name="Write Region", + read_region_name="Read Region", + inner=inner) + custom_transport.add_response_transformation( + is_get_account_predicate, + emulator_as_multi_region_sm_account_transformation) + return custom_transport + + def setup_info(self, error, mm=False): + expected_uri = self.host + uri_down = self.host.replace("localhost", "127.0.0.1") + custom_transport = create_custom_transport_mm() if mm else self.create_custom_transport_sm_mrr() + # two documents targeted to same partition, one will always fail and the other will succeed + doc = create_doc() + predicate = lambda r: (FaultInjectionTransport.predicate_is_resource_type(r, ResourceType.Collection) 
and + FaultInjectionTransport.predicate_is_operation_type(r, documents._OperationType.Delete) and + FaultInjectionTransport.predicate_targets_region(r, uri_down)) + custom_transport.add_fault(predicate, + error) + custom_setup = self.setup_method_with_custom_transport(custom_transport, default_endpoint=self.host, multiple_write_locations=mm) + setup = self.setup_method_with_custom_transport(None, default_endpoint=self.host, multiple_write_locations=mm) + return setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate + + @pytest.mark.parametrize("error", create_errors()) + def test_write_consecutive_failure_threshold_delete_all_items_by_pk_sm(self, setup_teardown, error): + error_lambda = lambda r: FaultInjectionTransport.error_after_delay(0, error) + setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = self.setup_info(error_lambda) + fault_injection_container = custom_setup['col'] + container = setup['col'] + global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager + + # writes should fail in sm mrr with circuit breaker and should not mark unavailable a partition + for i in range(6): + validate_unhealthy_partitions_sm_mrr(global_endpoint_manager, 0) + with pytest.raises((CosmosHttpResponseError, ServiceResponseError)): + perform_write_operation( + DELETE_ALL_ITEMS_BY_PARTITION_KEY, + container, + fault_injection_container, + str(uuid.uuid4()), + PK_VALUE, + expected_uri, + ) + + validate_unhealthy_partitions_sm_mrr(global_endpoint_manager, 0) + + + @pytest.mark.parametrize("error", create_errors()) + def test_write_consecutive_failure_threshold_delete_all_items_by_pk_mm(self, setup_teardown, error): + error_lambda = lambda r: FaultInjectionTransport.error_after_delay(0, error) + setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = self.setup_info(error_lambda, mm=True) + fault_injection_container = custom_setup['col'] + container = setup['col'] + 
global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager + + with pytest.raises((CosmosHttpResponseError, ServiceResponseError)) as exc_info: + perform_write_operation(DELETE_ALL_ITEMS_BY_PARTITION_KEY, + container, + fault_injection_container, + str(uuid.uuid4()), + PK_VALUE, + expected_uri) + assert exc_info.value == error + + validate_unhealthy_partitions_mm(global_endpoint_manager, 0) + + # writes should fail but still be tracked + for i in range(4): + with pytest.raises((CosmosHttpResponseError, ServiceResponseError)) as exc_info: + perform_write_operation(DELETE_ALL_ITEMS_BY_PARTITION_KEY, + container, + fault_injection_container, + str(uuid.uuid4()), + PK_VALUE, + expected_uri) + assert exc_info.value == error + + # writes should now succeed because going to the other region + perform_write_operation(DELETE_ALL_ITEMS_BY_PARTITION_KEY, + container, + fault_injection_container, + str(uuid.uuid4()), + PK_VALUE, + expected_uri) + + validate_unhealthy_partitions_mm(global_endpoint_manager, 1) + # remove faults and reduce initial recover time and perform a write + original_unavailable_time = _partition_health_tracker.INITIAL_UNAVAILABLE_TIME + _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = 1 + custom_transport.faults = [] + try: + perform_write_operation(DELETE_ALL_ITEMS_BY_PARTITION_KEY, + container, + fault_injection_container, + str(uuid.uuid4()), + PK_VALUE, + uri_down) + finally: + _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = original_unavailable_time + validate_unhealthy_partitions_mm(global_endpoint_manager, 0) + + + @pytest.mark.parametrize("error", create_errors()) + def test_write_failure_rate_threshold_delete_all_items_by_pk_mm(self, setup_teardown, error): + error_lambda = lambda r: FaultInjectionTransport.error_after_delay(0, error) + setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = self.setup_info(error_lambda, mm=True) + fault_injection_container = custom_setup['col'] 
+ container = setup['col'] + global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager + # lower minimum requests for testing + _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 10 + os.environ["AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED"] = "80" + try: + # writes should fail but still be tracked and mark unavailable a partition after crossing threshold + for i in range(10): + validate_unhealthy_partitions_mm(global_endpoint_manager, 0) + if i == 4 or i == 8: + # perform some successful creates to reset consecutive counter + # remove faults and perform a write + custom_transport.faults = [] + fault_injection_container.upsert_item(body=doc) + custom_transport.add_fault(predicate, + lambda r: FaultInjectionTransport.error_after_delay( + 0, + error + )) + else: + with pytest.raises((CosmosHttpResponseError, ServiceResponseError)) as exc_info: + perform_write_operation(DELETE_ALL_ITEMS_BY_PARTITION_KEY, + container, + fault_injection_container, + str(uuid.uuid4()), + PK_VALUE, + expected_uri) + assert exc_info.value == error + + validate_unhealthy_partitions_mm(global_endpoint_manager, 1) + + finally: + os.environ["AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED"] = "90" + # restore minimum requests + _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 + + @pytest.mark.parametrize("error", create_errors()) + def test_write_failure_rate_threshold_delete_all_items_by_pk_sm(self, setup_teardown, error): + error_lambda = lambda r: FaultInjectionTransport.error_after_delay(0, error) + setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = self.setup_info(error_lambda) + fault_injection_container = custom_setup['col'] + container = setup['col'] + global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager + # lower minimum requests for testing + _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 10 + 
os.environ["AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED"] = "80" + try: + # writes should fail but still be tracked and mark unavailable a partition after crossing threshold + for i in range(10): + validate_unhealthy_partitions_sm_mrr(global_endpoint_manager, 0) + if i == 4 or i == 8: + # perform some successful creates to reset consecutive counter + # remove faults and perform a write + custom_transport.faults = [] + fault_injection_container.upsert_item(body=doc) + custom_transport.add_fault(predicate, + lambda r: FaultInjectionTransport.error_after_delay( + 0, + error + )) + else: + with pytest.raises((CosmosHttpResponseError, ServiceResponseError)) as exc_info: + perform_write_operation(DELETE_ALL_ITEMS_BY_PARTITION_KEY, + container, + fault_injection_container, + str(uuid.uuid4()), + PK_VALUE, + expected_uri) + assert exc_info.value == error + + validate_unhealthy_partitions_sm_mrr(global_endpoint_manager, 0) + + finally: + os.environ["AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED"] = "90" + # restore minimum requests + _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 + + # test cosmos client timeout + +if __name__ == '__main__': + unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/test_circuit_breaker_emulator_async.py b/sdk/cosmos/azure-cosmos/tests/test_circuit_breaker_emulator_async.py new file mode 100644 index 000000000000..23315499c2da --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_circuit_breaker_emulator_async.py @@ -0,0 +1,269 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. 
+import asyncio +import os +import unittest +import uuid + +import pytest +import pytest_asyncio +from azure.core.exceptions import ServiceResponseError + +import test_config +from azure.cosmos import _partition_health_tracker, documents +from azure.cosmos.aio import CosmosClient +from azure.cosmos.exceptions import CosmosHttpResponseError +from azure.cosmos.http_constants import ResourceType +from test_per_partition_circuit_breaker_mm_async import (create_doc, PK_VALUE, create_errors, + DELETE_ALL_ITEMS_BY_PARTITION_KEY, + validate_unhealthy_partitions as validate_unhealthy_partitions_mm, + perform_write_operation, cleanup_method) +from test_per_partition_circuit_breaker_sm_mrr_async import validate_unhealthy_partitions as validate_unhealthy_partitions_sm_mrr +from _fault_injection_transport_async import FaultInjectionTransportAsync + +COLLECTION = "created_collection" +@pytest_asyncio.fixture(scope="class", autouse=True) +async def setup_teardown(): + os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "True" + yield + os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "False" + +async def create_custom_transport_mm(): + custom_transport = FaultInjectionTransportAsync() + is_get_account_predicate = lambda r: FaultInjectionTransportAsync.predicate_is_database_account_call(r) + emulator_as_multi_write_region_account_transformation = \ + lambda r, inner: FaultInjectionTransportAsync.transform_topology_mwr( + first_region_name="Write Region", + second_region_name="Read Region", + inner=inner) + custom_transport.add_response_transformation( + is_get_account_predicate, + emulator_as_multi_write_region_account_transformation) + return custom_transport + + +@pytest.mark.cosmosEmulator +@pytest.mark.asyncio +@pytest.mark.usefixtures("setup_teardown") +class TestCircuitBreakerEmulatorAsync: + host = test_config.TestConfig.host + master_key = test_config.TestConfig.masterKey + connectionPolicy = test_config.TestConfig.connectionPolicy + TEST_DATABASE_ID = 
test_config.TestConfig.TEST_DATABASE_ID + TEST_CONTAINER_SINGLE_PARTITION_ID = test_config.TestConfig.TEST_MULTI_PARTITION_CONTAINER_ID + + async def setup_method_with_custom_transport(self, custom_transport, default_endpoint=host, **kwargs): + client = CosmosClient(default_endpoint, self.master_key, consistency_level="Session", + preferred_locations=["Write Region", "Read Region"], + transport=custom_transport, **kwargs) + db = client.get_database_client(self.TEST_DATABASE_ID) + container = db.get_container_client(self.TEST_CONTAINER_SINGLE_PARTITION_ID) + return {"client": client, "db": db, "col": container} + + + async def create_custom_transport_sm_mrr(self): + custom_transport = FaultInjectionTransportAsync() + # Inject rule to disallow writes in the read-only region + is_write_operation_in_read_region_predicate = lambda \ + r: FaultInjectionTransportAsync.predicate_is_write_operation(r, self.host) + + custom_transport.add_fault( + is_write_operation_in_read_region_predicate, + lambda r: FaultInjectionTransportAsync.error_write_forbidden()) + + # Inject topology transformation that would make Emulator look like a single write region + # account with two read regions + is_get_account_predicate = lambda r: FaultInjectionTransportAsync.predicate_is_database_account_call(r) + emulator_as_multi_region_sm_account_transformation = \ + lambda r, inner: FaultInjectionTransportAsync.transform_topology_swr_mrr( + write_region_name="Write Region", + read_region_name="Read Region", + inner=inner) + custom_transport.add_response_transformation( + is_get_account_predicate, + emulator_as_multi_region_sm_account_transformation) + return custom_transport + + async def setup_info(self, error, mm=False): + expected_uri = self.host + uri_down = self.host.replace("localhost", "127.0.0.1") + custom_transport = await create_custom_transport_mm() if mm else await self.create_custom_transport_sm_mrr() + # two documents targeted to same partition, one will always fail and the other will 
succeed + doc = create_doc() + predicate = lambda r: (FaultInjectionTransportAsync.predicate_is_resource_type(r, ResourceType.Collection) and + FaultInjectionTransportAsync.predicate_is_operation_type(r, documents._OperationType.Delete) and + FaultInjectionTransportAsync.predicate_targets_region(r, uri_down)) + custom_transport.add_fault(predicate, error) + custom_setup = await self.setup_method_with_custom_transport(custom_transport, default_endpoint=self.host, multiple_write_locations=mm) + setup = await self.setup_method_with_custom_transport(None, default_endpoint=self.host, multiple_write_locations=mm) + return setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate + + @pytest.mark.parametrize("error", create_errors()) + async def test_write_consecutive_failure_threshold_delete_all_items_by_pk_sm_async(self, setup_teardown, error): + error_lambda = lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay(0, error)) + setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = await self.setup_info(error_lambda) + fault_injection_container = custom_setup['col'] + container = setup['col'] + global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager + + # writes should fail in sm mrr with circuit breaker and should not mark unavailable a partition + for i in range(6): + validate_unhealthy_partitions_sm_mrr(global_endpoint_manager, 0) + with pytest.raises((CosmosHttpResponseError, ServiceResponseError)): + await perform_write_operation( + DELETE_ALL_ITEMS_BY_PARTITION_KEY, + container, + fault_injection_container, + str(uuid.uuid4()), + PK_VALUE, + expected_uri, + ) + + validate_unhealthy_partitions_sm_mrr(global_endpoint_manager, 0) + await cleanup_method([custom_setup, setup]) + + + @pytest.mark.parametrize("error", create_errors()) + async def test_write_consecutive_failure_threshold_delete_all_items_by_pk_mm_async(self, setup_teardown, error): + error_lambda = 
lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay(0, error)) + setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = await self.setup_info(error_lambda, mm=True) + fault_injection_container = custom_setup['col'] + container = setup['col'] + global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager + + with pytest.raises((CosmosHttpResponseError, ServiceResponseError)) as exc_info: + await perform_write_operation(DELETE_ALL_ITEMS_BY_PARTITION_KEY, + container, + fault_injection_container, + str(uuid.uuid4()), + PK_VALUE, + expected_uri) + assert exc_info.value == error + + validate_unhealthy_partitions_mm(global_endpoint_manager, 0) + + # writes should fail but still be tracked + for i in range(4): + with pytest.raises((CosmosHttpResponseError, ServiceResponseError)) as exc_info: + await perform_write_operation(DELETE_ALL_ITEMS_BY_PARTITION_KEY, + container, + fault_injection_container, + str(uuid.uuid4()), + PK_VALUE, + expected_uri) + assert exc_info.value == error + + # writes should now succeed because going to the other region + await perform_write_operation(DELETE_ALL_ITEMS_BY_PARTITION_KEY, + container, + fault_injection_container, + str(uuid.uuid4()), + PK_VALUE, + expected_uri) + + validate_unhealthy_partitions_mm(global_endpoint_manager, 1) + # remove faults and reduce initial recover time and perform a write + original_unavailable_time = _partition_health_tracker.INITIAL_UNAVAILABLE_TIME + _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = 1 + custom_transport.faults = [] + try: + await perform_write_operation(DELETE_ALL_ITEMS_BY_PARTITION_KEY, + container, + fault_injection_container, + str(uuid.uuid4()), + PK_VALUE, + uri_down) + finally: + _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = original_unavailable_time + validate_unhealthy_partitions_mm(global_endpoint_manager, 0) + await cleanup_method([custom_setup, setup]) + + + 
@pytest.mark.parametrize("error", create_errors()) + async def test_write_failure_rate_threshold_delete_all_items_by_pk_mm_async(self, setup_teardown, error): + error_lambda = lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay(0, error)) + setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = await self.setup_info(error_lambda, mm=True) + fault_injection_container = custom_setup['col'] + container = setup['col'] + global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager + # lower minimum requests for testing + _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 10 + os.environ["AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED"] = "80" + try: + # writes should fail but still be tracked and mark unavailable a partition after crossing threshold + for i in range(10): + validate_unhealthy_partitions_mm(global_endpoint_manager, 0) + if i == 4 or i == 8: + # perform some successful creates to reset consecutive counter + # remove faults and perform a write + custom_transport.faults = [] + await fault_injection_container.upsert_item(body=doc) + custom_transport.add_fault(predicate, + lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( + 0, + error + ))) + else: + with pytest.raises((CosmosHttpResponseError, ServiceResponseError)) as exc_info: + await perform_write_operation(DELETE_ALL_ITEMS_BY_PARTITION_KEY, + container, + fault_injection_container, + str(uuid.uuid4()), + PK_VALUE, + expected_uri) + assert exc_info.value == error + + validate_unhealthy_partitions_mm(global_endpoint_manager, 1) + + finally: + os.environ["AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED"] = "90" + # restore minimum requests + _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 + await cleanup_method([custom_setup, setup]) + + @pytest.mark.parametrize("error", create_errors()) + async def test_write_failure_rate_threshold_delete_all_items_by_pk_sm_async(self, 
setup_teardown, error): + error_lambda = lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay(0, error)) + setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = await self.setup_info(error_lambda) + fault_injection_container = custom_setup['col'] + container = setup['col'] + global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager + # lower minimum requests for testing + _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 10 + os.environ["AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED"] = "80" + try: + # writes should fail but still be tracked and mark unavailable a partition after crossing threshold + for i in range(10): + validate_unhealthy_partitions_sm_mrr(global_endpoint_manager, 0) + if i == 4 or i == 8: + # perform some successful creates to reset consecutive counter + # remove faults and perform a write + custom_transport.faults = [] + await fault_injection_container.upsert_item(body=doc) + custom_transport.add_fault(predicate, + lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( + 0, + error + ))) + else: + with pytest.raises((CosmosHttpResponseError, ServiceResponseError)) as exc_info: + await perform_write_operation(DELETE_ALL_ITEMS_BY_PARTITION_KEY, + container, + fault_injection_container, + str(uuid.uuid4()), + PK_VALUE, + expected_uri) + assert exc_info.value == error + + validate_unhealthy_partitions_sm_mrr(global_endpoint_manager, 0) + + finally: + os.environ["AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED"] = "90" + # restore minimum requests + _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 + await cleanup_method([custom_setup, setup]) + + # test cosmos client timeout + +if __name__ == '__main__': + unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/test_config.py b/sdk/cosmos/azure-cosmos/tests/test_config.py index 232fe9fcd5cd..bcd8df3b494c 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_config.py +++ 
b/sdk/cosmos/azure-cosmos/tests/test_config.py @@ -7,15 +7,14 @@ import unittest import uuid -from azure.cosmos._retry_utility import _has_database_account_header, _has_read_retryable_headers +from azure.cosmos._retry_utility import _has_database_account_header, _has_read_retryable_headers, _configure_timeout from azure.cosmos.cosmos_client import CosmosClient from azure.cosmos.exceptions import CosmosHttpResponseError from azure.cosmos.http_constants import StatusCodes from azure.cosmos.partition_key import PartitionKey from azure.cosmos import (ContainerProxy, DatabaseProxy, documents, exceptions, http_constants, _retry_utility) -from azure.cosmos.aio import _retry_utility_async -from azure.core.exceptions import AzureError, ServiceRequestError, ServiceResponseError +from azure.core.exceptions import AzureError, ServiceRequestError, ServiceResponseError, ClientAuthenticationError from azure.core.pipeline.policies import AsyncRetryPolicy, RetryPolicy from devtools_testutils.azure_recorded_testcase import get_credential from devtools_testutils.helpers import is_live @@ -334,7 +333,10 @@ def __init__(self, resource_type, error=None, **kwargs): def send(self, request): self.counter = 0 absolute_timeout = request.context.options.pop('timeout', None) - + per_request_timeout = request.context.options.pop('connection_timeout', 0) + request_params = request.context.options.pop('request_params', None) + global_endpoint_manager = request.context.options.pop('global_endpoint_manager', None) + retry_error = None retry_active = True response = None retry_settings = self.configure_retries(request.context.options) @@ -346,26 +348,44 @@ def send(self, request): self.request_endpoints.append(request.http_request.url) if self.error: raise self.error + _configure_timeout(request, absolute_timeout, per_request_timeout) response = self.next.send(request) break + except ClientAuthenticationError: # pylint:disable=try-except-raise + # the authentication policy failed such that the 
client's request can't + # succeed--we'll never have a response to it, so propagate the exception + raise + except exceptions.CosmosClientTimeoutError as timeout_error: + timeout_error.inner_exception = retry_error + timeout_error.response = response + timeout_error.history = retry_settings['history'] + raise except ServiceRequestError as err: + retry_error = err # the request ran into a socket timeout or failed to establish a new connection # since request wasn't sent, raise exception immediately to be dealt with in client retry policies # This logic is based on the _retry.py file from azure-core - if retry_settings['connect'] > 0: - self.counter += 1 - retry_active = self.increment(retry_settings, response=request, error=err) - if retry_active: - self.sleep(retry_settings, request.context.transport) - continue + if (not _has_database_account_header(request.http_request.headers) + and not request_params.healthy_tentative_location): + if retry_settings['connect'] > 0: + self.counter += 1 + global_endpoint_manager.record_failure(request_params) + retry_active = self.increment(retry_settings, response=request, error=err) + if retry_active: + self.sleep(retry_settings, request.context.transport) + continue raise err except ServiceResponseError as err: + retry_error = err # Only read operations can be safely retried with ServiceResponseError - if not _retry_utility._has_read_retryable_headers(request.http_request.headers): + if (not _has_read_retryable_headers(request.http_request.headers) or + _has_database_account_header(request.http_request.headers) or + request_params.healthy_tentative_location): raise err # This logic is based on the _retry.py file from azure-core if retry_settings['read'] > 0: self.counter += 1 + global_endpoint_manager.record_failure(request_params) retry_active = self.increment(retry_settings, response=request, error=err) if retry_active: self.sleep(retry_settings, request.context.transport) @@ -375,10 +395,12 @@ def send(self, request): raise 
err except AzureError as err: retry_error = err - if _has_database_account_header(request.http_request.headers): + if (_has_database_account_header(request.http_request.headers) or + request_params.healthy_tentative_location): raise err if _has_read_retryable_headers(request.http_request.headers) and retry_settings['read'] > 0: self.counter += 1 + global_endpoint_manager.record_failure(request_params) retry_active = self.increment(retry_settings, response=request, error=err) if retry_active: self.sleep(retry_settings, request.context.transport) @@ -414,9 +436,11 @@ async def send(self, request): :raises ~azure.cosmos.exceptions.CosmosClientTimeoutError: Specified timeout exceeded. :raises ~azure.core.exceptions.ClientAuthenticationError: Authentication failed. """ + self.counter = 0 absolute_timeout = request.context.options.pop('timeout', None) per_request_timeout = request.context.options.pop('connection_timeout', 0) - self.counter = 0 + request_params = request.context.options.pop('request_params', None) + global_endpoint_manager = request.context.options.pop('global_endpoint_manager', None) retry_error = None retry_active = True response = None @@ -424,15 +448,18 @@ async def send(self, request): while retry_active: start_time = time.time() try: - # raise the passed in exception for the passed in resource + operation combination if request.http_request.headers.get( http_constants.HttpHeaders.ThinClientProxyResourceType) == self.resource_type: self.request_endpoints.append(request.http_request.url) if self.error: raise self.error - _retry_utility._configure_timeout(request, absolute_timeout, per_request_timeout) + _configure_timeout(request, absolute_timeout, per_request_timeout) response = await self.next.send(request) break + except ClientAuthenticationError: # pylint:disable=try-except-raise + # the authentication policy failed such that the client's request can't + # succeed--we'll never have a response to it, so propagate the exception + raise except 
exceptions.CosmosClientTimeoutError as timeout_error: timeout_error.inner_exception = retry_error timeout_error.response = response @@ -442,40 +469,57 @@ async def send(self, request): retry_error = err # the request ran into a socket timeout or failed to establish a new connection # since request wasn't sent, raise exception immediately to be dealt with in client retry policies - if retry_settings['connect'] > 0: - self.counter += 1 - retry_active = self.increment(retry_settings, response=request, error=err) - if retry_active: - await self.sleep(retry_settings, request.context.transport) - continue - raise err - except ServiceResponseError as err: - retry_error = err - # Since this is ClientConnectionError, it is safe to be retried on both read and write requests - from aiohttp.client_exceptions import ( - ClientConnectionError) # pylint: disable=networking-import-outside-azure-core-transport - if isinstance(err.inner_exception, ClientConnectionError) or _retry_utility_async._has_read_retryable_headers(request.http_request.headers): - # This logic is based on the _retry.py file from azure-core - if retry_settings['read'] > 0: + if (not _has_database_account_header(request.http_request.headers) + and not request_params.healthy_tentative_location): + if retry_settings['connect'] > 0: self.counter += 1 + await global_endpoint_manager.record_failure(request_params) retry_active = self.increment(retry_settings, response=request, error=err) if retry_active: await self.sleep(retry_settings, request.context.transport) continue raise err + except ServiceResponseError as err: + retry_error = err + if (_has_database_account_header(request.http_request.headers) or + request_params.healthy_tentative_location): + raise err + # Since this is ClientConnectionError, it is safe to be retried on both read and write requests + try: + # pylint: disable=networking-import-outside-azure-core-transport + from aiohttp.client_exceptions import ( + ClientConnectionError) + if 
(isinstance(err.inner_exception, ClientConnectionError) + or _has_read_retryable_headers(request.http_request.headers)): + # This logic is based on the _retry.py file from azure-core + if retry_settings['read'] > 0: + self.counter += 1 + await global_endpoint_manager.record_failure(request_params) + retry_active = self.increment(retry_settings, response=request, error=err) + if retry_active: + await self.sleep(retry_settings, request.context.transport) + continue + except ImportError: + raise err # pylint: disable=raise-missing-from + raise err except CosmosHttpResponseError as err: raise err except AzureError as err: retry_error = err - if _has_database_account_header(request.http_request.headers): + if (_has_database_account_header(request.http_request.headers) or + request_params.healthy_tentative_location): raise err if _has_read_retryable_headers(request.http_request.headers) and retry_settings['read'] > 0: - retry_active = self.increment(retry_settings, response=request, error=err) self.counter += 1 + retry_active = self.increment(retry_settings, response=request, error=err) if retry_active: await self.sleep(retry_settings, request.context.transport) continue raise err + finally: + end_time = time.time() + if absolute_timeout: + absolute_timeout -= (end_time - start_time) self.update_context(response.context, retry_settings) return response diff --git a/sdk/cosmos/azure-cosmos/tests/test_crud.py b/sdk/cosmos/azure-cosmos/tests/test_crud.py index 90fe5a22eaea..43f847ad45c5 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_crud.py +++ b/sdk/cosmos/azure-cosmos/tests/test_crud.py @@ -5,12 +5,11 @@ """End-to-end test. 
""" -import json -import os.path import time import unittest import urllib.parse as urllib import uuid +import os import pytest import requests @@ -19,7 +18,6 @@ from azure.core.pipeline.transport import RequestsTransport, RequestsTransportResponse from urllib3.util.retry import Retry -import azure.cosmos._base as base import azure.cosmos.cosmos_client as cosmos_client import azure.cosmos.documents as documents import azure.cosmos.exceptions as exceptions @@ -47,7 +45,7 @@ def send(self, *args, **kwargs): response = RequestsTransportResponse(None, output) return response - +@pytest.mark.cosmosCircuitBreaker @pytest.mark.cosmosLong class TestCRUDOperations(unittest.TestCase): """Python CRUD Tests. @@ -75,13 +73,16 @@ def __AssertHTTPFailureWithStatus(self, status_code, func, *args, **kwargs): @classmethod def setUpClass(cls): + use_multiple_write_locations = False + if os.environ.get("AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER", "False") == "True": + use_multiple_write_locations = True if (cls.masterKey == '[YOUR_KEY_HERE]' or cls.host == '[YOUR_ENDPOINT_HERE]'): raise Exception( "You must specify your Azure Cosmos account values for " "'masterKey' and 'host' at the top of this class to run the " "tests.") - cls.client = cosmos_client.CosmosClient(cls.host, cls.masterKey) + cls.client = cosmos_client.CosmosClient(cls.host, cls.masterKey, multiple_write_locations=use_multiple_write_locations) cls.databaseForTest = cls.client.get_database_client(cls.configs.TEST_DATABASE_ID) def test_partitioned_collection_document_crud_and_query(self): @@ -1143,7 +1144,7 @@ def test_client_request_timeout(self): container = databaseForTest.get_container_client(self.configs.TEST_SINGLE_PARTITION_CONTAINER_ID) container.create_item(body={'id': str(uuid.uuid4()), 'name': 'sample'}) - async def test_read_timeout_async(self): + def test_read_timeout_async(self): connection_policy = documents.ConnectionPolicy() # making timeout 0 ms to make sure it will throw connection_policy.DBAReadTimeout = 
0.000000000001 diff --git a/sdk/cosmos/azure-cosmos/tests/test_crud_async.py b/sdk/cosmos/azure-cosmos/tests/test_crud_async.py index ca6cfad8287d..7124c16e88b0 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_crud_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_crud_async.py @@ -4,10 +4,13 @@ """End-to-end test. """ +import asyncio +import os import time import unittest import urllib.parse as urllib import uuid +from asyncio import sleep import pytest import requests @@ -42,7 +45,7 @@ async def send(self, *args, **kwargs): response = AsyncioRequestsTransportResponse(None, output) return response - +@pytest.mark.cosmosCircuitBreaker @pytest.mark.cosmosLong class TestCRUDOperationsAsync(unittest.IsolatedAsyncioTestCase): """Python CRUD Tests. @@ -78,7 +81,10 @@ def setUpClass(cls): "tests.") async def asyncSetUp(self): - self.client = CosmosClient(self.host, self.masterKey) + use_multiple_write_locations = False + if os.environ.get("AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER", "False") == "True": + use_multiple_write_locations = True + self.client = CosmosClient(self.host, self.masterKey, multiple_write_locations=use_multiple_write_locations) self.database_for_test = self.client.get_database_client(self.configs.TEST_DATABASE_ID) async def asyncTearDown(self): @@ -996,6 +1002,7 @@ async def test_query_iterable_functionality_async(self): doc1 = await collection.upsert_item(body={'id': 'doc1', 'prop1': 'value1'}) doc2 = await collection.upsert_item(body={'id': 'doc2', 'prop1': 'value2'}) doc3 = await collection.upsert_item(body={'id': 'doc3', 'prop1': 'value3'}) + await asyncio.sleep(1) resources = { 'coll': collection, 'doc1': doc1, @@ -1124,6 +1131,7 @@ async def test_get_resource_with_dictionary_and_object_async(self): assert read_container.id == created_container.id created_item = await created_container.create_item({'id': '1' + str(uuid.uuid4()), 'pk': 'pk'}) + await sleep(5) # read item with id read_item = await created_container.read_item(item=created_item['id'], 
partition_key=created_item['pk']) diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py index 041e8f63789e..829f6a163745 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py +++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py @@ -10,6 +10,7 @@ from azure.cosmos import CosmosClient + class MockHandler(logging.Handler): def __init__(self): super(MockHandler, self).__init__() @@ -48,11 +49,6 @@ class TestDataType: L2 = "West US" L3 = "East US 2" -# L0 = "Default" -# L1 = "East US 2" -# L2 = "East US" -# L3 = "West US 2" - CLIENT_ONLY_TEST_DATA = [ # preferred_locations, client_excluded_locations, excluded_locations_request # 0. No excluded location @@ -123,22 +119,23 @@ def read_item_test_data(): ] return get_test_data_with_expected_output(client_only_output_data, client_and_request_output_data) + def write_item_test_data(): client_only_output_data = [ - [L1], #0 - [L2], #1 - [L0], #2 - [L1] #3 + [L1], # 0 + [L2], # 1 + [L0], # 2 + [L1], # 3 ] client_and_request_output_data = [ - [L2], #0 - [L2], #1 - [L2], #2 - [L0], #3 - [L0], #4 - [L1], #5 - [L1], #6 - [L1], #7 + [L2], # 0 + [L2], # 1 + [L2], # 2 + [L0], # 3 + [L0], # 4 + [L1], # 5 + [L1], # 6 + [L1], # 7 ] return get_test_data_with_expected_output(client_only_output_data, client_and_request_output_data) @@ -355,7 +352,7 @@ def test_create_item(self, test_data): # Single write verify_endpoint(MOCK_HANDLER.messages, client, expected_locations, multiple_write_locations) - @pytest.mark.parametrize('test_data', write_item_test_data()) + @pytest.mark.parametrize('test_data', read_and_write_item_test_data()) def test_patch_item(self, test_data): # Init test variables preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data @@ -380,7 +377,7 @@ def test_patch_item(self, test_data): # get location from mock_handler verify_endpoint(MOCK_HANDLER.messages, client, 
expected_locations, multiple_write_locations) - @pytest.mark.parametrize('test_data', write_item_test_data()) + @pytest.mark.parametrize('test_data', read_and_write_item_test_data()) def test_execute_item_batch(self, test_data): # Init test variables preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py index db85ebd08121..1b2928de217e 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py @@ -10,10 +10,12 @@ import pytest_asyncio from azure.cosmos.aio import CosmosClient +from azure.cosmos.partition_key import PartitionKey from test_excluded_locations import (TestDataType, set_test_data_type, read_item_test_data, write_item_test_data, read_and_write_item_test_data, verify_endpoint) + class MockHandler(logging.Handler): def __init__(self): super(MockHandler, self).__init__() @@ -74,9 +76,9 @@ async def setup_and_teardown_async(): @pytest.mark.cosmosMultiRegion @pytest.mark.asyncio @pytest.mark.usefixtures("setup_and_teardown_async") -class TestExcludedLocations: +class TestExcludedLocationsAsync: @pytest.mark.parametrize('test_data', read_item_test_data()) - async def test_read_item(self, test_data): + async def test_read_item_async(self, test_data): # Init test variables preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data @@ -97,7 +99,7 @@ async def test_read_item(self, test_data): verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) @pytest.mark.parametrize('test_data', read_item_test_data()) - async def test_read_all_items(self, test_data): + async def test_read_all_items_async(self, test_data): # Init test variables preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = 
test_data @@ -118,7 +120,7 @@ async def test_read_all_items(self, test_data): verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) @pytest.mark.parametrize('test_data', read_item_test_data()) - async def test_query_items_with_partition_key(self, test_data): + async def test_query_items_with_partition_key_async(self, test_data): # Init test variables preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data @@ -140,7 +142,7 @@ async def test_query_items_with_partition_key(self, test_data): verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) @pytest.mark.parametrize('test_data', read_item_test_data()) - async def test_query_items_with_query_plan(self, test_data): + async def test_query_items_with_query_plan_async(self, test_data): # Init test variables preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data @@ -162,7 +164,7 @@ async def test_query_items_with_query_plan(self, test_data): verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) @pytest.mark.parametrize('test_data', read_item_test_data()) - async def test_query_items_change_feed(self, test_data): + async def test_query_items_change_feed_async(self, test_data): # Init test variables preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data @@ -250,8 +252,8 @@ async def test_create_item(self, test_data): # get location from mock_handler verify_endpoint(MOCK_HANDLER.messages, client, expected_locations, multiple_write_locations) - @pytest.mark.parametrize('test_data', write_item_test_data()) - async def test_patch_item(self, test_data): + @pytest.mark.parametrize('test_data', read_and_write_item_test_data()) + async def test_patch_item_async(self, test_data): # Init test variables preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data @@ -278,8 +280,8 @@ async def 
test_patch_item(self, test_data): # get location from mock_handler verify_endpoint(MOCK_HANDLER.messages, client, expected_locations, multiple_write_locations) - @pytest.mark.parametrize('test_data', write_item_test_data()) - async def test_execute_item_batch(self, test_data): + @pytest.mark.parametrize('test_data', read_and_write_item_test_data()) + async def test_execute_item_batch_async(self, test_data): # Init test variables preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data @@ -312,7 +314,7 @@ async def test_execute_item_batch(self, test_data): verify_endpoint(MOCK_HANDLER.messages, client, expected_locations, multiple_write_locations) @pytest.mark.parametrize('test_data', write_item_test_data()) - async def test_delete_item(self, test_data): + async def test_delete_item_async(self, test_data): # Init test variables preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data diff --git a/sdk/cosmos/azure-cosmos/tests/test_globaldb.py b/sdk/cosmos/azure-cosmos/tests/test_globaldb.py index 3fc83ef68f0d..e700ab5e9b1c 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_globaldb.py +++ b/sdk/cosmos/azure-cosmos/tests/test_globaldb.py @@ -429,7 +429,8 @@ def test_global_db_service_request_errors(self): cosmos_client.CosmosClient(self.host, self.masterKey, connection_retry_policy=mock_retry_policy) pytest.fail("Exception was not raised") except ServiceRequestError: - assert mock_retry_policy.counter == 3 + # Database account calls should not be retried in connection retry policy + assert mock_retry_policy.counter == 0 def test_global_db_endpoint_discovery_retry_policy_mock(self): client = cosmos_client.CosmosClient(self.host, self.masterKey) diff --git a/sdk/cosmos/azure-cosmos/tests/test_globaldb_mock.py b/sdk/cosmos/azure-cosmos/tests/test_globaldb_mock.py index 9b2a221880b1..a7aee626bfb8 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_globaldb_mock.py +++ 
b/sdk/cosmos/azure-cosmos/tests/test_globaldb_mock.py @@ -1,7 +1,6 @@ # The MIT License (MIT) # Copyright (c) Microsoft Corporation. All rights reserved. -import json import unittest import pytest @@ -13,13 +12,16 @@ import azure.cosmos.exceptions as exceptions import test_config from azure.cosmos import _retry_utility +from azure.cosmos._global_partition_endpoint_manager_circuit_breaker import \ + _GlobalPartitionEndpointManagerForCircuitBreaker from azure.cosmos.http_constants import StatusCodes location_changed = False -class MockGlobalEndpointManager: +class MockGlobalEndpointManager(_GlobalPartitionEndpointManagerForCircuitBreaker): def __init__(self, client): + super(MockGlobalEndpointManager, self).__init__(client) self.Client = client self.DefaultEndpoint = client.url_connection self._ReadEndpoint = client.url_connection @@ -73,10 +75,10 @@ def get_write_endpoint(self): def get_read_endpoint(self): return self._ReadEndpoint - def resolve_service_endpoint(self, request): + def resolve_service_endpoint_for_partition(self, request, pk_range_wrapper): return - def refresh_endpoint_list(self): + def refresh_endpoint_list(self, database_account, **kwargs): return def can_use_multiple_write_locations(self, request): @@ -150,20 +152,6 @@ def tearDown(self): global_endpoint_manager._GlobalEndpointManager._GetDatabaseAccountStub = self.OriginalGetDatabaseAccountStub _retry_utility.ExecuteFunction = self.OriginalExecuteFunction - def MockExecuteFunction(self, function, *args, **kwargs): - global location_changed - - if self.endpoint_discovery_retry_count == 2: - _retry_utility.ExecuteFunction = self.OriginalExecuteFunction - return json.dumps([{'id': 'mock database'}]), None - else: - self.endpoint_discovery_retry_count += 1 - location_changed = True - raise exceptions.CosmosHttpResponseError( - status_code=StatusCodes.FORBIDDEN, - message="Forbidden", - response=test_config.FakeResponse({'x-ms-substatus': 3})) - def MockGetDatabaseAccountStub(self, endpoint): raise 
exceptions.CosmosHttpResponseError( status_code=StatusCodes.INTERNAL_SERVER_ERROR, message="Internal Server Error") @@ -176,6 +164,8 @@ def test_global_db_endpoint_discovery_retry_policy(self): TestGlobalDBMock.masterKey, consistency_level="Session", connection_policy=connection_policy) + write_location_client.client_connection._global_endpoint_manager = MockGlobalEndpointManager(write_location_client.client_connection) + write_location_client.client_connection._global_endpoint_manager.refresh_endpoint_list(None) self.assertEqual(write_location_client.client_connection.WriteEndpoint, TestGlobalDBMock.write_location_host) @@ -188,6 +178,8 @@ def test_global_db_database_account_unavailable(self): client = cosmos_client.CosmosClient(TestGlobalDBMock.host, TestGlobalDBMock.masterKey, consistency_level="Session", connection_policy=connection_policy) + client.client_connection._global_endpoint_manager = MockGlobalEndpointManager(client.client_connection) + client.client_connection._global_endpoint_manager.refresh_endpoint_list(None) self.assertEqual(client.client_connection.WriteEndpoint, TestGlobalDBMock.write_location_host) self.assertEqual(client.client_connection.ReadEndpoint, TestGlobalDBMock.write_location_host) @@ -195,7 +187,7 @@ def test_global_db_database_account_unavailable(self): global_endpoint_manager._GlobalEndpointManager._GetDatabaseAccountStub = self.MockGetDatabaseAccountStub client.client_connection.DatabaseAccountAvailable = False - client.client_connection._global_endpoint_manager.refresh_endpoint_list() + client.client_connection._global_endpoint_manager.refresh_endpoint_list(None) self.assertEqual(client.client_connection.WriteEndpoint, TestGlobalDBMock.host) self.assertEqual(client.client_connection.ReadEndpoint, TestGlobalDBMock.host) diff --git a/sdk/cosmos/azure-cosmos/tests/test_location_cache.py b/sdk/cosmos/azure-cosmos/tests/test_location_cache.py index b619339525d7..887be44f2273 100644 --- 
a/sdk/cosmos/azure-cosmos/tests/test_location_cache.py +++ b/sdk/cosmos/azure-cosmos/tests/test_location_cache.py @@ -126,8 +126,8 @@ def test_resolve_request_endpoint_preferred_regions(self): lc = refresh_location_cache([location1_name, location3_name, location4_name], True) db_acc = create_database_account(True) lc.perform_on_database_account_read(db_acc) - write_doc_request = RequestObject(ResourceType.Document, _OperationType.Create) - read_doc_request = RequestObject(ResourceType.Document, _OperationType.Read) + write_doc_request = RequestObject(ResourceType.Document, _OperationType.Create, None) + read_doc_request = RequestObject(ResourceType.Document, _OperationType.Read, None) # resolve both document requests with all regions available write_doc_resolved = lc.resolve_service_endpoint(write_doc_request) @@ -215,9 +215,9 @@ def test_get_applicable_regional_endpoints_excluded_regions(self, test_type): location_cache.perform_on_database_account_read(database_account) # Init requests and set excluded regions on requests - write_doc_request = RequestObject(ResourceType.Document, _OperationType.Create) + write_doc_request = RequestObject(ResourceType.Document, _OperationType.Create, None) write_doc_request.excluded_locations = excluded_locations_on_requests - read_doc_request = RequestObject(ResourceType.Document, _OperationType.Read) + read_doc_request = RequestObject(ResourceType.Document, _OperationType.Read, None) read_doc_request.excluded_locations = excluded_locations_on_requests # Test if read endpoints were correctly filtered on client level @@ -247,7 +247,7 @@ def test_set_excluded_locations_for_requests(self): options: Mapping[str, Any] = {"excludedLocations": excluded_locations} expected_excluded_locations = excluded_locations - read_doc_request = RequestObject(ResourceType.Document, _OperationType.Create) + read_doc_request = RequestObject(ResourceType.Document, _OperationType.Create, None) read_doc_request.set_excluded_location_from_options(options) 
actual_excluded_locations = read_doc_request.excluded_locations assert actual_excluded_locations == expected_excluded_locations @@ -262,7 +262,7 @@ def test_set_excluded_locations_for_requests(self): "If you want to remove all excluded locations, try passing an empty list.") with pytest.raises(ValueError) as e: options: Mapping[str, Any] = {"excludedLocations": None} - doc_request = RequestObject(ResourceType.Document, _OperationType.Create) + doc_request = RequestObject(ResourceType.Document, _OperationType.Create, None) doc_request.set_excluded_location_from_options(options) assert str( e.value) == expected_error_message diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py new file mode 100644 index 000000000000..be6b3f0f9ab6 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py @@ -0,0 +1,367 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. 
+import os +import unittest +import uuid +from time import sleep + +import pytest +from azure.core.exceptions import ServiceResponseError + +import test_config +from azure.cosmos import _location_cache, _partition_health_tracker +from azure.cosmos import CosmosClient +from azure.cosmos.exceptions import CosmosHttpResponseError +from _fault_injection_transport import FaultInjectionTransport +from test_per_partition_circuit_breaker_mm_async import DELETE, CREATE, UPSERT, REPLACE, PATCH, BATCH, validate_response_uri, READ, \ + QUERY_PK, QUERY, CHANGE_FEED, CHANGE_FEED_PK, CHANGE_FEED_EPK, READ_ALL_ITEMS, REGION_1, REGION_2, \ + write_operations_and_errors, validate_unhealthy_partitions, read_operations_and_errors, PK_VALUE, operations, \ + create_doc +from test_per_partition_circuit_breaker_mm_async import DELETE_ALL_ITEMS_BY_PARTITION_KEY + +def perform_write_operation(operation, container, fault_injection_container, doc_id, pk, expected_uri): + doc = {'id': doc_id, + 'pk': pk, + 'name': 'sample document', + 'key': 'value'} + if operation == CREATE: + resp = fault_injection_container.create_item(body=doc) + elif operation == UPSERT: + resp = fault_injection_container.upsert_item(body=doc) + elif operation == REPLACE: + container.create_item(body=doc) + sleep(1) + new_doc = {'id': doc_id, + 'pk': pk, + 'name': 'sample document' + str(uuid), + 'key': 'value'} + resp = fault_injection_container.replace_item(item=doc['id'], body=new_doc) + elif operation == DELETE: + container.create_item(body=doc) + sleep(1) + resp = fault_injection_container.delete_item(item=doc['id'], partition_key=doc['pk']) + elif operation == PATCH: + container.create_item(body=doc) + sleep(1) + operations = [{"op": "incr", "path": "/company", "value": 3}] + resp = fault_injection_container.patch_item(item=doc['id'], partition_key=doc['pk'], patch_operations=operations) + elif operation == BATCH: + batch_operations = [ + ("create", (doc, )), + ("upsert", (doc,)), + ("upsert", (doc,)), + ("upsert", 
(doc,)), + ] + resp = fault_injection_container.execute_item_batch(batch_operations, partition_key=doc['pk']) + # this will need to be emulator only + elif operation == DELETE_ALL_ITEMS_BY_PARTITION_KEY: + container.create_item(body=doc) + resp = fault_injection_container.delete_all_items_by_partition_key(pk) + if resp: + validate_response_uri(resp, expected_uri) + +def perform_read_operation(operation, container, doc_id, pk, expected_uri): + if operation == READ: + read_resp = container.read_item(item=doc_id, partition_key=pk) + request = read_resp.get_response_headers()["_request"] + # Validate the response comes from "Read Region" (the most preferred read-only region) + assert request.url.startswith(expected_uri) + elif operation == QUERY_PK: + # partition key filtered query + query = "SELECT * FROM c WHERE c.id = @id AND c.pk = @pk" + parameters = [{"name": "@id", "value": doc_id}, {"name": "@pk", "value": pk}] + for item in container.query_items(query=query, partition_key=pk, parameters=parameters): + assert item['id'] == doc_id + # need to do query with no pk and with feed range + elif operation == QUERY: + # cross partition query + query = "SELECT * FROM c WHERE c.id = @id" + for item in container.query_items(query=query): + assert item['id'] == doc_id + elif operation == CHANGE_FEED: + for _ in container.query_items_change_feed(): + pass + elif operation == CHANGE_FEED_PK: + # partition key filtered change feed + for _ in container.query_items_change_feed(partition_key=pk): + pass + elif operation == CHANGE_FEED_EPK: + # partition key filtered by feed range + feed_range = container.feed_range_from_partition_key(partition_key=pk) + for _ in container.query_items_change_feed(feed_range=feed_range): + pass + elif operation == READ_ALL_ITEMS: + for _ in container.read_all_items(): + pass + +@pytest.mark.cosmosCircuitBreaker +class TestPerPartitionCircuitBreakerMM: + host = test_config.TestConfig.host + master_key = test_config.TestConfig.masterKey + 
TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID + TEST_CONTAINER_SINGLE_PARTITION_ID = test_config.TestConfig.TEST_MULTI_PARTITION_CONTAINER_ID + + def setup_method_with_custom_transport(self, custom_transport, default_endpoint=host, **kwargs): + client = CosmosClient(default_endpoint, self.master_key, + preferred_locations=[REGION_1, REGION_2], + multiple_write_locations=True, + transport=custom_transport, **kwargs) + db = client.get_database_client(self.TEST_DATABASE_ID) + container = db.get_container_client(self.TEST_CONTAINER_SINGLE_PARTITION_ID) + return {"client": client, "db": db, "col": container} + + @pytest.mark.parametrize("write_operation, error", write_operations_and_errors()) + def test_write_consecutive_failure_threshold(self, write_operation, error): + error_lambda = lambda r: FaultInjectionTransport.error_after_delay( + 0, + error + ) + container, doc, expected_uri, uri_down, fault_injection_container, custom_transport, predicate = self.setup_info(error_lambda) + global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager + + with pytest.raises((CosmosHttpResponseError, ServiceResponseError)) as exc_info: + perform_write_operation(write_operation, + container, + fault_injection_container, + str(uuid.uuid4()), + PK_VALUE, + expected_uri) + assert exc_info.value == error + + validate_unhealthy_partitions(global_endpoint_manager, 0) + + # writes should fail but still be tracked + for i in range(4): + with pytest.raises((CosmosHttpResponseError, ServiceResponseError)) as exc_info: + perform_write_operation(write_operation, + container, + fault_injection_container, + str(uuid.uuid4()), + PK_VALUE, + expected_uri) + assert exc_info.value == error + + # writes should now succeed because going to the other region + perform_write_operation(write_operation, + container, + fault_injection_container, + str(uuid.uuid4()), + PK_VALUE, + expected_uri) + + validate_unhealthy_partitions(global_endpoint_manager, 1) + # 
remove faults and reduce initial recover time and perform a write + original_unavailable_time = _partition_health_tracker.INITIAL_UNAVAILABLE_TIME + _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = 1 + custom_transport.faults = [] + try: + perform_write_operation(write_operation, + container, + fault_injection_container, + str(uuid.uuid4()), + PK_VALUE, + uri_down) + finally: + _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = original_unavailable_time + validate_unhealthy_partitions(global_endpoint_manager, 0) + + @pytest.mark.cosmosCircuitBreakerMultiRegion + @pytest.mark.parametrize("read_operation, error", read_operations_and_errors()) + def test_read_consecutive_failure_threshold(self, read_operation, error): + error_lambda = lambda r: FaultInjectionTransport.error_after_delay( + 0, + error + ) + container, doc, expected_uri, uri_down, fault_injection_container, custom_transport, predicate = self.setup_info(error_lambda) + + global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager + + # create a document to read + container.create_item(body=doc) + + # reads should fail over and only the relevant partition should be marked as unavailable + perform_read_operation(read_operation, + fault_injection_container, + doc['id'], + doc['pk'], + expected_uri) + # partition should not have been marked unavailable after one error + validate_unhealthy_partitions(global_endpoint_manager, 0) + + for i in range(10): + perform_read_operation(read_operation, + fault_injection_container, + doc['id'], + doc['pk'], + expected_uri) + + # the partition should have been marked as unavailable after breaking read threshold + if read_operation in (CHANGE_FEED, QUERY, READ_ALL_ITEMS): + # these operations are cross partition so they would mark both partitions as unavailable + expected_unhealthy_partitions = 5 + else: + expected_unhealthy_partitions = 1 + validate_unhealthy_partitions(global_endpoint_manager, expected_unhealthy_partitions) + # 
remove faults and reduce initial recover time and perform a read + original_unavailable_time = _partition_health_tracker.INITIAL_UNAVAILABLE_TIME + _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = 1 + custom_transport.faults = [] + try: + perform_read_operation(read_operation, + fault_injection_container, + doc['id'], + doc['pk'], + uri_down) + finally: + _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = original_unavailable_time + validate_unhealthy_partitions(global_endpoint_manager, 0) + + @pytest.mark.parametrize("write_operation, error", write_operations_and_errors()) + def test_write_failure_rate_threshold(self, write_operation, error): + error_lambda = lambda r: FaultInjectionTransport.error_after_delay( + 0, + error + ) + container, doc, expected_uri, uri_down, fault_injection_container, custom_transport, predicate = self.setup_info(error_lambda) + global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager + # lower minimum requests for testing + _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 10 + os.environ["AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED"] = "80" + try: + # writes should fail but still be tracked and mark unavailable a partition after crossing threshold + for i in range(10): + validate_unhealthy_partitions(global_endpoint_manager, 0) + if i == 4 or i == 8: + # perform some successful creates to reset consecutive counter + # remove faults and perform a write + custom_transport.faults = [] + fault_injection_container.upsert_item(body=doc) + custom_transport.add_fault(predicate, + lambda r: FaultInjectionTransport.error_after_delay( + 0, + error + )) + else: + with pytest.raises((CosmosHttpResponseError, ServiceResponseError)) as exc_info: + perform_write_operation(write_operation, + container, + fault_injection_container, + str(uuid.uuid4()), + PK_VALUE, + expected_uri) + assert exc_info.value == error + + validate_unhealthy_partitions(global_endpoint_manager, 1) + + finally: + 
os.environ["AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED"] = "90" + # restore minimum requests + _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 + + @pytest.mark.cosmosCircuitBreakerMultiRegion + @pytest.mark.parametrize("read_operation, error", read_operations_and_errors()) + def test_read_failure_rate_threshold(self, read_operation, error): + error_lambda = lambda r: FaultInjectionTransport.error_after_delay( + 0, + error + ) + container, doc, expected_uri, uri_down, fault_injection_container, custom_transport, predicate = self.setup_info(error_lambda) + container.upsert_item(body=doc) + global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager + # lower minimum requests for testing + _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 8 + os.environ["AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED"] = "80" + try: + if isinstance(error, ServiceResponseError): + # service response error retries in region 3 additional times before failing over + num_operations = 2 + else: + num_operations = 8 + for i in range(num_operations): + validate_unhealthy_partitions(global_endpoint_manager, 0) + # read will fail and retry in other region + perform_read_operation(read_operation, + fault_injection_container, + doc['id'], + PK_VALUE, + expected_uri) + if read_operation in (CHANGE_FEED, QUERY, READ_ALL_ITEMS): + # these operations are cross partition so they would mark both partitions as unavailable + expected_unhealthy_partitions = 5 + else: + expected_unhealthy_partitions = 1 + + validate_unhealthy_partitions(global_endpoint_manager, expected_unhealthy_partitions) + finally: + os.environ["AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED"] = "90" + # restore minimum requests + _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 + + def setup_info(self, error): + expected_uri = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_2) + uri_down = 
_location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_1) + custom_transport = FaultInjectionTransport() + # two documents targeted to same partition, one will always fail and the other will succeed + doc = create_doc() + predicate = lambda r: (FaultInjectionTransport.predicate_is_document_operation(r) and + FaultInjectionTransport.predicate_targets_region(r, uri_down)) + custom_transport.add_fault(predicate, + error) + custom_setup = self.setup_method_with_custom_transport(custom_transport, default_endpoint=self.host) + fault_injection_container = custom_setup['col'] + setup = self.setup_method_with_custom_transport(None, default_endpoint=self.host) + container = setup['col'] + return container, doc, expected_uri, uri_down, fault_injection_container, custom_transport, predicate + + @pytest.mark.parametrize("read_operation, write_operation", operations()) + def test_service_request_error(self, read_operation, write_operation): + # the region should be tried 4 times before failing over and mark the partition as unavailable + # the region should not be marked as unavailable + error_lambda = lambda r: FaultInjectionTransport.error_region_down() + container, doc, expected_uri, uri_down, fault_injection_container, custom_transport, predicate = self.setup_info(error_lambda) + container.upsert_item(body=doc) + perform_read_operation(read_operation, + fault_injection_container, + doc['id'], + PK_VALUE, + expected_uri) + global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager + + validate_unhealthy_partitions(global_endpoint_manager, 0) + # there shouldn't be region marked as unavailable + assert len(global_endpoint_manager.location_cache.location_unavailability_info_by_endpoint) == 1 + + # recover partition + # remove faults and reduce initial recover time and perform a write + original_unavailable_time = _partition_health_tracker.INITIAL_UNAVAILABLE_TIME + _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = 1 + 
custom_transport.faults = [] + try: + perform_read_operation(read_operation, + fault_injection_container, + doc['id'], + PK_VALUE, + expected_uri) + finally: + _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = original_unavailable_time + validate_unhealthy_partitions(global_endpoint_manager, 0) + + custom_transport.add_fault(predicate, + lambda r: FaultInjectionTransport.error_region_down()) + + perform_write_operation(write_operation, + container, + fault_injection_container, + str(uuid.uuid4()), + PK_VALUE, + expected_uri) + global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager + + validate_unhealthy_partitions(global_endpoint_manager, 0) + # there shouldn't be region marked as unavailable + assert len(global_endpoint_manager.location_cache.location_unavailability_info_by_endpoint) == 1 + + # test cosmos client timeout + +if __name__ == '__main__': + unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py new file mode 100644 index 000000000000..095bd25bbb8b --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py @@ -0,0 +1,512 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. 
+import asyncio +import os +import unittest +import uuid +from typing import Dict, Any, List + +import pytest +from azure.core.pipeline.transport._aiohttp import AioHttpTransport +from azure.core.exceptions import ServiceResponseError + +import test_config +from azure.cosmos import _location_cache, _partition_health_tracker +from azure.cosmos._partition_health_tracker import HEALTH_STATUS, UNHEALTHY, UNHEALTHY_TENTATIVE +from azure.cosmos.aio import CosmosClient +from azure.cosmos.exceptions import CosmosHttpResponseError +from _fault_injection_transport_async import FaultInjectionTransportAsync + +REGION_1 = "West US 3" +REGION_2 = "West US" +CHANGE_FEED = "changefeed" +CHANGE_FEED_PK = "changefeed_pk" +CHANGE_FEED_EPK = "changefeed_epk" +READ = "read" +CREATE = "create" +READ_ALL_ITEMS = "read_all_items" +DELETE_ALL_ITEMS_BY_PARTITION_KEY = "delete_all_items_by_partition_key" +QUERY = "query" +QUERY_PK = "query_pk" +BATCH = "batch" +UPSERT = "upsert" +REPLACE = "replace" +PATCH = "patch" +DELETE = "delete" +PK_VALUE = "pk1" + + +COLLECTION = "created_collection" + +def create_errors(): + errors = [] + error_codes = [408, 500, 502, 503] + for error_code in error_codes: + errors.append(CosmosHttpResponseError( + status_code=error_code, + message="Some injected error.")) + errors.append(ServiceResponseError(message="Injected Service Response Error.")) + return errors + +def write_operations_and_errors(): + write_operations = [CREATE, UPSERT, REPLACE, DELETE, PATCH, BATCH] + errors = create_errors() + params = [] + for write_operation in write_operations: + for error in errors: + params.append((write_operation, error)) + + return params + +def create_doc(): + return {'id': str(uuid.uuid4()), + 'pk': PK_VALUE, + 'name': 'sample document', + 'key': 'value'} + +def read_operations_and_errors(): + read_operations = [READ, QUERY_PK, CHANGE_FEED, CHANGE_FEED_PK, CHANGE_FEED_EPK] + errors = create_errors() + params = [] + for read_operation in read_operations: + for error 
in errors: + params.append((read_operation, error)) + + return params + +def operations(): + write_operations = [CREATE, UPSERT, REPLACE, DELETE, PATCH, BATCH] + read_operations = [READ, QUERY_PK, CHANGE_FEED_PK, CHANGE_FEED_EPK] + operations = [] + for i, write_operation in enumerate(write_operations): + operations.append((read_operations[i % len(read_operations)], write_operation)) + + return operations + +def validate_response_uri(response, expected_uri): + request = response.get_response_headers()["_request"] + assert request.url.startswith(expected_uri) + +async def perform_write_operation(operation, container, fault_injection_container, doc_id, pk, expected_uri): + doc = {'id': doc_id, + 'pk': pk, + 'name': 'sample document', + 'key': 'value'} + if operation == CREATE: + resp = await fault_injection_container.create_item(body=doc) + elif operation == UPSERT: + resp = await fault_injection_container.upsert_item(body=doc) + elif operation == REPLACE: + await container.create_item(body=doc) + new_doc = {'id': doc_id, + 'pk': pk, + 'name': 'sample document' + str(uuid), + 'key': 'value'} + await asyncio.sleep(1) + resp = await fault_injection_container.replace_item(item=doc['id'], body=new_doc) + elif operation == DELETE: + await container.create_item(body=doc) + await asyncio.sleep(1) + resp = await fault_injection_container.delete_item(item=doc['id'], partition_key=doc['pk']) + elif operation == PATCH: + await container.create_item(body=doc) + await asyncio.sleep(1) + operations = [{"op": "incr", "path": "/company", "value": 3}] + resp = await fault_injection_container.patch_item(item=doc['id'], partition_key=doc['pk'], patch_operations=operations) + elif operation == BATCH: + batch_operations = [ + ("create", (doc, )), + ("upsert", (doc,)), + ("upsert", (doc,)), + ("upsert", (doc,)), + ] + resp = await fault_injection_container.execute_item_batch(batch_operations, partition_key=doc['pk']) + # this will need to be emulator only + elif operation == 
DELETE_ALL_ITEMS_BY_PARTITION_KEY: + await container.create_item(body=doc) + resp = await fault_injection_container.delete_all_items_by_partition_key(pk) + if resp: + validate_response_uri(resp, expected_uri) + +async def perform_read_operation(operation, container, doc_id, pk, expected_uri): + if operation == READ: + read_resp = await container.read_item(item=doc_id, partition_key=pk) + request = read_resp.get_response_headers()["_request"] + # Validate the response comes from "Read Region" (the most preferred read-only region) + assert request.url.startswith(expected_uri) + elif operation == QUERY_PK: + # partition key filtered query + query = "SELECT * FROM c WHERE c.id = @id AND c.pk = @pk" + parameters = [{"name": "@id", "value": doc_id}, {"name": "@pk", "value": pk}] + async for item in container.query_items(query=query, partition_key=pk, parameters=parameters): + assert item['id'] == doc_id + # need to do query with no pk and with feed range + elif operation == QUERY: + # cross partition query + query = "SELECT * FROM c WHERE c.id = @id" + async for item in container.query_items(query=query): + assert item['id'] == doc_id + elif operation == CHANGE_FEED: + async for _ in container.query_items_change_feed(): + pass + elif operation == CHANGE_FEED_PK: + # partition key filtered change feed + async for _ in container.query_items_change_feed(partition_key=pk): + pass + elif operation == CHANGE_FEED_EPK: + # partition key filtered by feed range + feed_range = await container.feed_range_from_partition_key(partition_key=pk) + async for _ in container.query_items_change_feed(feed_range=feed_range): + pass + elif operation == READ_ALL_ITEMS: + async for _ in container.read_all_items(): + pass + +def validate_unhealthy_partitions(global_endpoint_manager, + expected_unhealthy_partitions): + health_info_map = global_endpoint_manager.global_partition_endpoint_manager_core.partition_health_tracker.pk_range_wrapper_to_health_info + unhealthy_partitions = 0 + for 
pk_range_wrapper, location_to_health_info in health_info_map.items(): + for location, health_info in location_to_health_info.items(): + health_status = health_info.unavailability_info.get(HEALTH_STATUS) + if health_status == UNHEALTHY_TENTATIVE or health_status == UNHEALTHY: + unhealthy_partitions += 1 + else: + assert health_info.read_consecutive_failure_count < 10 + assert health_info.write_consecutive_failure_count < 5 + + assert unhealthy_partitions == expected_unhealthy_partitions + +async def cleanup_method(initialized_objects: List[Dict[str, Any]]): + for obj in initialized_objects: + method_client: CosmosClient = obj["client"] + await method_client.close() + +@pytest.mark.cosmosCircuitBreaker +@pytest.mark.asyncio +class TestPerPartitionCircuitBreakerMMAsync: + host = test_config.TestConfig.host + master_key = test_config.TestConfig.masterKey + TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID + TEST_CONTAINER_SINGLE_PARTITION_ID = test_config.TestConfig.TEST_MULTI_PARTITION_CONTAINER_ID + + async def setup_method_with_custom_transport(self, custom_transport: AioHttpTransport, default_endpoint=host, **kwargs): + os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "True" + client = CosmosClient(default_endpoint, self.master_key, + preferred_locations=[REGION_1, REGION_2], + multiple_write_locations=True, + transport=custom_transport, **kwargs) + db = client.get_database_client(self.TEST_DATABASE_ID) + container = db.get_container_client(self.TEST_CONTAINER_SINGLE_PARTITION_ID) + return {"client": client, "db": db, "col": container} + + @pytest.mark.parametrize("write_operation, error", write_operations_and_errors()) + async def test_write_consecutive_failure_threshold_async(self, write_operation, error): + error_lambda = lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( + 0, + error + )) + setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = await self.setup_info(error_lambda) + container = 
setup['col'] + fault_injection_container = custom_setup['col'] + global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager + + with pytest.raises((CosmosHttpResponseError, ServiceResponseError)) as exc_info: + await perform_write_operation(write_operation, + container, + fault_injection_container, + str(uuid.uuid4()), + PK_VALUE, + expected_uri) + assert exc_info.value == error + + validate_unhealthy_partitions(global_endpoint_manager, 0) + + # writes should fail but still be tracked + for i in range(4): + with pytest.raises((CosmosHttpResponseError, ServiceResponseError)) as exc_info: + await perform_write_operation(write_operation, + container, + fault_injection_container, + str(uuid.uuid4()), + PK_VALUE, + expected_uri) + assert exc_info.value == error + + # writes should now succeed because going to the other region + await perform_write_operation(write_operation, + container, + fault_injection_container, + str(uuid.uuid4()), + PK_VALUE, + expected_uri) + + validate_unhealthy_partitions(global_endpoint_manager, 1) + # remove faults and reduce initial recover time and perform a write + original_unavailable_time = _partition_health_tracker.INITIAL_UNAVAILABLE_TIME + _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = 1 + custom_transport.faults = [] + try: + await perform_write_operation(write_operation, + container, + fault_injection_container, + str(uuid.uuid4()), + PK_VALUE, + uri_down) + finally: + _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = original_unavailable_time + validate_unhealthy_partitions(global_endpoint_manager, 0) + await cleanup_method([custom_setup, setup]) + + async def setup_info(self, error): + expected_uri = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_2) + uri_down = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_1) + custom_transport = FaultInjectionTransportAsync() + # two documents targeted to same partition, one will always fail and the other 
will succeed + doc = create_doc() + predicate = lambda r: (FaultInjectionTransportAsync.predicate_is_document_operation(r) and + FaultInjectionTransportAsync.predicate_targets_region(r, uri_down)) + custom_transport.add_fault(predicate, + error) + custom_setup = await self.setup_method_with_custom_transport(custom_transport, default_endpoint=self.host) + setup = await self.setup_method_with_custom_transport(None, default_endpoint=self.host) + return setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate + + + @pytest.mark.cosmosCircuitBreakerMultiRegion + @pytest.mark.parametrize("read_operation, error", read_operations_and_errors()) + async def test_read_consecutive_failure_threshold_async(self, read_operation, error): + error_lambda = lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( + 0, + error + )) + setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = await self.setup_info(error_lambda) + container = setup['col'] + fault_injection_container = custom_setup['col'] + + global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager + + # create a document to read + await container.create_item(body=doc) + + # reads should fail over and only the relevant partition should be marked as unavailable + await perform_read_operation(read_operation, + fault_injection_container, + doc['id'], + doc['pk'], + expected_uri) + # partition should not have been marked unavailable after one error + validate_unhealthy_partitions(global_endpoint_manager, 0) + + for i in range(10): + await perform_read_operation(read_operation, + fault_injection_container, + doc['id'], + doc['pk'], + expected_uri) + + # the partition should have been marked as unavailable after breaking read threshold + if read_operation in (CHANGE_FEED, QUERY, READ_ALL_ITEMS): + # these operations are cross partition so they would mark both partitions as unavailable + expected_unhealthy_partitions = 5 + else: 
+ expected_unhealthy_partitions = 1 + validate_unhealthy_partitions(global_endpoint_manager, expected_unhealthy_partitions) + # remove faults and reduce initial recover time and perform a read + original_unavailable_time = _partition_health_tracker.INITIAL_UNAVAILABLE_TIME + _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = 1 + custom_transport.faults = [] + try: + await perform_read_operation(read_operation, + fault_injection_container, + doc['id'], + doc['pk'], + uri_down) + finally: + _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = original_unavailable_time + validate_unhealthy_partitions(global_endpoint_manager, 0) + await cleanup_method([custom_setup, setup]) + + @pytest.mark.parametrize("write_operation, error", write_operations_and_errors()) + async def test_write_failure_rate_threshold_async(self, write_operation, error): + error_lambda = lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( + 0, + error + )) + setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = await self.setup_info(error_lambda) + container = setup['col'] + fault_injection_container = custom_setup['col'] + global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager + # lower minimum requests for testing + _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 10 + os.environ["AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED"] = "80" + try: + # writes should fail but still be tracked and mark unavailable a partition after crossing threshold + for i in range(10): + validate_unhealthy_partitions(global_endpoint_manager, 0) + if i == 4 or i == 8: + # perform some successful creates to reset consecutive counter + # remove faults and perform a write + custom_transport.faults = [] + await fault_injection_container.upsert_item(body=doc) + custom_transport.add_fault(predicate, + lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( + 0, + error + ))) + else: + with 
pytest.raises((CosmosHttpResponseError, ServiceResponseError)) as exc_info: + await perform_write_operation(write_operation, + container, + fault_injection_container, + str(uuid.uuid4()), + PK_VALUE, + expected_uri) + assert exc_info.value == error + + validate_unhealthy_partitions(global_endpoint_manager, 1) + + finally: + os.environ["AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED"] = "90" + # restore minimum requests + _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 + await cleanup_method([custom_setup, setup]) + + @pytest.mark.cosmosCircuitBreakerMultiRegion + @pytest.mark.parametrize("read_operation, error", read_operations_and_errors()) + async def test_read_failure_rate_threshold_async(self, read_operation, error): + error_lambda = lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( + 0, + error + )) + setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = await self.setup_info(error_lambda) + container = setup['col'] + fault_injection_container = custom_setup['col'] + await container.upsert_item(body=doc) + global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager + # lower minimum requests for testing + _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 8 + os.environ["AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED"] = "80" + try: + if isinstance(error, ServiceResponseError): + # service response error retries in region 3 additional times before failing over + num_operations = 2 + else: + num_operations = 8 + for i in range(num_operations): + validate_unhealthy_partitions(global_endpoint_manager, 0) + # read will fail and retry in other region + await perform_read_operation(read_operation, + fault_injection_container, + doc['id'], + PK_VALUE, + expected_uri) + if read_operation in (CHANGE_FEED, QUERY, READ_ALL_ITEMS): + # these operations are cross partition so they would mark both partitions as unavailable + expected_unhealthy_partitions = 5 + else: + 
expected_unhealthy_partitions = 1 + + validate_unhealthy_partitions(global_endpoint_manager, expected_unhealthy_partitions) + finally: + os.environ["AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED"] = "90" + # restore minimum requests + _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 + await cleanup_method([custom_setup, setup]) + + @pytest.mark.parametrize("read_operation, write_operation", operations()) + async def test_service_request_error_async(self, read_operation, write_operation): + # the region should be tried 4 times before failing over and mark the partition as unavailable + # the region should not be marked as unavailable + error_lambda = lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_region_down()) + setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = await self.setup_info(error_lambda) + container = setup['col'] + fault_injection_container = custom_setup['col'] + await container.upsert_item(body=doc) + await perform_read_operation(read_operation, + fault_injection_container, + doc['id'], + PK_VALUE, + expected_uri) + global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager + + validate_unhealthy_partitions(global_endpoint_manager, 0) + assert len(global_endpoint_manager.location_cache.location_unavailability_info_by_endpoint) == 1 + + # recover partition + # remove faults and reduce initial recover time and perform a read + original_unavailable_time = _partition_health_tracker.INITIAL_UNAVAILABLE_TIME + _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = 1 + custom_transport.faults = [] + try: + await perform_read_operation(read_operation, + fault_injection_container, + doc['id'], + PK_VALUE, + expected_uri) + finally: + _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = original_unavailable_time + validate_unhealthy_partitions(global_endpoint_manager, 0) + # per partition circuit breaker should not regress connection timeouts marking the region as 
unavailable + assert len(global_endpoint_manager.location_cache.location_unavailability_info_by_endpoint) == 1 + + custom_transport.add_fault(predicate, + lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_region_down())) + + await perform_write_operation(write_operation, + container, + fault_injection_container, + str(uuid.uuid4()), + PK_VALUE, + expected_uri) + global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager + + validate_unhealthy_partitions(global_endpoint_manager, 0) + assert len(global_endpoint_manager.location_cache.location_unavailability_info_by_endpoint) == 1 + + + await cleanup_method([custom_setup, setup]) + + + # send 15 write concurrent requests when trying to recover + # verify that only one failed + async def test_recovering_only_fails_one_requests_async(self): + error_lambda = lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( + 0, CosmosHttpResponseError( + status_code=502, + message="Some envoy error."))) + setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = await self.setup_info(error_lambda) + fault_injection_container = custom_setup['col'] + for i in range(5): + with pytest.raises(CosmosHttpResponseError): + await fault_injection_container.create_item(body=doc) + + + number_of_errors = 0 + + async def concurrent_upsert(): + nonlocal number_of_errors + doc = {'id': str(uuid.uuid4()), + 'pk': PK_VALUE, + 'name': 'sample document', + 'key': 'value'} + try: + await fault_injection_container.upsert_item(doc) + except CosmosHttpResponseError as e: + number_of_errors += 1 + + # attempt to recover partition + original_unavailable_time = _partition_health_tracker.INITIAL_UNAVAILABLE_TIME + _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = 1 + try: + tasks = [] + for i in range(15): + tasks.append(concurrent_upsert()) + await asyncio.gather(*tasks) + assert number_of_errors == 1 + finally: + 
_partition_health_tracker.INITIAL_UNAVAILABLE_TIME = original_unavailable_time + await cleanup_method([custom_setup, setup]) + +if __name__ == '__main__': + unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr.py new file mode 100644 index 000000000000..34ce4109dd61 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr.py @@ -0,0 +1,171 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. +import os +import unittest +import uuid + +import pytest +from azure.core.exceptions import ServiceResponseError + +import test_config +from azure.cosmos import _partition_health_tracker, _location_cache +from azure.cosmos import CosmosClient +from azure.cosmos.exceptions import CosmosHttpResponseError +from _fault_injection_transport import FaultInjectionTransport +from test_per_partition_circuit_breaker_mm import perform_write_operation, perform_read_operation +from test_per_partition_circuit_breaker_mm_async import create_doc, PK_VALUE, write_operations_and_errors, \ + operations, REGION_2, REGION_1 +from test_per_partition_circuit_breaker_sm_mrr_async import validate_unhealthy_partitions + +COLLECTION = "created_collection" + +@pytest.mark.cosmosCircuitBreakerMultiRegion +class TestPerPartitionCircuitBreakerSmMrr: + host = test_config.TestConfig.host + master_key = test_config.TestConfig.masterKey + connectionPolicy = test_config.TestConfig.connectionPolicy + TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID + TEST_CONTAINER_SINGLE_PARTITION_ID = test_config.TestConfig.TEST_MULTI_PARTITION_CONTAINER_ID + + def setup_method_with_custom_transport(self, custom_transport, default_endpoint=host, **kwargs): + client = CosmosClient(default_endpoint, self.master_key, consistency_level="Session", + preferred_locations=[REGION_1, REGION_2], + transport=custom_transport, **kwargs) + 
db = client.get_database_client(self.TEST_DATABASE_ID) + container = db.get_container_client(self.TEST_CONTAINER_SINGLE_PARTITION_ID) + return {"client": client, "db": db, "col": container} + + def setup_info(self, error): + expected_uri = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_2) + uri_down = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_1) + custom_transport = FaultInjectionTransport() + # two documents targeted to same partition, one will always fail and the other will succeed + doc = create_doc() + predicate = lambda r: (FaultInjectionTransport.predicate_is_document_operation(r) and + FaultInjectionTransport.predicate_targets_region(r, uri_down)) + custom_transport.add_fault(predicate, + error) + custom_setup = self.setup_method_with_custom_transport(custom_transport, default_endpoint=self.host) + setup = self.setup_method_with_custom_transport(None, default_endpoint=self.host) + return setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate + + @pytest.mark.parametrize("write_operation, error", write_operations_and_errors()) + def test_write_consecutive_failure_threshold(self, write_operation, error): + error_lambda = lambda r: FaultInjectionTransport.error_after_delay(0, error) + setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = self.setup_info(error_lambda) + container = setup['col'] + fault_injection_container = custom_setup['col'] + global_endpoint_manager = container.client_connection._global_endpoint_manager + + # writes should fail in sm mrr with circuit breaker and should not mark unavailable a partition + for i in range(6): + validate_unhealthy_partitions(global_endpoint_manager, 0) + with pytest.raises((CosmosHttpResponseError, ServiceResponseError)): + perform_write_operation( + write_operation, + container, + fault_injection_container, + str(uuid.uuid4()), + PK_VALUE, + expected_uri, + ) + + 
validate_unhealthy_partitions(global_endpoint_manager, 0) + + @pytest.mark.parametrize("write_operation, error", write_operations_and_errors()) + def test_write_failure_rate_threshold(self, write_operation, error): + error_lambda = lambda r: FaultInjectionTransport.error_after_delay(0, error) + setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = self.setup_info(error_lambda) + container = setup['col'] + fault_injection_container = custom_setup['col'] + global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager + # lower minimum requests for testing + _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 10 + os.environ["AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED"] = "80" + try: + # writes should fail but still be tracked and mark unavailable a partition after crossing threshold + for i in range(10): + validate_unhealthy_partitions(global_endpoint_manager, 0) + if i == 4 or i == 8: + # perform some successful creates to reset consecutive counter + # remove faults and perform a write + custom_transport.faults = [] + fault_injection_container.upsert_item(body=doc) + custom_transport.add_fault(predicate, + lambda r: FaultInjectionTransport.error_after_delay( + 0, + error + )) + else: + with pytest.raises((CosmosHttpResponseError, ServiceResponseError)) as exc_info: + perform_write_operation(write_operation, + container, + fault_injection_container, + str(uuid.uuid4()), + PK_VALUE, + expected_uri) + assert exc_info.value == error + + validate_unhealthy_partitions(global_endpoint_manager, 0) + + finally: + os.environ["AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED"] = "90" + # restore minimum requests + _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 + + @pytest.mark.parametrize("read_operation, write_operation", operations()) + def test_service_request_error(self, read_operation, write_operation): + # the region should be tried 4 times before failing over and mark the partition as 
unavailable + # the region should not be marked as unavailable + error_lambda = lambda r: FaultInjectionTransport.error_region_down() + setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = self.setup_info(error_lambda) + container = setup['col'] + fault_injection_container = custom_setup['col'] + container.upsert_item(body=doc) + perform_read_operation(read_operation, + fault_injection_container, + doc['id'], + PK_VALUE, + expected_uri) + global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager + + # there shouldn't be partition marked as unavailable + validate_unhealthy_partitions(global_endpoint_manager, 0) + assert len(global_endpoint_manager.location_cache.location_unavailability_info_by_endpoint) == 1 + + # recover partition + # remove faults and reduce initial recover time and perform a write + original_unavailable_time = _partition_health_tracker.INITIAL_UNAVAILABLE_TIME + _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = 1 + custom_transport.faults = [] + try: + perform_read_operation(read_operation, + fault_injection_container, + doc['id'], + PK_VALUE, + expected_uri) + finally: + _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = original_unavailable_time + validate_unhealthy_partitions(global_endpoint_manager, 0) + + custom_transport.add_fault(predicate, + lambda r: FaultInjectionTransport.error_region_down()) + + # The global endpoint would be used for the write operation + expected_uri = self.host + perform_write_operation(write_operation, + container, + fault_injection_container, + str(uuid.uuid4()), + PK_VALUE, + expected_uri) + global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager + + validate_unhealthy_partitions(global_endpoint_manager, 0) + # there shouldn't be region marked as unavailable + assert len(global_endpoint_manager.location_cache.location_unavailability_info_by_endpoint) == 1 + + # test cosmos client timeout + +if __name__ == 
'__main__': + unittest.main() \ No newline at end of file diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr_async.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr_async.py new file mode 100644 index 000000000000..665eff84aa0a --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr_async.py @@ -0,0 +1,206 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. +import asyncio +import os +import unittest +import uuid +from typing import Dict, Any + +import pytest +from azure.core.pipeline.transport._aiohttp import AioHttpTransport +from azure.core.exceptions import ServiceResponseError + +import test_config +from azure.cosmos import _partition_health_tracker, _location_cache +from azure.cosmos._partition_health_tracker import UNHEALTHY_TENTATIVE, UNHEALTHY, HEALTH_STATUS +from azure.cosmos.aio import CosmosClient +from azure.cosmos.exceptions import CosmosHttpResponseError +from _fault_injection_transport_async import FaultInjectionTransportAsync +from test_per_partition_circuit_breaker_mm_async import perform_write_operation, create_doc, PK_VALUE, \ + write_operations_and_errors, cleanup_method, perform_read_operation, operations, REGION_2, REGION_1 + +COLLECTION = "created_collection" + +def validate_unhealthy_partitions(global_endpoint_manager, + expected_unhealthy_partitions): + health_info_map = global_endpoint_manager.global_partition_endpoint_manager_core.partition_health_tracker.pk_range_wrapper_to_health_info + unhealthy_partitions = 0 + for pk_range_wrapper, location_to_health_info in health_info_map.items(): + for location, health_info in location_to_health_info.items(): + health_status = health_info.unavailability_info.get(HEALTH_STATUS) + if health_status == UNHEALTHY_TENTATIVE or health_status == UNHEALTHY: + unhealthy_partitions += 1 + + else: + assert health_info.read_consecutive_failure_count < 10 + # single region write 
account should never track write failures + assert health_info.write_failure_count == 0 + assert health_info.write_consecutive_failure_count == 0 + + assert unhealthy_partitions == expected_unhealthy_partitions + + +@pytest.mark.cosmosCircuitBreakerMultiRegion +@pytest.mark.asyncio +class TestPerPartitionCircuitBreakerSmMrrAsync: + host = test_config.TestConfig.host + master_key = test_config.TestConfig.masterKey + connectionPolicy = test_config.TestConfig.connectionPolicy + TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID + TEST_CONTAINER_SINGLE_PARTITION_ID = test_config.TestConfig.TEST_MULTI_PARTITION_CONTAINER_ID + + async def setup_method_with_custom_transport(self, custom_transport: AioHttpTransport, default_endpoint=host, **kwargs): + client = CosmosClient(default_endpoint, self.master_key, consistency_level="Session", + preferred_locations=[REGION_1, REGION_2], + transport=custom_transport, **kwargs) + db = client.get_database_client(self.TEST_DATABASE_ID) + container = db.get_container_client(self.TEST_CONTAINER_SINGLE_PARTITION_ID) + return {"client": client, "db": db, "col": container} + + @staticmethod + async def cleanup_method(initialized_objects: Dict[str, Any]): + method_client: CosmosClient = initialized_objects["client"] + await method_client.close() + + async def setup_info(self, error): + expected_uri = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_2) + uri_down = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_1) + custom_transport = FaultInjectionTransportAsync() + # two documents targeted to same partition, one will always fail and the other will succeed + doc = create_doc() + predicate = lambda r: (FaultInjectionTransportAsync.predicate_is_document_operation(r) and + FaultInjectionTransportAsync.predicate_targets_region(r, uri_down)) + custom_transport.add_fault(predicate, + error) + custom_setup = await self.setup_method_with_custom_transport(custom_transport, 
default_endpoint=self.host) + setup = await self.setup_method_with_custom_transport(None, default_endpoint=self.host) + return setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate + + @pytest.mark.parametrize("write_operation, error", write_operations_and_errors()) + async def test_write_consecutive_failure_threshold_async(self, write_operation, error): + error_lambda = lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( + 0, + error + )) + setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = await self.setup_info(error_lambda) + container = setup['col'] + fault_injection_container = custom_setup['col'] + global_endpoint_manager = container.client_connection._global_endpoint_manager + + # writes should fail in sm mrr with circuit breaker and should not mark unavailable a partition + for i in range(6): + validate_unhealthy_partitions(global_endpoint_manager, 0) + with pytest.raises((CosmosHttpResponseError, ServiceResponseError)): + await perform_write_operation( + write_operation, + container, + fault_injection_container, + str(uuid.uuid4()), + PK_VALUE, + expected_uri, + ) + + validate_unhealthy_partitions(global_endpoint_manager, 0) + await cleanup_method([custom_setup, setup]) + + + @pytest.mark.parametrize("write_operation, error", write_operations_and_errors()) + async def test_write_failure_rate_threshold_async(self, write_operation, error): + error_lambda = lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( + 0, + error + )) + setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = await self.setup_info(error_lambda) + container = setup['col'] + fault_injection_container = custom_setup['col'] + global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager + # lower minimum requests for testing + _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 10 + 
os.environ["AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED"] = "80" + try: + # writes should fail but still be tracked and mark unavailable a partition after crossing threshold + for i in range(10): + validate_unhealthy_partitions(global_endpoint_manager, 0) + if i == 4 or i == 8: + # perform some successful creates to reset consecutive counter + # remove faults and perform a write + custom_transport.faults = [] + await fault_injection_container.upsert_item(body=doc) + custom_transport.add_fault(predicate, + lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( + 0, + error + ))) + else: + with pytest.raises((CosmosHttpResponseError, ServiceResponseError)) as exc_info: + await perform_write_operation(write_operation, + container, + fault_injection_container, + str(uuid.uuid4()), + PK_VALUE, + expected_uri) + assert exc_info.value == error + + validate_unhealthy_partitions(global_endpoint_manager, 0) + + finally: + os.environ["AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED"] = "90" + # restore minimum requests + _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 + await cleanup_method([custom_setup, setup]) + + @pytest.mark.parametrize("read_operation, write_operation", operations()) + async def test_service_request_error_async(self, read_operation, write_operation): + # the region should be tried 4 times before failing over and mark the partition as unavailable + # the region should not be marked as unavailable + error_lambda = lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_region_down()) + setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = await self.setup_info(error_lambda) + container = setup['col'] + fault_injection_container = custom_setup['col'] + await container.upsert_item(body=doc) + await perform_read_operation(read_operation, + fault_injection_container, + doc['id'], + PK_VALUE, + expected_uri) + global_endpoint_manager = 
fault_injection_container.client_connection._global_endpoint_manager + + validate_unhealthy_partitions(global_endpoint_manager, 0) + # there shouldn't be region marked as unavailable + assert len(global_endpoint_manager.location_cache.location_unavailability_info_by_endpoint) == 1 + + # recover partition + # remove faults and reduce initial recover time and perform a write + original_unavailable_time = _partition_health_tracker.INITIAL_UNAVAILABLE_TIME + _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = 1 + custom_transport.faults = [] + try: + await perform_read_operation(read_operation, + fault_injection_container, + doc['id'], + PK_VALUE, + expected_uri) + finally: + _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = original_unavailable_time + validate_unhealthy_partitions(global_endpoint_manager, 0) + + custom_transport.add_fault(predicate, + lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_region_down())) + # The global endpoint would be used for the write operation in single region write + expected_uri = self.host + await perform_write_operation(write_operation, + container, + fault_injection_container, + str(uuid.uuid4()), + PK_VALUE, + expected_uri) + global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager + + validate_unhealthy_partitions(global_endpoint_manager, 0) + assert len(global_endpoint_manager.location_cache.location_unavailability_info_by_endpoint) == 1 + await cleanup_method([custom_setup, setup]) + + # test cosmos client timeout + +if __name__ == '__main__': + unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/test_query.py b/sdk/cosmos/azure-cosmos/tests/test_query.py index 2a99263ed457..aa17116b2f39 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_query.py +++ b/sdk/cosmos/azure-cosmos/tests/test_query.py @@ -17,6 +17,7 @@ from azure.cosmos.documents import _DistinctType from azure.cosmos.partition_key import PartitionKey +@pytest.mark.cosmosCircuitBreaker 
@pytest.mark.cosmosQuery class TestQuery(unittest.TestCase): """Test to ensure escaping of non-ascii characters from partition key""" @@ -32,7 +33,10 @@ class TestQuery(unittest.TestCase): @classmethod def setUpClass(cls): - cls.client = cosmos_client.CosmosClient(cls.host, cls.credential) + use_multiple_write_locations = False + if os.environ.get("AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER", "False") == "True": + use_multiple_write_locations = True + cls.client = cosmos_client.CosmosClient(cls.host, cls.credential, multiple_write_locations=use_multiple_write_locations) cls.created_db = cls.client.get_database_client(cls.TEST_DATABASE_ID) if cls.host == "https://localhost:8081/": os.environ["AZURE_COSMOS_DISABLE_NON_STREAMING_ORDER_BY"] = "True" diff --git a/sdk/cosmos/azure-cosmos/tests/test_query_async.py b/sdk/cosmos/azure-cosmos/tests/test_query_async.py index c89b2d944110..11cbc8768fde 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_query_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_query_async.py @@ -1,6 +1,6 @@ # The MIT License (MIT) # Copyright (c) Microsoft Corporation. All rights reserved. 
- +import asyncio import os import unittest import uuid @@ -18,6 +18,7 @@ from azure.cosmos.documents import _DistinctType from azure.cosmos.partition_key import PartitionKey +@pytest.mark.cosmosCircuitBreaker @pytest.mark.cosmosQuery class TestQueryAsync(unittest.IsolatedAsyncioTestCase): """Test to ensure escaping of non-ascii characters from partition key""" @@ -34,6 +35,9 @@ class TestQueryAsync(unittest.IsolatedAsyncioTestCase): @classmethod def setUpClass(cls): + cls.use_multiple_write_locations = False + if os.environ.get("AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER", "False") == "True": + cls.use_multiple_write_locations = True if (cls.masterKey == '[YOUR_KEY_HERE]' or cls.host == '[YOUR_ENDPOINT_HERE]'): raise Exception( @@ -42,7 +46,7 @@ def setUpClass(cls): "tests.") async def asyncSetUp(self): - self.client = CosmosClient(self.host, self.masterKey) + self.client = CosmosClient(self.host, self.masterKey, multiple_write_locations=self.use_multiple_write_locations) self.created_db = self.client.get_database_client(self.TEST_DATABASE_ID) if self.host == "https://localhost:8081/": os.environ["AZURE_COSMOS_DISABLE_NON_STREAMING_ORDER_BY"] = "True" @@ -56,6 +60,7 @@ async def test_first_and_last_slashes_trimmed_for_query_string_async(self): doc_id = 'myId' + str(uuid.uuid4()) document_definition = {'pk': 'pk', 'id': doc_id} await created_collection.create_item(body=document_definition) + await asyncio.sleep(1) query = 'SELECT * from c' query_iterable = created_collection.query_items( @@ -75,6 +80,7 @@ async def test_populate_query_metrics_async(self): doc_id = 'MyId' + str(uuid.uuid4()) document_definition = {'pk': 'pk', 'id': doc_id} await created_collection.create_item(body=document_definition) + await asyncio.sleep(1) query = 'SELECT * from c' query_iterable = created_collection.query_items( @@ -103,6 +109,7 @@ async def test_populate_index_metrics_async(self): doc_id = 'MyId' + str(uuid.uuid4()) document_definition = {'pk': 'pk', 'id': doc_id} await 
created_collection.create_item(body=document_definition) + await asyncio.sleep(1) query = 'SELECT * from c' query_iterable = created_collection.query_items( diff --git a/sdk/cosmos/azure-cosmos/tests/test_query_cross_partition.py b/sdk/cosmos/azure-cosmos/tests/test_query_cross_partition.py index 2e518d075c73..1ee7b550f13b 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_query_cross_partition.py +++ b/sdk/cosmos/azure-cosmos/tests/test_query_cross_partition.py @@ -16,7 +16,7 @@ from azure.cosmos.documents import _DistinctType from azure.cosmos.partition_key import PartitionKey - +@pytest.mark.cosmosCircuitBreaker @pytest.mark.cosmosQuery class TestCrossPartitionQuery(unittest.TestCase): """Test to ensure escaping of non-ascii characters from partition key""" @@ -39,7 +39,10 @@ def setUpClass(cls): "'masterKey' and 'host' at the top of this class to run the " "tests.") - cls.client = cosmos_client.CosmosClient(cls.host, cls.masterKey) + use_multiple_write_locations = False + if os.environ.get("AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER", "False") == "True": + use_multiple_write_locations = True + cls.client = cosmos_client.CosmosClient(cls.host, cls.masterKey, multiple_write_locations=use_multiple_write_locations) cls.created_db = cls.client.get_database_client(cls.TEST_DATABASE_ID) if cls.host == "https://localhost:8081/": os.environ["AZURE_COSMOS_DISABLE_NON_STREAMING_ORDER_BY"] = "True" diff --git a/sdk/cosmos/azure-cosmos/tests/test_query_cross_partition_async.py b/sdk/cosmos/azure-cosmos/tests/test_query_cross_partition_async.py index d9a5b4b251b7..9a6961d1cd03 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_query_cross_partition_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_query_cross_partition_async.py @@ -16,6 +16,7 @@ from azure.cosmos.exceptions import CosmosHttpResponseError from azure.cosmos.partition_key import PartitionKey +@pytest.mark.cosmosCircuitBreaker @pytest.mark.cosmosQuery class TestQueryCrossPartitionAsync(unittest.IsolatedAsyncioTestCase): 
"""Test to ensure escaping of non-ascii characters from partition key""" @@ -40,7 +41,10 @@ def setUpClass(cls): "tests.") async def asyncSetUp(self): - self.client = CosmosClient(self.host, self.masterKey) + use_multiple_write_locations = False + if os.environ.get("AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER", "False") == "True": + use_multiple_write_locations = True + self.client = CosmosClient(self.host, self.masterKey, multiple_write_locations=use_multiple_write_locations) self.created_db = self.client.get_database_client(self.TEST_DATABASE_ID) self.created_container = await self.created_db.create_container( self.TEST_CONTAINER_ID, diff --git a/sdk/cosmos/azure-cosmos/tests/test_query_hybrid_search_async.py b/sdk/cosmos/azure-cosmos/tests/test_query_hybrid_search_async.py index 86f33f658204..d779ac2472e0 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_query_hybrid_search_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_query_hybrid_search_async.py @@ -295,7 +295,7 @@ async def test_hybrid_search_weighted_reciprocal_rank_fusion_async(self): query = "SELECT c.index, c.title FROM c " \ "ORDER BY RANK RRF(FullTextScore(c.text, 'United States'), VectorDistance(c.vector, {}), [1,1]) " \ "OFFSET 0 LIMIT 10".format(item_vector) - results = self.test_container.query_items(query, enable_cross_partition_query=True) + results = self.test_container.query_items(query) result_list = [res async for res in results] assert len(result_list) == 10 result_list = [res['index'] for res in result_list] diff --git a/sdk/cosmos/azure-cosmos/tests/test_retry_policy_async.py b/sdk/cosmos/azure-cosmos/tests/test_retry_policy_async.py index dd9870a99231..fccf19225cfe 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_retry_policy_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_retry_policy_async.py @@ -46,7 +46,7 @@ def __init__(self) -> None: self.ProxyConfiguration: Optional[ProxyConfiguration] = None self.EnableEndpointDiscovery: bool = True self.PreferredLocations: List[str] = [] - 
self.ExcludedLocations = None + self.ExcludedLocations = [] self.RetryOptions: RetryOptions = RetryOptions() self.DisableSSLVerification: bool = False self.UseMultipleWriteLocations: bool = False diff --git a/sdk/cosmos/azure-cosmos/tests/test_streaming_failover.py b/sdk/cosmos/azure-cosmos/tests/test_streaming_failover.py index b034287fbd6d..bb2a986e6471 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_streaming_failover.py +++ b/sdk/cosmos/azure-cosmos/tests/test_streaming_failover.py @@ -143,8 +143,8 @@ def test_retry_policy_does_not_mark_null_locations_unavailable(self): endpoint_manager.mark_endpoint_unavailable_for_read = self._mock_mark_endpoint_unavailable_for_read self.original_mark_endpoint_unavailable_for_write_function = endpoint_manager.mark_endpoint_unavailable_for_write endpoint_manager.mark_endpoint_unavailable_for_write = self._mock_mark_endpoint_unavailable_for_write - self.original_resolve_service_endpoint = endpoint_manager.resolve_service_endpoint - endpoint_manager.resolve_service_endpoint = self._mock_resolve_service_endpoint + self.original_resolve_service_endpoint = endpoint_manager.resolve_service_endpoint_for_partition + endpoint_manager.resolve_service_endpoint_for_partition = self._mock_resolve_service_endpoint # Read and write counters count the number of times the endpoint manager's # mark_endpoint_unavailable_for_read() and mark_endpoint_unavailable_for_read() diff --git a/sdk/cosmos/live-platform-matrix.json b/sdk/cosmos/live-platform-matrix.json index f9418a63aa9d..8edaff6fe68d 100644 --- a/sdk/cosmos/live-platform-matrix.json +++ b/sdk/cosmos/live-platform-matrix.json @@ -25,6 +25,40 @@ } } }, + { + "CircuitBreakerMultiWriteTestConfig": { + "Ubuntu2004_313_circuit_breaker": { + "OSVmImage": "env:LINUXVMIMAGE", + "Pool": "env:LINUXPOOL", + "PythonVersion": "3.13", + "CoverageArg": "--disablecov", + "TestSamples": "false", + "TestMarkArgument": "cosmosCircuitBreaker" + } + }, + "ArmConfig": { + "MultiMaster": { + 
"ArmTemplateParameters": "@{ enableMultipleWriteLocations = $true; defaultConsistencyLevel = 'Session'; enableMultipleRegions = $true; circuitBreakerEnabled = 'True' }" + } + } + }, + { + "CircuitBreakerMultiRegionTestConfig": { + "Ubuntu2004_39_circuit_breaker": { + "OSVmImage": "env:LINUXVMIMAGE", + "Pool": "env:LINUXPOOL", + "PythonVersion": "3.9", + "CoverageArg": "--disablecov", + "TestSamples": "false", + "TestMarkArgument": "cosmosCircuitBreakerMultiRegion" + } + }, + "ArmConfig": { + "MultiRegion": { + "ArmTemplateParameters": "@{ defaultConsistencyLevel = 'Session'; enableMultipleRegions = $true; circuitBreakerEnabled = 'True' }" + } + } + }, { "MacTestConfig": { "macos311_search_query": { diff --git a/sdk/cosmos/test-resources.bicep b/sdk/cosmos/test-resources.bicep index 88abe955f8d8..735c1a0e66ee 100644 --- a/sdk/cosmos/test-resources.bicep +++ b/sdk/cosmos/test-resources.bicep @@ -12,6 +12,9 @@ param enableMultipleRegions bool = false @description('Location for the Cosmos DB account.') param location string = resourceGroup().location +@description('Whether Per Partition Circuit Breaker should be enabled.') +param circuitBreakerEnabled string = 'False' + @description('The api version to be used by Bicep to create resources') param apiVersion string = '2023-04-15' @@ -101,6 +104,6 @@ resource accountName_roleAssignmentId 'Microsoft.DocumentDB/databaseAccounts/sql } } - +output AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER string = circuitBreakerEnabled output ACCOUNT_HOST string = reference(resourceId, apiVersion).documentEndpoint output ACCOUNT_KEY string = listKeys(resourceId, apiVersion).primaryMasterKey