From ff20cf9a2afcdf20c49a9e79434763314a77b107 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Wed, 5 Feb 2025 19:03:52 -0800 Subject: [PATCH 001/152] change default read timeout --- sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py | 2 ++ .../azure-cosmos/azure/cosmos/aio/_asynchronous_request.py | 2 ++ sdk/cosmos/azure-cosmos/azure/cosmos/documents.py | 1 + 3 files changed, 5 insertions(+) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py index 1c90bfa57150..68e37caf1d9d 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py @@ -93,6 +93,8 @@ def _Request(global_endpoint_manager, request_params, connection_policy, pipelin if request_params.resource_type != http_constants.ResourceType.DatabaseAccount: global_endpoint_manager.refresh_endpoint_list(None, **kwargs) else: + # always override database account call timeouts + read_timeout = connection_policy.DBAReadTimeout connection_timeout = connection_policy.DBAConnectionTimeout if client_timeout is not None: kwargs['timeout'] = client_timeout - (time.time() - start_time) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py index 377cb5d406b1..81430d8df42c 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py @@ -62,6 +62,8 @@ async def _Request(global_endpoint_manager, request_params, connection_policy, p if request_params.resource_type != http_constants.ResourceType.DatabaseAccount: await global_endpoint_manager.refresh_endpoint_list(None, **kwargs) else: + # always override database account call timeouts + read_timeout = connection_policy.DBAReadTimeout connection_timeout = connection_policy.DBAConnectionTimeout if client_timeout is not None: 
kwargs['timeout'] = client_timeout - (time.time() - start_time) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/documents.py b/sdk/cosmos/azure-cosmos/azure/cosmos/documents.py index fcf855f56921..8093b2f71fc0 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/documents.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/documents.py @@ -332,6 +332,7 @@ class ConnectionPolicy: # pylint: disable=too-many-instance-attributes __defaultRequestTimeout: int = 5 # seconds __defaultDBAConnectionTimeout: int = 3 # seconds __defaultReadTimeout: int = 65 # seconds + __defaultDBAReadTimeout: int = 3 # seconds __defaultMaxBackoff: int = 1 # seconds def __init__(self) -> None: From 40e43c40193fa7a7d8f87c2c570d9e82686c6185 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Wed, 5 Feb 2025 20:02:53 -0800 Subject: [PATCH 002/152] fix tests --- sdk/cosmos/azure-cosmos/azure/cosmos/documents.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/documents.py b/sdk/cosmos/azure-cosmos/azure/cosmos/documents.py index 8093b2f71fc0..57d6e75be534 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/documents.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/documents.py @@ -339,6 +339,7 @@ def __init__(self) -> None: self.RequestTimeout: int = self.__defaultRequestTimeout self.DBAConnectionTimeout: int = self.__defaultDBAConnectionTimeout self.ReadTimeout: int = self.__defaultReadTimeout + self.DBAReadTimeout: int = self.__defaultDBAReadTimeout self.MaxBackoff: int = self.__defaultMaxBackoff self.ConnectionMode: int = ConnectionMode.Gateway self.SSLConfiguration: Optional[SSLConfiguration] = None From aefe30b42efa1d02dd282468e70276333da0b127 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Wed, 5 Feb 2025 21:24:38 -0800 Subject: [PATCH 003/152] Add read timeout tests for database account calls --- sdk/cosmos/azure-cosmos/test/test_crud.py | 9 ++++++++- sdk/cosmos/azure-cosmos/test/test_crud_async.py | 9 +++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git 
a/sdk/cosmos/azure-cosmos/test/test_crud.py b/sdk/cosmos/azure-cosmos/test/test_crud.py index d41246f5d6cd..a96654cec6c0 100644 --- a/sdk/cosmos/azure-cosmos/test/test_crud.py +++ b/sdk/cosmos/azure-cosmos/test/test_crud.py @@ -1821,7 +1821,14 @@ def test_client_request_timeout(self): container = databaseForTest.get_container_client(self.configs.TEST_SINGLE_PARTITION_CONTAINER_ID) container.create_item(body={'id': str(uuid.uuid4()), 'name': 'sample'}) - + def test_read_timeout(self): + connection_policy = documents.ConnectionPolicy() + # making timeout 0 ms to make sure it will throw + connection_policy.DBAReadTimeout = 0.000000000001 + with self.assertRaises(ServiceResponseError): + # this will make a get database account call + with cosmos_client.CosmosClient(self.host, self.masterKey, connection_policy=connection_policy): + print('initialization') def test_client_request_timeout_when_connection_retry_configuration_specified(self): connection_policy = documents.ConnectionPolicy() diff --git a/sdk/cosmos/azure-cosmos/test/test_crud_async.py b/sdk/cosmos/azure-cosmos/test/test_crud_async.py index 1c40afc3edfa..00517d23ef8e 100644 --- a/sdk/cosmos/azure-cosmos/test/test_crud_async.py +++ b/sdk/cosmos/azure-cosmos/test/test_crud_async.py @@ -1689,6 +1689,15 @@ async def test_client_request_timeout_async(self): await container.create_item(body={'id': str(uuid.uuid4()), 'name': 'sample'}) print('Async initialization') + async def test_read_timeout_async(self): + connection_policy = documents.ConnectionPolicy() + # making timeout 0 ms to make sure it will throw + connection_policy.DBAReadTimeout = 0.000000000001 + with self.assertRaises(ServiceResponseError): + # this will make a get database account call + async with CosmosClient(self.host, self.masterKey, connection_policy=connection_policy): + print('Async initialization') + async def test_client_request_timeout_when_connection_retry_configuration_specified_async(self): connection_policy = 
documents.ConnectionPolicy() # making timeout 0 ms to make sure it will throw From 9a234f87d902ea22925313864a91d7a3b864fbee Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Wed, 5 Feb 2025 21:36:21 -0800 Subject: [PATCH 004/152] fix timeout retry policy --- .../azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py index 145dfd947ccf..036061a17b07 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py @@ -30,7 +30,7 @@ def ShouldRetry(self, _exception): :rtype: bool """ if self.request: - if _OperationType.IsReadOnlyOperation(self.request.operation_type): + if not _OperationType.IsReadOnlyOperation(self.request.operation_type): return False if not self.connection_policy.EnableEndpointDiscovery: From 8859c9fc8972c7ba684c953b6b6cc526cdd7ca01 Mon Sep 17 00:00:00 2001 From: Kushagra Thapar Date: Wed, 5 Feb 2025 21:56:45 -0800 Subject: [PATCH 005/152] Fixed the timeout logic --- .../cosmos/_timeout_failover_retry_policy.py | 39 +++---------------- 1 file changed, 5 insertions(+), 34 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py index 036061a17b07..aa66cd4f76e7 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py @@ -18,7 +18,6 @@ def __init__(self, connection_policy, global_endpoint_manager, *args): self.global_endpoint_manager = global_endpoint_manager self.retry_count = 0 - self.location_index = 0 self.connection_policy = connection_policy self.request = args[0] if args else None @@ -29,14 +28,13 @@ def ShouldRetry(self, 
_exception): :returns: a boolean stating whether the request should be retried :rtype: bool """ - if self.request: - if not _OperationType.IsReadOnlyOperation(self.request.operation_type): - return False + # we don't retry on write operations for timeouts or service unavailable + if self.request and (not _OperationType.IsReadOnlyOperation(self.request.operation_type)): + return False if not self.connection_policy.EnableEndpointDiscovery: return False - # Check if the next retry about to be done is safe if _exception.status_code == http_constants.StatusCodes.SERVICE_UNAVAILABLE and \ self.retry_count >= self._max_service_unavailable_retry_count: @@ -47,46 +45,19 @@ def ShouldRetry(self, _exception): return False if self.request: - # Update the last routed location to where this request was routed previously. - # So that we can check in location cache if we need to return the current or previous - # based on where the request was routed previously. - self.request.last_routed_location_endpoint_within_region = self.request.location_endpoint_to_route - - if _OperationType.IsReadOnlyOperation(self.request.operation_type): - # We just directly got to the next location in case of read requests - # We don't retry again on the same region for regional endpoint - location_endpoint = self.resolve_next_region_service_endpoint() - else: - self.global_endpoint_manager.swap_regional_endpoint_values(self.request) - location_endpoint = self.resolve_current_region_service_endpoint() - # This is the case where both current and previous point to the same writable endpoint - # In this case we don't want to retry again, rather failover to the next region - if self.request.last_routed_location_endpoint_within_region == location_endpoint: - location_endpoint = self.resolve_next_region_service_endpoint() - + location_endpoint = self.resolve_next_region_service_endpoint() self.request.route_to_location(location_endpoint) return True - - # This function prepares the request to go to the 
second endpoint in the same region - def resolve_current_region_service_endpoint(self): - # clear previous location-based routing directive - self.request.clear_route_to_location() - # resolve the next service endpoint in the same region - # since we maintain 2 endpoints per region for write operations - self.request.route_to_location_with_preferred_location_flag(self.location_index, True) - return self.global_endpoint_manager.resolve_service_endpoint(self.request) - # This function prepares the request to go to the next region def resolve_next_region_service_endpoint(self): - self.location_index += 1 # clear previous location-based routing directive self.request.clear_route_to_location() # clear the last routed endpoint within same region since we are going to a new region now self.request.last_routed_location_endpoint_within_region = None # set location-based routing directive based on retry count # ensuring usePreferredLocations is set to True for retry - self.request.route_to_location_with_preferred_location_flag(self.location_index, True) + self.request.route_to_location_with_preferred_location_flag(self.retry_count, True) # Resolve the endpoint for the request and pin the resolution to the resolved endpoint # This enables marking the endpoint unavailability on endpoint failover/unreachability return self.global_endpoint_manager.resolve_service_endpoint(self.request) From ac78da9632bf406e6ac38b4d9ea79cc03eb7d41d Mon Sep 17 00:00:00 2001 From: Kushagra Thapar Date: Wed, 5 Feb 2025 23:13:52 -0800 Subject: [PATCH 006/152] Fixed the timeout retry policy --- .../azure-cosmos/azure/cosmos/_retry_utility.py | 7 +++++-- .../azure/cosmos/_timeout_failover_retry_policy.py | 13 ++++--------- .../azure/cosmos/aio/_retry_utility_async.py | 5 +++-- .../azure-cosmos/azure/cosmos/http_constants.py | 1 - sdk/cosmos/azure-cosmos/test/test_globaldb_mock.py | 2 +- 5 files changed, 13 insertions(+), 15 deletions(-) diff --git 
a/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py index 99784facecc4..927ed7a41baa 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py @@ -131,7 +131,6 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs): sub_status_code=SubStatusCodes.THROUGHPUT_OFFER_NOT_FOUND) return result except exceptions.CosmosHttpResponseError as e: - retry_policy = defaultRetry_policy if request and _has_database_account_header(request.headers): retry_policy = database_account_retry_policy # Re-assign retry policy based on error code @@ -173,8 +172,12 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs): retry_policy.container_rid = cached_container["_rid"] request.headers[retry_policy._intended_headers] = retry_policy.container_rid - elif e.status_code in [StatusCodes.REQUEST_TIMEOUT, StatusCodes.SERVICE_UNAVAILABLE]: + elif e.status_code == StatusCodes.REQUEST_TIMEOUT: retry_policy = timeout_failover_retry_policy + elif e.status_code >= StatusCodes.INTERNAL_SERVER_ERROR: + retry_policy = timeout_failover_retry_policy + else: + retry_policy = defaultRetry_policy # If none of the retry policies applies or there is no retry needed, set the # throttle related response headers and re-throw the exception back arg[0] diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py index aa66cd4f76e7..60f0208e6351 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py @@ -5,18 +5,17 @@ Cosmos database service. """ from azure.cosmos.documents import _OperationType -from . 
import http_constants class _TimeoutFailoverRetryPolicy(object): def __init__(self, connection_policy, global_endpoint_manager, *args): - self._max_retry_attempt_count = 120 - self._max_service_unavailable_retry_count = 1 - self.retry_after_in_milliseconds = 0 + self.retry_after_in_milliseconds = 500 self.args = args self.global_endpoint_manager = global_endpoint_manager + # If an account only has 1 region, then we still want to retry once on the same region + self._max_retry_attempt_count = len(self.global_endpoint_manager.location_cache.read_regional_endpoints) + 1 self.retry_count = 0 self.connection_policy = connection_policy self.request = args[0] if args else None @@ -28,17 +27,13 @@ def ShouldRetry(self, _exception): :returns: a boolean stating whether the request should be retried :rtype: bool """ - # we don't retry on write operations for timeouts or service unavailable + # we don't retry on write operations for timeouts or any internal server errors if self.request and (not _OperationType.IsReadOnlyOperation(self.request.operation_type)): return False if not self.connection_policy.EnableEndpointDiscovery: return False - # Check if the next retry about to be done is safe - if _exception.status_code == http_constants.StatusCodes.SERVICE_UNAVAILABLE and \ - self.retry_count >= self._max_service_unavailable_retry_count: - return False self.retry_count += 1 # Check if the next retry about to be done is safe if self.retry_count >= self._max_retry_attempt_count: diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py index 74df8ea9479f..c4be5a1afc2c 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py @@ -130,7 +130,6 @@ async def ExecuteAsync(client, global_endpoint_manager, function, *args, **kwarg return result except exceptions.CosmosHttpResponseError as e: - retry_policy = 
None if request and _has_database_account_header(request.headers): retry_policy = database_account_retry_policy elif e.status_code == StatusCodes.FORBIDDEN and e.sub_status in \ @@ -171,7 +170,9 @@ async def ExecuteAsync(client, global_endpoint_manager, function, *args, **kwarg retry_policy.container_rid = cached_container["_rid"] request.headers[retry_policy._intended_headers] = retry_policy.container_rid - elif e.status_code in [StatusCodes.REQUEST_TIMEOUT, StatusCodes.SERVICE_UNAVAILABLE]: + elif e.status_code == StatusCodes.REQUEST_TIMEOUT: + retry_policy = timeout_failover_retry_policy + elif e.status_code >= StatusCodes.INTERNAL_SERVER_ERROR: retry_policy = timeout_failover_retry_policy else: retry_policy = defaultRetry_policy diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/http_constants.py b/sdk/cosmos/azure-cosmos/azure/cosmos/http_constants.py index 8a7b57b7c93f..31a95d2600d6 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/http_constants.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/http_constants.py @@ -400,7 +400,6 @@ class StatusCodes: RETRY_WITH = 449 INTERNAL_SERVER_ERROR = 500 - SERVICE_UNAVAILABLE = 503 # Operation pause and cancel. These are FAKE status codes for QOS logging purpose only. 
OPERATION_PAUSED = 1200 diff --git a/sdk/cosmos/azure-cosmos/test/test_globaldb_mock.py b/sdk/cosmos/azure-cosmos/test/test_globaldb_mock.py index 5548b51839b3..83d5e2603e9f 100644 --- a/sdk/cosmos/azure-cosmos/test/test_globaldb_mock.py +++ b/sdk/cosmos/azure-cosmos/test/test_globaldb_mock.py @@ -166,7 +166,7 @@ def MockExecuteFunction(self, function, *args, **kwargs): def MockGetDatabaseAccountStub(self, endpoint): raise exceptions.CosmosHttpResponseError( - status_code=StatusCodes.SERVICE_UNAVAILABLE, message="Service unavailable") + status_code=StatusCodes.INTERNAL_SERVER_ERROR, message="Internal Server Error") def test_global_db_endpoint_discovery_retry_policy(self): connection_policy = documents.ConnectionPolicy() From 09aac90f49d2391bebfe29cfab1658fcb892b161 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Thu, 6 Feb 2025 02:30:58 -0800 Subject: [PATCH 007/152] Mock tests for timeout and failover retry policy --- .../test_timeout_and_failover_retry_policy.py | 135 +++++++++++++++++ ...timeout_and_failover_retry_policy_async.py | 137 ++++++++++++++++++ 2 files changed, 272 insertions(+) create mode 100644 sdk/cosmos/azure-cosmos/test/test_timeout_and_failover_retry_policy.py create mode 100644 sdk/cosmos/azure-cosmos/test/test_timeout_and_failover_retry_policy_async.py diff --git a/sdk/cosmos/azure-cosmos/test/test_timeout_and_failover_retry_policy.py b/sdk/cosmos/azure-cosmos/test/test_timeout_and_failover_retry_policy.py new file mode 100644 index 000000000000..342453524246 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/test/test_timeout_and_failover_retry_policy.py @@ -0,0 +1,135 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. 
+ +import unittest +import uuid + +import pytest + +import azure.cosmos.cosmos_client as cosmos_client +import azure.cosmos.exceptions as exceptions +import test_config +from azure.cosmos import _retry_utility, PartitionKey + +COLLECTION = "created_collection" +@pytest.fixture(scope="class") +def setup(): + if (TestTimeoutRetryPolicy.masterKey == '[YOUR_KEY_HERE]' or + TestTimeoutRetryPolicy.host == '[YOUR_ENDPOINT_HERE]'): + raise Exception( + "You must specify your Azure Cosmos account values for " + "'masterKey' and 'host' at the top of this class to run the " + "tests.") + + client = cosmos_client.CosmosClient(TestTimeoutRetryPolicy.host, TestTimeoutRetryPolicy.masterKey, consistency_level="Session", + connection_policy=TestTimeoutRetryPolicy.connectionPolicy) + created_database = client.get_database_client(TestTimeoutRetryPolicy.TEST_DATABASE_ID) + created_collection = created_database.create_container(TestTimeoutRetryPolicy.TEST_CONTAINER_SINGLE_PARTITION_ID, + partition_key=PartitionKey("/pk")) + yield { + COLLECTION: created_collection + } + + created_database.delete_container(TestTimeoutRetryPolicy.TEST_CONTAINER_SINGLE_PARTITION_ID) + + + + +def error_codes(): + return [408, 500, 502, 503] + + +@pytest.mark.cosmosEmulator +@pytest.mark.unittest +@pytest.mark.usefixtures("setup") +class TestTimeoutRetryPolicy: + host = test_config.TestConfig.host + masterKey = test_config.TestConfig.masterKey + connectionPolicy = test_config.TestConfig.connectionPolicy + TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID + TEST_CONTAINER_SINGLE_PARTITION_ID = "test-timeout-retry-policy-container-" + str(uuid.uuid4()) + + @pytest.mark.parametrize("error_code", error_codes()) + def test_timeout_failover_retry_policy_for_read_success(self, setup, error_code): + document_definition = {'id': 'failoverDoc-' + str(uuid.uuid4()), + 'pk': 'pk', + 'name': 'sample document', + 'key': 'value'} + + created_document = setup[COLLECTION].create_item(body=document_definition) + 
self.original_execute_function = _retry_utility.ExecuteFunction + try: + # should retry once and then succeed + mf = self.MockExecuteFunction(self.original_execute_function, 1, error_code) + _retry_utility.ExecuteFunction = mf + doc = setup[COLLECTION].read_item(item=created_document['id'], + partition_key=created_document['pk']) + assert doc == created_document + finally: + _retry_utility.ExecuteFunction = self.original_execute_function + + @pytest.mark.parametrize("error_code", error_codes()) + def test_timeout_failover_retry_policy_for_read_failure(self, setup, error_code): + document_definition = {'id': 'failoverDoc-' + str(uuid.uuid4()), + 'pk': 'pk', + 'name': 'sample document', + 'key': 'value'} + + created_document = setup[COLLECTION].create_item(body=document_definition) + self.original_execute_function = _retry_utility.ExecuteFunction + try: + # should exhaust all retries and then fail + mf = self.MockExecuteFunction(self.original_execute_function, 5, error_code) + _retry_utility.ExecuteFunction = mf + setup[COLLECTION].read_item(item=created_document['id'], + partition_key=created_document['pk']) + pytest.fail("Exception was not raised.") + except exceptions.CosmosHttpResponseError as err: + assert err.status_code == error_code + finally: + _retry_utility.ExecuteFunction = self.original_execute_function + + @pytest.mark.parametrize("error_code", error_codes()) + def test_timeout_failover_retry_policy_for_write_failure(self, setup, error_code): + document_definition = {'id': 'failoverDoc' + str(uuid.uuid4()), + 'pk': 'pk', + 'name': 'sample document', + 'key': 'value'} + + self.original_execute_function = _retry_utility.ExecuteFunction + try: + # timeouts should fail immediately for writes + mf = self.MockExecuteFunction(self.original_execute_function,0, error_code) + _retry_utility.ExecuteFunction = mf + try: + setup[COLLECTION].create_item(body=document_definition) + pytest.fail("Exception was not raised.") + except exceptions.CosmosHttpResponseError as err: 
+ assert err.status_code == error_code + finally: + _retry_utility.ExecuteFunction = self.original_execute_function + + + + + class MockExecuteFunction(object): + def __init__(self, org_func, num_exceptions, status_code): + self.org_func = org_func + self.counter = 0 + self.num_exceptions = num_exceptions + self.status_code = status_code + + def __call__(self, func, *args, **kwargs): + if self.counter != 0 and self.counter >= self.num_exceptions: + return self.org_func(func, *args, **kwargs) + else: + self.counter += 1 + raise exceptions.CosmosHttpResponseError( + status_code=self.status_code, + message="Some Exception", + response=test_config.FakeResponse({})) + + + +if __name__ == '__main__': + unittest.main() diff --git a/sdk/cosmos/azure-cosmos/test/test_timeout_and_failover_retry_policy_async.py b/sdk/cosmos/azure-cosmos/test/test_timeout_and_failover_retry_policy_async.py new file mode 100644 index 000000000000..90b2f46dc651 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/test/test_timeout_and_failover_retry_policy_async.py @@ -0,0 +1,137 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. 
+ +import unittest +import uuid + +import pytest +import pytest_asyncio + +import azure.cosmos.exceptions as exceptions +import test_config +from azure.cosmos import PartitionKey +from azure.cosmos.aio import CosmosClient, _retry_utility_async + +COLLECTION = "created_collection" +@pytest_asyncio.fixture() +async def setup(): + if (TestTimeoutRetryPolicyAsync.masterKey == '[YOUR_KEY_HERE]' or + TestTimeoutRetryPolicyAsync.host == '[YOUR_ENDPOINT_HERE]'): + raise Exception( + "You must specify your Azure Cosmos account values for " + "'masterKey' and 'host' at the top of this class to run the " + "tests.") + + client = CosmosClient(TestTimeoutRetryPolicyAsync.host, TestTimeoutRetryPolicyAsync.masterKey, consistency_level="Session", + connection_policy=TestTimeoutRetryPolicyAsync.connectionPolicy) + created_database = client.get_database_client(TestTimeoutRetryPolicyAsync.TEST_DATABASE_ID) + created_collection = await created_database.create_container(TestTimeoutRetryPolicyAsync.TEST_CONTAINER_SINGLE_PARTITION_ID, + partition_key=PartitionKey("/pk")) + yield { + COLLECTION: created_collection + } + + await created_database.delete_container(TestTimeoutRetryPolicyAsync.TEST_CONTAINER_SINGLE_PARTITION_ID) + await client.close() + + + + +def error_codes(): + return [408, 500, 502, 503] + + +@pytest.mark.cosmosEmulator +@pytest.mark.asyncio +@pytest.mark.usefixtures("setup") +class TestTimeoutRetryPolicyAsync: + host = test_config.TestConfig.host + masterKey = test_config.TestConfig.masterKey + connectionPolicy = test_config.TestConfig.connectionPolicy + TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID + TEST_CONTAINER_SINGLE_PARTITION_ID = "test-timeout-retry-policy-container-" + str(uuid.uuid4()) + + @pytest.mark.parametrize("error_code", error_codes()) + async def test_timeout_failover_retry_policy_for_read_success_async(self, setup, error_code): + document_definition = {'id': 'failoverDoc-' + str(uuid.uuid4()), + 'pk': 'pk', + 'name': 'sample document', + 
'key': 'value'} + + created_document = await setup[COLLECTION].create_item(body=document_definition) + self.original_execute_function = _retry_utility_async.ExecuteFunctionAsync + try: + # should retry once and then succeed + mf = self.MockExecuteFunction(self.original_execute_function, 1, error_code) + _retry_utility_async.ExecuteFunctionAsync = mf + doc = await setup[COLLECTION].read_item(item=created_document['id'], + partition_key=created_document['pk']) + assert doc == created_document + finally: + _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function + + @pytest.mark.parametrize("error_code", error_codes()) + async def test_timeout_failover_retry_policy_for_read_failure_async(self, setup, error_code): + document_definition = {'id': 'failoverDoc-' + str(uuid.uuid4()), + 'pk': 'pk', + 'name': 'sample document', + 'key': 'value'} + + created_document = await setup[COLLECTION].create_item(body=document_definition) + self.original_execute_function = _retry_utility_async.ExecuteFunctionAsync + try: + # should exhaust all retries and then fail + mf = self.MockExecuteFunction(self.original_execute_function, 5, error_code) + _retry_utility_async.ExecuteFunctionAsync = mf + await setup[COLLECTION].read_item(item=created_document['id'], + partition_key=created_document['pk']) + pytest.fail("Exception was not raised.") + except exceptions.CosmosHttpResponseError as err: + assert err.status_code == error_code + finally: + _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function + + @pytest.mark.parametrize("error_code", error_codes()) + async def test_timeout_failover_retry_policy_for_write_failure_async(self, setup, error_code): + document_definition = {'id': 'failoverDoc' + str(uuid.uuid4()), + 'pk': 'pk', + 'name': 'sample document', + 'key': 'value'} + + self.original_execute_function = _retry_utility_async.ExecuteFunctionAsync + try: + # timeouts should fail immediately for writes + mf = 
self.MockExecuteFunction(self.original_execute_function,0, error_code) + _retry_utility_async.ExecuteFunctionAsync = mf + try: + await setup[COLLECTION].create_item(body=document_definition) + pytest.fail("Exception was not raised.") + except exceptions.CosmosHttpResponseError as err: + assert err.status_code == error_code + finally: + _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function + + + + + class MockExecuteFunction(object): + def __init__(self, org_func, num_exceptions, status_code): + self.org_func = org_func + self.counter = 0 + self.num_exceptions = num_exceptions + self.status_code = status_code + + def __call__(self, func, global_endpoint_manager, *args, **kwargs): + if self.counter != 0 and self.counter >= self.num_exceptions: + return self.org_func(func, global_endpoint_manager, *args, **kwargs) + else: + self.counter += 1 + raise exceptions.CosmosHttpResponseError( + status_code=self.status_code, + message="Some Exception", + response=test_config.FakeResponse({})) + + + +if __name__ == '__main__': + unittest.main() From f22e7d21d05e55eb8cf2ff06a1bb21d6ab0658de Mon Sep 17 00:00:00 2001 From: Fabian Meiswinkel Date: Fri, 7 Feb 2025 02:22:44 +0000 Subject: [PATCH 008/152] Create test_dummy.py --- sdk/cosmos/azure-cosmos/test/test_dummy.py | 164 +++++++++++++++++++++ 1 file changed, 164 insertions(+) create mode 100644 sdk/cosmos/azure-cosmos/test/test_dummy.py diff --git a/sdk/cosmos/azure-cosmos/test/test_dummy.py b/sdk/cosmos/azure-cosmos/test/test_dummy.py new file mode 100644 index 000000000000..4fe18ef001f1 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/test/test_dummy.py @@ -0,0 +1,164 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. 
+ +from collections.abc import MutableMapping +import logging +from typing import Any +import unittest +import uuid + +import pytest +import pytest_asyncio + +import azure.cosmos.exceptions as exceptions +import test_config +from azure.cosmos import PartitionKey +from azure.cosmos.aio import CosmosClient, _retry_utility_async +from azure.core.rest import HttpRequest, AsyncHttpResponse +import asyncio +import aiohttp +import sys +from azure.core.pipeline.transport import AioHttpTransport + +COLLECTION = "created_collection" +@pytest_asyncio.fixture() +async def setup(): + if (TestDummyAsync.masterKey == '[YOUR_KEY_HERE]' or + TestDummyAsync.host == '[YOUR_ENDPOINT_HERE]'): + raise Exception( + "You must specify your Azure Cosmos account values for " + "'masterKey' and 'host' at the top of this class to run the " + "tests.") + + logger = logging.getLogger('azure.cosmos') + logger.setLevel("INFO") + logger.setLevel(logging.DEBUG) + logger.addHandler(logging.StreamHandler(sys.stdout)) + custom_transport = TestDummyAsync.FaulInjectionTransport(logger) + client = CosmosClient(TestDummyAsync.host, TestDummyAsync.masterKey, consistency_level="Session", + connection_policy=TestDummyAsync.connectionPolicy, transport=custom_transport) + created_database = client.get_database_client(TestDummyAsync.TEST_DATABASE_ID) + created_collection = await created_database.create_container(TestDummyAsync.TEST_CONTAINER_SINGLE_PARTITION_ID, + partition_key=PartitionKey("/pk")) + yield { + COLLECTION: created_collection + } + + await created_database.delete_container(TestDummyAsync.TEST_CONTAINER_SINGLE_PARTITION_ID) + await client.close() + + +def error_codes(): + return [408, 500, 502, 503] + + +@pytest.mark.cosmosEmulator +@pytest.mark.asyncio +@pytest.mark.usefixtures("setup") +class TestDummyAsync: + host = test_config.TestConfig.host + masterKey = test_config.TestConfig.masterKey + connectionPolicy = test_config.TestConfig.connectionPolicy + TEST_DATABASE_ID = 
test_config.TestConfig.TEST_DATABASE_ID + TEST_CONTAINER_SINGLE_PARTITION_ID = "test-timeout-retry-policy-container-" + str(uuid.uuid4()) + + @pytest.mark.parametrize("error_code", error_codes()) + async def test_timeout_failover_retry_policy_for_read_success_async(self, setup, error_code): + document_definition = {'id': 'failoverDoc-' + str(uuid.uuid4()), + 'pk': 'pk', + 'name': 'sample document', + 'key': 'value'} + + created_document = await setup[COLLECTION].create_item(body=document_definition) + self.original_execute_function = _retry_utility_async.ExecuteFunctionAsync + try: + # should retry once and then succeed + mf = self.MockExecuteFunction(self.original_execute_function, 1, error_code) + _retry_utility_async.ExecuteFunctionAsync = mf + doc = await setup[COLLECTION].read_item(item=created_document['id'], + partition_key=created_document['pk']) + assert doc == created_document + finally: + _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function + + @pytest.mark.parametrize("error_code", error_codes()) + async def test_timeout_failover_retry_policy_for_read_failure_async(self, setup, error_code): + document_definition = {'id': 'failoverDoc-' + str(uuid.uuid4()), + 'pk': 'pk', + 'name': 'sample document', + 'key': 'value'} + + created_document = await setup[COLLECTION].create_item(body=document_definition) + self.original_execute_function = _retry_utility_async.ExecuteFunctionAsync + try: + # should retry once and then succeed + mf = self.MockExecuteFunction(self.original_execute_function, 5, error_code) + _retry_utility_async.ExecuteFunctionAsync = mf + await setup[COLLECTION].read_item(item=created_document['id'], + partition_key=created_document['pk']) + pytest.fail("Exception was not raised.") + except exceptions.CosmosHttpResponseError as err: + assert err.status_code == error_code + finally: + _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function + + @pytest.mark.parametrize("error_code", error_codes()) + async 
def test_timeout_failover_retry_policy_for_write_failure_async(self, setup, error_code): + document_definition = {'id': 'failoverDoc' + str(uuid.uuid4()), + 'pk': 'pk', + 'name': 'sample document', + 'key': 'value'} + + self.original_execute_function = _retry_utility_async.ExecuteFunctionAsync + try: + # timeouts should fail immediately for writes + mf = self.MockExecuteFunction(self.original_execute_function,0, error_code) + _retry_utility_async.ExecuteFunctionAsync = mf + try: + await setup[COLLECTION].create_item(body=document_definition) + pytest.fail("Exception was not raised.") + except exceptions.CosmosHttpResponseError as err: + assert err.status_code == error_code + finally: + _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function + + class FaulInjectionTransport(AioHttpTransport): + def __init__(self, logger: logging.Logger, *, session: aiohttp.ClientSession | None = None, loop=None, session_owner: bool = True, **config): + self.logger = logger + super().__init__(session=session, loop=loop, session_owner=session_owner, **config) + + async def send(self, request: HttpRequest, *, stream: bool = False, proxies: MutableMapping[str, str] | None = None, **config): + # Add custom logic before sending the request + self.logger.error(f"Sending request to {request.url}") + + # Call the base class's send method to actually send the request + try: + response = await super().send(request, stream=stream, proxies=proxies, **config) + except Exception as e: + self.logger.error(f"Error: {e}") + raise + + # Add custom logic after receiving the response + self.logger.info(f"Received response with status code {response.status_code}") + + return response + + class MockExecuteFunction(object): + def __init__(self, org_func, num_exceptions, status_code): + self.org_func = org_func + self.counter = 0 + self.num_exceptions = num_exceptions + self.status_code = status_code + + def __call__(self, func, global_endpoint_manager, *args, **kwargs): + if self.counter 
!= 0 and self.counter >= self.num_exceptions: + return self.org_func(func, global_endpoint_manager, *args, **kwargs) + else: + self.counter += 1 + raise exceptions.CosmosHttpResponseError( + status_code=self.status_code, + message="Some Exception", + response=test_config.FakeResponse({})) + +if __name__ == '__main__': + unittest.main() From dd8a466019ba9b574d093c7bffbf2d037817f2a8 Mon Sep 17 00:00:00 2001 From: Fabian Meiswinkel Date: Fri, 7 Feb 2025 02:24:06 +0000 Subject: [PATCH 009/152] Update test_dummy.py --- sdk/cosmos/azure-cosmos/test/test_dummy.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/test/test_dummy.py b/sdk/cosmos/azure-cosmos/test/test_dummy.py index 4fe18ef001f1..cd469e071ef8 100644 --- a/sdk/cosmos/azure-cosmos/test/test_dummy.py +++ b/sdk/cosmos/azure-cosmos/test/test_dummy.py @@ -3,7 +3,7 @@ from collections.abc import MutableMapping import logging -from typing import Any +from typing import Any, Callable import unittest import uuid @@ -125,11 +125,29 @@ async def test_timeout_failover_retry_policy_for_write_failure_async(self, setup class FaulInjectionTransport(AioHttpTransport): def __init__(self, logger: logging.Logger, *, session: aiohttp.ClientSession | None = None, loop=None, session_owner: bool = True, **config): self.logger = logger + self.faults = [] + self.requestTransformationOverrides = [] + self.responseTransformationOverrides = [] super().__init__(session=session, loop=loop, session_owner=session_owner, **config) + def addFault(self, predicate: Callable[[HttpRequest], bool], faultFactory: Callable[[], Exception | None]): + self.fault.append({"predicate": predicate, "faultFactory": faultFactory}) + + def addRequestTransformation(self, predicate: Callable[[HttpRequest], bool], faultFactory: Callable[[], Exception | None]): + self.fault.append({"predicate": predicate, "faultFactory": faultFactory}) + + def addResponseTransformation(self, predicate: 
Callable[[HttpRequest], bool], faultFactory: Callable[[], Exception | None]): + self.fault.append({"predicate": predicate, "faultFactory": faultFactory}) + async def send(self, request: HttpRequest, *, stream: bool = False, proxies: MutableMapping[str, str] | None = None, **config): + # find the first fault Factory with matching predicate if any + firstFaultFactory = next(lambda f: f["predicate"](f["predicate"]), self.faults, None) + + # Add custom logic before sending the request - self.logger.error(f"Sending request to {request.url}") + + + self.logger.info(f"Sending request to {request.url}") # Call the base class's send method to actually send the request try: From 8ac11c5a142a8091cf5a19249a6de981782059d2 Mon Sep 17 00:00:00 2001 From: Fabian Meiswinkel Date: Fri, 7 Feb 2025 07:01:31 +0000 Subject: [PATCH 010/152] Update test_dummy.py --- sdk/cosmos/azure-cosmos/test/test_dummy.py | 256 +++++++++------------ 1 file changed, 112 insertions(+), 144 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/test/test_dummy.py b/sdk/cosmos/azure-cosmos/test/test_dummy.py index cd469e071ef8..4b2fda5d3437 100644 --- a/sdk/cosmos/azure-cosmos/test/test_dummy.py +++ b/sdk/cosmos/azure-cosmos/test/test_dummy.py @@ -3,6 +3,8 @@ from collections.abc import MutableMapping import logging +import time +from tokenize import String from typing import Any, Callable import unittest import uuid @@ -10,6 +12,8 @@ import pytest import pytest_asyncio +from azure.cosmos.aio._container import ContainerProxy +from azure.cosmos.aio._database import DatabaseProxy import azure.cosmos.exceptions as exceptions import test_config from azure.cosmos import PartitionKey @@ -21,162 +25,126 @@ from azure.core.pipeline.transport import AioHttpTransport COLLECTION = "created_collection" -@pytest_asyncio.fixture() -async def setup(): - if (TestDummyAsync.masterKey == '[YOUR_KEY_HERE]' or - TestDummyAsync.host == '[YOUR_ENDPOINT_HERE]'): - raise Exception( - "You must specify your Azure Cosmos account values 
for " - "'masterKey' and 'host' at the top of this class to run the " - "tests.") - - logger = logging.getLogger('azure.cosmos') - logger.setLevel("INFO") - logger.setLevel(logging.DEBUG) - logger.addHandler(logging.StreamHandler(sys.stdout)) - custom_transport = TestDummyAsync.FaulInjectionTransport(logger) - client = CosmosClient(TestDummyAsync.host, TestDummyAsync.masterKey, consistency_level="Session", - connection_policy=TestDummyAsync.connectionPolicy, transport=custom_transport) - created_database = client.get_database_client(TestDummyAsync.TEST_DATABASE_ID) - created_collection = await created_database.create_container(TestDummyAsync.TEST_CONTAINER_SINGLE_PARTITION_ID, - partition_key=PartitionKey("/pk")) - yield { - COLLECTION: created_collection - } - - await created_database.delete_container(TestDummyAsync.TEST_CONTAINER_SINGLE_PARTITION_ID) - await client.close() - - -def error_codes(): - return [408, 500, 502, 503] +logger = logging.getLogger('azure.cosmos') +logger.setLevel("INFO") +logger.setLevel(logging.DEBUG) +logger.addHandler(logging.StreamHandler(sys.stdout)) + +host = test_config.TestConfig.host +masterKey = test_config.TestConfig.masterKey +connectionPolicy = test_config.TestConfig.connectionPolicy +TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID + +@pytest.fixture() +def setup(): + return @pytest.mark.cosmosEmulator @pytest.mark.asyncio @pytest.mark.usefixtures("setup") class TestDummyAsync: - host = test_config.TestConfig.host - masterKey = test_config.TestConfig.masterKey - connectionPolicy = test_config.TestConfig.connectionPolicy - TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID - TEST_CONTAINER_SINGLE_PARTITION_ID = "test-timeout-retry-policy-container-" + str(uuid.uuid4()) - - @pytest.mark.parametrize("error_code", error_codes()) - async def test_timeout_failover_retry_policy_for_read_success_async(self, setup, error_code): - document_definition = {'id': 'failoverDoc-' + str(uuid.uuid4()), - 'pk': 'pk', - 'name': 
'sample document', - 'key': 'value'} - - created_document = await setup[COLLECTION].create_item(body=document_definition) - self.original_execute_function = _retry_utility_async.ExecuteFunctionAsync - try: - # should retry once and then succeed - mf = self.MockExecuteFunction(self.original_execute_function, 1, error_code) - _retry_utility_async.ExecuteFunctionAsync = mf - doc = await setup[COLLECTION].read_item(item=created_document['id'], - partition_key=created_document['pk']) - assert doc == created_document - finally: - _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function - - @pytest.mark.parametrize("error_code", error_codes()) - async def test_timeout_failover_retry_policy_for_read_failure_async(self, setup, error_code): - document_definition = {'id': 'failoverDoc-' + str(uuid.uuid4()), - 'pk': 'pk', + logger = logger + + async def setup(self, custom_transport: AioHttpTransport): + + host = test_config.TestConfig.host + masterKey = test_config.TestConfig.masterKey + connectionPolicy = test_config.TestConfig.connectionPolicy + TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID + TEST_CONTAINER_SINGLE_PARTITION_ID = "test-timeout-retry-policy-container-" + str(uuid.uuid4()) + + if (masterKey == '[YOUR_KEY_HERE]' or + host == '[YOUR_ENDPOINT_HERE]'): + raise Exception( + "You must specify your Azure Cosmos account values for " + "'masterKey' and 'host' at the top of this class to run the " + "tests.") + + client = CosmosClient(host, masterKey, consistency_level="Session", + connection_policy=connectionPolicy, transport=custom_transport) + created_database: DatabaseProxy = client.get_database_client(TEST_DATABASE_ID) + created_collection = await created_database.create_container(TEST_CONTAINER_SINGLE_PARTITION_ID, + partition_key=PartitionKey("/pk")) + return {"client": client, "db": created_database, "col": created_collection} + + async def test_throws_injected_error(self, setup): + custom_transport = FaulInjectionTransport(logger) 
+ idValue: str = 'failoverDoc-' + str(uuid.uuid4()) + document_definition = {'id': idValue, + 'pk': idValue, 'name': 'sample document', 'key': 'value'} - created_document = await setup[COLLECTION].create_item(body=document_definition) - self.original_execute_function = _retry_utility_async.ExecuteFunctionAsync - try: - # should retry once and then succeed - mf = self.MockExecuteFunction(self.original_execute_function, 5, error_code) - _retry_utility_async.ExecuteFunctionAsync = mf - await setup[COLLECTION].read_item(item=created_document['id'], - partition_key=created_document['pk']) - pytest.fail("Exception was not raised.") - except exceptions.CosmosHttpResponseError as err: - assert err.status_code == error_code - finally: - _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function - - @pytest.mark.parametrize("error_code", error_codes()) - async def test_timeout_failover_retry_policy_for_write_failure_async(self, setup, error_code): - document_definition = {'id': 'failoverDoc' + str(uuid.uuid4()), - 'pk': 'pk', - 'name': 'sample document', - 'key': 'value'} - - self.original_execute_function = _retry_utility_async.ExecuteFunctionAsync - try: - # timeouts should fail immediately for writes - mf = self.MockExecuteFunction(self.original_execute_function,0, error_code) - _retry_utility_async.ExecuteFunctionAsync = mf - try: - await setup[COLLECTION].create_item(body=document_definition) - pytest.fail("Exception was not raised.") - except exceptions.CosmosHttpResponseError as err: - assert err.status_code == error_code - finally: - _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function - - class FaulInjectionTransport(AioHttpTransport): - def __init__(self, logger: logging.Logger, *, session: aiohttp.ClientSession | None = None, loop=None, session_owner: bool = True, **config): - self.logger = logger - self.faults = [] - self.requestTransformationOverrides = [] - self.responseTransformationOverrides = [] - 
super().__init__(session=session, loop=loop, session_owner=session_owner, **config) - - def addFault(self, predicate: Callable[[HttpRequest], bool], faultFactory: Callable[[], Exception | None]): - self.fault.append({"predicate": predicate, "faultFactory": faultFactory}) - - def addRequestTransformation(self, predicate: Callable[[HttpRequest], bool], faultFactory: Callable[[], Exception | None]): - self.fault.append({"predicate": predicate, "faultFactory": faultFactory}) - - def addResponseTransformation(self, predicate: Callable[[HttpRequest], bool], faultFactory: Callable[[], Exception | None]): - self.fault.append({"predicate": predicate, "faultFactory": faultFactory}) - - async def send(self, request: HttpRequest, *, stream: bool = False, proxies: MutableMapping[str, str] | None = None, **config): - # find the first fault Factory with matching predicate if any - firstFaultFactory = next(lambda f: f["predicate"](f["predicate"]), self.faults, None) - - - # Add custom logic before sending the request - - + initializedObjects = await self.setup(custom_transport) + container: ContainerProxy = initializedObjects["col"] + + created_document = await container.create_item(body=document_definition) + start = time.perf_counter() + + while ((time.perf_counter() - start) < 7): + await container.read_item(idValue, partition_key=idValue) + await asyncio.sleep(2) + + created_database: DatabaseProxy = initializedObjects["db"] + await created_database.delete_container(TestDummyAsync.TEST_CONTAINER_SINGLE_PARTITION_ID) + client: CosmosClient = initializedObjects["client"] + await client.close() + + +class FaulInjectionTransport(AioHttpTransport): + def __init__(self, logger: logging.Logger, *, session: aiohttp.ClientSession | None = None, loop=None, session_owner: bool = True, **config): + self.logger = logger + self.faults = [] + self.requestTransformations = [] + self.responseTransformations = [] + super().__init__(session=session, loop=loop, session_owner=session_owner, 
**config) + + def addFault(self, predicate: Callable[[HttpRequest], bool], faultFactory: Callable[[HttpRequest], asyncio.Task[Exception]]): + self.faults.append({"predicate": predicate, "apply": faultFactory}) + + def addRequestTransformation(self, predicate: Callable[[HttpRequest], bool], requestTransformation: Callable[[HttpRequest], asyncio.Task[HttpRequest]]): + self.requestTransformations.append({"predicate": predicate, "apply": requestTransformation}) + + def addResponseTransformation(self, predicate: Callable[[HttpRequest], bool], responseTransformation: Callable[[HttpRequest, Callable[[HttpRequest], asyncio.Task[AsyncHttpResponse]]], AsyncHttpResponse]): + self.responseTransformations.append({ + "predicate": predicate, + "apply": responseTransformation}) + + def firstItem(self, iterable, condition=lambda x: True): + """ + Returns the first item in the `iterable` that satisfies the `condition`. + + If no item satisfies the condition, it returns None. + """ + return next((x for x in iterable if condition(x)), None) + + async def send(self, request: HttpRequest, *, stream: bool = False, proxies: MutableMapping[str, str] | None = None, **config): + # find the first fault Factory with matching predicate if any + firstFaultFactory = self.firstItem(iter(self.faults), lambda f: f["predicate"](request)) + if (firstFaultFactory != None): + self.logger.info("") + return await firstFaultFactory["apply"]() + + # apply the chain of request transformations with matching predicates if any + matchingRequestTransformations = filter(lambda f: f["predicate"](f["predicate"]), iter(self.requestTransformations)) + for currentTransformation in matchingRequestTransformations: + request = await currentTransformation["apply"](request) + + firstResonseTransformation = self.firstItem(iter(self.responseTransformations), lambda f: f["predicate"](request)) + + getResponseTask = super().send(request, stream=stream, proxies=proxies, **config) + + if (firstResonseTransformation != None): + 
self.logger.info(f"Invoking response transformation") + response = await firstResonseTransformation["apply"](request, lambda: getResponseTask) + self.logger.info(f"Received response transformation result with status code {response.status_code}") + return response + else: self.logger.info(f"Sending request to {request.url}") - - # Call the base class's send method to actually send the request - try: - response = await super().send(request, stream=stream, proxies=proxies, **config) - except Exception as e: - self.logger.error(f"Error: {e}") - raise - - # Add custom logic after receiving the response + response = await getResponseTask self.logger.info(f"Received response with status code {response.status_code}") - return response - class MockExecuteFunction(object): - def __init__(self, org_func, num_exceptions, status_code): - self.org_func = org_func - self.counter = 0 - self.num_exceptions = num_exceptions - self.status_code = status_code - - def __call__(self, func, global_endpoint_manager, *args, **kwargs): - if self.counter != 0 and self.counter >= self.num_exceptions: - return self.org_func(func, global_endpoint_manager, *args, **kwargs) - else: - self.counter += 1 - raise exceptions.CosmosHttpResponseError( - status_code=self.status_code, - message="Some Exception", - response=test_config.FakeResponse({})) - if __name__ == '__main__': unittest.main() From b53e2e9ecbd1cecfc163ffd7666759d885845603 Mon Sep 17 00:00:00 2001 From: Fabian Meiswinkel Date: Fri, 7 Feb 2025 13:34:33 +0000 Subject: [PATCH 011/152] Update test_dummy.py --- sdk/cosmos/azure-cosmos/test/test_dummy.py | 38 ++++++++++++++-------- 1 file changed, 25 insertions(+), 13 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/test/test_dummy.py b/sdk/cosmos/azure-cosmos/test/test_dummy.py index 4b2fda5d3437..f30a2fc7099f 100644 --- a/sdk/cosmos/azure-cosmos/test/test_dummy.py +++ b/sdk/cosmos/azure-cosmos/test/test_dummy.py @@ -46,6 +46,19 @@ def setup(): class TestDummyAsync: logger = logger + async 
def cleanup(self, initializedObjects: dict[str, Any]): + created_database: DatabaseProxy = initializedObjects["db"] + try: + await created_database.delete_container(initializedObjects["col"]) + except Exception as containerDeleteError: + self.logger.warn("Exception trying to delete database {}. {}".format(created_database.id, containerDeleteError)) + finally: + client: CosmosClient = initializedObjects["client"] + try: + await client.close() + except Exception as closeError: + self.logger.warn("Exception trying to delete database {}. {}".format(created_database.id, closeError)) + async def setup(self, custom_transport: AioHttpTransport): host = test_config.TestConfig.host @@ -77,19 +90,18 @@ async def test_throws_injected_error(self, setup): 'key': 'value'} initializedObjects = await self.setup(custom_transport) - container: ContainerProxy = initializedObjects["col"] - - created_document = await container.create_item(body=document_definition) - start = time.perf_counter() - - while ((time.perf_counter() - start) < 7): - await container.read_item(idValue, partition_key=idValue) - await asyncio.sleep(2) - - created_database: DatabaseProxy = initializedObjects["db"] - await created_database.delete_container(TestDummyAsync.TEST_CONTAINER_SINGLE_PARTITION_ID) - client: CosmosClient = initializedObjects["client"] - await client.close() + try: + container: ContainerProxy = initializedObjects["col"] + + created_document = await container.create_item(body=document_definition) + start = time.perf_counter() + + while ((time.perf_counter() - start) < 7): + await container.read_item(idValue, partition_key=idValue) + await asyncio.sleep(2) + + finally: + self.cleanup(initializedObjects) class FaulInjectionTransport(AioHttpTransport): From 973ec4412ae0fa847dedb86a96b4caea2ea34ae9 Mon Sep 17 00:00:00 2001 From: Fabian Meiswinkel Date: Fri, 7 Feb 2025 15:30:49 +0000 Subject: [PATCH 012/152] Iterating on fault injection tooling --- sdk/cosmos/azure-cosmos/pytest.ini | 3 + 
sdk/cosmos/azure-cosmos/test/test_dummy.py | 69 ++++++++++++++++++---- 2 files changed, 61 insertions(+), 11 deletions(-) create mode 100644 sdk/cosmos/azure-cosmos/pytest.ini diff --git a/sdk/cosmos/azure-cosmos/pytest.ini b/sdk/cosmos/azure-cosmos/pytest.ini new file mode 100644 index 000000000000..e211052edef0 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/pytest.ini @@ -0,0 +1,3 @@ +[pytest] +markers = + cosmosEmulator: marks tests as depending in Cosmos DB Emulator \ No newline at end of file diff --git a/sdk/cosmos/azure-cosmos/test/test_dummy.py b/sdk/cosmos/azure-cosmos/test/test_dummy.py index f30a2fc7099f..9a96c75eb598 100644 --- a/sdk/cosmos/azure-cosmos/test/test_dummy.py +++ b/sdk/cosmos/azure-cosmos/test/test_dummy.py @@ -14,7 +14,7 @@ from azure.cosmos.aio._container import ContainerProxy from azure.cosmos.aio._database import DatabaseProxy -import azure.cosmos.exceptions as exceptions +from azure.cosmos.exceptions import CosmosHttpResponseError import test_config from azure.cosmos import PartitionKey from azure.cosmos.aio import CosmosClient, _retry_utility_async @@ -44,20 +44,19 @@ def setup(): @pytest.mark.asyncio @pytest.mark.usefixtures("setup") class TestDummyAsync: - logger = logger async def cleanup(self, initializedObjects: dict[str, Any]): created_database: DatabaseProxy = initializedObjects["db"] try: await created_database.delete_container(initializedObjects["col"]) except Exception as containerDeleteError: - self.logger.warn("Exception trying to delete database {}. {}".format(created_database.id, containerDeleteError)) + logger.warn("Exception trying to delete database {}. {}".format(created_database.id, containerDeleteError)) finally: client: CosmosClient = initializedObjects["client"] try: await client.close() except Exception as closeError: - self.logger.warn("Exception trying to delete database {}. {}".format(created_database.id, closeError)) + logger.warn("Exception trying to delete database {}. 
{}".format(created_database.id, closeError)) async def setup(self, custom_transport: AioHttpTransport): @@ -75,15 +74,62 @@ async def setup(self, custom_transport: AioHttpTransport): "tests.") client = CosmosClient(host, masterKey, consistency_level="Session", - connection_policy=connectionPolicy, transport=custom_transport) + connection_policy=connectionPolicy, transport=custom_transport, + logger=logger) created_database: DatabaseProxy = client.get_database_client(TEST_DATABASE_ID) created_collection = await created_database.create_container(TEST_CONTAINER_SINGLE_PARTITION_ID, partition_key=PartitionKey("/pk")) return {"client": client, "db": created_database, "col": created_collection} + def predicate_url_contains_id(self, r: HttpRequest, id: str): + logger.info("FaultPredicate for request {} {}".format(r.method, r.url)); + return id in r.url; + + def predicate_req_payload_contains_id(self, r: HttpRequest, id: str): + logger.info("FaultPredicate for request {} {} - request payload {}".format( + r.method, + r.url, + "NONE" if r.body is None else r.body)); + + if (r.body == None): + return False + + + return '"id":"{}"'.format(id) in r.body; + + async def throw_after_delay(self, delayInMs: int, error: Exception): + await asyncio.sleep(delayInMs/1000.0) + raise error + async def test_throws_injected_error(self, setup): + idValue: str = str(uuid.uuid4()) + document_definition = {'id': idValue, + 'pk': idValue, + 'name': 'sample document', + 'key': 'value'} + + custom_transport = FaulInjectionTransport(logger) + predicate : Callable[[HttpRequest], bool] = lambda r: self.predicate_req_payload_contains_id(r, idValue) + custom_transport.addFault(predicate, lambda: self.throw_after_delay( + 500, + CosmosHttpResponseError( + status_code=502, + message="Some random reverse proxy error."))) + + initializedObjects = await self.setup(custom_transport) + try: + container: ContainerProxy = initializedObjects["col"] + await container.create_item(body=document_definition) + 
pytest.fail("Expected exception not thrown") + except CosmosHttpResponseError as cosmosError: + if (cosmosError.status_code != 502): + raise cosmosError + finally: + await self.cleanup(initializedObjects) + + async def test_succeeds_with_multiple_endpoints(self, setup): custom_transport = FaulInjectionTransport(logger) - idValue: str = 'failoverDoc-' + str(uuid.uuid4()) + idValue: str = str(uuid.uuid4()) document_definition = {'id': idValue, 'pk': idValue, 'name': 'sample document', @@ -96,12 +142,12 @@ async def test_throws_injected_error(self, setup): created_document = await container.create_item(body=document_definition) start = time.perf_counter() - while ((time.perf_counter() - start) < 7): + while ((time.perf_counter() - start) < 2): await container.read_item(idValue, partition_key=idValue) - await asyncio.sleep(2) + await asyncio.sleep(0.2) finally: - self.cleanup(initializedObjects) + await self.cleanup(initializedObjects) class FaulInjectionTransport(AioHttpTransport): @@ -135,8 +181,9 @@ async def send(self, request: HttpRequest, *, stream: bool = False, proxies: Mut # find the first fault Factory with matching predicate if any firstFaultFactory = self.firstItem(iter(self.faults), lambda f: f["predicate"](request)) if (firstFaultFactory != None): - self.logger.info("") - return await firstFaultFactory["apply"]() + injectedError = await firstFaultFactory["apply"]() + self.logger.info("Found to-be-injected error {}".format(injectedError)) + raise injectedError # apply the chain of request transformations with matching predicates if any matchingRequestTransformations = filter(lambda f: f["predicate"](f["predicate"]), iter(self.requestTransformations)) From 5d72848fbf91c295007acbff2a7e110572ca07c2 Mon Sep 17 00:00:00 2001 From: Fabian Meiswinkel Date: Fri, 7 Feb 2025 16:31:23 +0000 Subject: [PATCH 013/152] Refactoring to have FaultInjectionTransport in its own file --- .../test/_fault_injection_transport.py | 89 +++++++++++++++++++ 
sdk/cosmos/azure-cosmos/test/test_dummy.py | 69 ++------------ 2 files changed, 94 insertions(+), 64 deletions(-) create mode 100644 sdk/cosmos/azure-cosmos/test/_fault_injection_transport.py diff --git a/sdk/cosmos/azure-cosmos/test/_fault_injection_transport.py b/sdk/cosmos/azure-cosmos/test/_fault_injection_transport.py new file mode 100644 index 000000000000..aa137fa34234 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/test/_fault_injection_transport.py @@ -0,0 +1,89 @@ +# The MIT License (MIT) +# Copyright (c) 2014 Microsoft Corporation + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+ +"""AioHttpTransport allowing injection of faults between SDK and Cosmos Gateway +""" + +import asyncio +import aiohttp +import logging +import sys + +from azure.core.pipeline.transport import AioHttpTransport +from azure.core.rest import HttpRequest, AsyncHttpResponse +from collections.abc import MutableMapping +from typing import Any, Callable + +class FaulInjectionTransport(AioHttpTransport): + def __init__(self, logger: logging.Logger, *, session: aiohttp.ClientSession | None = None, loop=None, session_owner: bool = True, **config): + self.logger = logger + self.faults = [] + self.requestTransformations = [] + self.responseTransformations = [] + super().__init__(session=session, loop=loop, session_owner=session_owner, **config) + + def addFault(self, predicate: Callable[[HttpRequest], bool], faultFactory: Callable[[HttpRequest], asyncio.Task[Exception]]): + self.faults.append({"predicate": predicate, "apply": faultFactory}) + + def addRequestTransformation(self, predicate: Callable[[HttpRequest], bool], requestTransformation: Callable[[HttpRequest], asyncio.Task[HttpRequest]]): + self.requestTransformations.append({"predicate": predicate, "apply": requestTransformation}) + + def addResponseTransformation(self, predicate: Callable[[HttpRequest], bool], responseTransformation: Callable[[HttpRequest, Callable[[HttpRequest], asyncio.Task[AsyncHttpResponse]]], AsyncHttpResponse]): + self.responseTransformations.append({ + "predicate": predicate, + "apply": responseTransformation}) + + def firstItem(self, iterable, condition=lambda x: True): + """ + Returns the first item in the `iterable` that satisfies the `condition`. + + If no item satisfies the condition, it returns None. 
+ """ + return next((x for x in iterable if condition(x)), None) + + async def send(self, request: HttpRequest, *, stream: bool = False, proxies: MutableMapping[str, str] | None = None, **config): + # find the first fault Factory with matching predicate if any + firstFaultFactory = self.firstItem(iter(self.faults), lambda f: f["predicate"](request)) + if (firstFaultFactory != None): + injectedError = await firstFaultFactory["apply"]() + self.logger.info("Found to-be-injected error {}".format(injectedError)) + raise injectedError + + # apply the chain of request transformations with matching predicates if any + matchingRequestTransformations = filter(lambda f: f["predicate"](f["predicate"]), iter(self.requestTransformations)) + for currentTransformation in matchingRequestTransformations: + request = await currentTransformation["apply"](request) + + firstResonseTransformation = self.firstItem(iter(self.responseTransformations), lambda f: f["predicate"](request)) + + getResponseTask = super().send(request, stream=stream, proxies=proxies, **config) + + if (firstResonseTransformation != None): + self.logger.info(f"Invoking response transformation") + response = await firstResonseTransformation["apply"](request, lambda: getResponseTask) + self.logger.info(f"Received response transformation result with status code {response.status_code}") + return response + else: + self.logger.info(f"Sending request to {request.url}") + response = await getResponseTask + self.logger.info(f"Received response with status code {response.status_code}") + return response + \ No newline at end of file diff --git a/sdk/cosmos/azure-cosmos/test/test_dummy.py b/sdk/cosmos/azure-cosmos/test/test_dummy.py index 9a96c75eb598..45408793b900 100644 --- a/sdk/cosmos/azure-cosmos/test/test_dummy.py +++ b/sdk/cosmos/azure-cosmos/test/test_dummy.py @@ -1,16 +1,12 @@ # The MIT License (MIT) # Copyright (c) Microsoft Corporation. All rights reserved. 
-from collections.abc import MutableMapping import logging import time -from tokenize import String -from typing import Any, Callable import unittest import uuid import pytest -import pytest_asyncio from azure.cosmos.aio._container import ContainerProxy from azure.cosmos.aio._database import DatabaseProxy @@ -18,11 +14,12 @@ import test_config from azure.cosmos import PartitionKey from azure.cosmos.aio import CosmosClient, _retry_utility_async -from azure.core.rest import HttpRequest, AsyncHttpResponse +from azure.core.rest import HttpRequest import asyncio -import aiohttp import sys from azure.core.pipeline.transport import AioHttpTransport +from typing import Any, Callable +import _fault_injection_transport COLLECTION = "created_collection" logger = logging.getLogger('azure.cosmos') @@ -108,7 +105,7 @@ async def test_throws_injected_error(self, setup): 'name': 'sample document', 'key': 'value'} - custom_transport = FaulInjectionTransport(logger) + custom_transport = _fault_injection_transport.FaulInjectionTransport(logger) predicate : Callable[[HttpRequest], bool] = lambda r: self.predicate_req_payload_contains_id(r, idValue) custom_transport.addFault(predicate, lambda: self.throw_after_delay( 500, @@ -128,7 +125,7 @@ async def test_throws_injected_error(self, setup): await self.cleanup(initializedObjects) async def test_succeeds_with_multiple_endpoints(self, setup): - custom_transport = FaulInjectionTransport(logger) + custom_transport = _fault_injection_transport.FaulInjectionTransport(logger) idValue: str = str(uuid.uuid4()) document_definition = {'id': idValue, 'pk': idValue, @@ -149,61 +146,5 @@ async def test_succeeds_with_multiple_endpoints(self, setup): finally: await self.cleanup(initializedObjects) - -class FaulInjectionTransport(AioHttpTransport): - def __init__(self, logger: logging.Logger, *, session: aiohttp.ClientSession | None = None, loop=None, session_owner: bool = True, **config): - self.logger = logger - self.faults = [] - 
self.requestTransformations = [] - self.responseTransformations = [] - super().__init__(session=session, loop=loop, session_owner=session_owner, **config) - - def addFault(self, predicate: Callable[[HttpRequest], bool], faultFactory: Callable[[HttpRequest], asyncio.Task[Exception]]): - self.faults.append({"predicate": predicate, "apply": faultFactory}) - - def addRequestTransformation(self, predicate: Callable[[HttpRequest], bool], requestTransformation: Callable[[HttpRequest], asyncio.Task[HttpRequest]]): - self.requestTransformations.append({"predicate": predicate, "apply": requestTransformation}) - - def addResponseTransformation(self, predicate: Callable[[HttpRequest], bool], responseTransformation: Callable[[HttpRequest, Callable[[HttpRequest], asyncio.Task[AsyncHttpResponse]]], AsyncHttpResponse]): - self.responseTransformations.append({ - "predicate": predicate, - "apply": responseTransformation}) - - def firstItem(self, iterable, condition=lambda x: True): - """ - Returns the first item in the `iterable` that satisfies the `condition`. - - If no item satisfies the condition, it returns None. 
- """ - return next((x for x in iterable if condition(x)), None) - - async def send(self, request: HttpRequest, *, stream: bool = False, proxies: MutableMapping[str, str] | None = None, **config): - # find the first fault Factory with matching predicate if any - firstFaultFactory = self.firstItem(iter(self.faults), lambda f: f["predicate"](request)) - if (firstFaultFactory != None): - injectedError = await firstFaultFactory["apply"]() - self.logger.info("Found to-be-injected error {}".format(injectedError)) - raise injectedError - - # apply the chain of request transformations with matching predicates if any - matchingRequestTransformations = filter(lambda f: f["predicate"](f["predicate"]), iter(self.requestTransformations)) - for currentTransformation in matchingRequestTransformations: - request = await currentTransformation["apply"](request) - - firstResonseTransformation = self.firstItem(iter(self.responseTransformations), lambda f: f["predicate"](request)) - - getResponseTask = super().send(request, stream=stream, proxies=proxies, **config) - - if (firstResonseTransformation != None): - self.logger.info(f"Invoking response transformation") - response = await firstResonseTransformation["apply"](request, lambda: getResponseTask) - self.logger.info(f"Received response transformation result with status code {response.status_code}") - return response - else: - self.logger.info(f"Sending request to {request.url}") - response = await getResponseTask - self.logger.info(f"Received response with status code {response.status_code}") - return response - if __name__ == '__main__': unittest.main() From 8c9aa4b370afbc7596acd8303959d3aef0c486cd Mon Sep 17 00:00:00 2001 From: Fabian Meiswinkel Date: Mon, 10 Feb 2025 14:53:52 +0000 Subject: [PATCH 014/152] Update test_dummy.py --- sdk/cosmos/azure-cosmos/test/test_dummy.py | 83 ++++++++++++++-------- 1 file changed, 53 insertions(+), 30 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/test/test_dummy.py 
b/sdk/cosmos/azure-cosmos/test/test_dummy.py index 45408793b900..91c9e219350b 100644 --- a/sdk/cosmos/azure-cosmos/test/test_dummy.py +++ b/sdk/cosmos/azure-cosmos/test/test_dummy.py @@ -20,8 +20,10 @@ from azure.core.pipeline.transport import AioHttpTransport from typing import Any, Callable import _fault_injection_transport +import os COLLECTION = "created_collection" +MGMT_TIMEOUT = 3.0 logger = logging.getLogger('azure.cosmos') logger.setLevel("INFO") logger.setLevel(logging.DEBUG) @@ -36,47 +38,66 @@ def setup(): return - @pytest.mark.cosmosEmulator @pytest.mark.asyncio @pytest.mark.usefixtures("setup") class TestDummyAsync: - - async def cleanup(self, initializedObjects: dict[str, Any]): - created_database: DatabaseProxy = initializedObjects["db"] + @classmethod + def setup_class(cls): + logger.info("starting class: {} execution".format(cls.__name__)) + cls.host = test_config.TestConfig.host + cls.masterKey = test_config.TestConfig.masterKey + + if (cls.masterKey == '[YOUR_KEY_HERE]' or + cls.host == '[YOUR_ENDPOINT_HERE]'): + raise Exception( + "You must specify your Azure Cosmos account values for " + "'masterKey' and 'host' at the top of this class to run the " + "tests.") + + cls.connectionPolicy = test_config.TestConfig.connectionPolicy + cls.database_id = test_config.TestConfig.TEST_DATABASE_ID + cls.single_partition_container_name= os.path.basename(__file__) + str(uuid.uuid4()) + + cls.mgmtClient = CosmosClient(host, masterKey, consistency_level="Session", + connection_policy=connectionPolicy, logger=logger) + created_database: DatabaseProxy = cls.mgmtClient.get_database_client(cls.database_id) + asyncio.run(asyncio.wait_for( + created_database.create_container( + cls.single_partition_container_name, + partition_key=PartitionKey("/pk")), + MGMT_TIMEOUT)) + + @classmethod + def teardown_class(cls): + logger.info("tearing down class: {}".format(cls.__name__)) + created_database: DatabaseProxy = cls.mgmtClient.get_database_client(cls.database_id) try: - 
await created_database.delete_container(initializedObjects["col"]) + asyncio.run(asyncio.wait_for( + created_database.delete_container(cls.single_partition_container_name), + MGMT_TIMEOUT)) except Exception as containerDeleteError: logger.warn("Exception trying to delete database {}. {}".format(created_database.id, containerDeleteError)) finally: - client: CosmosClient = initializedObjects["client"] try: - await client.close() + asyncio.run(asyncio.wait_for(cls.mgmtClient.close(), MGMT_TIMEOUT)) except Exception as closeError: logger.warn("Exception trying to delete database {}. {}".format(created_database.id, closeError)) - async def setup(self, custom_transport: AioHttpTransport): - - host = test_config.TestConfig.host - masterKey = test_config.TestConfig.masterKey - connectionPolicy = test_config.TestConfig.connectionPolicy - TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID - TEST_CONTAINER_SINGLE_PARTITION_ID = "test-timeout-retry-policy-container-" + str(uuid.uuid4()) - - if (masterKey == '[YOUR_KEY_HERE]' or - host == '[YOUR_ENDPOINT_HERE]'): - raise Exception( - "You must specify your Azure Cosmos account values for " - "'masterKey' and 'host' at the top of this class to run the " - "tests.") - + def setup_method_with_custom_transport(self, custom_transport: AioHttpTransport): client = CosmosClient(host, masterKey, consistency_level="Session", connection_policy=connectionPolicy, transport=custom_transport, logger=logger) - created_database: DatabaseProxy = client.get_database_client(TEST_DATABASE_ID) - created_collection = await created_database.create_container(TEST_CONTAINER_SINGLE_PARTITION_ID, - partition_key=PartitionKey("/pk")) - return {"client": client, "db": created_database, "col": created_collection} + db: DatabaseProxy = client.get_database_client(TEST_DATABASE_ID) + container: ContainerProxy = db.get_container_client(self.single_partition_container_name) + return {"client": client, "db": db, "col": container} + + def 
cleanup_method(self, initializedObjects: dict[str, Any]): + method_client: CosmosClient = initializedObjects["client"] + try: + asyncio.run(asyncio.wait_for(method_client.close(), MGMT_TIMEOUT)) + except Exception as closeError: + logger.warning("Exception trying to close method client.") def predicate_url_contains_id(self, r: HttpRequest, id: str): logger.info("FaultPredicate for request {} {}".format(r.method, r.url)); @@ -113,7 +134,7 @@ async def test_throws_injected_error(self, setup): status_code=502, message="Some random reverse proxy error."))) - initializedObjects = await self.setup(custom_transport) + initializedObjects = self.setup_method_with_custom_transport(custom_transport) try: container: ContainerProxy = initializedObjects["col"] await container.create_item(body=document_definition) @@ -122,7 +143,9 @@ async def test_throws_injected_error(self, setup): if (cosmosError.status_code != 502): raise cosmosError finally: - await self.cleanup(initializedObjects) + cleanupOp = self.cleanup_method(initializedObjects) + if (cleanupOp != None): + await cleanupOp async def test_succeeds_with_multiple_endpoints(self, setup): custom_transport = _fault_injection_transport.FaulInjectionTransport(logger) @@ -132,7 +155,7 @@ async def test_succeeds_with_multiple_endpoints(self, setup): 'name': 'sample document', 'key': 'value'} - initializedObjects = await self.setup(custom_transport) + initializedObjects = self.setup_method_with_custom_transport(custom_transport) try: container: ContainerProxy = initializedObjects["col"] @@ -144,7 +167,7 @@ async def test_succeeds_with_multiple_endpoints(self, setup): await asyncio.sleep(0.2) finally: - await self.cleanup(initializedObjects) + self.cleanup_method(initializedObjects) if __name__ == '__main__': unittest.main() From 7260e9d156ffc4321f0f3c13c89ddd136e8ea46a Mon Sep 17 00:00:00 2001 From: Fabian Meiswinkel Date: Tue, 18 Feb 2025 12:43:39 +0000 Subject: [PATCH 015/152] Reafctoring FaultInjectionTransport --- 
.../test/_fault_injection_transport.py | 42 ++++++++++++++++--- sdk/cosmos/azure-cosmos/test/test_dummy.py | 41 ++++++------------ 2 files changed, 49 insertions(+), 34 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/test/_fault_injection_transport.py b/sdk/cosmos/azure-cosmos/test/_fault_injection_transport.py index aa137fa34234..123658b1516c 100644 --- a/sdk/cosmos/azure-cosmos/test/_fault_injection_transport.py +++ b/sdk/cosmos/azure-cosmos/test/_fault_injection_transport.py @@ -25,14 +25,13 @@ import asyncio import aiohttp import logging -import sys from azure.core.pipeline.transport import AioHttpTransport from azure.core.rest import HttpRequest, AsyncHttpResponse from collections.abc import MutableMapping from typing import Any, Callable -class FaulInjectionTransport(AioHttpTransport): +class FaultInjectionTransport(AioHttpTransport): def __init__(self, logger: logging.Logger, *, session: aiohttp.ClientSession | None = None, loop=None, session_owner: bool = True, **config): self.logger = logger self.faults = [] @@ -46,7 +45,7 @@ def addFault(self, predicate: Callable[[HttpRequest], bool], faultFactory: Calla def addRequestTransformation(self, predicate: Callable[[HttpRequest], bool], requestTransformation: Callable[[HttpRequest], asyncio.Task[HttpRequest]]): self.requestTransformations.append({"predicate": predicate, "apply": requestTransformation}) - def addResponseTransformation(self, predicate: Callable[[HttpRequest], bool], responseTransformation: Callable[[HttpRequest, Callable[[HttpRequest], asyncio.Task[AsyncHttpResponse]]], AsyncHttpResponse]): + def add_response_transformation(self, predicate: Callable[[HttpRequest], bool], responseTransformation: Callable[[HttpRequest, Callable[[HttpRequest], asyncio.Task[AsyncHttpResponse]]], asyncio.Task[AsyncHttpResponse]]): self.responseTransformations.append({ "predicate": predicate, "apply": responseTransformation}) @@ -59,7 +58,7 @@ def firstItem(self, iterable, condition=lambda x: True): """ return next((x 
for x in iterable if condition(x)), None) - async def send(self, request: HttpRequest, *, stream: bool = False, proxies: MutableMapping[str, str] | None = None, **config): + async def send(self, request: HttpRequest, *, stream: bool = False, proxies: MutableMapping[str, str] | None = None, **config) -> AsyncHttpResponse: # find the first fault Factory with matching predicate if any firstFaultFactory = self.firstItem(iter(self.faults), lambda f: f["predicate"](request)) if (firstFaultFactory != None): @@ -74,7 +73,7 @@ async def send(self, request: HttpRequest, *, stream: bool = False, proxies: Mut firstResonseTransformation = self.firstItem(iter(self.responseTransformations), lambda f: f["predicate"](request)) - getResponseTask = super().send(request, stream=stream, proxies=proxies, **config) + getResponseTask = super().send(request, stream=stream, proxies=proxies, **config) if (firstResonseTransformation != None): self.logger.info(f"Invoking response transformation") @@ -86,4 +85,35 @@ async def send(self, request: HttpRequest, *, stream: bool = False, proxies: Mut response = await getResponseTask self.logger.info(f"Received response with status code {response.status_code}") return response - \ No newline at end of file + + @staticmethod + def predicate_url_contains_id(r: HttpRequest, id: str) -> bool: + return id in r.url + + @staticmethod + def predicate_req_payload_contains_id(r: HttpRequest, id: str): + if r.body is None: + return False + + return '"id":"{}"'.format(id) in r.body + + @staticmethod + def predicate_req_for_document_with_id(r: HttpRequest, id: str) -> bool: + return (FaultInjectionTransport.predicate_url_contains_id(r, id) + or FaultInjectionTransport.predicate_req_payload_contains_id(r, id)) + + @staticmethod + def predicate_is_database_account_call(r: HttpRequest) -> bool: + return (r.headers.get('x-ms-thinclient-proxy-resource-type') == 'databaseaccount' + and r.headers.get('x-ms-thinclient-proxy-operation-type') == 'Read') + + @staticmethod + 
async def throw_after_delay(delay_in_ms: int, error: Exception): + await asyncio.sleep(delay_in_ms / 1000.0) + raise error + + @staticmethod + async def transform_pass_through(r: HttpRequest, + inner: Callable[[],asyncio.Task[AsyncHttpResponse]]) -> asyncio.Task[AsyncHttpResponse]: + await asyncio.sleep(1) + return await inner() \ No newline at end of file diff --git a/sdk/cosmos/azure-cosmos/test/test_dummy.py b/sdk/cosmos/azure-cosmos/test/test_dummy.py index 91c9e219350b..91fa0e7283b0 100644 --- a/sdk/cosmos/azure-cosmos/test/test_dummy.py +++ b/sdk/cosmos/azure-cosmos/test/test_dummy.py @@ -14,12 +14,12 @@ import test_config from azure.cosmos import PartitionKey from azure.cosmos.aio import CosmosClient, _retry_utility_async -from azure.core.rest import HttpRequest +from azure.core.rest import HttpRequest, AsyncHttpResponse import asyncio import sys from azure.core.pipeline.transport import AioHttpTransport from typing import Any, Callable -import _fault_injection_transport +from _fault_injection_transport import FaultInjectionTransport import os COLLECTION = "created_collection" @@ -36,7 +36,8 @@ @pytest.fixture() def setup(): - return + return + @pytest.mark.cosmosEmulator @pytest.mark.asyncio @@ -87,7 +88,7 @@ def teardown_class(cls): def setup_method_with_custom_transport(self, custom_transport: AioHttpTransport): client = CosmosClient(host, masterKey, consistency_level="Session", connection_policy=connectionPolicy, transport=custom_transport, - logger=logger) + logger=logger, enable_diagnostics_logging=True) db: DatabaseProxy = client.get_database_client(TEST_DATABASE_ID) container: ContainerProxy = db.get_container_client(self.single_partition_container_name) return {"client": client, "db": db, "col": container} @@ -99,26 +100,6 @@ def cleanup_method(self, initializedObjects: dict[str, Any]): except Exception as closeError: logger.warning("Exception trying to close method client.") - def predicate_url_contains_id(self, r: HttpRequest, id: str): - 
logger.info("FaultPredicate for request {} {}".format(r.method, r.url)); - return id in r.url; - - def predicate_req_payload_contains_id(self, r: HttpRequest, id: str): - logger.info("FaultPredicate for request {} {} - request payload {}".format( - r.method, - r.url, - "NONE" if r.body is None else r.body)); - - if (r.body == None): - return False - - - return '"id":"{}"'.format(id) in r.body; - - async def throw_after_delay(self, delayInMs: int, error: Exception): - await asyncio.sleep(delayInMs/1000.0) - raise error - async def test_throws_injected_error(self, setup): idValue: str = str(uuid.uuid4()) document_definition = {'id': idValue, @@ -126,9 +107,9 @@ async def test_throws_injected_error(self, setup): 'name': 'sample document', 'key': 'value'} - custom_transport = _fault_injection_transport.FaulInjectionTransport(logger) - predicate : Callable[[HttpRequest], bool] = lambda r: self.predicate_req_payload_contains_id(r, idValue) - custom_transport.addFault(predicate, lambda: self.throw_after_delay( + custom_transport = FaultInjectionTransport(logger) + predicate : Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransport.predicate_req_for_document_with_id(r, idValue) + custom_transport.addFault(predicate, lambda: FaultInjectionTransport.throw_after_delay( 500, CosmosHttpResponseError( status_code=502, @@ -148,7 +129,11 @@ async def test_throws_injected_error(self, setup): await cleanupOp async def test_succeeds_with_multiple_endpoints(self, setup): - custom_transport = _fault_injection_transport.FaulInjectionTransport(logger) + custom_transport = FaultInjectionTransport(logger) + predicate: Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransport.predicate_is_database_account_call(r) + transformation: Callable[[HttpRequest, Callable[[HttpRequest], asyncio.Task[AsyncHttpResponse]]], AsyncHttpResponse] = lambda r, inner: FaultInjectionTransport.transform_pass_through(r, inner) + custom_transport.add_response_transformation(predicate, 
transformation) + idValue: str = str(uuid.uuid4()) document_definition = {'id': idValue, 'pk': idValue, From 0705aeb76a367c34c4f7bdfdc09ff3f78753d7ea Mon Sep 17 00:00:00 2001 From: Fabian Meiswinkel Date: Wed, 19 Feb 2025 20:04:39 +0000 Subject: [PATCH 016/152] Iterating on tests --- .../azure/cosmos/aio/_asynchronous_request.py | 1 + .../aio/_global_endpoint_manager_async.py | 5 + .../test/_fault_injection_transport.py | 92 +++++++++++++++---- sdk/cosmos/azure-cosmos/test/test_dummy.py | 49 +++++----- 4 files changed, 105 insertions(+), 42 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py index 81430d8df42c..f8ebf6ccdbb8 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py @@ -117,6 +117,7 @@ async def _Request(global_endpoint_manager, request_params, connection_policy, p response = response.http_response headers = copy.copy(response.headers) + await response.load_body() data = response.body() if data: data = data.decode("utf-8") diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py index 374fd940c184..365ef9c9b395 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py @@ -126,13 +126,18 @@ async def _endpoints_health_check(self, **kwargs): """ all_endpoints = [self.location_cache.read_regional_endpoints[0]] all_endpoints.extend(self.location_cache.write_regional_endpoints) + validated_endpoints = {} count = 0 for endpoint in all_endpoints: + if (endpoint.get_current() in validated_endpoints): + continue + count += 1 if count > 3: break try: await self.client._GetDatabaseAccountCheck(endpoint.get_current(), **kwargs) + 
validated_endpoints[endpoint.get_current()] = "" except (exceptions.CosmosHttpResponseError, AzureError): if endpoint in self.location_cache.read_regional_endpoints: self.mark_endpoint_unavailable_for_read(endpoint.get_current(), False) diff --git a/sdk/cosmos/azure-cosmos/test/_fault_injection_transport.py b/sdk/cosmos/azure-cosmos/test/_fault_injection_transport.py index 123658b1516c..0dd75b3c0e60 100644 --- a/sdk/cosmos/azure-cosmos/test/_fault_injection_transport.py +++ b/sdk/cosmos/azure-cosmos/test/_fault_injection_transport.py @@ -23,17 +23,24 @@ """ import asyncio -import aiohttp +import json import logging +import sys +from collections.abc import MutableMapping +from typing import Callable +import aiohttp from azure.core.pipeline.transport import AioHttpTransport from azure.core.rest import HttpRequest, AsyncHttpResponse -from collections.abc import MutableMapping -from typing import Any, Callable + +from azure.cosmos.exceptions import CosmosHttpResponseError + class FaultInjectionTransport(AioHttpTransport): - def __init__(self, logger: logging.Logger, *, session: aiohttp.ClientSession | None = None, loop=None, session_owner: bool = True, **config): - self.logger = logger + logger = logging.getLogger('azure.cosmos.fault_injection_transport') + logger.setLevel(logging.DEBUG) + + def __init__(self, *, session: aiohttp.ClientSession | None = None, loop=None, session_owner: bool = True, **config): self.faults = [] self.requestTransformations = [] self.responseTransformations = [] @@ -59,37 +66,50 @@ def firstItem(self, iterable, condition=lambda x: True): return next((x for x in iterable if condition(x)), None) async def send(self, request: HttpRequest, *, stream: bool = False, proxies: MutableMapping[str, str] | None = None, **config) -> AsyncHttpResponse: + FaultInjectionTransport.logger.info("--> FaultInjectionTransport.Send {} {}".format(request.method, request.url)) # find the first fault Factory with matching predicate if any firstFaultFactory = 
self.firstItem(iter(self.faults), lambda f: f["predicate"](request)) if (firstFaultFactory != None): + FaultInjectionTransport.logger.info("--> FaultInjectionTransport.ApplyFaultInjection") injectedError = await firstFaultFactory["apply"]() - self.logger.info("Found to-be-injected error {}".format(injectedError)) + FaultInjectionTransport.logger.info("Found to-be-injected error {}".format(injectedError)) raise injectedError # apply the chain of request transformations with matching predicates if any matchingRequestTransformations = filter(lambda f: f["predicate"](f["predicate"]), iter(self.requestTransformations)) for currentTransformation in matchingRequestTransformations: + FaultInjectionTransport.logger.info("--> FaultInjectionTransport.ApplyRequestTransformation") request = await currentTransformation["apply"](request) - firstResonseTransformation = self.firstItem(iter(self.responseTransformations), lambda f: f["predicate"](request)) - - getResponseTask = super().send(request, stream=stream, proxies=proxies, **config) - - if (firstResonseTransformation != None): - self.logger.info(f"Invoking response transformation") - response = await firstResonseTransformation["apply"](request, lambda: getResponseTask) - self.logger.info(f"Received response transformation result with status code {response.status_code}") + firstResponseTransformation = self.firstItem(iter(self.responseTransformations), lambda f: f["predicate"](request)) + + FaultInjectionTransport.logger.info("--> FaultInjectionTransport.BeforeGetResponseTask") + getResponseTask = asyncio.create_task(super().send(request, stream=stream, proxies=proxies, **config)) + FaultInjectionTransport.logger.info("<-- FaultInjectionTransport.AfterGetResponseTask") + + if (firstResponseTransformation != None): + FaultInjectionTransport.logger.info(f"Invoking response transformation") + response = await firstResponseTransformation["apply"](request, lambda: getResponseTask) + FaultInjectionTransport.logger.info(f"Received 
response transformation result with status code {response.status_code}") return response else: - self.logger.info(f"Sending request to {request.url}") + FaultInjectionTransport.logger.info(f"Sending request to {request.url}") response = await getResponseTask - self.logger.info(f"Received response with status code {response.status_code}") + FaultInjectionTransport.logger.info(f"Received response with status code {response.status_code}") return response @staticmethod def predicate_url_contains_id(r: HttpRequest, id: str) -> bool: return id in r.url + @staticmethod + def print_call_stack(): + print("Call stack:") + frame = sys._getframe() + while frame: + print(f"File: {frame.f_code.co_filename}, Line: {frame.f_lineno}, Function: {frame.f_code.co_name}") + frame = frame.f_back + @staticmethod def predicate_req_payload_contains_id(r: HttpRequest, id: str): if r.body is None: @@ -104,16 +124,48 @@ def predicate_req_for_document_with_id(r: HttpRequest, id: str) -> bool: @staticmethod def predicate_is_database_account_call(r: HttpRequest) -> bool: - return (r.headers.get('x-ms-thinclient-proxy-resource-type') == 'databaseaccount' + isDbAccountRead = (r.headers.get('x-ms-thinclient-proxy-resource-type') == 'databaseaccount' and r.headers.get('x-ms-thinclient-proxy-operation-type') == 'Read') + return isDbAccountRead + + @staticmethod + def predicate_is_write_operation(r: HttpRequest, uri_prefix: str) -> bool: + isWriteDocumentOperation = (r.headers.get('x-ms-thinclient-proxy-resource-type') == 'docs' + and r.headers.get('x-ms-thinclient-proxy-operation-type') != 'Read' + and r.headers.get('x-ms-thinclient-proxy-operation-type') != 'ReadFeed' + and r.headers.get('x-ms-thinclient-proxy-operation-type') != 'Query') + + return isWriteDocumentOperation and uri_prefix in r.url + + @staticmethod async def throw_after_delay(delay_in_ms: int, error: Exception): await asyncio.sleep(delay_in_ms / 1000.0) raise error @staticmethod - async def transform_pass_through(r: HttpRequest, + 
async def throw_write_forbidden(): + raise CosmosHttpResponseError( + status_code=403, + message="Injected error disallowing writes in this region.", + response=None, + sub_status_code=3, + ) + + @staticmethod + async def transform_convert_emulator_to_single_master_read_multi_region_account(r: HttpRequest, inner: Callable[[],asyncio.Task[AsyncHttpResponse]]) -> asyncio.Task[AsyncHttpResponse]: - await asyncio.sleep(1) - return await inner() \ No newline at end of file + + response = await inner() + if not FaultInjectionTransport.predicate_is_database_account_call(response.request): + return response + + await response.load_body() + data = response.body() + if response.status_code == 200 and data: + data = data.decode("utf-8") + result = json.loads(data) + result["readableLocations"].append({"name": "East US", "databaseAccountEndpoint" : "https://localhost:8888/"}) + FaultInjectionTransport.logger.info("Transformed Account Topology: {}".format(result)) + return response \ No newline at end of file diff --git a/sdk/cosmos/azure-cosmos/test/test_dummy.py b/sdk/cosmos/azure-cosmos/test/test_dummy.py index 91fa0e7283b0..1300717de2c3 100644 --- a/sdk/cosmos/azure-cosmos/test/test_dummy.py +++ b/sdk/cosmos/azure-cosmos/test/test_dummy.py @@ -1,31 +1,30 @@ # The MIT License (MIT) # Copyright (c) Microsoft Corporation. All rights reserved. 
+import asyncio import logging +import os +import sys import time import unittest import uuid +from typing import Any, Callable import pytest +from azure.core.pipeline.transport import AioHttpTransport +from azure.core.rest import HttpRequest, AsyncHttpResponse +import test_config +from _fault_injection_transport import FaultInjectionTransport +from azure.cosmos import PartitionKey +from azure.cosmos.aio import CosmosClient from azure.cosmos.aio._container import ContainerProxy from azure.cosmos.aio._database import DatabaseProxy from azure.cosmos.exceptions import CosmosHttpResponseError -import test_config -from azure.cosmos import PartitionKey -from azure.cosmos.aio import CosmosClient, _retry_utility_async -from azure.core.rest import HttpRequest, AsyncHttpResponse -import asyncio -import sys -from azure.core.pipeline.transport import AioHttpTransport -from typing import Any, Callable -from _fault_injection_transport import FaultInjectionTransport -import os COLLECTION = "created_collection" MGMT_TIMEOUT = 3.0 logger = logging.getLogger('azure.cosmos') -logger.setLevel("INFO") logger.setLevel(logging.DEBUG) logger.addHandler(logging.StreamHandler(sys.stdout)) @@ -107,7 +106,7 @@ async def test_throws_injected_error(self, setup): 'name': 'sample document', 'key': 'value'} - custom_transport = FaultInjectionTransport(logger) + custom_transport = FaultInjectionTransport() predicate : Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransport.predicate_req_for_document_with_id(r, idValue) custom_transport.addFault(predicate, lambda: FaultInjectionTransport.throw_after_delay( 500, @@ -129,10 +128,15 @@ async def test_throws_injected_error(self, setup): await cleanupOp async def test_succeeds_with_multiple_endpoints(self, setup): - custom_transport = FaultInjectionTransport(logger) - predicate: Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransport.predicate_is_database_account_call(r) - transformation: Callable[[HttpRequest, Callable[[HttpRequest], 
asyncio.Task[AsyncHttpResponse]]], AsyncHttpResponse] = lambda r, inner: FaultInjectionTransport.transform_pass_through(r, inner) - custom_transport.add_response_transformation(predicate, transformation) + custom_transport = FaultInjectionTransport() + is_get_account_predicate: Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransport.predicate_is_database_account_call(r) + is_write_operation_predicate: Callable[[HttpRequest], bool] = lambda \ + r: FaultInjectionTransport.predicate_is_write_operation(r, "https://localhost") + emulator_as_multi_region_sm_account_transformation: Callable[[HttpRequest, Callable[[HttpRequest], asyncio.Task[AsyncHttpResponse]]], AsyncHttpResponse] = \ + lambda r, inner: FaultInjectionTransport.transform_convert_emulator_to_single_master_read_multi_region_account(r, inner) + + custom_transport.addFault(is_write_operation_predicate, lambda: FaultInjectionTransport.throw_write_forbidden()) + custom_transport.add_response_transformation(is_get_account_predicate, emulator_as_multi_region_sm_account_transformation) idValue: str = str(uuid.uuid4()) document_definition = {'id': idValue, @@ -143,16 +147,17 @@ async def test_succeeds_with_multiple_endpoints(self, setup): initializedObjects = self.setup_method_with_custom_transport(custom_transport) try: container: ContainerProxy = initializedObjects["col"] - + created_document = await container.create_item(body=document_definition) start = time.perf_counter() - - while ((time.perf_counter() - start) < 2): - await container.read_item(idValue, partition_key=idValue) - await asyncio.sleep(0.2) + + + #while ((time.perf_counter() - start) < 2): + # await container.read_item(idValue, partition_key=idValue) + # await asyncio.sleep(0.2) finally: - self.cleanup_method(initializedObjects) + self.cleanup_method(initializedObjects) if __name__ == '__main__': unittest.main() From baf7aea226895998c136fb0298c492ac4811f2a3 Mon Sep 17 00:00:00 2001 From: Fabian Meiswinkel Date: Thu, 20 Feb 2025 16:55:42 
+0000 Subject: [PATCH 017/152] Prettifying tests --- .../test/_fault_injection_transport.py | 119 +++++++++++------- ...> test_fault_injection_transport_async.py} | 116 +++++++++-------- 2 files changed, 144 insertions(+), 91 deletions(-) rename sdk/cosmos/azure-cosmos/test/{test_dummy.py => test_fault_injection_transport_async.py} (50%) diff --git a/sdk/cosmos/azure-cosmos/test/_fault_injection_transport.py b/sdk/cosmos/azure-cosmos/test/_fault_injection_transport.py index 0dd75b3c0e60..efe58bae3032 100644 --- a/sdk/cosmos/azure-cosmos/test/_fault_injection_transport.py +++ b/sdk/cosmos/azure-cosmos/test/_fault_injection_transport.py @@ -27,10 +27,10 @@ import logging import sys from collections.abc import MutableMapping -from typing import Callable +from typing import Callable, Optional import aiohttp -from azure.core.pipeline.transport import AioHttpTransport +from azure.core.pipeline.transport import AioHttpTransport, AioHttpTransportResponse from azure.core.rest import HttpRequest, AsyncHttpResponse from azure.cosmos.exceptions import CosmosHttpResponseError @@ -46,18 +46,16 @@ def __init__(self, *, session: aiohttp.ClientSession | None = None, loop=None, s self.responseTransformations = [] super().__init__(session=session, loop=loop, session_owner=session_owner, **config) - def addFault(self, predicate: Callable[[HttpRequest], bool], faultFactory: Callable[[HttpRequest], asyncio.Task[Exception]]): - self.faults.append({"predicate": predicate, "apply": faultFactory}) + def add_fault(self, predicate: Callable[[HttpRequest], bool], fault_factory: Callable[[HttpRequest], asyncio.Task[Exception]]): + self.faults.append({"predicate": predicate, "apply": fault_factory}) - def addRequestTransformation(self, predicate: Callable[[HttpRequest], bool], requestTransformation: Callable[[HttpRequest], asyncio.Task[HttpRequest]]): - self.requestTransformations.append({"predicate": predicate, "apply": requestTransformation}) - - def add_response_transformation(self, 
predicate: Callable[[HttpRequest], bool], responseTransformation: Callable[[HttpRequest, Callable[[HttpRequest], asyncio.Task[AsyncHttpResponse]]], asyncio.Task[AsyncHttpResponse]]): + def add_response_transformation(self, predicate: Callable[[HttpRequest], bool], response_transformation: Callable[[HttpRequest, Callable[[HttpRequest], asyncio.Task[AsyncHttpResponse]]], asyncio.Task[AsyncHttpResponse]]): self.responseTransformations.append({ "predicate": predicate, - "apply": responseTransformation}) + "apply": response_transformation}) - def firstItem(self, iterable, condition=lambda x: True): + @staticmethod + def __first_item(iterable, condition=lambda x: True): """ Returns the first item in the `iterable` that satisfies the `condition`. @@ -68,39 +66,41 @@ def firstItem(self, iterable, condition=lambda x: True): async def send(self, request: HttpRequest, *, stream: bool = False, proxies: MutableMapping[str, str] | None = None, **config) -> AsyncHttpResponse: FaultInjectionTransport.logger.info("--> FaultInjectionTransport.Send {} {}".format(request.method, request.url)) # find the first fault Factory with matching predicate if any - firstFaultFactory = self.firstItem(iter(self.faults), lambda f: f["predicate"](request)) - if (firstFaultFactory != None): + first_fault_factory = FaultInjectionTransport.__first_item(iter(self.faults), lambda f: f["predicate"](request)) + if first_fault_factory: FaultInjectionTransport.logger.info("--> FaultInjectionTransport.ApplyFaultInjection") - injectedError = await firstFaultFactory["apply"]() - FaultInjectionTransport.logger.info("Found to-be-injected error {}".format(injectedError)) - raise injectedError + injected_error = await first_fault_factory["apply"](request) + FaultInjectionTransport.logger.info("Found to-be-injected error {}".format(injected_error)) + raise injected_error # apply the chain of request transformations with matching predicates if any - matchingRequestTransformations = filter(lambda f: 
f["predicate"](f["predicate"]), iter(self.requestTransformations)) - for currentTransformation in matchingRequestTransformations: + matching_request_transformations = filter(lambda f: f["predicate"](f["predicate"]), iter(self.requestTransformations)) + for currentTransformation in matching_request_transformations: FaultInjectionTransport.logger.info("--> FaultInjectionTransport.ApplyRequestTransformation") request = await currentTransformation["apply"](request) - firstResponseTransformation = self.firstItem(iter(self.responseTransformations), lambda f: f["predicate"](request)) + first_response_transformation = FaultInjectionTransport.__first_item(iter(self.responseTransformations), lambda f: f["predicate"](request)) FaultInjectionTransport.logger.info("--> FaultInjectionTransport.BeforeGetResponseTask") - getResponseTask = asyncio.create_task(super().send(request, stream=stream, proxies=proxies, **config)) + get_response_task = asyncio.create_task(super().send(request, stream=stream, proxies=proxies, **config)) FaultInjectionTransport.logger.info("<-- FaultInjectionTransport.AfterGetResponseTask") - if (firstResponseTransformation != None): + if first_response_transformation: FaultInjectionTransport.logger.info(f"Invoking response transformation") - response = await firstResponseTransformation["apply"](request, lambda: getResponseTask) + response = await first_response_transformation["apply"](request, lambda: get_response_task) + response.headers["_request"] = request FaultInjectionTransport.logger.info(f"Received response transformation result with status code {response.status_code}") return response else: FaultInjectionTransport.logger.info(f"Sending request to {request.url}") - response = await getResponseTask + response = await get_response_task + response.headers["_request"] = request FaultInjectionTransport.logger.info(f"Received response with status code {response.status_code}") return response @staticmethod - def predicate_url_contains_id(r: HttpRequest, 
id: str) -> bool: - return id in r.url + def predicate_url_contains_id(r: HttpRequest, id_value: str) -> bool: + return id_value in r.url @staticmethod def print_call_stack(): @@ -111,42 +111,41 @@ def print_call_stack(): frame = frame.f_back @staticmethod - def predicate_req_payload_contains_id(r: HttpRequest, id: str): + def predicate_req_payload_contains_id(r: HttpRequest, id_value: str): if r.body is None: return False - return '"id":"{}"'.format(id) in r.body + return '"id":"{}"'.format(id_value) in r.body @staticmethod - def predicate_req_for_document_with_id(r: HttpRequest, id: str) -> bool: - return (FaultInjectionTransport.predicate_url_contains_id(r, id) - or FaultInjectionTransport.predicate_req_payload_contains_id(r, id)) + def predicate_req_for_document_with_id(r: HttpRequest, id_value: str) -> bool: + return (FaultInjectionTransport.predicate_url_contains_id(r, id_value) + or FaultInjectionTransport.predicate_req_payload_contains_id(r, id_value)) @staticmethod def predicate_is_database_account_call(r: HttpRequest) -> bool: - isDbAccountRead = (r.headers.get('x-ms-thinclient-proxy-resource-type') == 'databaseaccount' + is_db_account_read = (r.headers.get('x-ms-thinclient-proxy-resource-type') == 'databaseaccount' and r.headers.get('x-ms-thinclient-proxy-operation-type') == 'Read') - return isDbAccountRead + return is_db_account_read @staticmethod def predicate_is_write_operation(r: HttpRequest, uri_prefix: str) -> bool: - isWriteDocumentOperation = (r.headers.get('x-ms-thinclient-proxy-resource-type') == 'docs' + is_write_document_operation = (r.headers.get('x-ms-thinclient-proxy-resource-type') == 'docs' and r.headers.get('x-ms-thinclient-proxy-operation-type') != 'Read' and r.headers.get('x-ms-thinclient-proxy-operation-type') != 'ReadFeed' and r.headers.get('x-ms-thinclient-proxy-operation-type') != 'Query') - return isWriteDocumentOperation and uri_prefix in r.url - + return is_write_document_operation and uri_prefix in r.url @staticmethod - async 
def throw_after_delay(delay_in_ms: int, error: Exception): + async def error_after_delay(delay_in_ms: int, error: Exception) -> Exception: await asyncio.sleep(delay_in_ms / 1000.0) - raise error + return error @staticmethod - async def throw_write_forbidden(): - raise CosmosHttpResponseError( + async def error_write_forbidden() -> Exception: + return CosmosHttpResponseError( status_code=403, message="Injected error disallowing writes in this region.", response=None, @@ -154,8 +153,11 @@ async def throw_write_forbidden(): ) @staticmethod - async def transform_convert_emulator_to_single_master_read_multi_region_account(r: HttpRequest, - inner: Callable[[],asyncio.Task[AsyncHttpResponse]]) -> asyncio.Task[AsyncHttpResponse]: + async def transform_convert_emulator_to_single_master_read_multi_region_account( + additional_region: str, + artificial_uri: str, + r: HttpRequest, + inner: Callable[[],asyncio.Task[AsyncHttpResponse]]) -> asyncio.Task[AsyncHttpResponse]: response = await inner() if not FaultInjectionTransport.predicate_is_database_account_call(response.request): @@ -166,6 +168,39 @@ async def transform_convert_emulator_to_single_master_read_multi_region_account( if response.status_code == 200 and data: data = data.decode("utf-8") result = json.loads(data) - result["readableLocations"].append({"name": "East US", "databaseAccountEndpoint" : "https://localhost:8888/"}) + result["readableLocations"].append({"name": additional_region, "databaseAccountEndpoint" : artificial_uri}) FaultInjectionTransport.logger.info("Transformed Account Topology: {}".format(result)) - return response \ No newline at end of file + request: HttpRequest = response.request + return FaultInjectionTransport.MockHttpResponse(request, 200, result) + + return response + + class MockHttpResponse(AioHttpTransportResponse): + def __init__(self, request: HttpRequest, status_code: int, content:Optional[dict[str, any]]): + self.request: HttpRequest = request + # This is actually never None, and set 
by all implementations after the call to + # __init__ of this class. This class is also a legacy impl, so it's risky to change it + # for low benefits The new "rest" implementation does define correctly status_code + # as non-optional. + self.status_code: int = status_code + self.headers: MutableMapping[str, str] = {} + self.reason: Optional[str] = None + self.content_type: Optional[str] = None + self.block_size: int = 4096 # Default to same as R + self.content: Optional[dict[str, any]] = None + self.json_text: Optional[str] = None + self.bytes: Optional[bytes] = None + if content: + self.content:Optional[dict[str, any]] = content + self.json_text:Optional[str] = json.dumps(content) + self.bytes:bytes = self.json_text.encode("utf-8") + + + def body(self) -> bytes: + return self.bytes + + def text(self, encoding: Optional[str] = None) -> str: + return self.json_text + + async def load_body(self) -> None: + return \ No newline at end of file diff --git a/sdk/cosmos/azure-cosmos/test/test_dummy.py b/sdk/cosmos/azure-cosmos/test/test_fault_injection_transport_async.py similarity index 50% rename from sdk/cosmos/azure-cosmos/test/test_dummy.py rename to sdk/cosmos/azure-cosmos/test/test_fault_injection_transport_async.py index 1300717de2c3..cba2c1074ec4 100644 --- a/sdk/cosmos/azure-cosmos/test/test_dummy.py +++ b/sdk/cosmos/azure-cosmos/test/test_fault_injection_transport_async.py @@ -29,8 +29,8 @@ logger.addHandler(logging.StreamHandler(sys.stdout)) host = test_config.TestConfig.host -masterKey = test_config.TestConfig.masterKey -connectionPolicy = test_config.TestConfig.connectionPolicy +master_key = test_config.TestConfig.masterKey +connection_policy = test_config.TestConfig.connectionPolicy TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID @pytest.fixture() @@ -41,27 +41,27 @@ def setup(): @pytest.mark.cosmosEmulator @pytest.mark.asyncio @pytest.mark.usefixtures("setup") -class TestDummyAsync: +class TestFaultInjectionTransportAsync: @classmethod def 
setup_class(cls): logger.info("starting class: {} execution".format(cls.__name__)) cls.host = test_config.TestConfig.host - cls.masterKey = test_config.TestConfig.masterKey + cls.master_key = test_config.TestConfig.masterKey - if (cls.masterKey == '[YOUR_KEY_HERE]' or + if (cls.master_key == '[YOUR_KEY_HERE]' or cls.host == '[YOUR_ENDPOINT_HERE]'): raise Exception( "You must specify your Azure Cosmos account values for " "'masterKey' and 'host' at the top of this class to run the " "tests.") - cls.connectionPolicy = test_config.TestConfig.connectionPolicy + cls.connection_policy = test_config.TestConfig.connectionPolicy cls.database_id = test_config.TestConfig.TEST_DATABASE_ID cls.single_partition_container_name= os.path.basename(__file__) + str(uuid.uuid4()) - cls.mgmtClient = CosmosClient(host, masterKey, consistency_level="Session", - connection_policy=connectionPolicy, logger=logger) - created_database: DatabaseProxy = cls.mgmtClient.get_database_client(cls.database_id) + cls.mgmt_client = CosmosClient(host, master_key, consistency_level="Session", + connection_policy=connection_policy, logger=logger) + created_database: DatabaseProxy = cls.mgmt_client.get_database_client(cls.database_id) asyncio.run(asyncio.wait_for( created_database.create_container( cls.single_partition_container_name, @@ -71,93 +71,111 @@ def setup_class(cls): @classmethod def teardown_class(cls): logger.info("tearing down class: {}".format(cls.__name__)) - created_database: DatabaseProxy = cls.mgmtClient.get_database_client(cls.database_id) + created_database: DatabaseProxy = cls.mgmt_client.get_database_client(cls.database_id) try: asyncio.run(asyncio.wait_for( created_database.delete_container(cls.single_partition_container_name), MGMT_TIMEOUT)) except Exception as containerDeleteError: - logger.warn("Exception trying to delete database {}. {}".format(created_database.id, containerDeleteError)) + logger.warning("Exception trying to delete database {}. 
{}".format(created_database.id, containerDeleteError)) finally: try: - asyncio.run(asyncio.wait_for(cls.mgmtClient.close(), MGMT_TIMEOUT)) + asyncio.run(asyncio.wait_for(cls.mgmt_client.close(), MGMT_TIMEOUT)) except Exception as closeError: - logger.warn("Exception trying to delete database {}. {}".format(created_database.id, closeError)) + logger.warning("Exception trying to delete database {}. {}".format(created_database.id, closeError)) - def setup_method_with_custom_transport(self, custom_transport: AioHttpTransport): - client = CosmosClient(host, masterKey, consistency_level="Session", - connection_policy=connectionPolicy, transport=custom_transport, - logger=logger, enable_diagnostics_logging=True) + def setup_method_with_custom_transport(self, custom_transport: AioHttpTransport, **kwargs): + client = CosmosClient(host, master_key, consistency_level="Session", + connection_policy=connection_policy, transport=custom_transport, + logger=logger, enable_diagnostics_logging=True, **kwargs) db: DatabaseProxy = client.get_database_client(TEST_DATABASE_ID) container: ContainerProxy = db.get_container_client(self.single_partition_container_name) return {"client": client, "db": db, "col": container} - def cleanup_method(self, initializedObjects: dict[str, Any]): - method_client: CosmosClient = initializedObjects["client"] + @staticmethod + def cleanup_method(initialized_objects: dict[str, Any]): + method_client: CosmosClient = initialized_objects["client"] try: asyncio.run(asyncio.wait_for(method_client.close(), MGMT_TIMEOUT)) - except Exception as closeError: - logger.warning("Exception trying to close method client.") + except Exception as close_error: + logger.warning(f"Exception trying to close method client. 
{close_error}") async def test_throws_injected_error(self, setup): - idValue: str = str(uuid.uuid4()) - document_definition = {'id': idValue, - 'pk': idValue, + id_value: str = str(uuid.uuid4()) + document_definition = {'id': id_value, + 'pk': id_value, 'name': 'sample document', 'key': 'value'} custom_transport = FaultInjectionTransport() - predicate : Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransport.predicate_req_for_document_with_id(r, idValue) - custom_transport.addFault(predicate, lambda: FaultInjectionTransport.throw_after_delay( + predicate : Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransport.predicate_req_for_document_with_id(r, id_value) + custom_transport.add_fault(predicate, lambda r: asyncio.create_task(FaultInjectionTransport.error_after_delay( 500, CosmosHttpResponseError( status_code=502, - message="Some random reverse proxy error."))) + message="Some random reverse proxy error.")))) - initializedObjects = self.setup_method_with_custom_transport(custom_transport) + initialized_objects = self.setup_method_with_custom_transport(custom_transport) try: - container: ContainerProxy = initializedObjects["col"] + container: ContainerProxy = initialized_objects["col"] await container.create_item(body=document_definition) pytest.fail("Expected exception not thrown") except CosmosHttpResponseError as cosmosError: - if (cosmosError.status_code != 502): + if cosmosError.status_code != 502: raise cosmosError finally: - cleanupOp = self.cleanup_method(initializedObjects) - if (cleanupOp != None): - await cleanupOp + TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) async def test_succeeds_with_multiple_endpoints(self, setup): + localhost_uri: str = test_config.TestConfig.local_host + alternate_localhost_uri: str = localhost_uri.replace("localhost", "127.0.0.1") custom_transport = FaultInjectionTransport() is_get_account_predicate: Callable[[HttpRequest], bool] = lambda r: 
FaultInjectionTransport.predicate_is_database_account_call(r) is_write_operation_predicate: Callable[[HttpRequest], bool] = lambda \ - r: FaultInjectionTransport.predicate_is_write_operation(r, "https://localhost") - emulator_as_multi_region_sm_account_transformation: Callable[[HttpRequest, Callable[[HttpRequest], asyncio.Task[AsyncHttpResponse]]], AsyncHttpResponse] = \ - lambda r, inner: FaultInjectionTransport.transform_convert_emulator_to_single_master_read_multi_region_account(r, inner) + r: FaultInjectionTransport.predicate_is_write_operation(r, localhost_uri) - custom_transport.addFault(is_write_operation_predicate, lambda: FaultInjectionTransport.throw_write_forbidden()) - custom_transport.add_response_transformation(is_get_account_predicate, emulator_as_multi_region_sm_account_transformation) + # Emulator uses "South Central US" with Uri https://127.0.0.1:8888 - idValue: str = str(uuid.uuid4()) - document_definition = {'id': idValue, - 'pk': idValue, + emulator_as_multi_region_sm_account_transformation: Callable[[HttpRequest, Callable[[HttpRequest], asyncio.Task[AsyncHttpResponse]]], AsyncHttpResponse] = \ + lambda r, inner: FaultInjectionTransport.transform_convert_emulator_to_single_master_read_multi_region_account( + additional_region="East US", + artificial_uri=localhost_uri, + r=r, + inner=inner) + + custom_transport.add_fault( + is_write_operation_predicate, + lambda r: asyncio.create_task(FaultInjectionTransport.error_write_forbidden())) + custom_transport.add_response_transformation( + is_get_account_predicate, + emulator_as_multi_region_sm_account_transformation) + + id_value: str = str(uuid.uuid4()) + document_definition = {'id': id_value, + 'pk': id_value, 'name': 'sample document', 'key': 'value'} - initializedObjects = self.setup_method_with_custom_transport(custom_transport) + initialized_objects = self.setup_method_with_custom_transport( + custom_transport, + preferred_locations=["East US", "South Central US"]) try: - container: 
ContainerProxy = initializedObjects["col"] + container: ContainerProxy = initialized_objects["col"] created_document = await container.create_item(body=document_definition) - start = time.perf_counter() - + request: HttpRequest = created_document.get_response_headers()["_request"] + # Validate the response comes from "South Central US" (the write region) + assert request.url.startswith(alternate_localhost_uri) + start:float = time.perf_counter() - #while ((time.perf_counter() - start) < 2): - # await container.read_item(idValue, partition_key=idValue) - # await asyncio.sleep(0.2) + while (time.perf_counter() - start) < 2: + read_document = await container.read_item(id_value, partition_key=id_value) + request: HttpRequest = read_document.get_response_headers()["_request"] + # Validate the response comes from "East US" (the most preferred read-only region) + assert request.url.startswith(localhost_uri) finally: - self.cleanup_method(initializedObjects) + TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) if __name__ == '__main__': unittest.main() From e90b722d30e123bdf3ccf15c87f0923074354ddf Mon Sep 17 00:00:00 2001 From: Fabian Meiswinkel Date: Fri, 21 Feb 2025 18:30:44 +0000 Subject: [PATCH 018/152] small refactoring --- .../test/_fault_injection_transport.py | 13 +++++--- .../test_fault_injection_transport_async.py | 33 ++++++++++--------- 2 files changed, 26 insertions(+), 20 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/test/_fault_injection_transport.py b/sdk/cosmos/azure-cosmos/test/_fault_injection_transport.py index efe58bae3032..46ea8c2c6ce2 100644 --- a/sdk/cosmos/azure-cosmos/test/_fault_injection_transport.py +++ b/sdk/cosmos/azure-cosmos/test/_fault_injection_transport.py @@ -33,6 +33,7 @@ from azure.core.pipeline.transport import AioHttpTransport, AioHttpTransportResponse from azure.core.rest import HttpRequest, AsyncHttpResponse +import test_config from azure.cosmos.exceptions import CosmosHttpResponseError @@ -153,9 +154,9 @@ 
async def error_write_forbidden() -> Exception: ) @staticmethod - async def transform_convert_emulator_to_single_master_read_multi_region_account( - additional_region: str, - artificial_uri: str, + async def transform_topology_swr_mrr( + write_region_name: str, + read_region_name: str, r: HttpRequest, inner: Callable[[],asyncio.Task[AsyncHttpResponse]]) -> asyncio.Task[AsyncHttpResponse]: @@ -168,7 +169,11 @@ async def transform_convert_emulator_to_single_master_read_multi_region_account( if response.status_code == 200 and data: data = data.decode("utf-8") result = json.loads(data) - result["readableLocations"].append({"name": additional_region, "databaseAccountEndpoint" : artificial_uri}) + readable_locations = result["readableLocations"] + writable_locations = result["writableLocations"] + readable_locations[0]["name"] = write_region_name + writable_locations[0]["name"] = write_region_name + readable_locations.append({"name": read_region_name, "databaseAccountEndpoint" : test_config.TestConfig.local_host}) FaultInjectionTransport.logger.info("Transformed Account Topology: {}".format(result)) request: HttpRequest = response.request return FaultInjectionTransport.MockHttpResponse(request, 200, result) diff --git a/sdk/cosmos/azure-cosmos/test/test_fault_injection_transport_async.py b/sdk/cosmos/azure-cosmos/test/test_fault_injection_transport_async.py index cba2c1074ec4..320992fe221b 100644 --- a/sdk/cosmos/azure-cosmos/test/test_fault_injection_transport_async.py +++ b/sdk/cosmos/azure-cosmos/test/test_fault_injection_transport_async.py @@ -127,25 +127,26 @@ async def test_throws_injected_error(self, setup): TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) async def test_succeeds_with_multiple_endpoints(self, setup): - localhost_uri: str = test_config.TestConfig.local_host - alternate_localhost_uri: str = localhost_uri.replace("localhost", "127.0.0.1") + expected_read_region_uri: str = test_config.TestConfig.local_host + 
expected_write_region_uri: str = expected_read_region_uri.replace("localhost", "127.0.0.1") custom_transport = FaultInjectionTransport() - is_get_account_predicate: Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransport.predicate_is_database_account_call(r) - is_write_operation_predicate: Callable[[HttpRequest], bool] = lambda \ - r: FaultInjectionTransport.predicate_is_write_operation(r, localhost_uri) + # Inject rule to disallow writes in the read-only region + is_write_operation_in_read_region_predicate: Callable[[HttpRequest], bool] = lambda \ + r: FaultInjectionTransport.predicate_is_write_operation(r, expected_read_region_uri) - # Emulator uses "South Central US" with Uri https://127.0.0.1:8888 + custom_transport.add_fault( + is_write_operation_in_read_region_predicate, + lambda r: asyncio.create_task(FaultInjectionTransport.error_write_forbidden())) + # Inject topology transformation that would make Emulator look like a single write region + # account with two read regions + is_get_account_predicate: Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransport.predicate_is_database_account_call(r) emulator_as_multi_region_sm_account_transformation: Callable[[HttpRequest, Callable[[HttpRequest], asyncio.Task[AsyncHttpResponse]]], AsyncHttpResponse] = \ - lambda r, inner: FaultInjectionTransport.transform_convert_emulator_to_single_master_read_multi_region_account( - additional_region="East US", - artificial_uri=localhost_uri, + lambda r, inner: FaultInjectionTransport.transform_topology_swr_mrr( + write_region_name="Write Region", + read_region_name="Read Region", r=r, inner=inner) - - custom_transport.add_fault( - is_write_operation_predicate, - lambda r: asyncio.create_task(FaultInjectionTransport.error_write_forbidden())) custom_transport.add_response_transformation( is_get_account_predicate, emulator_as_multi_region_sm_account_transformation) @@ -158,21 +159,21 @@ async def test_succeeds_with_multiple_endpoints(self, setup): 
initialized_objects = self.setup_method_with_custom_transport( custom_transport, - preferred_locations=["East US", "South Central US"]) + preferred_locations=["Read Region", "Write Region"]) try: container: ContainerProxy = initialized_objects["col"] created_document = await container.create_item(body=document_definition) request: HttpRequest = created_document.get_response_headers()["_request"] # Validate the response comes from "South Central US" (the write region) - assert request.url.startswith(alternate_localhost_uri) + assert request.url.startswith(expected_write_region_uri) start:float = time.perf_counter() while (time.perf_counter() - start) < 2: read_document = await container.read_item(id_value, partition_key=id_value) request: HttpRequest = read_document.get_response_headers()["_request"] # Validate the response comes from "East US" (the most preferred read-only region) - assert request.url.startswith(localhost_uri) + assert request.url.startswith(expected_read_region_uri) finally: TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) From cb58896447b57b7ec76c0dfd24d51e570a33e854 Mon Sep 17 00:00:00 2001 From: Fabian Meiswinkel Date: Fri, 21 Feb 2025 18:58:50 +0000 Subject: [PATCH 019/152] Adding MM topology on Emulator --- .../test/_fault_injection_transport.py | 31 ++++++++++++ .../test_fault_injection_transport_async.py | 47 ++++++++++++++++++- 2 files changed, 77 insertions(+), 1 deletion(-) diff --git a/sdk/cosmos/azure-cosmos/test/_fault_injection_transport.py b/sdk/cosmos/azure-cosmos/test/_fault_injection_transport.py index 46ea8c2c6ce2..bd76d795452a 100644 --- a/sdk/cosmos/azure-cosmos/test/_fault_injection_transport.py +++ b/sdk/cosmos/azure-cosmos/test/_fault_injection_transport.py @@ -180,6 +180,37 @@ async def transform_topology_swr_mrr( return response + @staticmethod + async def transform_topology_mwr( + first_region_name: str, + second_region_name: str, + r: HttpRequest, + inner: Callable[[], 
asyncio.Task[AsyncHttpResponse]]) -> asyncio.Task[AsyncHttpResponse]: + + response = await inner() + if not FaultInjectionTransport.predicate_is_database_account_call(response.request): + return response + + await response.load_body() + data = response.body() + if response.status_code == 200 and data: + data = data.decode("utf-8") + result = json.loads(data) + readable_locations = result["readableLocations"] + writable_locations = result["writableLocations"] + readable_locations[0]["name"] = first_region_name + writable_locations[0]["name"] = first_region_name + readable_locations.append( + {"name": second_region_name, "databaseAccountEndpoint": test_config.TestConfig.local_host}) + writable_locations.append( + {"name": second_region_name, "databaseAccountEndpoint": test_config.TestConfig.local_host}) + result["enableMultipleWriteLocations"] = True + FaultInjectionTransport.logger.info("Transformed Account Topology: {}".format(result)) + request: HttpRequest = response.request + return FaultInjectionTransport.MockHttpResponse(request, 200, result) + + return response + class MockHttpResponse(AioHttpTransportResponse): def __init__(self, request: HttpRequest, status_code: int, content:Optional[dict[str, any]]): self.request: HttpRequest = request diff --git a/sdk/cosmos/azure-cosmos/test/test_fault_injection_transport_async.py b/sdk/cosmos/azure-cosmos/test/test_fault_injection_transport_async.py index 320992fe221b..817c263cad34 100644 --- a/sdk/cosmos/azure-cosmos/test/test_fault_injection_transport_async.py +++ b/sdk/cosmos/azure-cosmos/test/test_fault_injection_transport_async.py @@ -126,7 +126,7 @@ async def test_throws_injected_error(self, setup): finally: TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) - async def test_succeeds_with_multiple_endpoints(self, setup): + async def test_swr_mrr_succeeds(self, setup): expected_read_region_uri: str = test_config.TestConfig.local_host expected_write_region_uri: str = 
expected_read_region_uri.replace("localhost", "127.0.0.1") custom_transport = FaultInjectionTransport() @@ -178,5 +178,50 @@ async def test_succeeds_with_multiple_endpoints(self, setup): finally: TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) + async def test_mwr_succeeds(self, setup): + first_region_uri: str = test_config.TestConfig.local_host.replace("localhost", "127.0.0.1") + second_region_uri: str = test_config.TestConfig.local_host + custom_transport = FaultInjectionTransport() + + # Inject topology transformation that would make Emulator look like a single write region + # account with two read regions + is_get_account_predicate: Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransport.predicate_is_database_account_call(r) + emulator_as_multi_write_region_account_transformation: Callable[[HttpRequest, Callable[[HttpRequest], asyncio.Task[AsyncHttpResponse]]], AsyncHttpResponse] = \ + lambda r, inner: FaultInjectionTransport.transform_topology_mwr( + first_region_name="First Region", + second_region_name="Second Region", + r=r, + inner=inner) + custom_transport.add_response_transformation( + is_get_account_predicate, + emulator_as_multi_write_region_account_transformation) + + id_value: str = str(uuid.uuid4()) + document_definition = {'id': id_value, + 'pk': id_value, + 'name': 'sample document', + 'key': 'value'} + + initialized_objects = self.setup_method_with_custom_transport( + custom_transport, + preferred_locations=["First Region", "Second Region"]) + try: + container: ContainerProxy = initialized_objects["col"] + + created_document = await container.create_item(body=document_definition) + request: HttpRequest = created_document.get_response_headers()["_request"] + # Validate the response comes from "South Central US" (the write region) + assert request.url.startswith(first_region_uri) + start:float = time.perf_counter() + + while (time.perf_counter() - start) < 2: + read_document = await container.read_item(id_value, 
partition_key=id_value) + request: HttpRequest = read_document.get_response_headers()["_request"] + # Validate the response comes from "East US" (the most preferred read-only region) + assert request.url.startswith(first_region_uri) + + finally: + TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) + if __name__ == '__main__': unittest.main() From 46ec31ca1eef20f815c658fb51d5410cb3a60390 Mon Sep 17 00:00:00 2001 From: Fabian Meiswinkel Date: Sat, 22 Feb 2025 00:05:07 +0000 Subject: [PATCH 020/152] Adding cross region retry tests --- .../test/_fault_injection_transport.py | 18 ++- .../test_fault_injection_transport_async.py | 136 +++++++++++++++++- 2 files changed, 149 insertions(+), 5 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/test/_fault_injection_transport.py b/sdk/cosmos/azure-cosmos/test/_fault_injection_transport.py index bd76d795452a..37c00667544a 100644 --- a/sdk/cosmos/azure-cosmos/test/_fault_injection_transport.py +++ b/sdk/cosmos/azure-cosmos/test/_fault_injection_transport.py @@ -35,7 +35,7 @@ import test_config from azure.cosmos.exceptions import CosmosHttpResponseError - +from azure.core.exceptions import ServiceRequestError class FaultInjectionTransport(AioHttpTransport): logger = logging.getLogger('azure.cosmos.fault_injection_transport') @@ -103,6 +103,10 @@ async def send(self, request: HttpRequest, *, stream: bool = False, proxies: Mut def predicate_url_contains_id(r: HttpRequest, id_value: str) -> bool: return id_value in r.url + @staticmethod + def predicate_targets_region(r: HttpRequest, region_endpoint: str) -> bool: + return r.url.startswith(region_endpoint) + @staticmethod def print_call_stack(): print("Call stack:") @@ -130,6 +134,12 @@ def predicate_is_database_account_call(r: HttpRequest) -> bool: return is_db_account_read + @staticmethod + def predicate_is_document_operation(r: HttpRequest) -> bool: + is_document_operation = (r.headers.get('x-ms-thinclient-proxy-resource-type') == 'docs') + + return 
is_document_operation + @staticmethod def predicate_is_write_operation(r: HttpRequest, uri_prefix: str) -> bool: is_write_document_operation = (r.headers.get('x-ms-thinclient-proxy-resource-type') == 'docs' @@ -153,6 +163,12 @@ async def error_write_forbidden() -> Exception: sub_status_code=3, ) + @staticmethod + async def error_region_down() -> Exception: + return ServiceRequestError( + message="Injected region down.", + ) + @staticmethod async def transform_topology_swr_mrr( write_region_name: str, diff --git a/sdk/cosmos/azure-cosmos/test/test_fault_injection_transport_async.py b/sdk/cosmos/azure-cosmos/test/test_fault_injection_transport_async.py index 817c263cad34..d09f017febbf 100644 --- a/sdk/cosmos/azure-cosmos/test/test_fault_injection_transport_async.py +++ b/sdk/cosmos/azure-cosmos/test/test_fault_injection_transport_async.py @@ -84,8 +84,8 @@ def teardown_class(cls): except Exception as closeError: logger.warning("Exception trying to delete database {}. {}".format(created_database.id, closeError)) - def setup_method_with_custom_transport(self, custom_transport: AioHttpTransport, **kwargs): - client = CosmosClient(host, master_key, consistency_level="Session", + def setup_method_with_custom_transport(self, custom_transport: AioHttpTransport, default_endpoint=host, **kwargs): + client = CosmosClient(default_endpoint, master_key, consistency_level="Session", connection_policy=connection_policy, transport=custom_transport, logger=logger, enable_diagnostics_logging=True, **kwargs) db: DatabaseProxy = client.get_database_client(TEST_DATABASE_ID) @@ -165,19 +165,147 @@ async def test_swr_mrr_succeeds(self, setup): created_document = await container.create_item(body=document_definition) request: HttpRequest = created_document.get_response_headers()["_request"] - # Validate the response comes from "South Central US" (the write region) + # Validate the response comes from "Write Region" (the write region) assert request.url.startswith(expected_write_region_uri) 
start:float = time.perf_counter() while (time.perf_counter() - start) < 2: read_document = await container.read_item(id_value, partition_key=id_value) request: HttpRequest = read_document.get_response_headers()["_request"] - # Validate the response comes from "East US" (the most preferred read-only region) + # Validate the response comes from "Read Region" (the most preferred read-only region) assert request.url.startswith(expected_read_region_uri) finally: TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) + async def test_swr_mrr_region_down_read_succeeds(self, setup): + expected_read_region_uri: str = test_config.TestConfig.local_host + expected_write_region_uri: str = expected_read_region_uri.replace("localhost", "127.0.0.1") + custom_transport = FaultInjectionTransport() + # Inject rule to disallow writes in the read-only region + is_write_operation_in_read_region_predicate: Callable[[HttpRequest], bool] = lambda \ + r: FaultInjectionTransport.predicate_is_write_operation(r, expected_read_region_uri) + + custom_transport.add_fault( + is_write_operation_in_read_region_predicate, + lambda r: asyncio.create_task(FaultInjectionTransport.error_write_forbidden())) + + # Inject rule to simulate regional outage in "Read Region" + is_request_to_read_region: Callable[[HttpRequest], bool] = lambda \ + r: FaultInjectionTransport.predicate_targets_region(r, expected_read_region_uri) + + custom_transport.add_fault( + is_request_to_read_region, + lambda r: asyncio.create_task(FaultInjectionTransport.error_region_down())) + + # Inject topology transformation that would make Emulator look like a single write region + # account with two read regions + is_get_account_predicate: Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransport.predicate_is_database_account_call(r) + emulator_as_multi_region_sm_account_transformation: Callable[[HttpRequest, Callable[[HttpRequest], asyncio.Task[AsyncHttpResponse]]], AsyncHttpResponse] = \ + lambda r, inner: 
FaultInjectionTransport.transform_topology_swr_mrr( + write_region_name="Write Region", + read_region_name="Read Region", + r=r, + inner=inner) + custom_transport.add_response_transformation( + is_get_account_predicate, + emulator_as_multi_region_sm_account_transformation) + + id_value: str = str(uuid.uuid4()) + document_definition = {'id': id_value, + 'pk': id_value, + 'name': 'sample document', + 'key': 'value'} + + initialized_objects = self.setup_method_with_custom_transport( + custom_transport, + default_endpoint=expected_write_region_uri, + preferred_locations=["Read Region", "Write Region"]) + try: + container: ContainerProxy = initialized_objects["col"] + + created_document = await container.create_item(body=document_definition) + request: HttpRequest = created_document.get_response_headers()["_request"] + # Validate the response comes from "Write Region" (the write region) + assert request.url.startswith(expected_write_region_uri) + start:float = time.perf_counter() + + while (time.perf_counter() - start) < 2: + read_document = await container.read_item(id_value, partition_key=id_value) + request: HttpRequest = read_document.get_response_headers()["_request"] + # Validate the response comes from "Write Region" ("Read Region" the most preferred read-only region is down) + assert request.url.startswith(expected_write_region_uri) + + finally: + TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) + + async def test_swr_mrr_region_down_envoy_read_succeeds(self, setup): + expected_read_region_uri: str = test_config.TestConfig.local_host + expected_write_region_uri: str = expected_read_region_uri.replace("localhost", "127.0.0.1") + custom_transport = FaultInjectionTransport() + # Inject rule to disallow writes in the read-only region + is_write_operation_in_read_region_predicate: Callable[[HttpRequest], bool] = lambda \ + r: FaultInjectionTransport.predicate_is_write_operation(r, expected_read_region_uri) + + custom_transport.add_fault( + 
is_write_operation_in_read_region_predicate, + lambda r: asyncio.create_task(FaultInjectionTransport.error_write_forbidden())) + + # Inject rule to simulate regional outage in "Read Region" + is_request_to_read_region: Callable[[HttpRequest], bool] = lambda \ + r: FaultInjectionTransport.predicate_targets_region(r, expected_read_region_uri) and \ + FaultInjectionTransport.predicate_is_document_operation(r) + + custom_transport.add_fault( + is_request_to_read_region, + lambda r: asyncio.create_task(FaultInjectionTransport.error_after_delay( + 35000, + CosmosHttpResponseError( + status_code=502, + message="Some random reverse proxy error.")))) + + # Inject topology transformation that would make Emulator look like a single write region + # account with two read regions + is_get_account_predicate: Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransport.predicate_is_database_account_call(r) + emulator_as_multi_region_sm_account_transformation: Callable[[HttpRequest, Callable[[HttpRequest], asyncio.Task[AsyncHttpResponse]]], AsyncHttpResponse] = \ + lambda r, inner: FaultInjectionTransport.transform_topology_swr_mrr( + write_region_name="Write Region", + read_region_name="Read Region", + r=r, + inner=inner) + custom_transport.add_response_transformation( + is_get_account_predicate, + emulator_as_multi_region_sm_account_transformation) + + id_value: str = str(uuid.uuid4()) + document_definition = {'id': id_value, + 'pk': id_value, + 'name': 'sample document', + 'key': 'value'} + + initialized_objects = self.setup_method_with_custom_transport( + custom_transport, + default_endpoint=expected_write_region_uri, + preferred_locations=["Read Region", "Write Region"]) + try: + container: ContainerProxy = initialized_objects["col"] + + created_document = await container.create_item(body=document_definition) + request: HttpRequest = created_document.get_response_headers()["_request"] + # Validate the response comes from "Write Region" (the write region) + assert 
request.url.startswith(expected_write_region_uri) + start:float = time.perf_counter() + + while (time.perf_counter() - start) < 2: + read_document = await container.read_item(id_value, partition_key=id_value) + request: HttpRequest = read_document.get_response_headers()["_request"] + # Validate the response comes from "Write Region" ("Read Region" the most preferred read-only region is down) + assert request.url.startswith(expected_write_region_uri) + + finally: + TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) + + async def test_mwr_succeeds(self, setup): first_region_uri: str = test_config.TestConfig.local_host.replace("localhost", "127.0.0.1") second_region_uri: str = test_config.TestConfig.local_host From f03f51f3f6035a8a73b864bb41debd18f97cf6df Mon Sep 17 00:00:00 2001 From: Allen Kim Date: Mon, 31 Mar 2025 12:25:32 -0700 Subject: [PATCH 021/152] Add Excluded Locations Feature --- sdk/cosmos/azure-cosmos/azure/cosmos/_base.py | 1 + .../azure/cosmos/_cosmos_client_connection.py | 10 ++ .../azure/cosmos/_global_endpoint_manager.py | 4 +- .../azure/cosmos/_location_cache.py | 101 +++++++++--- .../azure/cosmos/_request_object.py | 25 ++- .../azure/cosmos/aio/_container.py | 30 ++++ .../aio/_cosmos_client_connection_async.py | 10 ++ .../aio/_global_endpoint_manager_async.py | 6 +- .../azure-cosmos/azure/cosmos/container.py | 30 ++++ .../azure/cosmos/cosmos_client.py | 4 + .../azure-cosmos/azure/cosmos/documents.py | 8 + .../samples/excluded_locations.py | 110 +++++++++++++ .../azure-cosmos/tests/test_health_check.py | 6 +- .../tests/test_health_check_async.py | 12 +- .../azure-cosmos/tests/test_location_cache.py | 148 +++++++++++++++++- .../tests/test_retry_policy_async.py | 1 + 16 files changed, 462 insertions(+), 44 deletions(-) create mode 100644 sdk/cosmos/azure-cosmos/samples/excluded_locations.py diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py index 654b23c5d71f..bcfc611456ec 
100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py @@ -63,6 +63,7 @@ 'priority': 'priorityLevel', 'no_response': 'responsePayloadOnWriteDisabled', 'max_item_count': 'maxItemCount', + 'excluded_locations': 'excludedLocations', } # Cosmos resource ID validation regex breakdown: diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py index 3934c23bcf99..d64da38defb1 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py @@ -2044,6 +2044,7 @@ def PatchItem( documents._OperationType.Patch, options) # Patch will use WriteEndpoint since it uses PUT operation request_params = RequestObject(resource_type, documents._OperationType.Patch) + request_params.set_excluded_location_from_options(options) request_data = {} if options.get("filterPredicate"): request_data["condition"] = options.get("filterPredicate") @@ -2132,6 +2133,7 @@ def _Batch( headers = base.GetHeaders(self, initial_headers, "post", path, collection_id, "docs", documents._OperationType.Batch, options) request_params = RequestObject("docs", documents._OperationType.Batch) + request_params.set_excluded_location_from_options(options) return cast( Tuple[List[Dict[str, Any]], CaseInsensitiveDict], self.__Post(path, request_params, batch_operations, headers, **kwargs) @@ -2192,6 +2194,7 @@ def DeleteAllItemsByPartitionKey( headers = base.GetHeaders(self, self.default_headers, "post", path, collection_id, "partitionkey", documents._OperationType.Delete, options) request_params = RequestObject("partitionkey", documents._OperationType.Delete) + request_params.set_excluded_location_from_options(options) _, last_response_headers = self.__Post( path=path, request_params=request_params, @@ -2647,6 +2650,7 @@ def Create( # Create will use WriteEndpoint since it uses POST operation 
request_params = RequestObject(typ, documents._OperationType.Create) + request_params.set_excluded_location_from_options(options) result, last_response_headers = self.__Post(path, request_params, body, headers, **kwargs) self.last_response_headers = last_response_headers @@ -2693,6 +2697,7 @@ def Upsert( # Upsert will use WriteEndpoint since it uses POST operation request_params = RequestObject(typ, documents._OperationType.Upsert) + request_params.set_excluded_location_from_options(options) result, last_response_headers = self.__Post(path, request_params, body, headers, **kwargs) self.last_response_headers = last_response_headers # update session for write request @@ -2736,6 +2741,7 @@ def Replace( options) # Replace will use WriteEndpoint since it uses PUT operation request_params = RequestObject(typ, documents._OperationType.Replace) + request_params.set_excluded_location_from_options(options) result, last_response_headers = self.__Put(path, request_params, resource, headers, **kwargs) self.last_response_headers = last_response_headers @@ -2777,6 +2783,7 @@ def Read( headers = base.GetHeaders(self, initial_headers, "get", path, id, typ, documents._OperationType.Read, options) # Read will use ReadEndpoint since it uses GET operation request_params = RequestObject(typ, documents._OperationType.Read) + request_params.set_excluded_location_from_options(options) result, last_response_headers = self.__Get(path, request_params, headers, **kwargs) self.last_response_headers = last_response_headers if response_hook: @@ -2816,6 +2823,7 @@ def DeleteResource( options) # Delete will use WriteEndpoint since it uses DELETE operation request_params = RequestObject(typ, documents._OperationType.Delete) + request_params.set_excluded_location_from_options(options) result, last_response_headers = self.__Delete(path, request_params, headers, **kwargs) self.last_response_headers = last_response_headers @@ -3052,6 +3060,7 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> 
List[Dict[str, Any]]: resource_type, documents._OperationType.QueryPlan if is_query_plan else documents._OperationType.ReadFeed ) + request_params.set_excluded_location_from_options(options) headers = base.GetHeaders( self, initial_headers, @@ -3090,6 +3099,7 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: # Query operations will use ReadEndpoint even though it uses POST(for regular query operations) request_params = RequestObject(resource_type, documents._OperationType.SqlQuery) + request_params.set_excluded_location_from_options(options) req_headers = base.GetHeaders( self, initial_headers, diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py index e167871dd4a5..944b684e392b 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py @@ -50,10 +50,8 @@ def __init__(self, client): self.DefaultEndpoint = client.url_connection self.refresh_time_interval_in_ms = self.get_refresh_time_interval_in_ms_stub() self.location_cache = LocationCache( - self.PreferredLocations, self.DefaultEndpoint, - self.EnableEndpointDiscovery, - client.connection_policy.UseMultipleWriteLocations + client.connection_policy ) self.refresh_needed = False self.refresh_lock = threading.RLock() diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py index 96651d5c8b7f..02b293e29b4b 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py @@ -25,12 +25,13 @@ import collections import logging import time -from typing import Set +from typing import Set, Mapping, List from urllib.parse import urlparse from . import documents from . 
import http_constants from .documents import _OperationType +from ._request_object import RequestObject # pylint: disable=protected-access @@ -113,7 +114,10 @@ def get_endpoints_by_location(new_locations, except Exception as e: raise e - return endpoints_by_location, parsed_locations + # Also store a hash map of endpoints for each location + locations_by_endpoints = {value.get_primary(): key for key, value in endpoints_by_location.items()} + + return endpoints_by_location, locations_by_endpoints, parsed_locations def add_endpoint_if_preferred(endpoint: str, preferred_endpoints: Set[str], endpoints: Set[str]) -> bool: if endpoint in preferred_endpoints: @@ -150,6 +154,21 @@ def _get_health_check_endpoints( return endpoints +def _get_applicable_regional_endpoints(endpoints: List[RegionalRoutingContext], + location_name_by_endpoint: Mapping[str, str], + fall_back_endpoint: RegionalRoutingContext, + exclude_location_list: List[str]) -> List[RegionalRoutingContext]: + # filter endpoints by excluded locations + applicable_endpoints = [] + for endpoint in endpoints: + if location_name_by_endpoint.get(endpoint.get_primary()) not in exclude_location_list: + applicable_endpoints.append(endpoint) + + # if endpoint is empty add fallback endpoint + if not applicable_endpoints: + applicable_endpoints.append(fall_back_endpoint) + + return applicable_endpoints class LocationCache(object): # pylint: disable=too-many-public-methods,too-many-instance-attributes def current_time_millis(self): @@ -157,15 +176,10 @@ def current_time_millis(self): def __init__( self, - preferred_locations, default_endpoint, - enable_endpoint_discovery, - use_multiple_write_locations, + connection_policy, ): - self.preferred_locations = preferred_locations self.default_regional_routing_context = RegionalRoutingContext(default_endpoint, default_endpoint) - self.enable_endpoint_discovery = enable_endpoint_discovery - self.use_multiple_write_locations = use_multiple_write_locations 
self.enable_multiple_writable_locations = False self.write_regional_routing_contexts = [self.default_regional_routing_context] self.read_regional_routing_contexts = [self.default_regional_routing_context] @@ -173,8 +187,11 @@ def __init__( self.last_cache_update_time_stamp = 0 self.account_read_regional_routing_contexts_by_location = {} # pylint: disable=name-too-long self.account_write_regional_routing_contexts_by_location = {} # pylint: disable=name-too-long + self.account_locations_by_read_regional_routing_context = {} # pylint: disable=name-too-long + self.account_locations_by_write_regional_routing_context = {} # pylint: disable=name-too-long self.account_write_locations = [] self.account_read_locations = [] + self.connection_policy = connection_policy def get_write_regional_routing_contexts(self): return self.write_regional_routing_contexts @@ -207,6 +224,44 @@ def get_ordered_write_locations(self): def get_ordered_read_locations(self): return self.account_read_locations + def _get_configured_excluded_locations(self, request: RequestObject): + # If excluded locations were configured on request, use request level excluded locations. + excluded_locations = request.excluded_locations + if excluded_locations is None: + # If excluded locations were only configured on client(connection_policy), use client level + excluded_locations = self.connection_policy.ExcludedLocations + return excluded_locations + + def _get_applicable_read_regional_endpoints(self, request: RequestObject): + # Get configured excluded locations + excluded_locations = self._get_configured_excluded_locations(request) + + # If excluded locations were configured, return filtered regional endpoints by excluded locations. 
+ if excluded_locations: + return _get_applicable_regional_endpoints( + self.get_read_regional_routing_contexts(), + self.account_locations_by_read_regional_routing_context, + self.get_write_regional_routing_contexts()[0], + excluded_locations) + + # Else, return all regional endpoints + return self.get_read_regional_routing_contexts() + + def _get_applicable_write_regional_endpoints(self, request: RequestObject): + # Get configured excluded locations + excluded_locations = self._get_configured_excluded_locations(request) + + # If excluded locations were configured, return filtered regional endpoints by excluded locations. + if excluded_locations: + return _get_applicable_regional_endpoints( + self.get_write_regional_routing_contexts(), + self.account_locations_by_write_regional_routing_context, + self.default_regional_routing_context, + excluded_locations) + + # Else, return all regional endpoints + return self.get_write_regional_routing_contexts() + def resolve_service_endpoint(self, request): if request.location_endpoint_to_route: return request.location_endpoint_to_route @@ -227,7 +282,7 @@ def resolve_service_endpoint(self, request): # For non-document resource types in case of client can use multiple write locations # or when client cannot use multiple write locations, flip-flop between the # first and the second writable region in DatabaseAccount (for manual failover) - if self.enable_endpoint_discovery and self.account_write_locations: + if self.connection_policy.EnableEndpointDiscovery and self.account_write_locations: location_index = min(location_index % 2, len(self.account_write_locations) - 1) write_location = self.account_write_locations[location_index] if (self.account_write_regional_routing_contexts_by_location @@ -247,9 +302,9 @@ def resolve_service_endpoint(self, request): return self.default_regional_routing_context.get_primary() regional_routing_contexts = ( - self.get_write_regional_routing_contexts() + 
self._get_applicable_write_regional_endpoints(request) if documents._OperationType.IsWriteOperation(request.operation_type) - else self.get_read_regional_routing_contexts() + else self._get_applicable_read_regional_endpoints(request) ) regional_routing_context = regional_routing_contexts[location_index % len(regional_routing_contexts)] if ( @@ -263,12 +318,14 @@ def resolve_service_endpoint(self, request): return regional_routing_context.get_primary() def should_refresh_endpoints(self): # pylint: disable=too-many-return-statements - most_preferred_location = self.preferred_locations[0] if self.preferred_locations else None + most_preferred_location = self.connection_policy.PreferredLocations[0] \ + if self.connection_policy.PreferredLocations else None # we should schedule refresh in background if we are unable to target the user's most preferredLocation. - if self.enable_endpoint_discovery: + if self.connection_policy.EnableEndpointDiscovery: - should_refresh = self.use_multiple_write_locations and not self.enable_multiple_writable_locations + should_refresh = (self.connection_policy.UseMultipleWriteLocations + and not self.enable_multiple_writable_locations) if (most_preferred_location and most_preferred_location in self.account_read_regional_routing_contexts_by_location): @@ -358,25 +415,27 @@ def update_location_cache(self, write_locations=None, read_locations=None, enabl if enable_multiple_writable_locations: self.enable_multiple_writable_locations = enable_multiple_writable_locations - if self.enable_endpoint_discovery: + if self.connection_policy.EnableEndpointDiscovery: if read_locations: (self.account_read_regional_routing_contexts_by_location, + self.account_locations_by_read_regional_routing_context, self.account_read_locations) = get_endpoints_by_location( read_locations, self.account_read_regional_routing_contexts_by_location, self.default_regional_routing_context, False, - self.use_multiple_write_locations + 
self.connection_policy.UseMultipleWriteLocations ) if write_locations: (self.account_write_regional_routing_contexts_by_location, + self.account_locations_by_write_regional_routing_context, self.account_write_locations) = get_endpoints_by_location( write_locations, self.account_write_regional_routing_contexts_by_location, self.default_regional_routing_context, True, - self.use_multiple_write_locations + self.connection_policy.UseMultipleWriteLocations ) self.write_regional_routing_contexts = self.get_preferred_regional_routing_contexts( @@ -399,18 +458,18 @@ def get_preferred_regional_routing_contexts( regional_endpoints = [] # if enableEndpointDiscovery is false, we always use the defaultEndpoint that # user passed in during documentClient init - if self.enable_endpoint_discovery and endpoints_by_location: # pylint: disable=too-many-nested-blocks + if self.connection_policy.EnableEndpointDiscovery and endpoints_by_location: # pylint: disable=too-many-nested-blocks if ( self.can_use_multiple_write_locations() or expected_available_operation == EndpointOperationType.ReadType ): unavailable_endpoints = [] - if self.preferred_locations: + if self.connection_policy.PreferredLocations: # When client can not use multiple write locations, preferred locations # list should only be used determining read endpoints order. If client # can use multiple write locations, preferred locations list should be # used for determining both read and write endpoints order. 
- for location in self.preferred_locations: + for location in self.connection_policy.PreferredLocations: regional_endpoint = endpoints_by_location[location] if location in endpoints_by_location \ else None if regional_endpoint: @@ -436,7 +495,7 @@ def get_preferred_regional_routing_contexts( return regional_endpoints def can_use_multiple_write_locations(self): - return self.use_multiple_write_locations and self.enable_multiple_writable_locations + return self.connection_policy.UseMultipleWriteLocations and self.enable_multiple_writable_locations def can_use_multiple_write_locations_for_request(self, request): # pylint: disable=name-too-long return self.can_use_multiple_write_locations() and ( diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py index a220c6af42c2..94805934ce74 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py @@ -21,7 +21,8 @@ """Represents a request object. 
""" -from typing import Optional +from typing import Optional, Mapping, Any + class RequestObject(object): def __init__(self, resource_type: str, operation_type: str, endpoint_override: Optional[str] = None) -> None: @@ -33,6 +34,7 @@ def __init__(self, resource_type: str, operation_type: str, endpoint_override: O self.location_index_to_route: Optional[int] = None self.location_endpoint_to_route: Optional[str] = None self.last_routed_location_endpoint_within_region: Optional[str] = None + self.excluded_locations = None def route_to_location_with_preferred_location_flag( # pylint: disable=name-too-long self, @@ -52,3 +54,24 @@ def clear_route_to_location(self) -> None: self.location_index_to_route = None self.use_preferred_locations = None self.location_endpoint_to_route = None + + def _can_set_excluded_location(self, options: Mapping[str, Any]) -> bool: + # If resource types for requests are not one of the followings, excluded locations cannot be set + if self.resource_type.lower() not in ['docs', 'documents', 'partitionkey']: + return False + + # If 'excludedLocations' wasn't in the options, excluded locations cannot be set + if (options is None + or 'excludedLocations' not in options): + return False + + # The 'excludedLocations' cannot be None + if options['excludedLocations'] is None: + raise ValueError("Excluded locations cannot be None. 
" + "If you want to remove all excluded locations, try passing an empty list.") + + return True + + def set_excluded_location_from_options(self, options: Mapping[str, Any]) -> None: + if self._can_set_excluded_location(options): + self.excluded_locations = options['excludedLocations'] diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py index 0142e215f318..590f43331652 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py @@ -224,6 +224,8 @@ async def create_item( :keyword bool enable_automatic_id_generation: Enable automatic id generation if no id present. :keyword str session_token: Token for use with Session consistency. :keyword dict[str, str] initial_headers: Initial headers to be sent as part of the request. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Mapping[str, str], Dict[str, Any]], None] :keyword Literal["High", "Low"] priority: Priority based execution allows users to set a priority for each @@ -303,6 +305,8 @@ async def read_item( :keyword Literal["High", "Low"] priority: Priority based execution allows users to set a priority for each request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. 
:raises ~azure.cosmos.exceptions.CosmosHttpResponseError: The given item couldn't be retrieved. :returns: A CosmosDict representing the retrieved item. :rtype: ~azure.cosmos.CosmosDict[str, Any] @@ -361,6 +365,8 @@ def read_all_items( :keyword Literal["High", "Low"] priority: Priority based execution allows users to set a priority for each request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :returns: An AsyncItemPaged of items (dicts). :rtype: AsyncItemPaged[Dict[str, Any]] """ @@ -441,6 +447,8 @@ def query_items( :keyword Literal["High", "Low"] priority: Priority based execution allows users to set a priority for each request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :returns: An AsyncItemPaged of items (dicts). :rtype: AsyncItemPaged[Dict[str, Any]] @@ -537,6 +545,8 @@ def query_items_change_feed( ALL_VERSIONS_AND_DELETES: Query all versions and deleted items from either `start_time='Now'` or 'continuation' token. :paramtype mode: Literal["LatestVersion", "AllVersionsAndDeletes"] + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. 
:keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Mapping[str, str], Dict[str, Any]], None] :returns: An AsyncItemPaged of items (dicts). @@ -575,6 +585,8 @@ def query_items_change_feed( ALL_VERSIONS_AND_DELETES: Query all versions and deleted items from either `start_time='Now'` or 'continuation' token. :paramtype mode: Literal["LatestVersion", "AllVersionsAndDeletes"] + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Mapping[str, str], Dict[str, Any]], None] :returns: An AsyncItemPaged of items (dicts). @@ -601,6 +613,8 @@ def query_items_change_feed( request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. :paramtype priority: Literal["High", "Low"] + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Mapping[str, str], Dict[str, Any]], None] :returns: An AsyncItemPaged of items (dicts). @@ -639,6 +653,8 @@ def query_items_change_feed( ALL_VERSIONS_AND_DELETES: Query all versions and deleted items from either `start_time='Now'` or 'continuation' token. :paramtype mode: Literal["LatestVersion", "AllVersionsAndDeletes"] + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. 
The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Mapping[str, str], Dict[str, Any]], None] :returns: An AsyncItemPaged of items (dicts). @@ -675,6 +691,8 @@ def query_items_change_feed( # pylint: disable=unused-argument ALL_VERSIONS_AND_DELETES: Query all versions and deleted items from either `start_time='Now'` or 'continuation' token. :paramtype mode: Literal["LatestVersion", "AllVersionsAndDeletes"] + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Mapping[str, str], Dict[str, Any]], None] :returns: An AsyncItemPaged of items (dicts). @@ -748,6 +766,8 @@ async def upsert_item( :keyword bool no_response: Indicates whether service should be instructed to skip sending response payloads. When not specified explicitly here, the default value will be determined from client-level options. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: The given item could not be upserted. :returns: A CosmosDict representing the upserted item. The dict will be empty if `no_response` is specified. @@ -830,6 +850,8 @@ async def replace_item( :keyword bool no_response: Indicates whether service should be instructed to skip sending response payloads. When not specified explicitly here, the default value will be determined from client-level options. 
+ :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: The replace operation failed or the item with given id does not exist. :returns: A CosmosDict representing the item after replace went through. The dict will be empty if `no_response` @@ -906,6 +928,8 @@ async def patch_item( :keyword bool no_response: Indicates whether service should be instructed to skip sending response payloads. When not specified explicitly here, the default value will be determined from client-level options. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: The patch operations failed or the item with given id does not exist. :returns: A CosmosDict representing the item after the patch operations went through. The dict will be empty if @@ -973,6 +997,8 @@ async def delete_item( :keyword Literal["High", "Low"] priority: Priority based execution allows users to set a priority for each request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Mapping[str, str], None], None] :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: The item wasn't deleted successfully. 
@@ -1223,6 +1249,8 @@ async def delete_all_items_by_partition_key( :keyword str pre_trigger_include: trigger id to be used as pre operation trigger. :keyword str post_trigger_include: trigger id to be used as post operation trigger. :keyword str session_token: Token for use with Session consistency. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword Callable response_hook: A callable invoked with the response metadata. :rtype: None """ @@ -1278,6 +1306,8 @@ async def execute_item_batch( :keyword Literal["High", "Low"] priority: Priority based execution allows users to set a priority for each request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword Callable response_hook: A callable invoked with the response metadata. :returns: A CosmosList representing the items after the batch operations went through. :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: The batch failed to execute. 
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py index 49219533a7e6..9008b46bb1c1 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py @@ -768,6 +768,7 @@ async def Create( # Create will use WriteEndpoint since it uses POST operation request_params = _request_object.RequestObject(typ, documents._OperationType.Create) + request_params.set_excluded_location_from_options(options) result, last_response_headers = await self.__Post(path, request_params, body, headers, **kwargs) self.last_response_headers = last_response_headers @@ -907,6 +908,7 @@ async def Upsert( # Upsert will use WriteEndpoint since it uses POST operation request_params = _request_object.RequestObject(typ, documents._OperationType.Upsert) + request_params.set_excluded_location_from_options(options) result, last_response_headers = await self.__Post(path, request_params, body, headers, **kwargs) self.last_response_headers = last_response_headers # update session for write request @@ -1208,6 +1210,7 @@ async def Read( options) # Read will use ReadEndpoint since it uses GET operation request_params = _request_object.RequestObject(typ, documents._OperationType.Read) + request_params.set_excluded_location_from_options(options) result, last_response_headers = await self.__Get(path, request_params, headers, **kwargs) self.last_response_headers = last_response_headers if response_hook: @@ -1466,6 +1469,7 @@ async def PatchItem( documents._OperationType.Patch, options) # Patch will use WriteEndpoint since it uses PUT operation request_params = _request_object.RequestObject(typ, documents._OperationType.Patch) + request_params.set_excluded_location_from_options(options) request_data = {} if options.get("filterPredicate"): request_data["condition"] = options.get("filterPredicate") 
@@ -1570,6 +1574,7 @@ async def Replace( options) # Replace will use WriteEndpoint since it uses PUT operation request_params = _request_object.RequestObject(typ, documents._OperationType.Replace) + request_params.set_excluded_location_from_options(options) result, last_response_headers = await self.__Put(path, request_params, resource, headers, **kwargs) self.last_response_headers = last_response_headers @@ -1893,6 +1898,7 @@ async def DeleteResource( options) # Delete will use WriteEndpoint since it uses DELETE operation request_params = _request_object.RequestObject(typ, documents._OperationType.Delete) + request_params.set_excluded_location_from_options(options) result, last_response_headers = await self.__Delete(path, request_params, headers, **kwargs) self.last_response_headers = last_response_headers @@ -2006,6 +2012,7 @@ async def _Batch( headers = base.GetHeaders(self, initial_headers, "post", path, collection_id, "docs", documents._OperationType.Batch, options) request_params = _request_object.RequestObject("docs", documents._OperationType.Batch) + request_params.set_excluded_location_from_options(options) result = await self.__Post(path, request_params, batch_operations, headers, **kwargs) return cast(Tuple[List[Dict[str, Any]], CaseInsensitiveDict], result) @@ -2861,6 +2868,7 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: typ, documents._OperationType.QueryPlan if is_query_plan else documents._OperationType.ReadFeed ) + request_params.set_excluded_location_from_options(options) headers = base.GetHeaders(self, initial_headers, "get", path, id_, typ, request_params.operation_type, options, partition_key_range_id) @@ -2890,6 +2898,7 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: # Query operations will use ReadEndpoint even though it uses POST(for regular query operations) request_params = _request_object.RequestObject(typ, documents._OperationType.SqlQuery) + 
request_params.set_excluded_location_from_options(options) req_headers = base.GetHeaders(self, initial_headers, "post", path, id_, typ, request_params.operation_type, options, partition_key_range_id) @@ -3259,6 +3268,7 @@ async def DeleteAllItemsByPartitionKey( headers = base.GetHeaders(self, initial_headers, "post", path, collection_id, "partitionkey", documents._OperationType.Delete, options) request_params = _request_object.RequestObject("partitionkey", documents._OperationType.Delete) + request_params.set_excluded_location_from_options(options) _, last_response_headers = await self.__Post(path=path, request_params=request_params, req_headers=headers, body=None, **kwargs) self.last_response_headers = last_response_headers diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py index 4d00a7ef5629..f576e97d8e0b 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py @@ -25,7 +25,7 @@ import asyncio # pylint: disable=do-not-import-asyncio import logging -from asyncio import CancelledError +from asyncio import CancelledError # pylint: disable=do-not-import-asyncio from typing import Tuple from azure.core.exceptions import AzureError @@ -53,10 +53,8 @@ def __init__(self, client): self.DefaultEndpoint = client.url_connection self.refresh_time_interval_in_ms = self.get_refresh_time_interval_in_ms_stub() self.location_cache = LocationCache( - self.PreferredLocations, self.DefaultEndpoint, - self.EnableEndpointDiscovery, - client.connection_policy.UseMultipleWriteLocations + client.connection_policy ) self.startup = True self.refresh_task = None diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/container.py b/sdk/cosmos/azure-cosmos/azure/cosmos/container.py index efa5e7c09a50..a815c9110471 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/container.py 
+++ b/sdk/cosmos/azure-cosmos/azure/cosmos/container.py @@ -233,6 +233,8 @@ def read_item( # pylint:disable=docstring-missing-param :keyword Literal["High", "Low"] priority: Priority based execution allows users to set a priority for each request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :returns: A CosmosDict representing the item to be retrieved. :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: The given item couldn't be retrieved. :rtype: ~azure.cosmos.CosmosDict[str, Any] @@ -298,6 +300,8 @@ def read_all_items( # pylint:disable=docstring-missing-param :keyword Literal["High", "Low"] priority: Priority based execution allows users to set a priority for each request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :returns: An Iterable of items (dicts). :rtype: Iterable[Dict[str, Any]] """ @@ -364,6 +368,8 @@ def query_items_change_feed( ALL_VERSIONS_AND_DELETES: Query all versions and deleted items from either `start_time='Now'` or 'continuation' token. :paramtype mode: Literal["LatestVersion", "AllVersionsAndDeletes"] + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. 
The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Mapping[str, str], Dict[str, Any]], None] :returns: An Iterable of items (dicts). @@ -403,6 +409,8 @@ def query_items_change_feed( ALL_VERSIONS_AND_DELETES: Query all versions and deleted items from either `start_time='Now'` or 'continuation' token. :paramtype mode: Literal["LatestVersion", "AllVersionsAndDeletes"] + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Mapping[str, str], Dict[str, Any]], None] :returns: An Iterable of items (dicts). @@ -429,6 +437,8 @@ def query_items_change_feed( request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. :paramtype priority: Literal["High", "Low"] + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Mapping[str, str], Dict[str, Any]], None] :returns: An Iterable of items (dicts). @@ -466,6 +476,8 @@ def query_items_change_feed( ALL_VERSIONS_AND_DELETES: Query all versions and deleted items from either `start_time='Now'` or 'continuation' token. 
:paramtype mode: Literal["LatestVersion", "AllVersionsAndDeletes"] + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Mapping[str, str], Dict[str, Any]], None] :returns: An Iterable of items (dicts). @@ -501,6 +513,8 @@ def query_items_change_feed( ALL_VERSIONS_AND_DELETES: Query all versions and deleted items from either `start_time='Now'` or 'continuation' token. :paramtype mode: Literal["LatestVersion", "AllVersionsAndDeletes"] + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Mapping[str, str], Dict[str, Any]], None] :param Any args: args @@ -601,6 +615,8 @@ def query_items( # pylint:disable=docstring-missing-param :keyword bool populate_index_metrics: Used to obtain the index metrics to understand how the query engine used existing indexes and how it could use potential new indexes. Please note that this options will incur overhead, so it should be enabled only when debugging slow queries. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :returns: An Iterable of items (dicts). :rtype: ItemPaged[Dict[str, Any]] @@ -716,6 +732,8 @@ def replace_item( # pylint:disable=docstring-missing-param :keyword bool no_response: Indicates whether service should be instructed to skip sending response payloads. 
When not specified explicitly here, the default value will be determined from kwargs or when also not specified there from client-level kwargs. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: The replace operation failed or the item with given id does not exist. :returns: A CosmosDict representing the item after replace went through. The dict will be empty if `no_response` @@ -790,6 +808,8 @@ def upsert_item( # pylint:disable=docstring-missing-param :keyword bool no_response: Indicates whether service should be instructed to skip sending response payloads. When not specified explicitly here, the default value will be determined from kwargs or when also not specified there from client-level kwargs. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: The given item could not be upserted. :returns: A CosmosDict representing the upserted item. The dict will be empty if `no_response` is specified. :rtype: ~azure.cosmos.CosmosDict[str, Any] @@ -879,6 +899,8 @@ def create_item( # pylint:disable=docstring-missing-param :keyword bool no_response: Indicates whether service should be instructed to skip sending response payloads. When not specified explicitly here, the default value will be determined from kwargs or when also not specified there from client-level kwargs. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. 
:raises ~azure.cosmos.exceptions.CosmosHttpResponseError: Item with the given ID already exists. :returns: A CosmosDict representing the new item. The dict will be empty if `no_response` is specified. :rtype: ~azure.cosmos.CosmosDict[str, Any] @@ -970,6 +992,8 @@ def patch_item( :keyword bool no_response: Indicates whether service should be instructed to skip sending response payloads. When not specified explicitly here, the default value will be determined from kwargs or when also not specified there from client-level kwargs. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: The patch operations failed or the item with given id does not exist. :returns: A CosmosDict representing the item after the patch operations went through. The dict will be empty @@ -1030,6 +1054,8 @@ def execute_item_batch( :keyword Literal["High", "Low"] priority: Priority based execution allows users to set a priority for each request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: [Callable[[Mapping[str, str], List[Dict[str, Any]]], None] :returns: A CosmosList representing the items after the batch operations went through. 
@@ -1102,6 +1128,8 @@ def delete_item( # pylint:disable=docstring-missing-param :keyword Literal["High", "Low"] priority: Priority based execution allows users to set a priority for each request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Mapping[str, str], None], None] :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: The item wasn't deleted successfully. @@ -1377,6 +1405,8 @@ def delete_all_items_by_partition_key( :keyword str pre_trigger_include: trigger id to be used as pre operation trigger. :keyword str post_trigger_include: trigger id to be used as post operation trigger. :keyword str session_token: Token for use with Session consistency. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword response_hook: A callable invoked with the response metadata. 
:paramtype response_hook: Callable[[Mapping[str, str], None], None] = None, :rtype: None diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/cosmos_client.py b/sdk/cosmos/azure-cosmos/azure/cosmos/cosmos_client.py index 10543f97c47b..b7a6ea94bd2b 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/cosmos_client.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/cosmos_client.py @@ -93,6 +93,8 @@ def _build_connection_policy(kwargs: Dict[str, Any]) -> ConnectionPolicy: policy.ProxyConfiguration = kwargs.pop('proxy_config', policy.ProxyConfiguration) policy.EnableEndpointDiscovery = kwargs.pop('enable_endpoint_discovery', policy.EnableEndpointDiscovery) policy.PreferredLocations = kwargs.pop('preferred_locations', policy.PreferredLocations) + # TODO: Consider storing callback method instead, such as 'Supplier' in JAVA SDK + policy.ExcludedLocations = kwargs.pop('excluded_locations', policy.ExcludedLocations) policy.UseMultipleWriteLocations = kwargs.pop('multiple_write_locations', policy.UseMultipleWriteLocations) # SSL config @@ -181,6 +183,8 @@ class CosmosClient: # pylint: disable=client-accepts-api-version-keyword :keyword bool enable_endpoint_discovery: Enable endpoint discovery for geo-replicated database accounts. (Default: True) :keyword list[str] preferred_locations: The preferred locations for geo-replicated database accounts. + :keyword list[str] excluded_locations: The excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword bool enable_diagnostics_logging: Enable the CosmosHttpLogging policy. Must be used along with a logger to work. :keyword ~logging.Logger logger: Logger to be used for collecting request diagnostics. 
Can be passed in at client diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/documents.py b/sdk/cosmos/azure-cosmos/azure/cosmos/documents.py index 40fbed24451f..9e04829be52f 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/documents.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/documents.py @@ -308,6 +308,13 @@ class ConnectionPolicy: # pylint: disable=too-many-instance-attributes locations in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US', 'Central India' and so on. :vartype PreferredLocations: List[str] + :ivar ExcludedLocations: + Gets or sets the excluded locations for geo-replicated database + accounts. When ExcludedLocations is non-empty, the client will skip this + set of locations from the final location evaluation. The locations in + this list are specified as the names of the azure Cosmos locations like, + 'West US', 'East US', 'Central India' and so on. + :vartype ExcludedLocations: ~CosmosExcludedLocations :ivar RetryOptions: Gets or sets the retry options to be applied to all requests when retrying. @@ -347,6 +354,7 @@ def __init__(self) -> None: self.ProxyConfiguration: Optional[ProxyConfiguration] = None self.EnableEndpointDiscovery: bool = True self.PreferredLocations: List[str] = [] + self.ExcludedLocations: List[str] = [] self.RetryOptions: RetryOptions = RetryOptions() self.DisableSSLVerification: bool = False self.UseMultipleWriteLocations: bool = False diff --git a/sdk/cosmos/azure-cosmos/samples/excluded_locations.py b/sdk/cosmos/azure-cosmos/samples/excluded_locations.py new file mode 100644 index 000000000000..06228c1a8cea --- /dev/null +++ b/sdk/cosmos/azure-cosmos/samples/excluded_locations.py @@ -0,0 +1,110 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See LICENSE.txt in the project root for +# license information. 
+# ------------------------------------------------------------------------- +from azure.cosmos import CosmosClient +from azure.cosmos.partition_key import PartitionKey +import config + +# ---------------------------------------------------------------------------------------------------------- +# Prerequisites - +# +# 1. An Azure Cosmos account - +# https://learn.microsoft.com/azure/cosmos-db/create-sql-api-python#create-a-database-account +# +# 2. Microsoft Azure Cosmos +# pip install azure-cosmos>=4.3.0b4 +# ---------------------------------------------------------------------------------------------------------- +# Sample - demonstrates how to use excluded locations in client level and request level +# ---------------------------------------------------------------------------------------------------------- +# Note: +# This sample creates a Container to your database account. +# Each time a Container is created the account will be billed for 1 hour of usage based on +# the provisioned throughput (RU/s) of that account. 
+# ---------------------------------------------------------------------------------------------------------- + +HOST = config.settings["host"] +MASTER_KEY = config.settings["master_key"] + +TENANT_ID = config.settings["tenant_id"] +CLIENT_ID = config.settings["client_id"] +CLIENT_SECRET = config.settings["client_secret"] + +DATABASE_ID = config.settings["database_id"] +CONTAINER_ID = config.settings["container_id"] +PARTITION_KEY = PartitionKey(path="/id") + + +def get_test_item(num): + test_item = { + 'id': 'Item_' + str(num), + 'test_object': True, + 'lastName': 'Smith' + } + return test_item + +def clean_up_db(client): + try: + client.delete_database(DATABASE_ID) + except Exception as e: + pass + +def excluded_locations_client_level_sample(): + preferred_locations = ['West US 3', 'West US', 'East US 2'] + excluded_locations = ['West US 3', 'West US'] + client = CosmosClient( + HOST, + MASTER_KEY, + preferred_locations=preferred_locations, + excluded_locations=excluded_locations + ) + clean_up_db(client) + + db = client.create_database(DATABASE_ID) + container = db.create_container(id=CONTAINER_ID, partition_key=PARTITION_KEY) + + # For write operations with single master account, write endpoint will be the default endpoint, + # since preferred_locations and excluded_locations are ignored and not used + container.create_item(get_test_item(0)) + + # For read operations, read endpoints will be 'preferred_locations' - 'excluded_locations'.
+ # In our sample, ['West US 3', 'West US', 'East US 2'] - ['West US 3', 'West US'] => ['East US 2'], + # therefore 'East US 2' will be the read endpoint, and items will be read from 'East US 2' location + item = container.read_item(item='Item_0', partition_key='Item_0') + + clean_up_db(client) + +def excluded_locations_request_level_sample(): + preferred_locations = ['West US 3', 'West US', 'East US 2'] + excluded_locations_on_client = ['West US 3', 'West US'] + excluded_locations_on_request = ['West US 3'] + client = CosmosClient( + HOST, + MASTER_KEY, + preferred_locations=preferred_locations, + excluded_locations=excluded_locations_on_client + ) + clean_up_db(client) + + db = client.create_database(DATABASE_ID) + container = db.create_container(id=CONTAINER_ID, partition_key=PARTITION_KEY) + + # For write operations with single master account, write endpoint will be the default endpoint, + # since preferred_locations and excluded_locations are ignored and not used + container.create_item(get_test_item(0)) + + # For read operations, read endpoints will be 'preferred_locations' - 'excluded_locations'. + # However, in our sample, since the `excluded_locations` were passed with the read request, the `excluded_locations` + # will be replaced with the locations from request, ['West US 3']. The `excluded_locations` on request always takes + # the highest priority!
+ # With the excluded_locations on request, the read endpoints will be ['West US', 'East US 2'] + # ['West US 3', 'West US', 'East US 2'] - ['West US 3'] => ['West US', 'East US 2'] + # Therefore, items will be read from 'West US' or 'East US 2' location + item = container.read_item(item='Item_0', partition_key='Item_0', excluded_locations=excluded_locations_on_request) + + clean_up_db(client) + +if __name__ == "__main__": + # excluded_locations_client_level_sample() + excluded_locations_request_level_sample() diff --git a/sdk/cosmos/azure-cosmos/tests/test_health_check.py b/sdk/cosmos/azure-cosmos/tests/test_health_check.py index 0d313e6c911c..75db9deacc41 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_health_check.py +++ b/sdk/cosmos/azure-cosmos/tests/test_health_check.py @@ -126,14 +126,14 @@ def test_health_check_timeouts_on_unavailable_endpoints(self, setup): locational_endpoint = _location_cache.LocationCache.GetLocationalEndpoint(TestHealthCheck.host, REGION_1) setup[COLLECTION].client_connection._global_endpoint_manager.location_cache.mark_endpoint_unavailable_for_read( locational_endpoint, True) - self.original_preferred_locations = setup[COLLECTION].client_connection._global_endpoint_manager.location_cache.preferred_locations - setup[COLLECTION].client_connection._global_endpoint_manager.location_cache.preferred_locations = REGIONS + self.original_preferred_locations = setup[COLLECTION].client_connection.connection_policy.PreferredLocations + setup[COLLECTION].client_connection.connection_policy.PreferredLocations = REGIONS try: setup[COLLECTION].create_item(body={'id': 'item' + str(uuid.uuid4()), 'pk': 'pk'}) finally: _global_endpoint_manager._GlobalEndpointManager._GetDatabaseAccountStub = self.original_getDatabaseAccountStub _cosmos_client_connection.CosmosClientConnection._GetDatabaseAccountCheck = self.original_getDatabaseAccountCheck - setup[COLLECTION].client_connection._global_endpoint_manager.location_cache.preferred_locations = 
self.original_preferred_locations + setup[COLLECTION].client_connection.connection_policy.PreferredLocations = self.original_preferred_locations class MockGetDatabaseAccountCheck(object): def __init__(self, client_connection=None, endpoint_unavailable=False): diff --git a/sdk/cosmos/azure-cosmos/tests/test_health_check_async.py b/sdk/cosmos/azure-cosmos/tests/test_health_check_async.py index ae2bf13fd8a7..a92eca0dd778 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_health_check_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_health_check_async.py @@ -153,8 +153,8 @@ async def test_health_check_success(self, setup, preferred_location, use_write_g # checks the background health check works as expected when all endpoints healthy self.original_getDatabaseAccountStub = _global_endpoint_manager_async._GlobalEndpointManager._GetDatabaseAccountStub self.original_getDatabaseAccountCheck = _cosmos_client_connection_async.CosmosClientConnection._GetDatabaseAccountCheck - self.original_preferred_locations = setup[COLLECTION].client_connection._global_endpoint_manager.location_cache.preferred_locations - setup[COLLECTION].client_connection._global_endpoint_manager.location_cache.preferred_locations = preferred_location + self.original_preferred_locations = setup[COLLECTION].client_connection.connection_policy.PreferredLocations + setup[COLLECTION].client_connection.connection_policy.PreferredLocations = preferred_location mock_get_database_account_check = self.MockGetDatabaseAccountCheck() _global_endpoint_manager_async._GlobalEndpointManager._GetDatabaseAccountStub = ( self.MockGetDatabaseAccount(REGIONS, use_write_global_endpoint, use_read_global_endpoint)) @@ -168,7 +168,7 @@ async def test_health_check_success(self, setup, preferred_location, use_write_g finally: _global_endpoint_manager_async._GlobalEndpointManager._GetDatabaseAccountStub = self.original_getDatabaseAccountStub _cosmos_client_connection_async.CosmosClientConnection._GetDatabaseAccountCheck = 
self.original_getDatabaseAccountCheck - setup[COLLECTION].client_connection._global_endpoint_manager.location_cache.preferred_locations = self.original_preferred_locations + setup[COLLECTION].client_connection.connection_policy.PreferredLocations = self.original_preferred_locations expected_regional_routing_contexts = [] locational_endpoint = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_1) @@ -189,8 +189,8 @@ async def test_health_check_failure(self, setup, preferred_location, use_write_g self.original_getDatabaseAccountStub = _global_endpoint_manager_async._GlobalEndpointManager._GetDatabaseAccountStub _global_endpoint_manager_async._GlobalEndpointManager._GetDatabaseAccountStub = ( self.MockGetDatabaseAccount(REGIONS, use_write_global_endpoint, use_read_global_endpoint)) - self.original_preferred_locations = setup[COLLECTION].client_connection._global_endpoint_manager.location_cache.preferred_locations - setup[COLLECTION].client_connection._global_endpoint_manager.location_cache.preferred_locations = preferred_location + self.original_preferred_locations = setup[COLLECTION].client_connection.connection_policy.PreferredLocations + setup[COLLECTION].client_connection.connection_policy.PreferredLocations = preferred_location try: setup[COLLECTION].client_connection._global_endpoint_manager.startup = False @@ -201,7 +201,7 @@ async def test_health_check_failure(self, setup, preferred_location, use_write_g await asyncio.sleep(1) finally: _global_endpoint_manager_async._GlobalEndpointManager._GetDatabaseAccountStub = self.original_getDatabaseAccountStub - setup[COLLECTION].client_connection._global_endpoint_manager.location_cache.preferred_locations = self.original_preferred_locations + setup[COLLECTION].client_connection.connection_policy.PreferredLocations = self.original_preferred_locations if not use_write_global_endpoint: num_unavailable_endpoints = len(REGIONS) diff --git a/sdk/cosmos/azure-cosmos/tests/test_location_cache.py 
b/sdk/cosmos/azure-cosmos/tests/test_location_cache.py index a957094f1790..f65a1f1a3d21 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_location_cache.py +++ b/sdk/cosmos/azure-cosmos/tests/test_location_cache.py @@ -3,8 +3,10 @@ import time import unittest +from typing import Mapping, Any import pytest +from azure.cosmos import documents from azure.cosmos.documents import DatabaseAccount, _OperationType from azure.cosmos.http_constants import ResourceType @@ -35,15 +37,15 @@ def create_database_account(enable_multiple_writable_locations): return db_acc -def refresh_location_cache(preferred_locations, use_multiple_write_locations): - lc = LocationCache(preferred_locations=preferred_locations, - default_endpoint=default_endpoint, - enable_endpoint_discovery=True, - use_multiple_write_locations=use_multiple_write_locations) +def refresh_location_cache(preferred_locations, use_multiple_write_locations, connection_policy=documents.ConnectionPolicy()): + connection_policy.PreferredLocations = preferred_locations + connection_policy.UseMultipleWriteLocations = use_multiple_write_locations + lc = LocationCache(default_endpoint=default_endpoint, + connection_policy=connection_policy) return lc @pytest.mark.cosmosEmulator -class TestLocationCache(unittest.TestCase): +class TestLocationCache: def test_mark_endpoint_unavailable(self): lc = refresh_location_cache([], False) @@ -136,6 +138,140 @@ def test_resolve_request_endpoint_preferred_regions(self): assert read_resolved == write_resolved assert read_resolved == default_endpoint + @pytest.mark.parametrize("test_type",["OnClient", "OnRequest", "OnBoth"]) + def test_get_applicable_regional_endpoints_excluded_regions(self, test_type): + # Init test data + if test_type == "OnClient": + excluded_locations_on_client_list = [ + [location1_name], + [location1_name, location2_name], + [location1_name, location2_name, location3_name], + [location4_name], + [], + ] + excluded_locations_on_requests_list = [None] * 5 + elif test_type == 
"OnRequest": + excluded_locations_on_client_list = [[]] * 5 + excluded_locations_on_requests_list = [ + [location1_name], + [location1_name, location2_name], + [location1_name, location2_name, location3_name], + [location4_name], + [], + ] + else: + excluded_locations_on_client_list = [ + [location1_name], + [location1_name, location2_name, location3_name], + [location1_name, location2_name], + [location2_name], + [location1_name, location2_name, location3_name], + ] + excluded_locations_on_requests_list = [ + [location1_name], + [location1_name, location2_name], + [location1_name, location2_name, location3_name], + [location4_name], + [], + ] + + expected_read_endpoints_list = [ + [location2_endpoint], + [location1_endpoint], + [location1_endpoint], + [location1_endpoint, location2_endpoint], + [location1_endpoint, location2_endpoint], + ] + expected_write_endpoints_list = [ + [location2_endpoint, location3_endpoint], + [location3_endpoint], + [default_endpoint], + [location1_endpoint, location2_endpoint, location3_endpoint], + [location1_endpoint, location2_endpoint, location3_endpoint], + ] + + # Loop over each test cases + for excluded_locations_on_client, excluded_locations_on_requests, expected_read_endpoints, expected_write_endpoints in zip(excluded_locations_on_client_list, excluded_locations_on_requests_list, expected_read_endpoints_list, expected_write_endpoints_list): + # Init excluded_locations in ConnectionPolicy + connection_policy = documents.ConnectionPolicy() + connection_policy.ExcludedLocations = excluded_locations_on_client + + # Init location_cache + location_cache = refresh_location_cache([location1_name, location2_name, location3_name], True, + connection_policy) + database_account = create_database_account(True) + location_cache.perform_on_database_account_read(database_account) + + # Init requests and set excluded regions on requests + write_doc_request = RequestObject(ResourceType.Document, _OperationType.Create) + 
write_doc_request.excluded_locations = excluded_locations_on_requests + read_doc_request = RequestObject(ResourceType.Document, _OperationType.Read) + read_doc_request.excluded_locations = excluded_locations_on_requests + + # Test if read endpoints were correctly filtered on client level + read_doc_endpoint = location_cache._get_applicable_read_regional_endpoints(read_doc_request) + read_doc_endpoint = [regional_endpoint.get_primary() for regional_endpoint in read_doc_endpoint] + assert read_doc_endpoint == expected_read_endpoints + + # Test if write endpoints were correctly filtered on client level + write_doc_endpoint = location_cache._get_applicable_write_regional_endpoints(write_doc_request) + write_doc_endpoint = [regional_endpoint.get_primary() for regional_endpoint in write_doc_endpoint] + assert write_doc_endpoint == expected_write_endpoints + + def test_set_excluded_locations_for_requests(self): + # Init excluded_locations in ConnectionPolicy + excluded_locations_on_client = [location1_name, location2_name] + connection_policy = documents.ConnectionPolicy() + connection_policy.ExcludedLocations = excluded_locations_on_client + + # Init location_cache + location_cache = refresh_location_cache([location1_name, location2_name, location3_name], True, + connection_policy) + database_account = create_database_account(True) + location_cache.perform_on_database_account_read(database_account) + + # Test setting excluded locations + excluded_locations = [location1_name] + options: Mapping[str, Any] = {"excludedLocations": excluded_locations} + + expected_excluded_locations = excluded_locations + read_doc_request = RequestObject(ResourceType.Document, _OperationType.Create) + read_doc_request.set_excluded_location_from_options(options) + actual_excluded_locations = read_doc_request.excluded_locations + assert actual_excluded_locations == expected_excluded_locations + + expected_read_endpoints = [location2_endpoint] + read_doc_endpoint = 
location_cache._get_applicable_read_regional_endpoints(read_doc_request) + read_doc_endpoint = [regional_endpoint.get_primary() for regional_endpoint in read_doc_endpoint] + assert read_doc_endpoint == expected_read_endpoints + + + # Test setting excluded locations with invalid resource types + expected_excluded_locations = None + for resource_type in [ResourceType.Offer, ResourceType.Conflict]: + options: Mapping[str, Any] = {"excludedLocations": [location1_name]} + read_doc_request = RequestObject(resource_type, _OperationType.Create) + read_doc_request.set_excluded_location_from_options(options) + actual_excluded_locations = read_doc_request.excluded_locations + assert actual_excluded_locations == expected_excluded_locations + + expected_read_endpoints = [location1_endpoint] + read_doc_endpoint = location_cache._get_applicable_read_regional_endpoints(read_doc_request) + read_doc_endpoint = [regional_endpoint.get_primary() for regional_endpoint in read_doc_endpoint] + assert read_doc_endpoint == expected_read_endpoints + + + + # Test setting excluded locations with None value + expected_error_message = ("Excluded locations cannot be None. 
" + "If you want to remove all excluded locations, try passing an empty list.") + with pytest.raises(ValueError) as e: + options: Mapping[str, Any] = {"excludedLocations": None} + doc_request = RequestObject(ResourceType.Document, _OperationType.Create) + doc_request.set_excluded_location_from_options(options) + assert str( + e.value) == expected_error_message + if __name__ == "__main__": unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/test_retry_policy_async.py b/sdk/cosmos/azure-cosmos/tests/test_retry_policy_async.py index be1683d1504d..4faef31c9495 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_retry_policy_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_retry_policy_async.py @@ -42,6 +42,7 @@ def __init__(self) -> None: self.ProxyConfiguration: Optional[ProxyConfiguration] = None self.EnableEndpointDiscovery: bool = True self.PreferredLocations: List[str] = [] + self.ExcludedLocations = None self.RetryOptions: RetryOptions = RetryOptions() self.DisableSSLVerification: bool = False self.UseMultipleWriteLocations: bool = False From cf42098ac85b35a818ffddb774bb5f404738b70e Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Mon, 31 Mar 2025 15:47:16 -0700 Subject: [PATCH 022/152] initial ppcb changes --- sdk/cosmos/azure-cosmos/CHANGELOG.md | 1 + sdk/cosmos/azure-cosmos/azure/cosmos/_base.py | 1 + .../azure-cosmos/azure/cosmos/_constants.py | 20 +++ .../azure/cosmos/_cosmos_client_connection.py | 52 ++++--- .../aio/execution_dispatcher.py | 6 +- .../execution_dispatcher.py | 11 +- .../azure/cosmos/_global_endpoint_manager.py | 11 +- .../azure/cosmos/_location_cache.py | 77 +++++++++- .../azure/cosmos/_request_object.py | 39 ++++- .../azure/cosmos/_routing/routing_range.py | 22 +++ .../cosmos/_service_request_retry_policy.py | 13 +- .../cosmos/_timeout_failover_retry_policy.py | 6 +- .../azure/cosmos/aio/_container.py | 32 ++++ .../aio/_cosmos_client_connection_async.py | 58 ++++--- .../aio/_global_endpoint_manager_async.py | 21 ++- 
.../azure/cosmos/aio/_retry_utility_async.py | 8 + .../azure-cosmos/azure/cosmos/container.py | 32 ++++ .../azure/cosmos/cosmos_client.py | 4 + .../azure-cosmos/azure/cosmos/documents.py | 8 + .../azure-cosmos/tests/test_location_cache.py | 143 +++++++++++++++++- .../tests/test_query_hybrid_search.py | 13 ++ .../tests/test_query_hybrid_search_async.py | 14 +- .../tests/test_query_vector_similarity.py | 19 ++- .../test_query_vector_similarity_async.py | 19 ++- 24 files changed, 550 insertions(+), 80 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/CHANGELOG.md b/sdk/cosmos/azure-cosmos/CHANGELOG.md index 0a4906c4cefe..c5f836ce2a03 100644 --- a/sdk/cosmos/azure-cosmos/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos/CHANGELOG.md @@ -3,6 +3,7 @@ ### 4.10.0b3 (Unreleased) #### Features Added +* Per partition circuit breaker support. It can be enabled through environment variable `AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER`. See #### Breaking Changes diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py index 654b23c5d71f..bcfc611456ec 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py @@ -63,6 +63,7 @@ 'priority': 'priorityLevel', 'no_response': 'responsePayloadOnWriteDisabled', 'max_item_count': 'maxItemCount', + 'excluded_locations': 'excludedLocations', } # Cosmos resource ID validation regex breakdown: diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py index 890a172ccee6..2af40c74d77a 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py @@ -44,6 +44,26 @@ class _Constants: # ServiceDocument Resource EnableMultipleWritableLocations: Literal["enableMultipleWriteLocations"] = "enableMultipleWriteLocations" + # Environment variables + NON_STREAMING_ORDER_BY_DISABLED_CONFIG: str = "AZURE_COSMOS_DISABLE_NON_STREAMING_ORDER_BY" + 
NON_STREAMING_ORDER_BY_DISABLED_CONFIG_DEFAULT: str = "False" + HS_MAX_ITEMS_CONFIG: str = "AZURE_COSMOS_HYBRID_SEARCH_MAX_ITEMS" + HS_MAX_ITEMS_CONFIG_DEFAULT: int = 1000 + MAX_ITEM_BUFFER_VS_CONFIG: str = "AZURE_COSMOS_MAX_ITEM_BUFFER_VECTOR_SEARCH" + MAX_ITEM_BUFFER_VS_CONFIG_DEFAULT: int = 50000 + CIRCUIT_BREAKER_ENABLED_CONFIG: str = "AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER" + CIRCUIT_BREAKER_ENABLED_CONFIG_DEFAULT: str = "False" + # Only applicable when circuit breaker is enabled ------------------------- + CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_READ: str = "AZURE_COSMOS_CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_READ" + CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_READ_DEFAULT: int = 10 + CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_WRITE: str = "AZURE_COSMOS_CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_WRITE" + CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_WRITE_DEFAULT: int = 5 + FAILURE_PERCENTAGE_TOLERATED = "AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED" + FAILURE_PERCENTAGE_TOLERATED_DEFAULT: int = 70 + STALE_PARTITION_UNAVAILABILITY_CHECK = "AZURE_COSMOS_STALE_PARTITION_UNAVAILABILITY_CHECK_IN_SECONDS" + STALE_PARTITION_UNAVAILABILITY_CHECK_DEFAULT: int = 120 + # ------------------------------------------------------------------------- + # Error code translations ERROR_TRANSLATIONS: Dict[int, str] = { 400: "BAD_REQUEST - Request being sent is invalid.", diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py index 3934c23bcf99..298d3032877b 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py @@ -2043,7 +2043,8 @@ def PatchItem( headers = base.GetHeaders(self, self.default_headers, "patch", path, document_id, resource_type, documents._OperationType.Patch, options) # Patch will use WriteEndpoint since it uses PUT operation - request_params = RequestObject(resource_type, documents._OperationType.Patch) + 
request_params = RequestObject(resource_type, documents._OperationType.Patch, headers) + request_params.set_excluded_location_from_options(options) request_data = {} if options.get("filterPredicate"): request_data["condition"] = options.get("filterPredicate") @@ -2131,7 +2132,8 @@ def _Batch( base._populate_batch_headers(initial_headers) headers = base.GetHeaders(self, initial_headers, "post", path, collection_id, "docs", documents._OperationType.Batch, options) - request_params = RequestObject("docs", documents._OperationType.Batch) + request_params = RequestObject("docs", documents._OperationType.Batch, headers) + request_params.set_excluded_location_from_options(options) return cast( Tuple[List[Dict[str, Any]], CaseInsensitiveDict], self.__Post(path, request_params, batch_operations, headers, **kwargs) @@ -2191,7 +2193,8 @@ def DeleteAllItemsByPartitionKey( collection_id = base.GetResourceIdOrFullNameFromLink(collection_link) headers = base.GetHeaders(self, self.default_headers, "post", path, collection_id, "partitionkey", documents._OperationType.Delete, options) - request_params = RequestObject("partitionkey", documents._OperationType.Delete) + request_params = RequestObject("partitionkey", documents._OperationType.Delete, headers) + request_params.set_excluded_location_from_options(options) _, last_response_headers = self.__Post( path=path, request_params=request_params, @@ -2362,7 +2365,8 @@ def ExecuteStoredProcedure( documents._OperationType.ExecuteJavaScript, options) # ExecuteStoredProcedure will use WriteEndpoint since it uses POST operation - request_params = RequestObject("sprocs", documents._OperationType.ExecuteJavaScript) + request_params = RequestObject("sprocs", documents._OperationType.ExecuteJavaScript, headers) + request_params.set_excluded_location_from_options(options) result, self.last_response_headers = self.__Post(path, request_params, params, headers, **kwargs) return result @@ -2558,7 +2562,7 @@ def GetDatabaseAccount( headers = 
base.GetHeaders(self, self.default_headers, "get", "", "", "", documents._OperationType.Read,{}, client_id=self.client_id) - request_params = RequestObject("databaseaccount", documents._OperationType.Read, url_connection) + request_params = RequestObject("databaseaccount", documents._OperationType.Read, headers, url_connection) result, last_response_headers = self.__Get("", request_params, headers, **kwargs) self.last_response_headers = last_response_headers database_account = DatabaseAccount() @@ -2607,7 +2611,7 @@ def _GetDatabaseAccountCheck( headers = base.GetHeaders(self, self.default_headers, "get", "", "", "", documents._OperationType.Read,{}, client_id=self.client_id) - request_params = RequestObject("databaseaccount", documents._OperationType.Read, url_connection) + request_params = RequestObject("databaseaccount", documents._OperationType.Read, headers, url_connection) self.__Get("", request_params, headers, **kwargs) @@ -2646,7 +2650,8 @@ def Create( options) # Create will use WriteEndpoint since it uses POST operation - request_params = RequestObject(typ, documents._OperationType.Create) + request_params = RequestObject(typ, documents._OperationType.Create, headers) + request_params.set_excluded_location_from_options(options) result, last_response_headers = self.__Post(path, request_params, body, headers, **kwargs) self.last_response_headers = last_response_headers @@ -2692,7 +2697,8 @@ def Upsert( headers[http_constants.HttpHeaders.IsUpsert] = True # Upsert will use WriteEndpoint since it uses POST operation - request_params = RequestObject(typ, documents._OperationType.Upsert) + request_params = RequestObject(typ, documents._OperationType.Upsert, headers) + request_params.set_excluded_location_from_options(options) result, last_response_headers = self.__Post(path, request_params, body, headers, **kwargs) self.last_response_headers = last_response_headers # update session for write request @@ -2735,7 +2741,8 @@ def Replace( headers = 
base.GetHeaders(self, initial_headers, "put", path, id, typ, documents._OperationType.Replace, options) # Replace will use WriteEndpoint since it uses PUT operation - request_params = RequestObject(typ, documents._OperationType.Replace) + request_params = RequestObject(typ, documents._OperationType.Replace, headers) + request_params.set_excluded_location_from_options(options) result, last_response_headers = self.__Put(path, request_params, resource, headers, **kwargs) self.last_response_headers = last_response_headers @@ -2776,7 +2783,8 @@ def Read( initial_headers = initial_headers or self.default_headers headers = base.GetHeaders(self, initial_headers, "get", path, id, typ, documents._OperationType.Read, options) # Read will use ReadEndpoint since it uses GET operation - request_params = RequestObject(typ, documents._OperationType.Read) + request_params = RequestObject(typ, documents._OperationType.Read, headers) + request_params.set_excluded_location_from_options(options) result, last_response_headers = self.__Get(path, request_params, headers, **kwargs) self.last_response_headers = last_response_headers if response_hook: @@ -2815,7 +2823,8 @@ def DeleteResource( headers = base.GetHeaders(self, initial_headers, "delete", path, id, typ, documents._OperationType.Delete, options) # Delete will use WriteEndpoint since it uses DELETE operation - request_params = RequestObject(typ, documents._OperationType.Delete) + request_params = RequestObject(typ, documents._OperationType.Delete, headers) + request_params.set_excluded_location_from_options(options) result, last_response_headers = self.__Delete(path, request_params, headers, **kwargs) self.last_response_headers = last_response_headers @@ -3047,11 +3056,8 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: initial_headers = self.default_headers.copy() # Copy to make sure that default_headers won't be changed. 
if query is None: + op_typ = documents._OperationType.QueryPlan if is_query_plan else documents._OperationType.ReadFeed # Query operations will use ReadEndpoint even though it uses GET(for feed requests) - request_params = RequestObject( - resource_type, - documents._OperationType.QueryPlan if is_query_plan else documents._OperationType.ReadFeed - ) headers = base.GetHeaders( self, initial_headers, @@ -3059,11 +3065,18 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: path, resource_id, resource_type, - request_params.operation_type, + op_typ, options, partition_key_range_id ) + request_params = RequestObject( + resource_type, + op_typ, + headers + ) + request_params.set_excluded_location_from_options(options) + change_feed_state: Optional[ChangeFeedState] = options.get("changeFeedState") if change_feed_state is not None: change_feed_state.populate_request_headers(self._routing_map_provider, headers) @@ -3089,7 +3102,6 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: raise SystemError("Unexpected query compatibility mode.") # Query operations will use ReadEndpoint even though it uses POST(for regular query operations) - request_params = RequestObject(resource_type, documents._OperationType.SqlQuery) req_headers = base.GetHeaders( self, initial_headers, @@ -3102,6 +3114,9 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: partition_key_range_id ) + request_params = RequestObject(resource_type, documents._OperationType.SqlQuery, req_headers) + request_params.set_excluded_location_from_options(options) + # check if query has prefix partition key isPrefixPartitionQuery = kwargs.pop("isPrefixPartitionQuery", None) if isPrefixPartitionQuery: @@ -3183,7 +3198,8 @@ def _GetQueryPlanThroughGateway(self, query: str, resource_link: str, **kwargs: documents._QueryFeature.NonStreamingOrderBy + "," + documents._QueryFeature.HybridSearch + "," + documents._QueryFeature.CountIf) - if 
os.environ.get('AZURE_COSMOS_DISABLE_NON_STREAMING_ORDER_BY', False): + if os.environ.get(Constants.NON_STREAMING_ORDER_BY_DISABLED_CONFIG, + Constants.NON_STREAMING_ORDER_BY_DISABLED_CONFIG_DEFAULT) == "True": supported_query_features = (documents._QueryFeature.Aggregate + "," + documents._QueryFeature.CompositeAggregate + "," + documents._QueryFeature.Distinct + "," + diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py index 70df36d6d015..9b4d32a4598b 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py @@ -34,6 +34,7 @@ from azure.cosmos.documents import _DistinctType from azure.cosmos.exceptions import CosmosHttpResponseError from azure.cosmos.http_constants import StatusCodes +from ..._constants import _Constants as Constants # pylint: disable=protected-access @@ -117,8 +118,9 @@ async def _create_pipelined_execution_context(self, query_execution_info): raise ValueError("Executing a vector search query without TOP or LIMIT can consume many" + " RUs very fast and have long runtimes. Please ensure you are using one" + " of the two filters with your vector search query.") - if total_item_buffer > os.environ.get('AZURE_COSMOS_MAX_ITEM_BUFFER_VECTOR_SEARCH', 50000): - raise ValueError("Executing a vector search query with more items than the max is not allowed." + + if total_item_buffer > int(os.environ.get(Constants.MAX_ITEM_BUFFER_VS_CONFIG, + Constants.MAX_ITEM_BUFFER_VS_CONFIG_DEFAULT)): + raise ValueError("Executing a vector search query with more items than the max is not allowed. 
" + "Please ensure you are using a limit smaller than the max, or change the max.") execution_context_aggregator =\ non_streaming_order_by_aggregator._NonStreamingOrderByContextAggregator(self._client, diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py index 1155e537c68c..453a4dc38d25 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py @@ -26,6 +26,7 @@ import json import os from azure.cosmos.exceptions import CosmosHttpResponseError +from .._constants import _Constants as Constants from azure.cosmos._execution_context import endpoint_component, multi_execution_aggregator from azure.cosmos._execution_context import non_streaming_order_by_aggregator, hybrid_search_aggregator from azure.cosmos._execution_context.base_execution_context import _QueryExecutionContextBase @@ -56,8 +57,9 @@ def _verify_valid_hybrid_search_query(hybrid_search_query_info): raise ValueError("Executing a hybrid search query without TOP or LIMIT can consume many" + " RUs very fast and have long runtimes. Please ensure you are using one" + " of the two filters with your hybrid search query.") - if hybrid_search_query_info['take'] > os.environ.get('AZURE_COSMOS_HYBRID_SEARCH_MAX_ITEMS', 1000): - raise ValueError("Executing a hybrid search query with more items than the max is not allowed." + + if hybrid_search_query_info['take'] > int(os.environ.get(Constants.HS_MAX_ITEMS_CONFIG, + Constants.HS_MAX_ITEMS_CONFIG_DEFAULT)): + raise ValueError("Executing a hybrid search query with more items than the max is not allowed. 
" + "Please ensure you are using a limit smaller than the max, or change the max.") @@ -149,8 +151,9 @@ def _create_pipelined_execution_context(self, query_execution_info): raise ValueError("Executing a vector search query without TOP or LIMIT can consume many" + " RUs very fast and have long runtimes. Please ensure you are using one" + " of the two filters with your vector search query.") - if total_item_buffer > os.environ.get('AZURE_COSMOS_MAX_ITEM_BUFFER_VECTOR_SEARCH', 50000): - raise ValueError("Executing a vector search query with more items than the max is not allowed." + + if total_item_buffer > int(os.environ.get(Constants.MAX_ITEM_BUFFER_VS_CONFIG, + Constants.MAX_ITEM_BUFFER_VS_CONFIG_DEFAULT)): + raise ValueError("Executing a vector search query with more items than the max is not allowed. " + "Please ensure you are using a limit smaller than the max, or change the max.") execution_context_aggregator = \ non_streaming_order_by_aggregator._NonStreamingOrderByContextAggregator(self._client, diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py index e167871dd4a5..8aa7c388a9b6 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py @@ -31,7 +31,7 @@ from . import _constants as constants from . 
import exceptions from .documents import DatabaseAccount -from ._location_cache import LocationCache +from ._location_cache import LocationCache, current_time_millis # pylint: disable=protected-access @@ -53,7 +53,8 @@ def __init__(self, client): self.PreferredLocations, self.DefaultEndpoint, self.EnableEndpointDiscovery, - client.connection_policy.UseMultipleWriteLocations + client.connection_policy.UseMultipleWriteLocations, + client.connection_policy ) self.refresh_needed = False self.refresh_lock = threading.RLock() @@ -98,7 +99,7 @@ def update_location_cache(self): self.location_cache.update_location_cache() def refresh_endpoint_list(self, database_account, **kwargs): - if self.location_cache.current_time_millis() - self.last_refresh_time > self.refresh_time_interval_in_ms: + if current_time_millis() - self.last_refresh_time > self.refresh_time_interval_in_ms: self.refresh_needed = True if self.refresh_needed: with self.refresh_lock: @@ -114,11 +115,11 @@ def _refresh_endpoint_list_private(self, database_account=None, **kwargs): if database_account: self.location_cache.perform_on_database_account_read(database_account) self.refresh_needed = False - self.last_refresh_time = self.location_cache.current_time_millis() + self.last_refresh_time = current_time_millis() else: if self.location_cache.should_refresh_endpoints() or self.refresh_needed: self.refresh_needed = False - self.last_refresh_time = self.location_cache.current_time_millis() + self.last_refresh_time = current_time_millis() # this will perform getDatabaseAccount calls to check endpoint health self._endpoints_health_check(**kwargs) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py index 96651d5c8b7f..d40a99f7c69f 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py @@ -113,7 +113,10 @@ def get_endpoints_by_location(new_locations, except Exception as 
e: raise e - return endpoints_by_location, parsed_locations + # Also store a hash map of endpoints for each location + locations_by_endpoints = {value.get_primary(): key for key, value in endpoints_by_location.items()} + + return endpoints_by_location, locations_by_endpoints, parsed_locations def add_endpoint_if_preferred(endpoint: str, preferred_endpoints: Set[str], endpoints: Set[str]) -> bool: if endpoint in preferred_endpoints: @@ -150,10 +153,24 @@ def _get_health_check_endpoints( return endpoints +def get_applicable_regional_endpoints(endpoints, location_name_by_endpoint, fall_back_endpoint, + exclude_location_list): + # filter endpoints by excluded locations + applicable_endpoints = [] + for endpoint in endpoints: + if location_name_by_endpoint.get(endpoint.get_primary()) not in exclude_location_list: + applicable_endpoints.append(endpoint) + + # if endpoint is empty add fallback endpoint + if not applicable_endpoints: + applicable_endpoints.append(fall_back_endpoint) + + return applicable_endpoints + +def current_time_millis(): + return int(round(time.time() * 1000)) class LocationCache(object): # pylint: disable=too-many-public-methods,too-many-instance-attributes - def current_time_millis(self): - return int(round(time.time() * 1000)) def __init__( self, @@ -161,6 +178,7 @@ def __init__( default_endpoint, enable_endpoint_discovery, use_multiple_write_locations, + connection_policy, ): self.preferred_locations = preferred_locations self.default_regional_routing_context = RegionalRoutingContext(default_endpoint, default_endpoint) @@ -173,8 +191,11 @@ def __init__( self.last_cache_update_time_stamp = 0 self.account_read_regional_routing_contexts_by_location = {} # pylint: disable=name-too-long self.account_write_regional_routing_contexts_by_location = {} # pylint: disable=name-too-long + self.account_locations_by_read_regional_endpoints = {} # pylint: disable=name-too-long + self.account_locations_by_write_regional_endpoints = {} # pylint: 
disable=name-too-long self.account_write_locations = [] self.account_read_locations = [] + self.connection_policy = connection_policy def get_write_regional_routing_contexts(self): return self.write_regional_routing_contexts @@ -182,6 +203,10 @@ def get_write_regional_routing_contexts(self): def get_read_regional_routing_contexts(self): return self.read_regional_routing_contexts + def get_location_from_endpoint(self, endpoint: str) -> str: + regional_routing_context = RegionalRoutingContext(endpoint, endpoint) + return self.account_locations_by_read_regional_endpoints[regional_routing_context] + def get_write_regional_routing_context(self): return self.get_write_regional_routing_contexts()[0].get_primary() @@ -207,6 +232,45 @@ def get_ordered_write_locations(self): def get_ordered_read_locations(self): return self.account_read_locations + # Todo: @tvaron3 client should be appeneded to if using circuitbreaker exclude regions + def _get_configured_excluded_locations(self, request): + # If excluded locations were configured on request, use request level excluded locations. + excluded_locations = request.excluded_locations + if excluded_locations is None: + # If excluded locations were only configured on client(connection_policy), use client level + excluded_locations = self.connection_policy.ExcludedLocations + return excluded_locations + + def get_applicable_read_regional_endpoints(self, request): + # Get configured excluded locations + excluded_locations = self._get_configured_excluded_locations(request) + + # If excluded locations were configured, return filtered regional endpoints by excluded locations. 
+ if excluded_locations: + return get_applicable_regional_endpoints( + self.get_read_regional_routing_contexts(), + self.account_locations_by_read_regional_endpoints, + self.get_write_regional_routing_contexts()[0], + excluded_locations) + + # Else, return all regional endpoints + return self.get_read_regional_routing_contexts() + + def get_applicable_write_regional_endpoints(self, request): + # Get configured excluded locations + excluded_locations = self._get_configured_excluded_locations(request) + + # If excluded locations were configured, return filtered regional endpoints by excluded locations. + if excluded_locations: + return get_applicable_regional_endpoints( + self.get_write_regional_routing_contexts(), + self.account_locations_by_write_regional_endpoints, + self.default_regional_routing_context, + excluded_locations) + + # Else, return all regional endpoints + return self.get_write_regional_routing_contexts() + def resolve_service_endpoint(self, request): if request.location_endpoint_to_route: return request.location_endpoint_to_route @@ -247,9 +311,9 @@ def resolve_service_endpoint(self, request): return self.default_regional_routing_context.get_primary() regional_routing_contexts = ( - self.get_write_regional_routing_contexts() + self.get_applicable_write_regional_endpoints(request) if documents._OperationType.IsWriteOperation(request.operation_type) - else self.get_read_regional_routing_contexts() + else self.get_applicable_read_regional_endpoints(request) ) regional_routing_context = regional_routing_contexts[location_index % len(regional_routing_contexts)] if ( @@ -361,6 +425,7 @@ def update_location_cache(self, write_locations=None, read_locations=None, enabl if self.enable_endpoint_discovery: if read_locations: (self.account_read_regional_routing_contexts_by_location, + self.account_locations_by_read_regional_endpoints, self.account_read_locations) = get_endpoints_by_location( read_locations, 
self.account_read_regional_routing_contexts_by_location, @@ -371,6 +436,7 @@ def update_location_cache(self, write_locations=None, read_locations=None, enabl if write_locations: (self.account_write_regional_routing_contexts_by_location, + self.account_locations_by_write_regional_endpoints, self.account_write_locations) = get_endpoints_by_location( write_locations, self.account_write_regional_routing_contexts_by_location, @@ -391,7 +457,6 @@ def update_location_cache(self, write_locations=None, read_locations=None, enabl EndpointOperationType.ReadType, self.write_regional_routing_contexts[0] ) - self.last_cache_update_timestamp = self.current_time_millis() # pylint: disable=attribute-defined-outside-init def get_preferred_regional_routing_contexts( self, endpoints_by_location, orderedLocations, expected_available_operation, fallback_endpoint diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py index a220c6af42c2..afc9fa4d30a9 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py @@ -21,18 +21,27 @@ """Represents a request object. 
""" -from typing import Optional +from typing import Optional, Mapping, Any, List, Dict + class RequestObject(object): - def __init__(self, resource_type: str, operation_type: str, endpoint_override: Optional[str] = None) -> None: + def __init__( + self, + resource_type: str, + operation_type: str, + headers: Dict[str, Any], + endpoint_override: Optional[str] = None, + ) -> None: self.resource_type = resource_type self.operation_type = operation_type self.endpoint_override = endpoint_override self.should_clear_session_token_on_session_read_failure: bool = False # pylint: disable=name-too-long + self.headers = headers self.use_preferred_locations: Optional[bool] = None self.location_index_to_route: Optional[int] = None self.location_endpoint_to_route: Optional[str] = None self.last_routed_location_endpoint_within_region: Optional[str] = None + self.excluded_locations: Optional[List[str]] = None def route_to_location_with_preferred_location_flag( # pylint: disable=name-too-long self, @@ -52,3 +61,29 @@ def clear_route_to_location(self) -> None: self.location_index_to_route = None self.use_preferred_locations = None self.location_endpoint_to_route = None + + def _can_set_excluded_location(self, options: Mapping[str, Any]) -> bool: + # If resource types for requests are one of the followings, excluded locations cannot be set + if self.resource_type.lower() in ['offers', 'conflicts']: + return False + + # If 'excludedLocations' wasn't in the options, excluded locations cannot be set + if (options is None + or 'excludedLocations' not in options): + return False + + # The 'excludedLocations' cannot be None + if options['excludedLocations'] is None: + raise ValueError("Excluded locations cannot be None. 
" + "If you want to remove all excluded locations, try passing an empty list.") + + return True + + def set_excluded_location_from_options(self, options: Mapping[str, Any]) -> None: + if self._can_set_excluded_location(options): + self.excluded_locations = options['excludedLocations'] + + def set_excluded_locations_from_circuit_breaker(self, excluded_locations: List[str]) -> None: + if self.excluded_locations: + self.excluded_locations.extend(excluded_locations) + self.excluded_locations = excluded_locations diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_range.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_range.py index 452bc32e5b34..4e3d603ef0d8 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_range.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_range.py @@ -226,3 +226,25 @@ def is_subset(self, parent_range: 'Range') -> bool: normalized_child_range = self.to_normalized_range() return (normalized_parent_range.min <= normalized_child_range.min and normalized_parent_range.max >= normalized_child_range.max) + +class PartitionKeyRangeWrapper(object): + """Internal class for a representation of a unique partition for an account + """ + + def __init__(self, partition_key_range: Range, collection_rid: str) -> None: + self.partition_key_range = partition_key_range + self.collection_rid = collection_rid + + + def __str__(self) -> str: + return ( + f"PartitionKeyRangeWrapper(" + f"partition_key_range={self.partition_key_range}, " + f"collection_rid={self.collection_rid}, " + ) + + def __eq__(self, other): + if not isinstance(other, PartitionKeyRangeWrapper): + return False + return self.partition_key_range == other.partition_key_range and self.collection_rid == other.collection_rid + diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_service_request_retry_policy.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_service_request_retry_policy.py index 9b34d048e3a6..8630714bc6f3 100644 --- 
a/sdk/cosmos/azure-cosmos/azure/cosmos/_service_request_retry_policy.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_service_request_retry_policy.py @@ -44,9 +44,12 @@ def ShouldRetry(self): if self.request.resource_type == ResourceType.DatabaseAccount: return False - refresh_cache = self.request.last_routed_location_endpoint_within_region is not None - # This logic is for the last retry and mark the region unavailable - self.mark_endpoint_unavailable(self.request.location_endpoint_to_route, refresh_cache) + if self.global_endpoint_manager.is_circuit_breaker_applicable(self.request): + self.global_endpoint_manager.mark_partition_unavailable(self.request) + else: + refresh_cache = self.request.last_routed_location_endpoint_within_region is not None + # This logic is for the last retry and mark the region unavailable + self.mark_endpoint_unavailable(self.request.location_endpoint_to_route, refresh_cache) # Check if it is safe to do another retry if self.in_region_retry_count >= self.total_in_region_retries: @@ -65,7 +68,7 @@ def ShouldRetry(self): self.failover_retry_count += 1 if self.failover_retry_count >= self.total_retries: return False - # # Check if it is safe to failover to another region + # Check if it is safe to failover to another region location_endpoint = self.resolve_next_region_service_endpoint() else: location_endpoint = self.resolve_current_region_service_endpoint() @@ -80,7 +83,7 @@ def ShouldRetry(self): # and we reset the in region retry count self.in_region_retry_count = 0 self.failover_retry_count += 1 - # # Check if it is safe to failover to another region + # Check if it is safe to failover to another region if self.failover_retry_count >= self.total_retries: return False location_endpoint = self.resolve_next_region_service_endpoint() diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py index b170fb4fd9d2..505a8edf3d06 100644 --- 
a/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py @@ -5,14 +5,13 @@ Cosmos database service. """ from azure.cosmos.documents import _OperationType +from azure.cosmos.http_constants import HttpHeaders class _TimeoutFailoverRetryPolicy(object): def __init__(self, connection_policy, global_endpoint_manager, *args): self.retry_after_in_milliseconds = 500 - self.args = args - self.global_endpoint_manager = global_endpoint_manager # If an account only has 1 region, then we still want to retry once on the same region self._max_retry_attempt_count = (len(self.global_endpoint_manager.location_cache.read_regional_routing_contexts) @@ -28,6 +27,9 @@ def ShouldRetry(self, _exception): :returns: a boolean stating whether the request should be retried :rtype: bool """ + # record the failure for circuit breaker tracking + self.global_endpoint_manager.record_failure(self.request) + # we don't retry on write operations for timeouts or any internal server errors if self.request and (not _OperationType.IsReadOnlyOperation(self.request.operation_type)): return False diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py index 59bf8ee71ba3..d5e5f967be4d 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py @@ -166,6 +166,8 @@ async def read( :keyword bool populate_quota_info: Enable returning collection storage quota information in response headers. :keyword str session_token: Token for use with Session consistency. :keyword dict[str, str] initial_headers: Initial headers to be sent as part of the request. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. 
:keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Dict[str, str], Dict[str, Any]], None] :keyword Literal["High", "Low"] priority: Priority based execution allows users to set a priority for each @@ -227,6 +229,8 @@ async def create_item( has changed, and act according to the condition specified by the `match_condition` parameter. :keyword match_condition: The match condition to use upon the etag. :paramtype match_condition: ~azure.core.MatchConditions + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Dict[str, str], Dict[str, Any]], None] :keyword Literal["High", "Low"] priority: Priority based execution allows users to set a priority for each @@ -297,6 +301,8 @@ async def read_item( :keyword Literal["High", "Low"] priority: Priority based execution allows users to set a priority for each request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: The given item couldn't be retrieved. :returns: A CosmosDict representing the retrieved item. :rtype: ~azure.cosmos.CosmosDict[str, Any] @@ -359,6 +365,8 @@ def read_all_items( :keyword Literal["High", "Low"] priority: Priority based execution allows users to set a priority for each request. 
Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :returns: An AsyncItemPaged of items (dicts). :rtype: AsyncItemPaged[Dict[str, Any]] """ @@ -443,6 +451,8 @@ def query_items( :keyword Literal["High", "Low"] priority: Priority based execution allows users to set a priority for each request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :returns: An AsyncItemPaged of items (dicts). :rtype: AsyncItemPaged[Dict[str, Any]] @@ -541,6 +551,8 @@ def query_items_change_feed( ALL_VERSIONS_AND_DELETES: Query all versions and deleted items from either `start_time='Now'` or 'continuation' token. :paramtype mode: Literal["LatestVersion", "AllVersionsAndDeletes"] + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword Callable response_hook: A callable invoked with the response metadata. :returns: An AsyncItemPaged of items (dicts). :rtype: AsyncItemPaged[Dict[str, Any]] @@ -578,6 +590,8 @@ def query_items_change_feed( ALL_VERSIONS_AND_DELETES: Query all versions and deleted items from either `start_time='Now'` or 'continuation' token. 
:paramtype mode: Literal["LatestVersion", "AllVersionsAndDeletes"] + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Mapping[str, Any], AsyncItemPaged[Dict[str, Any]]], None] :returns: An AsyncItemPaged of items (dicts). @@ -604,6 +618,8 @@ def query_items_change_feed( request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. :paramtype priority: Literal["High", "Low"] + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Mapping[str, Any], AsyncItemPaged[Dict[str, Any]]], None] :returns: An AsyncItemPaged of items (dicts). @@ -642,6 +658,8 @@ def query_items_change_feed( ALL_VERSIONS_AND_DELETES: Query all versions and deleted items from either `start_time='Now'` or 'continuation' token. :paramtype mode: Literal["LatestVersion", "AllVersionsAndDeletes"] + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Mapping[str, Any], AsyncItemPaged[Dict[str, Any]]], None] :returns: An AsyncItemPaged of items (dicts). 
@@ -678,6 +696,8 @@ def query_items_change_feed( # pylint: disable=unused-argument ALL_VERSIONS_AND_DELETES: Query all versions and deleted items from either `start_time='Now'` or 'continuation' token. :paramtype mode: Literal["LatestVersion", "AllVersionsAndDeletes"] + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Mapping[str, Any], Mapping[str, Any]], None] :returns: An AsyncItemPaged of items (dicts). @@ -760,6 +780,8 @@ async def upsert_item( :keyword bool no_response: Indicates whether service should be instructed to skip sending response payloads. When not specified explicitly here, the default value will be determined from client-level options. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: The given item could not be upserted. :returns: A CosmosDict representing the upserted item. The dict will be empty if `no_response` is specified. @@ -833,6 +855,8 @@ async def replace_item( :keyword bool no_response: Indicates whether service should be instructed to skip sending response payloads. When not specified explicitly here, the default value will be determined from client-level options. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: The replace operation failed or the item with given id does not exist. 
:returns: A CosmosDict representing the item after replace went through. The dict will be empty if `no_response` @@ -908,6 +932,8 @@ async def patch_item( :keyword bool no_response: Indicates whether service should be instructed to skip sending response payloads. When not specified explicitly here, the default value will be determined from client-level options. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: The patch operations failed or the item with given id does not exist. :returns: A CosmosDict representing the item after the patch operations went through. The dict will be empty if @@ -975,6 +1001,8 @@ async def delete_item( :keyword Literal["High", "Low"] priority: Priority based execution allows users to set a priority for each request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Dict[str, str], None], None] :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: The item wasn't deleted successfully. @@ -1230,6 +1258,8 @@ async def delete_all_items_by_partition_key( :keyword str etag: An ETag value, or the wildcard character (*). Used to check if the resource has changed, and act according to the condition specified by the `match_condition` parameter. :keyword ~azure.core.MatchConditions match_condition: The match condition to use upon the etag. 
+ :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword Callable response_hook: A callable invoked with the response metadata. :rtype: None """ @@ -1281,6 +1311,8 @@ async def execute_item_batch( :keyword Literal["High", "Low"] priority: Priority based execution allows users to set a priority for each request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword Callable response_hook: A callable invoked with the response metadata. :returns: A CosmosList representing the items after the batch operations went through. :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: The batch failed to execute. diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py index 49219533a7e6..0d3d61e87fa0 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py @@ -54,6 +54,7 @@ from .. 
import documents from .._change_feed.aio.change_feed_iterable import ChangeFeedIterable from .._change_feed.change_feed_state import ChangeFeedState +from azure.cosmos.aio._global_partition_endpoint_manager_circuit_breaker_async import _GlobalPartitionEndpointManagerForCircuitBreaker from .._routing import routing_range from ..documents import ConnectionPolicy, DatabaseAccount from .._constants import _Constants as Constants @@ -63,7 +64,6 @@ from .. import _runtime_constants as runtime_constants from .. import _request_object from . import _asynchronous_request as asynchronous_request -from . import _global_endpoint_manager_async as global_endpoint_manager_async from .._routing.aio.routing_map_provider import SmartRoutingMapProvider from ._retry_utility_async import _ConnectionRetryPolicy from .. import _session @@ -169,7 +169,7 @@ def __init__( # pylint: disable=too-many-statements # Keeps the latest response headers from the server. self.last_response_headers: CaseInsensitiveDict = CaseInsensitiveDict() self.UseMultipleWriteLocations = False - self._global_endpoint_manager = global_endpoint_manager_async._GlobalEndpointManager(self) + self._global_endpoint_manager = _GlobalPartitionEndpointManagerForCircuitBreaker(self) retry_policy = None if isinstance(self.connection_policy.ConnectionRetryConfiguration, AsyncHTTPPolicy): @@ -415,7 +415,8 @@ async def GetDatabaseAccount( documents._OperationType.Read, {}, client_id=self.client_id) # path # id # type - request_params = _request_object.RequestObject("databaseaccount", documents._OperationType.Read, url_connection) + request_params = _request_object.RequestObject("databaseaccount", documents._OperationType.Read, + headers, url_connection) result, self.last_response_headers = await self.__Get("", request_params, headers, **kwargs) database_account = documents.DatabaseAccount() @@ -465,7 +466,9 @@ async def _GetDatabaseAccountCheck( documents._OperationType.Read, {}, client_id=self.client_id) # path # id # type - 
request_params = _request_object.RequestObject("databaseaccount", documents._OperationType.Read, url_connection) + request_params = _request_object.RequestObject("databaseaccount", documents._OperationType.Read, + headers, + url_connection) await self.__Get("", request_params, headers, **kwargs) async def CreateDatabase( @@ -729,7 +732,9 @@ async def ExecuteStoredProcedure( documents._OperationType.ExecuteJavaScript, options) # ExecuteStoredProcedure will use WriteEndpoint since it uses POST operation - request_params = _request_object.RequestObject("sprocs", documents._OperationType.ExecuteJavaScript) + request_params = _request_object.RequestObject("sprocs", + documents._OperationType.ExecuteJavaScript, headers) + request_params.set_excluded_location_from_options(options) result, self.last_response_headers = await self.__Post(path, request_params, params, headers, **kwargs) return result @@ -767,7 +772,8 @@ async def Create( documents._OperationType.Create, options) # Create will use WriteEndpoint since it uses POST operation - request_params = _request_object.RequestObject(typ, documents._OperationType.Create) + request_params = _request_object.RequestObject(typ, documents._OperationType.Create, headers) + request_params.set_excluded_location_from_options(options) result, last_response_headers = await self.__Post(path, request_params, body, headers, **kwargs) self.last_response_headers = last_response_headers @@ -906,7 +912,8 @@ async def Upsert( headers[http_constants.HttpHeaders.IsUpsert] = True # Upsert will use WriteEndpoint since it uses POST operation - request_params = _request_object.RequestObject(typ, documents._OperationType.Upsert) + request_params = _request_object.RequestObject(typ, documents._OperationType.Upsert, headers) + request_params.set_excluded_location_from_options(options) result, last_response_headers = await self.__Post(path, request_params, body, headers, **kwargs) self.last_response_headers = last_response_headers # update session for 
write request @@ -1207,7 +1214,8 @@ async def Read( headers = base.GetHeaders(self, initial_headers, "get", path, id, typ, documents._OperationType.Read, options) # Read will use ReadEndpoint since it uses GET operation - request_params = _request_object.RequestObject(typ, documents._OperationType.Read) + request_params = _request_object.RequestObject(typ, documents._OperationType.Read, headers) + request_params.set_excluded_location_from_options(options) result, last_response_headers = await self.__Get(path, request_params, headers, **kwargs) self.last_response_headers = last_response_headers if response_hook: @@ -1465,7 +1473,8 @@ async def PatchItem( headers = base.GetHeaders(self, initial_headers, "patch", path, document_id, typ, documents._OperationType.Patch, options) # Patch will use WriteEndpoint since it uses PUT operation - request_params = _request_object.RequestObject(typ, documents._OperationType.Patch) + request_params = _request_object.RequestObject(typ, documents._OperationType.Patch, headers) + request_params.set_excluded_location_from_options(options) request_data = {} if options.get("filterPredicate"): request_data["condition"] = options.get("filterPredicate") @@ -1569,7 +1578,8 @@ async def Replace( headers = base.GetHeaders(self, initial_headers, "put", path, id, typ, documents._OperationType.Replace, options) # Replace will use WriteEndpoint since it uses PUT operation - request_params = _request_object.RequestObject(typ, documents._OperationType.Replace) + request_params = _request_object.RequestObject(typ, documents._OperationType.Replace, headers) + request_params.set_excluded_location_from_options(options) result, last_response_headers = await self.__Put(path, request_params, resource, headers, **kwargs) self.last_response_headers = last_response_headers @@ -1892,7 +1902,8 @@ async def DeleteResource( headers = base.GetHeaders(self, initial_headers, "delete", path, id, typ, documents._OperationType.Delete, options) # Delete will use 
WriteEndpoint since it uses DELETE operation - request_params = _request_object.RequestObject(typ, documents._OperationType.Delete) + request_params = _request_object.RequestObject(typ, documents._OperationType.Delete, headers) + request_params.set_excluded_location_from_options(options) result, last_response_headers = await self.__Delete(path, request_params, headers, **kwargs) self.last_response_headers = last_response_headers @@ -2005,7 +2016,8 @@ async def _Batch( base._populate_batch_headers(initial_headers) headers = base.GetHeaders(self, initial_headers, "post", path, collection_id, "docs", documents._OperationType.Batch, options) - request_params = _request_object.RequestObject("docs", documents._OperationType.Batch) + request_params = _request_object.RequestObject("docs", documents._OperationType.Batch, headers) + request_params.set_excluded_location_from_options(options) result = await self.__Post(path, request_params, batch_operations, headers, **kwargs) return cast(Tuple[List[Dict[str, Any]], CaseInsensitiveDict], result) @@ -2856,13 +2868,16 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: initial_headers = self.default_headers.copy() # Copy to make sure that default_headers won't be changed. 
if query is None: + op_typ = documents._OperationType.QueryPlan if is_query_plan else documents._OperationType.ReadFeed # Query operations will use ReadEndpoint even though it uses GET(for feed requests) + headers = base.GetHeaders(self, initial_headers, "get", path, id_, typ, op_typ, + options, partition_key_range_id) request_params = _request_object.RequestObject( typ, - documents._OperationType.QueryPlan if is_query_plan else documents._OperationType.ReadFeed + op_typ, + headers ) - headers = base.GetHeaders(self, initial_headers, "get", path, id_, typ, request_params.operation_type, - options, partition_key_range_id) + request_params.set_excluded_location_from_options(options) change_feed_state: Optional[ChangeFeedState] = options.get("changeFeedState") if change_feed_state is not None: @@ -2889,9 +2904,11 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: raise SystemError("Unexpected query compatibility mode.") # Query operations will use ReadEndpoint even though it uses POST(for regular query operations) - request_params = _request_object.RequestObject(typ, documents._OperationType.SqlQuery) - req_headers = base.GetHeaders(self, initial_headers, "post", path, id_, typ, request_params.operation_type, + req_headers = base.GetHeaders(self, initial_headers, "post", path, id_, typ, + documents._OperationType.SqlQuery, options, partition_key_range_id) + request_params = _request_object.RequestObject(typ, documents._OperationType.SqlQuery, req_headers) + request_params.set_excluded_location_from_options(options) # check if query has prefix partition key cont_prop = kwargs.pop("containerProperties", None) @@ -3195,7 +3212,8 @@ async def _GetQueryPlanThroughGateway(self, query: str, resource_link: str, **kw documents._QueryFeature.NonStreamingOrderBy + "," + documents._QueryFeature.HybridSearch + "," + documents._QueryFeature.CountIf) - if os.environ.get('AZURE_COSMOS_DISABLE_NON_STREAMING_ORDER_BY', False): + if 
os.environ.get(Constants.NON_STREAMING_ORDER_BY_DISABLED_CONFIG, + Constants.NON_STREAMING_ORDER_BY_DISABLED_CONFIG_DEFAULT) == "True": supported_query_features = (documents._QueryFeature.Aggregate + "," + documents._QueryFeature.CompositeAggregate + "," + documents._QueryFeature.Distinct + "," + @@ -3258,7 +3276,9 @@ async def DeleteAllItemsByPartitionKey( initial_headers = dict(self.default_headers) headers = base.GetHeaders(self, initial_headers, "post", path, collection_id, "partitionkey", documents._OperationType.Delete, options) - request_params = _request_object.RequestObject("partitionkey", documents._OperationType.Delete) + request_params = _request_object.RequestObject("partitionkey", documents._OperationType.Delete, + headers) + request_params.set_excluded_location_from_options(options) _, last_response_headers = await self.__Post(path=path, request_params=request_params, req_headers=headers, body=None, **kwargs) self.last_response_headers = last_response_headers diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py index 7c845841b224..ab07f90c411d 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py @@ -25,7 +25,6 @@ import asyncio # pylint: disable=do-not-import-asyncio import logging -from asyncio import CancelledError from typing import Tuple from azure.core.exceptions import AzureError @@ -33,8 +32,7 @@ from .. import _constants as constants from .. 
import exceptions -from .._location_cache import LocationCache - +from .._location_cache import LocationCache, current_time_millis # pylint: disable=protected-access @@ -48,15 +46,15 @@ class _GlobalEndpointManager(object): # pylint: disable=too-many-instance-attrib def __init__(self, client): self.client = client - self.EnableEndpointDiscovery = client.connection_policy.EnableEndpointDiscovery self.PreferredLocations = client.connection_policy.PreferredLocations self.DefaultEndpoint = client.url_connection self.refresh_time_interval_in_ms = self.get_refresh_time_interval_in_ms_stub() self.location_cache = LocationCache( self.PreferredLocations, self.DefaultEndpoint, - self.EnableEndpointDiscovery, - client.connection_policy.UseMultipleWriteLocations + client.connection_policy.EnableEndpointDiscovery, + client.connection_policy.UseMultipleWriteLocations, + client.connection_policy ) self.startup = True self.refresh_task = None @@ -65,6 +63,7 @@ def __init__(self, client): self.last_refresh_time = 0 self._database_account_cache = None + # TODO: @tvaron3 fix this def get_use_multiple_write_locations(self): return self.location_cache.can_use_multiple_write_locations() @@ -105,9 +104,9 @@ async def refresh_endpoint_list(self, database_account, **kwargs): try: await self.refresh_task self.refresh_task = None - except (Exception, CancelledError) as exception: #pylint: disable=broad-exception-caught + except (Exception, asyncio.CancelledError) as exception: #pylint: disable=broad-exception-caught logger.exception("Health check task failed: %s", exception) - if self.location_cache.current_time_millis() - self.last_refresh_time > self.refresh_time_interval_in_ms: + if current_time_millis() - self.last_refresh_time > self.refresh_time_interval_in_ms: self.refresh_needed = True if self.refresh_needed: async with self.refresh_lock: @@ -123,11 +122,11 @@ async def _refresh_endpoint_list_private(self, database_account=None, **kwargs): if database_account and not self.startup: 
self.location_cache.perform_on_database_account_read(database_account) self.refresh_needed = False - self.last_refresh_time = self.location_cache.current_time_millis() + self.last_refresh_time = current_time_millis() else: if self.location_cache.should_refresh_endpoints() or self.refresh_needed: self.refresh_needed = False - self.last_refresh_time = self.location_cache.current_time_millis() + self.last_refresh_time = current_time_millis() if not self.startup: # this will perform getDatabaseAccount calls to check endpoint health # in background @@ -216,5 +215,5 @@ async def close(self): self.refresh_task.cancel() try: await self.refresh_task - except (Exception, CancelledError) : #pylint: disable=broad-exception-caught + except (Exception, asyncio.CancelledError) : #pylint: disable=broad-exception-caught pass diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py index ef5bad070014..c020a5d4e31c 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py @@ -104,6 +104,7 @@ async def ExecuteAsync(client, global_endpoint_manager, function, *args, **kwarg result = await ExecuteFunctionAsync(function, global_endpoint_manager, *args, **kwargs) else: result = await ExecuteFunctionAsync(function, *args, **kwargs) + global_endpoint_manager.record_success(request) if not client.last_response_headers: client.last_response_headers = {} @@ -198,6 +199,7 @@ async def ExecuteAsync(client, global_endpoint_manager, function, *args, **kwarg if client_timeout: kwargs['timeout'] = client_timeout - (time.time() - start_time) if kwargs['timeout'] <= 0: + global_endpoint_manager.record_failure(request) raise exceptions.CosmosClientTimeoutError() except ServiceRequestError as e: @@ -279,6 +281,8 @@ async def send(self, request): # the request ran into a socket timeout or failed to establish a new connection # 
since request wasn't sent, raise exception immediately to be dealt with in client retry policies if not _has_database_account_header(request.http_request.headers): + # TODO: @tvaron3 record failure here + request.context.global_endpoint_manager.record_failure(request.context.options['request']) if retry_settings['connect'] > 0: retry_active = self.increment(retry_settings, response=request, error=err) if retry_active: @@ -296,6 +300,8 @@ async def send(self, request): ClientConnectionError) if (isinstance(err.inner_exception, ClientConnectionError) or _has_read_retryable_headers(request.http_request.headers)): + # TODO: @tvaron3 record failure here + request.context.global_endpoint_manager.record_failure(request.context.options['request']) # This logic is based on the _retry.py file from azure-core if retry_settings['read'] > 0: retry_active = self.increment(retry_settings, response=request, error=err) @@ -310,6 +316,8 @@ async def send(self, request): if _has_database_account_header(request.http_request.headers): raise err if self._is_method_retryable(retry_settings, request.http_request): + # TODO: @tvaron3 record failure here ? Not sure + request.context.global_endpoint_manager.record_failure(request.context.options['request']) retry_active = self.increment(retry_settings, response=request, error=err) if retry_active: await self.sleep(retry_settings, request.context.transport) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/container.py b/sdk/cosmos/azure-cosmos/azure/cosmos/container.py index 69be9491f27b..42c3d31acae6 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/container.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/container.py @@ -168,6 +168,8 @@ def read( # pylint:disable=docstring-missing-param request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. 
:keyword dict[str, str] initial_headers: Initial headers to be sent as part of the request. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword Callable response_hook: A callable invoked with the response metadata. :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: Raised if the container couldn't be retrieved. This includes if the container does not exist. @@ -225,6 +227,8 @@ def read_item( # pylint:disable=docstring-missing-param :keyword Literal["High", "Low"] priority: Priority based execution allows users to set a priority for each request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :returns: A CosmosDict representing the item to be retrieved. :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: The given item couldn't be retrieved. :rtype: ~azure.cosmos.CosmosDict[str, Any] @@ -292,6 +296,8 @@ def read_all_items( # pylint:disable=docstring-missing-param :keyword Literal["High", "Low"] priority: Priority based execution allows users to set a priority for each request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. 
:returns: An Iterable of items (dicts). :rtype: Iterable[Dict[str, Any]] """ @@ -359,6 +365,8 @@ def query_items_change_feed( ALL_VERSIONS_AND_DELETES: Query all versions and deleted items from either `start_time='Now'` or 'continuation' token. :paramtype mode: Literal["LatestVersion", "AllVersionsAndDeletes"] + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword response_hook: A callable invoked with the response metadata. Note that due to the nature of combining calls to build the results, this function may be called with a either single dict or iterable of dicts :type response_hook: @@ -400,6 +408,8 @@ def query_items_change_feed( ALL_VERSIONS_AND_DELETES: Query all versions and deleted items from either `start_time='Now'` or 'continuation' token. :paramtype mode: Literal["LatestVersion", "AllVersionsAndDeletes"] + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Mapping[str, Any], ItemPaged[Dict[str, Any]]], None] :returns: An Iterable of items (dicts). @@ -426,6 +436,8 @@ def query_items_change_feed( request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. :paramtype priority: Literal["High", "Low"] + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. 
:keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Mapping[str, Any], ItemPaged[Dict[str, Any]]], None] :returns: An Iterable of items (dicts). @@ -463,6 +475,8 @@ def query_items_change_feed( ALL_VERSIONS_AND_DELETES: Query all versions and deleted items from either `start_time='Now'` or 'continuation' token. :paramtype mode: Literal["LatestVersion", "AllVersionsAndDeletes"] + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Mapping[str, Any], ItemPaged[Dict[str, Any]]], None] :returns: An Iterable of items (dicts). @@ -498,6 +512,8 @@ def query_items_change_feed( ALL_VERSIONS_AND_DELETES: Query all versions and deleted items from either `start_time='Now'` or 'continuation' token. :paramtype mode: Literal["LatestVersion", "AllVersionsAndDeletes"] + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Mapping[str, Any], ItemPaged[Dict[str, Any]]], None] :param Any args: args @@ -604,6 +620,8 @@ def query_items( # pylint:disable=docstring-missing-param :keyword bool populate_index_metrics: Used to obtain the index metrics to understand how the query engine used existing indexes and how it could use potential new indexes. Please note that this options will incur overhead, so it should be enabled only when debugging slow queries. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. 
The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :returns: An Iterable of items (dicts). :rtype: ItemPaged[Dict[str, Any]] @@ -719,6 +737,8 @@ def replace_item( # pylint:disable=docstring-missing-param :keyword bool no_response: Indicates whether service should be instructed to skip sending response payloads. When not specified explicitly here, the default value will be determined from kwargs or when also not specified there from client-level kwargs. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: The replace operation failed or the item with given id does not exist. :returns: A CosmosDict representing the item after replace went through. The dict will be empty if `no_response` @@ -794,6 +814,8 @@ def upsert_item( # pylint:disable=docstring-missing-param :keyword bool no_response: Indicates whether service should be instructed to skip sending response payloads. When not specified explicitly here, the default value will be determined from kwargs or when also not specified there from client-level kwargs. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: The given item could not be upserted. :returns: A CosmosDict representing the upserted item. The dict will be empty if `no_response` is specified. :rtype: ~azure.cosmos.CosmosDict[str, Any] @@ -876,6 +898,8 @@ def create_item( # pylint:disable=docstring-missing-param :keyword bool no_response: Indicates whether service should be instructed to skip sending response payloads. 
When not specified explicitly here, the default value will be determined from kwargs or when also not specified there from client-level kwargs. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: Item with the given ID already exists. :returns: A CosmosDict representing the new item. The dict will be empty if `no_response` is specified. :rtype: ~azure.cosmos.CosmosDict[str, Any] @@ -954,6 +978,8 @@ def patch_item( :keyword bool no_response: Indicates whether service should be instructed to skip sending response payloads. When not specified explicitly here, the default value will be determined from kwargs or when also not specified there from client-level kwargs. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: The patch operations failed or the item with given id does not exist. :returns: A CosmosDict representing the item after the patch operations went through. The dict will be empty @@ -1016,6 +1042,8 @@ def execute_item_batch( :keyword Literal["High", "Low"] priority: Priority based execution allows users to set a priority for each request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. 
:keyword Callable response_hook: A callable invoked with the response metadata. :returns: A CosmosList representing the items after the batch operations went through. :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: The batch failed to execute. @@ -1075,6 +1103,8 @@ def delete_item( # pylint:disable=docstring-missing-param :keyword Literal["High", "Low"] priority: Priority based execution allows users to set a priority for each request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword Callable response_hook: A callable invoked with the response metadata. :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: The item wasn't deleted successfully. :raises ~azure.cosmos.exceptions.CosmosResourceNotFoundError: The item does not exist in the container. @@ -1351,6 +1381,8 @@ def delete_all_items_by_partition_key( :keyword str etag: An ETag value, or the wildcard character (*). Used to check if the resource has changed, and act according to the condition specified by the `match_condition` parameter. :keyword ~azure.core.MatchConditions match_condition: The match condition to use upon the etag. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword Callable response_hook: A callable invoked with the response metadata. 
:rtype: None """ diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/cosmos_client.py b/sdk/cosmos/azure-cosmos/azure/cosmos/cosmos_client.py index cc4bd43d13a2..3c4399ad6d60 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/cosmos_client.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/cosmos_client.py @@ -94,6 +94,8 @@ def _build_connection_policy(kwargs: Dict[str, Any]) -> ConnectionPolicy: policy.ProxyConfiguration = kwargs.pop('proxy_config', policy.ProxyConfiguration) policy.EnableEndpointDiscovery = kwargs.pop('enable_endpoint_discovery', policy.EnableEndpointDiscovery) policy.PreferredLocations = kwargs.pop('preferred_locations', policy.PreferredLocations) + # TODO: Consider storing callback method instead, such as 'Supplier' in JAVA SDK + policy.ExcludedLocations = kwargs.pop('excluded_locations', policy.ExcludedLocations) policy.UseMultipleWriteLocations = kwargs.pop('multiple_write_locations', policy.UseMultipleWriteLocations) # SSL config @@ -182,6 +184,8 @@ class CosmosClient: # pylint: disable=client-accepts-api-version-keyword :keyword bool enable_endpoint_discovery: Enable endpoint discovery for geo-replicated database accounts. (Default: True) :keyword list[str] preferred_locations: The preferred locations for geo-replicated database accounts. + :keyword list[str] excluded_locations: The excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword bool enable_diagnostics_logging: Enable the CosmosHttpLogging policy. Must be used along with a logger to work. :keyword ~logging.Logger logger: Logger to be used for collecting request diagnostics. 
Can be passed in at client diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/documents.py b/sdk/cosmos/azure-cosmos/azure/cosmos/documents.py index 942e426934b7..b5e18f3680e6 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/documents.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/documents.py @@ -308,6 +308,13 @@ class ConnectionPolicy: # pylint: disable=too-many-instance-attributes locations in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US', 'Central India' and so on. :vartype PreferredLocations: List[str] + :ivar ExcludedLocations: + Gets or sets the excluded locations for geo-replicated database + accounts. When ExcludedLocations is non-empty, the client will skip this + set of locations from the final location evaluation. The locations in + this list are specified as the names of the azure Cosmos locations like, + 'West US', 'East US', 'Central India' and so on. + :vartype ExcludedLocations: ~CosmosExcludedLocations :ivar RetryOptions: Gets or sets the retry options to be applied to all requests when retrying. 
@@ -347,6 +354,7 @@ def __init__(self) -> None: self.ProxyConfiguration: Optional[ProxyConfiguration] = None self.EnableEndpointDiscovery: bool = True self.PreferredLocations: List[str] = [] + self.ExcludedLocations: List[str] = [] self.RetryOptions: RetryOptions = RetryOptions() self.DisableSSLVerification: bool = False self.UseMultipleWriteLocations: bool = False diff --git a/sdk/cosmos/azure-cosmos/tests/test_location_cache.py b/sdk/cosmos/azure-cosmos/tests/test_location_cache.py index a957094f1790..f540797bc112 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_location_cache.py +++ b/sdk/cosmos/azure-cosmos/tests/test_location_cache.py @@ -3,8 +3,10 @@ import time import unittest +from typing import Mapping, Any import pytest +from azure.cosmos import documents from azure.cosmos.documents import DatabaseAccount, _OperationType from azure.cosmos.http_constants import ResourceType @@ -35,15 +37,16 @@ def create_database_account(enable_multiple_writable_locations): return db_acc -def refresh_location_cache(preferred_locations, use_multiple_write_locations): +def refresh_location_cache(preferred_locations, use_multiple_write_locations, connection_policy=documents.ConnectionPolicy()): lc = LocationCache(preferred_locations=preferred_locations, default_endpoint=default_endpoint, enable_endpoint_discovery=True, - use_multiple_write_locations=use_multiple_write_locations) + use_multiple_write_locations=use_multiple_write_locations, + connection_policy=connection_policy) return lc @pytest.mark.cosmosEmulator -class TestLocationCache(unittest.TestCase): +class TestLocationCache: def test_mark_endpoint_unavailable(self): lc = refresh_location_cache([], False) @@ -136,6 +139,140 @@ def test_resolve_request_endpoint_preferred_regions(self): assert read_resolved == write_resolved assert read_resolved == default_endpoint + @pytest.mark.parametrize("test_type",["OnClient", "OnRequest", "OnBoth"]) + def test_get_applicable_regional_endpoints_excluded_regions(self, test_type): + 
# Init test data + if test_type == "OnClient": + excluded_locations_on_client_list = [ + [location1_name], + [location1_name, location2_name], + [location1_name, location2_name, location3_name], + [location4_name], + [], + ] + excluded_locations_on_requests_list = [None] * 5 + elif test_type == "OnRequest": + excluded_locations_on_client_list = [[]] * 5 + excluded_locations_on_requests_list = [ + [location1_name], + [location1_name, location2_name], + [location1_name, location2_name, location3_name], + [location4_name], + [], + ] + else: + excluded_locations_on_client_list = [ + [location1_name], + [location1_name, location2_name, location3_name], + [location1_name, location2_name], + [location2_name], + [location1_name, location2_name, location3_name], + ] + excluded_locations_on_requests_list = [ + [location1_name], + [location1_name, location2_name], + [location1_name, location2_name, location3_name], + [location4_name], + [], + ] + + expected_read_endpoints_list = [ + [location2_endpoint], + [location1_endpoint], + [location1_endpoint], + [location1_endpoint, location2_endpoint], + [location1_endpoint, location2_endpoint], + ] + expected_write_endpoints_list = [ + [location2_endpoint, location3_endpoint], + [location3_endpoint], + [default_endpoint], + [location1_endpoint, location2_endpoint, location3_endpoint], + [location1_endpoint, location2_endpoint, location3_endpoint], + ] + + # Loop over each test cases + for excluded_locations_on_client, excluded_locations_on_requests, expected_read_endpoints, expected_write_endpoints in zip(excluded_locations_on_client_list, excluded_locations_on_requests_list, expected_read_endpoints_list, expected_write_endpoints_list): + # Init excluded_locations in ConnectionPolicy + connection_policy = documents.ConnectionPolicy() + connection_policy.ExcludedLocations = excluded_locations_on_client + + # Init location_cache + location_cache = refresh_location_cache([location1_name, location2_name, location3_name], True, + 
connection_policy) + database_account = create_database_account(True) + location_cache.perform_on_database_account_read(database_account) + + # Init requests and set excluded regions on requests + write_doc_request = RequestObject(ResourceType.Document, _OperationType.Create) + write_doc_request.excluded_locations = excluded_locations_on_requests + read_doc_request = RequestObject(ResourceType.Document, _OperationType.Read) + read_doc_request.excluded_locations = excluded_locations_on_requests + + # Test if read endpoints were correctly filtered on client level + read_doc_endpoint = location_cache.get_applicable_read_regional_endpoints(read_doc_request) + read_doc_endpoint = [regional_endpoint.get_primary() for regional_endpoint in read_doc_endpoint] + assert read_doc_endpoint == expected_read_endpoints + + # Test if write endpoints were correctly filtered on client level + write_doc_endpoint = location_cache.get_applicable_write_regional_endpoints(write_doc_request) + write_doc_endpoint = [regional_endpoint.get_primary() for regional_endpoint in write_doc_endpoint] + assert write_doc_endpoint == expected_write_endpoints + + def test_set_excluded_locations_for_requests(self): + # Init excluded_locations in ConnectionPolicy + excluded_locations_on_client = [location1_name, location2_name] + connection_policy = documents.ConnectionPolicy() + connection_policy.ExcludedLocations = excluded_locations_on_client + + # Init location_cache + location_cache = refresh_location_cache([location1_name, location2_name, location3_name], True, + connection_policy) + database_account = create_database_account(True) + location_cache.perform_on_database_account_read(database_account) + + # Test setting excluded locations + excluded_locations = [location1_name] + options: Mapping[str, Any] = {"excludedLocations": excluded_locations} + + expected_excluded_locations = excluded_locations + read_doc_request = RequestObject(ResourceType.Document, _OperationType.Create) + 
read_doc_request.set_excluded_location_from_options(options) + actual_excluded_locations = read_doc_request.excluded_locations + assert actual_excluded_locations == expected_excluded_locations + + expected_read_endpoints = [location2_endpoint] + read_doc_endpoint = location_cache.get_applicable_read_regional_endpoints(read_doc_request) + read_doc_endpoint = [regional_endpoint.get_primary() for regional_endpoint in read_doc_endpoint] + assert read_doc_endpoint == expected_read_endpoints + + + # Test setting excluded locations with invalid resource types + expected_excluded_locations = None + for resource_type in [ResourceType.Offer, ResourceType.Conflict]: + options: Mapping[str, Any] = {"excludedLocations": [location1_name]} + read_doc_request = RequestObject(resource_type, _OperationType.Create) + read_doc_request.set_excluded_location_from_options(options) + actual_excluded_locations = read_doc_request.excluded_locations + assert actual_excluded_locations == expected_excluded_locations + + expected_read_endpoints = [location1_endpoint] + read_doc_endpoint = location_cache.get_applicable_read_regional_endpoints(read_doc_request) + read_doc_endpoint = [regional_endpoint.get_primary() for regional_endpoint in read_doc_endpoint] + assert read_doc_endpoint == expected_read_endpoints + + + + # Test setting excluded locations with None value + expected_error_message = ("Excluded locations cannot be None. 
" + "If you want to remove all excluded locations, try passing an empty list.") + with pytest.raises(ValueError) as e: + options: Mapping[str, Any] = {"excludedLocations": None} + doc_request = RequestObject(ResourceType.Document, _OperationType.Create) + doc_request.set_excluded_location_from_options(options) + assert str( + e.value) == expected_error_message + if __name__ == "__main__": unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/test_query_hybrid_search.py b/sdk/cosmos/azure-cosmos/tests/test_query_hybrid_search.py index 429498b3071f..3a9e8527992e 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_query_hybrid_search.py +++ b/sdk/cosmos/azure-cosmos/tests/test_query_hybrid_search.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. import json +import os import time import unittest import uuid @@ -94,6 +95,18 @@ def test_wrong_hybrid_search_queries(self): assert ("One of the input values is invalid" in e.message or "Specifying a sort order (ASC or DESC) in the ORDER BY RANK clause is not allowed." in e.message) + def test_hybrid_search_env_variables_async(self): + os.environ["AZURE_COSMOS_HYBRID_SEARCH_MAX_ITEMS"] = "1" + try: + query = "SELECT TOP 1 c.index, c.title FROM c WHERE FullTextContains(c.title, 'John') OR " \ + "FullTextContains(c.text, 'John') ORDER BY RANK FullTextScore(c.title, ['John'])" + results = self.test_container.query_items(query) + [item for item in results] + pytest.fail("Config was not applied properly.") + except ValueError as e: + assert e.args[0] == ("Executing a hybrid search query with more items than the max is not allowed. 
" + "Please ensure you are using a limit smaller than the max, or change the max.") + def test_hybrid_search_queries(self): query = "SELECT TOP 10 c.index, c.title FROM c WHERE FullTextContains(c.title, 'John') OR " \ "FullTextContains(c.text, 'John') ORDER BY RANK FullTextScore(c.title, ['John'])" diff --git a/sdk/cosmos/azure-cosmos/tests/test_query_hybrid_search_async.py b/sdk/cosmos/azure-cosmos/tests/test_query_hybrid_search_async.py index 964d9579a2e9..4223cc6bdd50 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_query_hybrid_search_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_query_hybrid_search_async.py @@ -1,6 +1,6 @@ # The MIT License (MIT) # Copyright (c) Microsoft Corporation. All rights reserved. - +import os import time import unittest import uuid @@ -98,6 +98,18 @@ async def test_wrong_hybrid_search_queries_async(self): assert ("One of the input values is invalid" in e.message or "Specifying a sort order (ASC or DESC) in the ORDER BY RANK clause is not allowed." in e.message) + async def test_hybrid_search_env_variables_async(self): + os.environ["AZURE_COSMOS_HYBRID_SEARCH_MAX_ITEMS"] = "1" + try: + query = "SELECT TOP 1 c.index, c.title FROM c WHERE FullTextContains(c.title, 'John') OR " \ + "FullTextContains(c.text, 'John') ORDER BY RANK FullTextScore(c.title, ['John'])" + results = self.test_container.query_items(query) + [item async for item in results] + pytest.fail("Config was not applied properly.") + except ValueError as e: + assert e.args[0] == ("Executing a hybrid search query with more items than the max is not allowed. 
" + "Please ensure you are using a limit smaller than the max, or change the max.") + async def test_hybrid_search_queries_async(self): query = "SELECT TOP 10 c.index, c.title FROM c WHERE FullTextContains(c.title, 'John') OR " \ "FullTextContains(c.text, 'John') ORDER BY RANK FullTextScore(c.title, ['John'])" diff --git a/sdk/cosmos/azure-cosmos/tests/test_query_vector_similarity.py b/sdk/cosmos/azure-cosmos/tests/test_query_vector_similarity.py index e2614c5fb85f..96e3eee02936 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_query_vector_similarity.py +++ b/sdk/cosmos/azure-cosmos/tests/test_query_vector_similarity.py @@ -1,6 +1,6 @@ # The MIT License (MIT) # Copyright (c) Microsoft Corporation. All rights reserved. - +import os import unittest import uuid @@ -121,6 +121,23 @@ def test_wrong_vector_search_queries(self): assert ("One of the input values is invalid." in e.message or "Specifying a sorting order (ASC or DESC) with VectorDistance function is not supported." in e.message) + + def test_vector_search_environment_variables(self): + vector_string = vector_test_data.get_embedding_string("I am having a wonderful day.") + query = "SELECT TOP 10 c.text, VectorDistance(c.embedding, [{}]) AS " \ + "SimilarityScore FROM c ORDER BY VectorDistance(c.embedding, [{}])".format(vector_string, vector_string) + os.environ["AZURE_COSMOS_MAX_ITEM_BUFFER_VECTOR_SEARCH"] = "1" + try: + [item for item in self.created_large_container.query_items(query=query, enable_cross_partition_query=True)] + pytest.fail("Config was not set correctly.") + except ValueError as e: + assert e.args[0] == ("Executing a vector search query with more items than the max is not allowed. 
" + "Please ensure you are using a limit smaller than the max, or change the max.") + + os.environ["AZURE_COSMOS_MAX_ITEM_BUFFER_VECTOR_SEARCH"] = "50000" + os.environ["AZURE_COSMOS_DISABLE_NON_STREAMING_ORDER_BY"] = "False" + [item for item in self.created_large_container.query_items(query=query, enable_cross_partition_query=True)] + def test_ordering_distances(self): # Besides ordering distances, we also verify that the query text properly replaces any set embedding policies # load up previously calculated embedding for the given string diff --git a/sdk/cosmos/azure-cosmos/tests/test_query_vector_similarity_async.py b/sdk/cosmos/azure-cosmos/tests/test_query_vector_similarity_async.py index 0cb031847a6f..716150358ff3 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_query_vector_similarity_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_query_vector_similarity_async.py @@ -1,6 +1,6 @@ # The MIT License (MIT) # Copyright (c) Microsoft Corporation. All rights reserved. - +import os import unittest import uuid @@ -127,6 +127,23 @@ async def test_wrong_vector_search_queries_async(self): assert ("One of the input values is invalid." in e.message or "Specifying a sorting order (ASC or DESC) with VectorDistance function is not supported." in e.message) + async def test_vector_search_environment_variables_async(self): + vector_string = vector_test_data.get_embedding_string("I am having a wonderful day.") + query = "SELECT TOP 10 c.text, VectorDistance(c.embedding, [{}]) AS " \ + "SimilarityScore FROM c ORDER BY VectorDistance(c.embedding, [{}])".format(vector_string, vector_string) + os.environ["AZURE_COSMOS_MAX_ITEM_BUFFER_VECTOR_SEARCH"] = "1" + try: + [item async for item in self.created_large_container.query_items(query=query)] + pytest.fail("Config was not set correctly.") + except ValueError as e: + assert e.args[0] == ("Executing a vector search query with more items than the max is not allowed. 
" + "Please ensure you are using a limit smaller than the max, or change the max.") + + os.environ["AZURE_COSMOS_MAX_ITEM_BUFFER_VECTOR_SEARCH"] = "50000" + + os.environ["AZURE_COSMOS_DISABLE_NON_STREAMING_ORDER_BY"] = "False" + [item async for item in self.created_large_container.query_items(query=query)] + async def test_ordering_distances_async(self): # Besides ordering distances, we also verify that the query text properly replaces any set embedding policies From 9d461220a745a7652ce8a82199a35288d29ba475 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Tue, 1 Apr 2025 09:18:04 -0700 Subject: [PATCH 023/152] add missing changes --- ...tition_endpoint_manager_circuit_breaker.py | 136 ++++++++++ .../azure/cosmos/_partition_health_tracker.py | 256 ++++++++++++++++++ ..._endpoint_manager_circuit_breaker_async.py | 136 ++++++++++ .../samples/excluded_locations.py | 110 ++++++++ 4 files changed, 638 insertions(+) create mode 100644 sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py create mode 100644 sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py create mode 100644 sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py create mode 100644 sdk/cosmos/azure-cosmos/samples/excluded_locations.py diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py new file mode 100644 index 000000000000..849630c83d4d --- /dev/null +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py @@ -0,0 +1,136 @@ +# The MIT License (MIT) +# Copyright (c) 2021 Microsoft Corporation + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights 
class _GlobalPartitionEndpointManagerForCircuitBreaker(_GlobalEndpointManager):
    """
    This internal class implements the logic for partition endpoint management for
    geo-replicated database accounts.

    It layers a per-partition circuit breaker on top of the base endpoint
    manager: request successes and failures are reported to a
    ``PartitionHealthTracker``, and regions the tracker considers unhealthy for
    a partition are excluded when resolving the service endpoint.
    """

    def __init__(self, client: "CosmosClientConnection"):
        super(_GlobalPartitionEndpointManagerForCircuitBreaker, self).__init__(client)
        self.partition_health_tracker = PartitionHealthTracker()

    def is_circuit_breaker_applicable(self, request: RequestObject) -> bool:
        """
        Check if circuit breaker is applicable for a request.

        :param request: the request being routed (may be None).
        :returns: True when per-partition circuit breaking should be used.
        """
        if not request:
            return False

        # Circuit breaking is opt-in through an environment variable.
        circuit_breaker_enabled = os.environ.get(Constants.CIRCUIT_BREAKER_ENABLED_CONFIG,
                                                 Constants.CIRCUIT_BREAKER_ENABLED_CONFIG_DEFAULT) == "True"
        if not circuit_breaker_enabled:
            return False

        # Writes can only be re-routed to another region on multi-write accounts.
        if (not self.location_cache.can_use_multiple_write_locations_for_request(request)
                and documents._OperationType.IsWriteOperation(request.operation_type)):
            return False

        # Only document operations participate in circuit breaking.
        if request.resource_type != ResourceType.Document:
            return False

        # BUGFIX: the original compared with `!=`, which restricted circuit
        # breaking to query-plan requests only. Query-plan (metadata) calls are
        # the requests that must be *exempt* from circuit breaking.
        if request.operation_type == documents._OperationType.QueryPlan:
            return False

        return True

    @staticmethod
    def _endpoint_operation_type(request: RequestObject) -> str:
        # Collapse the request's operation type onto the coarse read/write
        # classification used by the health tracker.
        if documents._OperationType.IsWriteOperation(request.operation_type):
            return EndpointOperationType.WriteType
        return EndpointOperationType.ReadType

    def _create_pkrange_wrapper(self, request: RequestObject) -> PartitionKeyRangeWrapper:
        """
        Create a PartitionKeyRangeWrapper identifying the partition targeted by
        the request.
        """
        container_rid = request.headers[HttpHeaders.IntendedCollectionRID]
        partition_key = request.headers[HttpHeaders.PartitionKey]
        # Resolve the container link from the client-side properties cache.
        # BUGFIX: iterating the cache directly yields only keys; .items() is
        # required to unpack (container_link, properties) pairs (assumes the
        # cache is a dict of link -> properties, as used elsewhere in the SDK
        # -- TODO confirm).
        target_container_link = None
        for container_link, properties in self.client._container_properties_cache.items():
            # TODO: @tvaron3 consider moving this to a constant with other usages
            if properties["_rid"] == container_rid:
                target_container_link = container_link
        # TODO(review): raise a descriptive error when the rid is not cached
        # instead of passing None through to the routing map provider.
        pkrange = self.client._routing_map_provider.get_overlapping_range(target_container_link, partition_key)
        return PartitionKeyRangeWrapper(pkrange, container_rid)

    def record_failure(
        self,
        request: RequestObject
    ) -> None:
        """Report a failed request to the partition health tracker."""
        if self.is_circuit_breaker_applicable(request):
            endpoint_operation_type = self._endpoint_operation_type(request)
            location = self.location_cache.get_location_from_endpoint(request.location_endpoint_to_route)
            pkrange_wrapper = self._create_pkrange_wrapper(request)
            self.partition_health_tracker.add_failure(pkrange_wrapper, endpoint_operation_type, location)

    def resolve_service_endpoint(self, request):
        # Feed the tracker's per-partition excluded locations into the normal
        # endpoint resolution performed by the base class.
        if self.is_circuit_breaker_applicable(request):
            pkrange_wrapper = self._create_pkrange_wrapper(request)
            request.set_excluded_locations_from_circuit_breaker(
                self.partition_health_tracker.get_excluded_locations(pkrange_wrapper)
            )
        return super(_GlobalPartitionEndpointManagerForCircuitBreaker, self).resolve_service_endpoint(request)

    def mark_partition_unavailable(self, request: RequestObject) -> None:
        """
        Mark the partition unavailable from the given request.
        """
        location = self.location_cache.get_location_from_endpoint(request.location_endpoint_to_route)
        pkrange_wrapper = self._create_pkrange_wrapper(request)
        self.partition_health_tracker.mark_partition_unavailable(pkrange_wrapper, location)

    def record_success(
        self,
        request: RequestObject
    ) -> None:
        """Report a successful request to the partition health tracker."""
        if self.is_circuit_breaker_applicable(request):
            endpoint_operation_type = self._endpoint_operation_type(request)
            location = self.location_cache.get_location_from_endpoint(request.location_endpoint_to_route)
            pkrange_wrapper = self._create_pkrange_wrapper(request)
            self.partition_health_tracker.add_success(pkrange_wrapper, endpoint_operation_type, location)

# TODO: @tvaron3 there should be no in region retries when trying on healthy tentative
+ +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Internal class for partition health tracker for circuit breaker. +""" +import os +from typing import Dict, Set, Any +from ._constants import _Constants as Constants +from azure.cosmos._location_cache import current_time_millis, EndpointOperationType +from azure.cosmos._routing.routing_range import PartitionKeyRangeWrapper, Range + + +MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 +REFRESH_INTERVAL = 60 * 1000 # milliseconds +INITIAL_UNAVAILABLE_TIME = 60 * 1000 # milliseconds +# partition is unhealthy if sdk tried to recover and failed +UNHEALTHY = "unhealthy" +# partition is unhealthy tentative when it initially marked unavailable +UNHEALTHY_TENTATIVE = "unhealthy_tentative" +# partition is healthy tentative when sdk is trying to recover +HEALTHY_TENTATIVE = "healthy_tentative" +# unavailability info keys +LAST_UNAVAILABILITY_CHECK_TIME_STAMP = "lastUnavailabilityCheckTimeStamp" +HEALTH_STATUS = "healthStatus" + + +def _has_exceeded_failure_rate_threshold( + successes: int, + failures: int, + failure_rate_threshold: int, +) -> bool: + if successes + failures < MINIMUM_REQUESTS_FOR_FAILURE_RATE: + return False + return (failures / successes * 100) >= failure_rate_threshold + +class _PartitionHealthInfo(object): + """ + This internal class keeps the health and statistics for a partition. 
+ """ + + def __init__(self) -> None: + self.write_failure_count: int = 0 + self.read_failure_count: int = 0 + self.write_success_count: int = 0 + self.read_success_count: int = 0 + self.read_consecutive_failure_count: int = 0 + self.write_consecutive_failure_count: int = 0 + self.unavailability_info: Dict[str, Any] = {} + + + def reset_health_stats(self) -> None: + self.write_failure_count = 0 + self.read_failure_count = 0 + self.write_success_count = 0 + self.read_success_count = 0 + self.read_consecutive_failure_count = 0 + self.write_consecutive_failure_count = 0 + + +class PartitionHealthTracker(object): + """ + This internal class implements the logic for tracking health thresholds for a partition. + """ + + + def __init__(self) -> None: + # partition -> regions -> health info + self.pkrange_wrapper_to_health_info: Dict[PartitionKeyRangeWrapper, Dict[str, _PartitionHealthInfo]] = {} + self.last_refresh = current_time_millis() + + # TODO: @tvaron3 look for useful places to add logs + + def mark_partition_unavailable(self, pkrange_wrapper: PartitionKeyRangeWrapper, location: str) -> None: + # mark the partition key range as unavailable + self._transition_health_status_on_failure(pkrange_wrapper, location) + + def _transition_health_status_on_failure( + self, + pkrange_wrapper: PartitionKeyRangeWrapper, + location: str + ) -> None: + current_time = current_time_millis() + if pkrange_wrapper not in self.pkrange_wrapper_to_health_info: + # healthy -> unhealthy tentative + partition_health_info = _PartitionHealthInfo() + partition_health_info.unavailability_info = { + LAST_UNAVAILABILITY_CHECK_TIME_STAMP: current_time, + HEALTH_STATUS: UNHEALTHY_TENTATIVE + } + self.pkrange_wrapper_to_health_info[pkrange_wrapper] = { + location: partition_health_info + } + else: + region_to_partition_health = self.pkrange_wrapper_to_health_info[pkrange_wrapper] + if location in region_to_partition_health: + # healthy tentative -> unhealthy + # if the operation type is not empty, we 
are in the healthy tentative state + self.pkrange_wrapper_to_health_info[pkrange_wrapper][location].unavailability_info[HEALTH_STATUS] = UNHEALTHY + # reset the last unavailability check time stamp + self.pkrange_wrapper_to_health_info[pkrange_wrapper][location].unavailability_info[LAST_UNAVAILABILITY_CHECK_TIME_STAMP] = UNHEALTHY + else: + # healthy -> unhealthy tentative + # if the operation type is empty, we are in the unhealthy tentative state + partition_health_info = _PartitionHealthInfo() + partition_health_info.unavailability_info = { + LAST_UNAVAILABILITY_CHECK_TIME_STAMP: current_time, + HEALTH_STATUS: UNHEALTHY_TENTATIVE + } + self.pkrange_wrapper_to_health_info[pkrange_wrapper][location] = partition_health_info + + def _transition_health_status_on_success( + self, + pkrange_wrapper: PartitionKeyRangeWrapper, + location: str + ) -> None: + if pkrange_wrapper in self.pkrange_wrapper_to_health_info: + # healthy tentative -> healthy + self.pkrange_wrapper_to_health_info[pkrange_wrapper].pop(location, None) + + def _check_stale_partition_info(self, pkrange_wrapper: PartitionKeyRangeWrapper) -> None: + current_time = current_time_millis() + + stale_partition_unavailability_check = int(os.getenv(Constants.STALE_PARTITION_UNAVAILABILITY_CHECK, + Constants.STALE_PARTITION_UNAVAILABILITY_CHECK_DEFAULT)) * 1000 + if pkrange_wrapper in self.pkrange_wrapper_to_health_info: + for location, partition_health_info in self.pkrange_wrapper_to_health_info[pkrange_wrapper].items(): + elapsed_time = current_time - partition_health_info.unavailability_info[LAST_UNAVAILABILITY_CHECK_TIME_STAMP] + current_health_status = partition_health_info.unavailability_info[HEALTH_STATUS] + # check if the partition key range is still unavailable + if ((current_health_status == UNHEALTHY and elapsed_time > stale_partition_unavailability_check) + or (current_health_status == UNHEALTHY_TENTATIVE + and elapsed_time > INITIAL_UNAVAILABLE_TIME)): + # unhealthy or unhealthy tentative -> healthy 
tentative + self.pkrange_wrapper_to_health_info[pkrange_wrapper][location].unavailability_info[HEALTH_STATUS] = HEALTHY_TENTATIVE + + if current_time - self.last_refresh < REFRESH_INTERVAL: + # all partition stats reset every minute + self._reset_partition_health_tracker_stats() + + + def get_excluded_locations(self, pkrange_wrapper: PartitionKeyRangeWrapper) -> Set[str]: + self._check_stale_partition_info(pkrange_wrapper) + if pkrange_wrapper in self.pkrange_wrapper_to_health_info: + return set(self.pkrange_wrapper_to_health_info[pkrange_wrapper].keys()) + else: + return set() + + + def add_failure(self, pkrange_wrapper: PartitionKeyRangeWrapper, operation_type: str, location: str) -> None: + # Retrieve the failure rate threshold from the environment. + failure_rate_threshold = int(os.getenv(Constants.FAILURE_PERCENTAGE_TOLERATED, + Constants.FAILURE_PERCENTAGE_TOLERATED_DEFAULT)) + + # Ensure that the health info dictionary is properly initialized. + if pkrange_wrapper not in self.pkrange_wrapper_to_health_info: + self.pkrange_wrapper_to_health_info[pkrange_wrapper] = {} + if location not in self.pkrange_wrapper_to_health_info[pkrange_wrapper]: + self.pkrange_wrapper_to_health_info[pkrange_wrapper][location] = _PartitionHealthInfo() + + health_info = self.pkrange_wrapper_to_health_info[pkrange_wrapper][location] + + # Determine attribute names and environment variables based on the operation type. 
+ if operation_type == EndpointOperationType.WriteType: + success_attr = 'write_success_count' + failure_attr = 'write_failure_count' + consecutive_attr = 'write_consecutive_failure_count' + env_key = Constants.CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_WRITE + default_consec_threshold = Constants.CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_WRITE_DEFAULT + else: + success_attr = 'read_success_count' + failure_attr = 'read_failure_count' + consecutive_attr = 'read_consecutive_failure_count' + env_key = Constants.CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_READ + default_consec_threshold = Constants.CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_READ_DEFAULT + + # Increment failure and consecutive failure counts. + setattr(health_info, failure_attr, getattr(health_info, failure_attr) + 1) + setattr(health_info, consecutive_attr, getattr(health_info, consecutive_attr) + 1) + + # Retrieve the consecutive failure threshold from the environment. + consecutive_failure_threshold = int(os.getenv(env_key, default_consec_threshold)) + + # Call the threshold checker with the current stats. 
+ self._check_thresholds( + pkrange_wrapper, + getattr(health_info, success_attr), + getattr(health_info, failure_attr), + getattr(health_info, consecutive_attr), + location, + failure_rate_threshold, + consecutive_failure_threshold + ) + + def _check_thresholds( + self, + pkrange_wrapper: PartitionKeyRangeWrapper, + successes: int, + failures: int, + consecutive_failures: int, + location: str, + failure_rate_threshold: int, + consecutive_failure_threshold: int, + ) -> None: + + # check the failure rate was not exceeded + if _has_exceeded_failure_rate_threshold( + successes, + failures, + failure_rate_threshold + ): + self._transition_health_status_on_failure(pkrange_wrapper, location) + + # add to consecutive failures and check that threshold was not exceeded + if consecutive_failures >= consecutive_failure_threshold: + self._transition_health_status_on_failure(pkrange_wrapper, location) + + def add_success(self, pkrange_wrapper: PartitionKeyRangeWrapper, operation_type: str, location: str) -> None: + # Ensure that the health info dictionary is initialized. 
+ if pkrange_wrapper not in self.pkrange_wrapper_to_health_info: + self.pkrange_wrapper_to_health_info[pkrange_wrapper] = {} + if location not in self.pkrange_wrapper_to_health_info[pkrange_wrapper]: + self.pkrange_wrapper_to_health_info[pkrange_wrapper][location] = _PartitionHealthInfo() + + health_info = self.pkrange_wrapper_to_health_info[pkrange_wrapper][location] + + if operation_type == EndpointOperationType.WriteType: + health_info.write_success_count += 1 + health_info.write_consecutive_failure_count = 0 + else: + health_info.read_success_count += 1 + health_info.read_consecutive_failure_count = 0 + self._transition_health_status_on_success(pkrange_wrapper, operation_type) + + + def _reset_partition_health_tracker_stats(self) -> None: + for pkrange_wrapper in self.pkrange_wrapper_to_health_info: + for location in self.pkrange_wrapper_to_health_info[pkrange_wrapper]: + self.pkrange_wrapper_to_health_info[pkrange_wrapper][location].reset_health_stats() diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py new file mode 100644 index 000000000000..849630c83d4d --- /dev/null +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py @@ -0,0 +1,136 @@ +# The MIT License (MIT) +# Copyright (c) 2021 Microsoft Corporation + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies 
class _GlobalPartitionEndpointManagerForCircuitBreaker(_GlobalEndpointManager):
    """
    This internal class implements the logic for partition endpoint management for
    geo-replicated database accounts (async client).

    NOTE(review): this file currently duplicates the sync implementation
    byte-for-byte; both extend the async _GlobalEndpointManager — confirm the
    sync module should not import its own base class instead.
    """

    def __init__(self, client: "CosmosClientConnection"):
        super(_GlobalPartitionEndpointManagerForCircuitBreaker, self).__init__(client)
        self.partition_health_tracker = PartitionHealthTracker()

    def is_circuit_breaker_applicable(self, request: RequestObject) -> bool:
        """
        Check if circuit breaker is applicable for a request.

        :param request: the request being routed (may be None).
        :returns: True when per-partition circuit breaking should be used.
        """
        if not request:
            return False

        # Circuit breaking is opt-in through an environment variable.
        circuit_breaker_enabled = os.environ.get(Constants.CIRCUIT_BREAKER_ENABLED_CONFIG,
                                                 Constants.CIRCUIT_BREAKER_ENABLED_CONFIG_DEFAULT) == "True"
        if not circuit_breaker_enabled:
            return False

        # Writes can only be re-routed to another region on multi-write accounts.
        if (not self.location_cache.can_use_multiple_write_locations_for_request(request)
                and documents._OperationType.IsWriteOperation(request.operation_type)):
            return False

        # Only document operations participate in circuit breaking.
        if request.resource_type != ResourceType.Document:
            return False

        # BUGFIX: the original compared with `!=`, which restricted circuit
        # breaking to query-plan requests only. Query-plan (metadata) calls are
        # the requests that must be *exempt* from circuit breaking.
        if request.operation_type == documents._OperationType.QueryPlan:
            return False

        return True

    @staticmethod
    def _endpoint_operation_type(request: RequestObject) -> str:
        # Collapse the request's operation type onto the coarse read/write
        # classification used by the health tracker.
        if documents._OperationType.IsWriteOperation(request.operation_type):
            return EndpointOperationType.WriteType
        return EndpointOperationType.ReadType

    def _create_pkrange_wrapper(self, request: RequestObject) -> PartitionKeyRangeWrapper:
        """
        Create a PartitionKeyRangeWrapper identifying the partition targeted by
        the request.
        """
        container_rid = request.headers[HttpHeaders.IntendedCollectionRID]
        partition_key = request.headers[HttpHeaders.PartitionKey]
        # Resolve the container link from the client-side properties cache.
        # BUGFIX: iterating the cache directly yields only keys; .items() is
        # required to unpack (container_link, properties) pairs (assumes the
        # cache is a dict of link -> properties, as used elsewhere in the SDK
        # -- TODO confirm).
        target_container_link = None
        for container_link, properties in self.client._container_properties_cache.items():
            # TODO: @tvaron3 consider moving this to a constant with other usages
            if properties["_rid"] == container_rid:
                target_container_link = container_link
        # TODO(review): raise a descriptive error when the rid is not cached
        # instead of passing None through to the routing map provider.
        pkrange = self.client._routing_map_provider.get_overlapping_range(target_container_link, partition_key)
        return PartitionKeyRangeWrapper(pkrange, container_rid)

    def record_failure(
        self,
        request: RequestObject
    ) -> None:
        """Report a failed request to the partition health tracker."""
        if self.is_circuit_breaker_applicable(request):
            endpoint_operation_type = self._endpoint_operation_type(request)
            location = self.location_cache.get_location_from_endpoint(request.location_endpoint_to_route)
            pkrange_wrapper = self._create_pkrange_wrapper(request)
            self.partition_health_tracker.add_failure(pkrange_wrapper, endpoint_operation_type, location)

    def resolve_service_endpoint(self, request):
        # Feed the tracker's per-partition excluded locations into the normal
        # endpoint resolution performed by the base class.
        if self.is_circuit_breaker_applicable(request):
            pkrange_wrapper = self._create_pkrange_wrapper(request)
            request.set_excluded_locations_from_circuit_breaker(
                self.partition_health_tracker.get_excluded_locations(pkrange_wrapper)
            )
        return super(_GlobalPartitionEndpointManagerForCircuitBreaker, self).resolve_service_endpoint(request)

    def mark_partition_unavailable(self, request: RequestObject) -> None:
        """
        Mark the partition unavailable from the given request.
        """
        location = self.location_cache.get_location_from_endpoint(request.location_endpoint_to_route)
        pkrange_wrapper = self._create_pkrange_wrapper(request)
        self.partition_health_tracker.mark_partition_unavailable(pkrange_wrapper, location)

    def record_success(
        self,
        request: RequestObject
    ) -> None:
        """Report a successful request to the partition health tracker."""
        if self.is_circuit_breaker_applicable(request):
            endpoint_operation_type = self._endpoint_operation_type(request)
            location = self.location_cache.get_location_from_endpoint(request.location_endpoint_to_route)
            pkrange_wrapper = self._create_pkrange_wrapper(request)
            self.partition_health_tracker.add_success(pkrange_wrapper, endpoint_operation_type, location)

# TODO: @tvaron3 there should be no in region retries when trying on healthy tentative
# -------------------------------------------------------------------------
from azure.cosmos import CosmosClient
from azure.cosmos.partition_key import PartitionKey
import config

# ----------------------------------------------------------------------------------------------------------
# Prerequisites -
#
# 1. An Azure Cosmos account -
#    https://learn.microsoft.com/azure/cosmos-db/create-sql-api-python#create-a-database-account
#
# 2. Microsoft Azure Cosmos
#    pip install azure-cosmos>=4.3.0b4
# ----------------------------------------------------------------------------------------------------------
# Sample - demonstrates how to use excluded locations in client level and request level
# ----------------------------------------------------------------------------------------------------------
# Note:
# This sample creates a Container to your database account.
# Each time a Container is created the account will be billed for 1 hour of usage based on
# the provisioned throughput (RU/s) of that account.
# ----------------------------------------------------------------------------------------------------------

HOST = config.settings["host"]
MASTER_KEY = config.settings["master_key"]

TENANT_ID = config.settings["tenant_id"]
CLIENT_ID = config.settings["client_id"]
CLIENT_SECRET = config.settings["client_secret"]

DATABASE_ID = config.settings["database_id"]
CONTAINER_ID = config.settings["container_id"]
# The container is partitioned on /id, so each item's id doubles as its
# partition key value.
PARTITION_KEY = PartitionKey(path="/id")


def get_test_item(num: int) -> dict:
    # Build a small test document; its id ('Item_<num>') is also its
    # partition key value (see PARTITION_KEY above).
    test_item = {
        'id': 'Item_' + str(num),
        'test_object': True,
        'lastName': 'Smith'
    }
    return test_item

def clean_up_db(client) -> None:
    # Best-effort cleanup: ignore the error raised when the database does not
    # exist yet (e.g. on the first run).
    try:
        client.delete_database(DATABASE_ID)
    except Exception as e:
        pass

def excluded_locations_client_level_sample():
    # Demonstrates excluded_locations configured once, on the client.
    preferred_locations = ['West US 3', 'West US', 'East US 2']
    excluded_locations = ['West US 3', 'West US']
    client = CosmosClient(
        HOST,
        MASTER_KEY,
        preferred_locations=preferred_locations,
        excluded_locations=excluded_locations
    )
    clean_up_db(client)

    db = client.create_database(DATABASE_ID)
    container = db.create_container(id=CONTAINER_ID, partition_key=PARTITION_KEY)

    # For write operations with a single master account, the write endpoint will be the
    # default endpoint, since preferred_locations and excluded_locations are ignored for
    # writes and the account's default endpoint is used instead.
    container.create_item(get_test_item(0))

    # For read operations, read endpoints will be 'preferred_locations' - 'excluded_locations'.
    # In our sample, ['West US 3', 'West US', 'East US 2'] - ['West US 3', 'West US'] => ['East US 2'],
    # therefore 'East US 2' will be the read endpoint, and items will be read from 'East US 2' location
    item = container.read_item(item='Item_0', partition_key='Item_0')

    clean_up_db(client)

def excluded_locations_request_level_sample():
    # Demonstrates request-level excluded_locations overriding the client-level setting.
    preferred_locations = ['West US 3', 'West US', 'East US 2']
    excluded_locations_on_client = ['West US 3', 'West US']
    excluded_locations_on_request = ['West US 3']
    client = CosmosClient(
        HOST,
        MASTER_KEY,
        preferred_locations=preferred_locations,
        excluded_locations=excluded_locations_on_client
    )
    clean_up_db(client)

    db = client.create_database(DATABASE_ID)
    container = db.create_container(id=CONTAINER_ID, partition_key=PARTITION_KEY)

    # For write operations with a single master account, the write endpoint will be the
    # default endpoint, since preferred_locations and excluded_locations are ignored for
    # writes and the account's default endpoint is used instead.
    container.create_item(get_test_item(0))

    # For read operations, read endpoints will be 'preferred_locations' - 'excluded_locations'.
    # However, in our sample, since the excluded_locations` were passed with the read request, the `excluded_location`
    # will be replaced with the locations from request, ['West US 3']. The `excluded_locations` on request always takes
    # the highest priority!
    # With the excluded_locations on request, the read endpoints will be ['West US', 'East US 2']
    # ['West US 3', 'West US', 'East US 2'] - ['West US 3'] => ['West US', 'East US 2']
    # Therefore, items will be read from 'West US' or 'East US 2' location
    item = container.read_item(item='Item_0', partition_key='Item_0', excluded_locations=excluded_locations_on_request)

    clean_up_db(client)

if __name__ == "__main__":
    # NOTE(review): the client-level sample is left disabled; uncomment the
    # next line to run both demonstrations.
    # excluded_locations_client_level_sample()
    excluded_locations_request_level_sample()
x: True): """ return next((x for x in iterable if condition(x)), None) - async def send(self, request: HttpRequest, *, stream: bool = False, proxies: MutableMapping[str, str] | None = None, **config) -> AsyncHttpResponse: + async def send(self, request: HttpRequest, *, stream: bool = False, proxies: Optional[MutableMapping[str, str]] = None, **config) -> AsyncHttpResponse: FaultInjectionTransport.logger.info("--> FaultInjectionTransport.Send {} {}".format(request.method, request.url)) # find the first fault Factory with matching predicate if any first_fault_factory = FaultInjectionTransport.__first_item(iter(self.faults), lambda f: f["predicate"](request)) @@ -201,7 +201,7 @@ async def transform_topology_mwr( first_region_name: str, second_region_name: str, r: HttpRequest, - inner: Callable[[], asyncio.Task[AsyncHttpResponse]]) -> asyncio.Task[AsyncHttpResponse]: + inner: Callable[[], asyncio.Task[AsyncHttpResponse]]) -> AsyncHttpResponse: response = await inner() if not FaultInjectionTransport.predicate_is_database_account_call(response.request): @@ -228,7 +228,7 @@ async def transform_topology_mwr( return response class MockHttpResponse(AioHttpTransportResponse): - def __init__(self, request: HttpRequest, status_code: int, content:Optional[dict[str, any]]): + def __init__(self, request: HttpRequest, status_code: int, content:Optional[dict[str, Any]]): self.request: HttpRequest = request # This is actually never None, and set by all implementations after the call to # __init__ of this class. 
This class is also a legacy impl, so it's risky to change it @@ -239,19 +239,19 @@ def __init__(self, request: HttpRequest, status_code: int, content:Optional[dict self.reason: Optional[str] = None self.content_type: Optional[str] = None self.block_size: int = 4096 # Default to same as R - self.content: Optional[dict[str, any]] = None + self.content: Optional[dict[str, Any]] = None self.json_text: Optional[str] = None self.bytes: Optional[bytes] = None if content: - self.content:Optional[dict[str, any]] = content - self.json_text:Optional[str] = json.dumps(content) - self.bytes:bytes = self.json_text.encode("utf-8") + self.content = content + self.json_text = json.dumps(content) + self.bytes = self.json_text.encode("utf-8") - def body(self) -> bytes: + def body(self) -> Optional[bytes]: return self.bytes - def text(self, encoding: Optional[str] = None) -> str: + def text(self, encoding: Optional[str] = None) -> Optional[str]: return self.json_text async def load_body(self) -> None: diff --git a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py index d09f017febbf..71b6824ef240 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py @@ -61,7 +61,7 @@ def setup_class(cls): cls.mgmt_client = CosmosClient(host, master_key, consistency_level="Session", connection_policy=connection_policy, logger=logger) - created_database: DatabaseProxy = cls.mgmt_client.get_database_client(cls.database_id) + created_database = cls.mgmt_client.get_database_client(cls.database_id) asyncio.run(asyncio.wait_for( created_database.create_container( cls.single_partition_container_name, @@ -71,7 +71,7 @@ def setup_class(cls): @classmethod def teardown_class(cls): logger.info("tearing down class: {}".format(cls.__name__)) - created_database: DatabaseProxy = 
cls.mgmt_client.get_database_client(cls.database_id) + created_database = cls.mgmt_client.get_database_client(cls.database_id) try: asyncio.run(asyncio.wait_for( created_database.delete_container(cls.single_partition_container_name), @@ -100,7 +100,7 @@ def cleanup_method(initialized_objects: dict[str, Any]): except Exception as close_error: logger.warning(f"Exception trying to close method client. {close_error}") - async def test_throws_injected_error(self, setup): + async def test_throws_injected_error(self, setup: object): id_value: str = str(uuid.uuid4()) document_definition = {'id': id_value, 'pk': id_value, @@ -126,7 +126,7 @@ async def test_throws_injected_error(self, setup): finally: TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) - async def test_swr_mrr_succeeds(self, setup): + async def test_swr_mrr_succeeds(self, setup: object): expected_read_region_uri: str = test_config.TestConfig.local_host expected_write_region_uri: str = expected_read_region_uri.replace("localhost", "127.0.0.1") custom_transport = FaultInjectionTransport() @@ -167,18 +167,18 @@ async def test_swr_mrr_succeeds(self, setup): request: HttpRequest = created_document.get_response_headers()["_request"] # Validate the response comes from "Write Region" (the write region) assert request.url.startswith(expected_write_region_uri) - start:float = time.perf_counter() + start: float = time.perf_counter() while (time.perf_counter() - start) < 2: read_document = await container.read_item(id_value, partition_key=id_value) - request: HttpRequest = read_document.get_response_headers()["_request"] + request = read_document.get_response_headers()["_request"] # Validate the response comes from "Read Region" (the most preferred read-only region) assert request.url.startswith(expected_read_region_uri) finally: TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) - async def test_swr_mrr_region_down_read_succeeds(self, setup): + async def 
test_swr_mrr_region_down_read_succeeds(self, setup: object): expected_read_region_uri: str = test_config.TestConfig.local_host expected_write_region_uri: str = expected_read_region_uri.replace("localhost", "127.0.0.1") custom_transport = FaultInjectionTransport() @@ -232,14 +232,14 @@ async def test_swr_mrr_region_down_read_succeeds(self, setup): while (time.perf_counter() - start) < 2: read_document = await container.read_item(id_value, partition_key=id_value) - request: HttpRequest = read_document.get_response_headers()["_request"] + request = read_document.get_response_headers()["_request"] # Validate the response comes from "Write Region" ("Read Region" the most preferred read-only region is down) assert request.url.startswith(expected_write_region_uri) finally: TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) - async def test_swr_mrr_region_down_envoy_read_succeeds(self, setup): + async def test_swr_mrr_region_down_envoy_read_succeeds(self, setup: object): expected_read_region_uri: str = test_config.TestConfig.local_host expected_write_region_uri: str = expected_read_region_uri.replace("localhost", "127.0.0.1") custom_transport = FaultInjectionTransport() @@ -298,7 +298,7 @@ async def test_swr_mrr_region_down_envoy_read_succeeds(self, setup): while (time.perf_counter() - start) < 2: read_document = await container.read_item(id_value, partition_key=id_value) - request: HttpRequest = read_document.get_response_headers()["_request"] + request = read_document.get_response_headers()["_request"] # Validate the response comes from "Write Region" ("Read Region" the most preferred read-only region is down) assert request.url.startswith(expected_write_region_uri) @@ -306,7 +306,7 @@ async def test_swr_mrr_region_down_envoy_read_succeeds(self, setup): TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) - async def test_mwr_succeeds(self, setup): + async def test_mwr_succeeds(self, setup: object): first_region_uri: str = 
test_config.TestConfig.local_host.replace("localhost", "127.0.0.1") second_region_uri: str = test_config.TestConfig.local_host custom_transport = FaultInjectionTransport() @@ -344,7 +344,7 @@ async def test_mwr_succeeds(self, setup): while (time.perf_counter() - start) < 2: read_document = await container.read_item(id_value, partition_key=id_value) - request: HttpRequest = read_document.get_response_headers()["_request"] + request = read_document.get_response_headers()["_request"] # Validate the response comes from "East US" (the most preferred read-only region) assert request.url.startswith(first_region_uri) From d86d381c66b54876c6a1f1773a05edc742361680 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Wed, 2 Apr 2025 10:01:53 -0700 Subject: [PATCH 025/152] Refactored gem for ppcb and hooked up retryconfigurations with failure tracking --- sdk/cosmos/azure-cosmos/CHANGELOG.md | 2 +- ...tition_endpoint_manager_circuit_breaker.py | 80 ++--------- ...n_endpoint_manager_circuit_breaker_core.py | 131 ++++++++++++++++++ .../azure/cosmos/_location_cache.py | 4 +- .../azure/cosmos/_partition_health_tracker.py | 9 +- .../azure/cosmos/_request_object.py | 11 +- .../azure/cosmos/aio/_asynchronous_request.py | 4 + .../aio/_cosmos_client_connection_async.py | 4 +- ..._endpoint_manager_circuit_breaker_async.py | 87 ++---------- .../azure/cosmos/aio/_retry_utility_async.py | 12 +- 10 files changed, 179 insertions(+), 165 deletions(-) create mode 100644 sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py diff --git a/sdk/cosmos/azure-cosmos/CHANGELOG.md b/sdk/cosmos/azure-cosmos/CHANGELOG.md index c5f836ce2a03..475203a1e213 100644 --- a/sdk/cosmos/azure-cosmos/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos/CHANGELOG.md @@ -3,7 +3,7 @@ ### 4.10.0b3 (Unreleased) #### Features Added -* Per partition circuit breaker support. It can be enabled through environment variable `AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER`. 
See +* Per partition circuit breaker support. It can be enabled through the environment variable `AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER`. See [PR 40302](https://github.com/Azure/azure-sdk-for-python/pull/40302) #### Breaking Changes diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py index 849630c83d4d..072b946c1402 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py @@ -21,24 +21,17 @@ """Internal class for global endpoint manager for circuit breaker. """ -import logging -import os from typing import TYPE_CHECKING -from azure.cosmos import documents +from azure.cosmos._global_partition_endpoint_manager_circuit_breaker_core import \ + _GlobalPartitionEndpointManagerForCircuitBreakerCore -from azure.cosmos._partition_health_tracker import PartitionHealthTracker -from azure.cosmos._routing.routing_range import Range, PartitionKeyRangeWrapper from azure.cosmos.aio._global_endpoint_manager_async import _GlobalEndpointManager -from azure.cosmos._location_cache import EndpointOperationType from azure.cosmos._request_object import RequestObject -from azure.cosmos.http_constants import ResourceType, HttpHeaders -from azure.cosmos._constants import _Constants as Constants if TYPE_CHECKING: - from azure.cosmos.aio._cosmos_client_connection_async import CosmosClientConnection + from azure.cosmos._cosmos_client_connection import CosmosClientConnection -logger = logging.getLogger("azure.cosmos._GlobalEndpointManagerForCircuitBreaker") class _GlobalPartitionEndpointManagerForCircuitBreaker(_GlobalEndpointManager): """ @@ -49,88 +42,33 @@ class _GlobalPartitionEndpointManagerForCircuitBreaker(_GlobalEndpointManager): def __init__(self, client: "CosmosClientConnection"): 
super(_GlobalPartitionEndpointManagerForCircuitBreaker, self).__init__(client) - self.partition_health_tracker = PartitionHealthTracker() + self.global_partition_endpoint_manager_core = _GlobalPartitionEndpointManagerForCircuitBreakerCore(client, self.location_cache) def is_circuit_breaker_applicable(self, request: RequestObject) -> bool: """ Check if circuit breaker is applicable for a request. """ - if not request: - return False - - circuit_breaker_enabled = os.environ.get(Constants.CIRCUIT_BREAKER_ENABLED_CONFIG, - Constants.CIRCUIT_BREAKER_ENABLED_CONFIG_DEFAULT) == "True" - if not circuit_breaker_enabled: - return False - - if (not self.location_cache.can_use_multiple_write_locations_for_request(request) - and documents._OperationType.IsWriteOperation(request.operation_type)): - return False - - if request.resource_type != ResourceType.Document: - return False - - if request.operation_type != documents._OperationType.QueryPlan: - return False - - return True - - def _create_pkrange_wrapper(self, request: RequestObject) -> PartitionKeyRangeWrapper: - """ - Create a PartitionKeyRangeWrapper object. 
- """ - container_rid = request.headers[HttpHeaders.IntendedCollectionRID] - partition_key = request.headers[HttpHeaders.PartitionKey] - # get the partition key range for the given partition key - target_container_link = None - for container_link, properties in self.client._container_properties_cache: - # TODO: @tvaron3 consider moving this to a constant with other usages - if properties["_rid"] == container_rid: - target_container_link = container_link - # throw exception if it is not found - pkrange = self.client._routing_map_provider.get_overlapping_range(target_container_link, partition_key) - return PartitionKeyRangeWrapper(pkrange, container_rid) + return self.global_partition_endpoint_manager_core.is_circuit_breaker_applicable(request) def record_failure( self, request: RequestObject ) -> None: - if self.is_circuit_breaker_applicable(request): - #convert operation_type to EndpointOperationType - endpoint_operation_type = EndpointOperationType.WriteType if ( - documents._OperationType.IsWriteOperation(request.operation_type)) else EndpointOperationType.ReadType - location = self.location_cache.get_location_from_endpoint(request.location_endpoint_to_route) - pkrange_wrapper = self._create_pkrange_wrapper(request) - self.partition_health_tracker.add_failure(pkrange_wrapper, endpoint_operation_type, location) + self.global_partition_endpoint_manager_core.record_failure(request) def resolve_service_endpoint(self, request): - if self.is_circuit_breaker_applicable(request): - pkrange_wrapper = self._create_pkrange_wrapper(request) - request.set_excluded_locations_from_circuit_breaker( - self.partition_health_tracker.get_excluded_locations(pkrange_wrapper) - ) + request = self.global_partition_endpoint_manager_core.add_excluded_locations_to_request(request) return super(_GlobalPartitionEndpointManagerForCircuitBreaker, self).resolve_service_endpoint(request) - def mark_partition_unavailable(self, request: RequestObject) -> None: """ Mark the partition unavailable 
from the given request. """ - location = self.location_cache.get_location_from_endpoint(request.location_endpoint_to_route) - pkrange_wrapper = self._create_pkrange_wrapper(request) - self.partition_health_tracker.mark_partition_unavailable(pkrange_wrapper, location) + self.global_partition_endpoint_manager_core.mark_partition_unavailable(request) def record_success( self, request: RequestObject ) -> None: - if self.is_circuit_breaker_applicable(request): - #convert operation_type to either Read or Write - endpoint_operation_type = EndpointOperationType.WriteType if ( - documents._OperationType.IsWriteOperation(request.operation_type)) else EndpointOperationType.ReadType - location = self.location_cache.get_location_from_endpoint(request.location_endpoint_to_route) - pkrange_wrapper = self._create_pkrange_wrapper(request) - self.partition_health_tracker.add_success(pkrange_wrapper, endpoint_operation_type, location) - -# TODO: @tvaron3 there should be no in region retries when trying on healthy tentative ----------------------- + self.global_partition_endpoint_manager_core.record_success(request) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py new file mode 100644 index 000000000000..23a7c50047ca --- /dev/null +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py @@ -0,0 +1,131 @@ +# The MIT License (MIT) +# Copyright (c) 2021 Microsoft Corporation + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, 
subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Internal class for global endpoint manager for circuit breaker. +""" +import logging +import os + +from azure.cosmos import documents + +from azure.cosmos._partition_health_tracker import PartitionHealthTracker +from azure.cosmos._routing.routing_range import PartitionKeyRangeWrapper +from azure.cosmos._location_cache import EndpointOperationType, LocationCache +from azure.cosmos._request_object import RequestObject +from azure.cosmos.http_constants import ResourceType, HttpHeaders +from azure.cosmos._constants import _Constants as Constants + + +logger = logging.getLogger("azure.cosmos._GlobalEndpointManagerForCircuitBreakerCore") + +class _GlobalPartitionEndpointManagerForCircuitBreakerCore(object): + """ + This internal class implements the logic for partition endpoint management for + geo-replicated database accounts. + """ + + + def __init__(self, client, location_cache: LocationCache): + self.partition_health_tracker = PartitionHealthTracker() + self.location_cache = location_cache + self.client = client + + def is_circuit_breaker_applicable(self, request: RequestObject) -> bool: + """ + Check if circuit breaker is applicable for a request. 
+ """ + if not request: + return False + + circuit_breaker_enabled = os.environ.get(Constants.CIRCUIT_BREAKER_ENABLED_CONFIG, + Constants.CIRCUIT_BREAKER_ENABLED_CONFIG_DEFAULT) == "True" + if not circuit_breaker_enabled: + return False + + if (not self.location_cache.can_use_multiple_write_locations_for_request(request) + and documents._OperationType.IsWriteOperation(request.operation_type)): + return False + + if request.resource_type != ResourceType.Document: + return False + + if request.operation_type != documents._OperationType.QueryPlan: + return False + + return True + + def _create_pkrange_wrapper(self, request: RequestObject) -> PartitionKeyRangeWrapper: + """ + Create a PartitionKeyRangeWrapper object. + """ + container_rid = request.headers[HttpHeaders.IntendedCollectionRID] + partition_key = request.headers[HttpHeaders.PartitionKey] + # get the partition key range for the given partition key + target_container_link = None + for container_link, properties in self.client._container_properties_cache: + # TODO: @tvaron3 consider moving this to a constant with other usages + if properties["_rid"] == container_rid: + target_container_link = container_link + # throw exception if it is not found + pkrange = self.client._routing_map_provider.get_overlapping_range(target_container_link, partition_key) + return PartitionKeyRangeWrapper(pkrange, container_rid) + + def record_failure( + self, + request: RequestObject + ) -> None: + if self.is_circuit_breaker_applicable(request): + #convert operation_type to EndpointOperationType + endpoint_operation_type = EndpointOperationType.WriteType if ( + documents._OperationType.IsWriteOperation(request.operation_type)) else EndpointOperationType.ReadType + location = self.location_cache.get_location_from_endpoint(str(request.location_endpoint_to_route)) + pkrange_wrapper = self._create_pkrange_wrapper(request) + self.partition_health_tracker.add_failure(pkrange_wrapper, endpoint_operation_type, location) + + def 
add_excluded_locations_to_request(self, request: RequestObject) -> RequestObject: + if self.is_circuit_breaker_applicable(request): + pkrange_wrapper = self._create_pkrange_wrapper(request) + request.set_excluded_locations_from_circuit_breaker( + self.partition_health_tracker.get_excluded_locations(pkrange_wrapper) + ) + return request + + def mark_partition_unavailable(self, request: RequestObject) -> None: + """ + Mark the partition unavailable from the given request. + """ + location = self.location_cache.get_location_from_endpoint(request.location_endpoint_to_route) + pkrange_wrapper = self._create_pkrange_wrapper(request) + self.partition_health_tracker.mark_partition_unavailable(pkrange_wrapper, location) + + def record_success( + self, + request: RequestObject + ) -> None: + if self.is_circuit_breaker_applicable(request): + #convert operation_type to either Read or Write + endpoint_operation_type = EndpointOperationType.WriteType if ( + documents._OperationType.IsWriteOperation(request.operation_type)) else EndpointOperationType.ReadType + location = self.location_cache.get_location_from_endpoint(request.location_endpoint_to_route) + pkrange_wrapper = self._create_pkrange_wrapper(request) + self.partition_health_tracker.add_success(pkrange_wrapper, endpoint_operation_type, location) + +# TODO: @tvaron3 there should be no in region retries when trying on healthy tentative ----------------------- \ No newline at end of file diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py index d40a99f7c69f..ae6cac6feddc 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py @@ -25,7 +25,7 @@ import collections import logging import time -from typing import Set +from typing import Set, Optional from urllib.parse import urlparse from . 
import documents @@ -232,13 +232,13 @@ def get_ordered_write_locations(self): def get_ordered_read_locations(self): return self.account_read_locations - # Todo: @tvaron3 client should be appeneded to if using circuitbreaker exclude regions def _get_configured_excluded_locations(self, request): # If excluded locations were configured on request, use request level excluded locations. excluded_locations = request.excluded_locations if excluded_locations is None: # If excluded locations were only configured on client(connection_policy), use client level excluded_locations = self.connection_policy.ExcludedLocations + excluded_locations.union(request.excluded_locations_circuit_breaker) return excluded_locations def get_applicable_read_regional_endpoints(self, request): diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py index d7c5c4a2cb31..74665a5e7eb5 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py @@ -22,7 +22,7 @@ """Internal class for partition health tracker for circuit breaker. """ import os -from typing import Dict, Set, Any +from typing import Dict, Set, Any, Optional from ._constants import _Constants as Constants from azure.cosmos._location_cache import current_time_millis, EndpointOperationType from azure.cosmos._routing.routing_range import PartitionKeyRangeWrapper, Range @@ -164,7 +164,12 @@ def get_excluded_locations(self, pkrange_wrapper: PartitionKeyRangeWrapper) -> S return set() - def add_failure(self, pkrange_wrapper: PartitionKeyRangeWrapper, operation_type: str, location: str) -> None: + def add_failure( + self, + pkrange_wrapper: PartitionKeyRangeWrapper, + operation_type: str, + location: Optional[str] + ) -> None: # Retrieve the failure rate threshold from the environment. 
failure_rate_threshold = int(os.getenv(Constants.FAILURE_PERCENTAGE_TOLERATED, Constants.FAILURE_PERCENTAGE_TOLERATED_DEFAULT)) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py index afc9fa4d30a9..50dd4c7fc4d1 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py @@ -21,7 +21,7 @@ """Represents a request object. """ -from typing import Optional, Mapping, Any, List, Dict +from typing import Optional, Mapping, Any, Dict, Set class RequestObject(object): @@ -41,7 +41,8 @@ def __init__( self.location_index_to_route: Optional[int] = None self.location_endpoint_to_route: Optional[str] = None self.last_routed_location_endpoint_within_region: Optional[str] = None - self.excluded_locations: Optional[List[str]] = None + self.excluded_locations: Optional[Set[str]] = None + self.excluded_locations_circuit_breaker: Set[str] = set() def route_to_location_with_preferred_location_flag( # pylint: disable=name-too-long self, @@ -83,7 +84,5 @@ def set_excluded_location_from_options(self, options: Mapping[str, Any]) -> None if self._can_set_excluded_location(options): self.excluded_locations = options['excludedLocations'] - def set_excluded_locations_from_circuit_breaker(self, excluded_locations: List[str]) -> None: - if self.excluded_locations: - self.excluded_locations.extend(excluded_locations) - self.excluded_locations = excluded_locations + def set_excluded_locations_from_circuit_breaker(self, excluded_locations: Set[str]) -> None: + self.excluded_locations_circuit_breaker = excluded_locations diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py index 81430d8df42c..25f6ac203d85 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py @@ -101,6 
+101,8 @@ async def _Request(global_endpoint_manager, request_params, connection_policy, p read_timeout=read_timeout, connection_verify=kwargs.pop("connection_verify", ca_certs), connection_cert=kwargs.pop("connection_cert", cert_files), + request_params=request_params, + global_endpoint_manager=global_endpoint_manager, **kwargs ) else: @@ -111,6 +113,8 @@ async def _Request(global_endpoint_manager, request_params, connection_policy, p read_timeout=read_timeout, # If SSL is disabled, verify = false connection_verify=kwargs.pop("connection_verify", is_ssl_enabled), + request_params=request_params, + global_endpoint_manager=global_endpoint_manager, **kwargs ) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py index 0d3d61e87fa0..8fbdb0fb9a83 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py @@ -54,7 +54,7 @@ from .. import documents from .._change_feed.aio.change_feed_iterable import ChangeFeedIterable from .._change_feed.change_feed_state import ChangeFeedState -from azure.cosmos.aio._global_partition_endpoint_manager_circuit_breaker_async import _GlobalPartitionEndpointManagerForCircuitBreaker +from azure.cosmos.aio._global_partition_endpoint_manager_circuit_breaker_async import _GlobalPartitionEndpointManagerForCircuitBreakerAsync from .._routing import routing_range from ..documents import ConnectionPolicy, DatabaseAccount from .._constants import _Constants as Constants @@ -169,7 +169,7 @@ def __init__( # pylint: disable=too-many-statements # Keeps the latest response headers from the server. 
self.last_response_headers: CaseInsensitiveDict = CaseInsensitiveDict() self.UseMultipleWriteLocations = False - self._global_endpoint_manager = _GlobalPartitionEndpointManagerForCircuitBreaker(self) + self._global_endpoint_manager = _GlobalPartitionEndpointManagerForCircuitBreakerAsync(self) retry_policy = None if isinstance(self.connection_policy.ConnectionRetryConfiguration, AsyncHTTPPolicy): diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py index 849630c83d4d..71bb628c31a0 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py @@ -21,26 +21,19 @@ """Internal class for global endpoint manager for circuit breaker. """ -import logging -import os -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional -from azure.cosmos import documents +from azure.cosmos._global_partition_endpoint_manager_circuit_breaker_core import \ + _GlobalPartitionEndpointManagerForCircuitBreakerCore -from azure.cosmos._partition_health_tracker import PartitionHealthTracker -from azure.cosmos._routing.routing_range import Range, PartitionKeyRangeWrapper from azure.cosmos.aio._global_endpoint_manager_async import _GlobalEndpointManager -from azure.cosmos._location_cache import EndpointOperationType from azure.cosmos._request_object import RequestObject -from azure.cosmos.http_constants import ResourceType, HttpHeaders -from azure.cosmos._constants import _Constants as Constants if TYPE_CHECKING: from azure.cosmos.aio._cosmos_client_connection_async import CosmosClientConnection -logger = logging.getLogger("azure.cosmos._GlobalEndpointManagerForCircuitBreaker") -class _GlobalPartitionEndpointManagerForCircuitBreaker(_GlobalEndpointManager): 
+class _GlobalPartitionEndpointManagerForCircuitBreakerAsync(_GlobalEndpointManager): """ This internal class implements the logic for partition endpoint management for geo-replicated database accounts. @@ -48,89 +41,33 @@ class _GlobalPartitionEndpointManagerForCircuitBreaker(_GlobalEndpointManager): def __init__(self, client: "CosmosClientConnection"): - super(_GlobalPartitionEndpointManagerForCircuitBreaker, self).__init__(client) - self.partition_health_tracker = PartitionHealthTracker() + super(_GlobalPartitionEndpointManagerForCircuitBreakerAsync, self).__init__(client) + self.global_partition_endpoint_manager_core = _GlobalPartitionEndpointManagerForCircuitBreakerCore(client, self.location_cache) def is_circuit_breaker_applicable(self, request: RequestObject) -> bool: """ Check if circuit breaker is applicable for a request. """ - if not request: - return False - - circuit_breaker_enabled = os.environ.get(Constants.CIRCUIT_BREAKER_ENABLED_CONFIG, - Constants.CIRCUIT_BREAKER_ENABLED_CONFIG_DEFAULT) == "True" - if not circuit_breaker_enabled: - return False - - if (not self.location_cache.can_use_multiple_write_locations_for_request(request) - and documents._OperationType.IsWriteOperation(request.operation_type)): - return False - - if request.resource_type != ResourceType.Document: - return False - - if request.operation_type != documents._OperationType.QueryPlan: - return False - - return True - - def _create_pkrange_wrapper(self, request: RequestObject) -> PartitionKeyRangeWrapper: - """ - Create a PartitionKeyRangeWrapper object. 
- """ - container_rid = request.headers[HttpHeaders.IntendedCollectionRID] - partition_key = request.headers[HttpHeaders.PartitionKey] - # get the partition key range for the given partition key - target_container_link = None - for container_link, properties in self.client._container_properties_cache: - # TODO: @tvaron3 consider moving this to a constant with other usages - if properties["_rid"] == container_rid: - target_container_link = container_link - # throw exception if it is not found - pkrange = self.client._routing_map_provider.get_overlapping_range(target_container_link, partition_key) - return PartitionKeyRangeWrapper(pkrange, container_rid) + return self.global_partition_endpoint_manager_core.is_circuit_breaker_applicable(request) def record_failure( self, request: RequestObject ) -> None: - if self.is_circuit_breaker_applicable(request): - #convert operation_type to EndpointOperationType - endpoint_operation_type = EndpointOperationType.WriteType if ( - documents._OperationType.IsWriteOperation(request.operation_type)) else EndpointOperationType.ReadType - location = self.location_cache.get_location_from_endpoint(request.location_endpoint_to_route) - pkrange_wrapper = self._create_pkrange_wrapper(request) - self.partition_health_tracker.add_failure(pkrange_wrapper, endpoint_operation_type, location) + self.global_partition_endpoint_manager_core.record_failure(request) def resolve_service_endpoint(self, request): - if self.is_circuit_breaker_applicable(request): - pkrange_wrapper = self._create_pkrange_wrapper(request) - request.set_excluded_locations_from_circuit_breaker( - self.partition_health_tracker.get_excluded_locations(pkrange_wrapper) - ) - return super(_GlobalPartitionEndpointManagerForCircuitBreaker, self).resolve_service_endpoint(request) - + request = self.global_partition_endpoint_manager_core.add_excluded_locations_to_request(request) + return super(_GlobalPartitionEndpointManagerForCircuitBreakerAsync, 
self).resolve_service_endpoint(request) def mark_partition_unavailable(self, request: RequestObject) -> None: """ Mark the partition unavailable from the given request. """ - location = self.location_cache.get_location_from_endpoint(request.location_endpoint_to_route) - pkrange_wrapper = self._create_pkrange_wrapper(request) - self.partition_health_tracker.mark_partition_unavailable(pkrange_wrapper, location) + self.global_partition_endpoint_manager_core.mark_partition_unavailable(request) def record_success( self, request: RequestObject ) -> None: - if self.is_circuit_breaker_applicable(request): - #convert operation_type to either Read or Write - endpoint_operation_type = EndpointOperationType.WriteType if ( - documents._OperationType.IsWriteOperation(request.operation_type)) else EndpointOperationType.ReadType - location = self.location_cache.get_location_from_endpoint(request.location_endpoint_to_route) - pkrange_wrapper = self._create_pkrange_wrapper(request) - self.partition_health_tracker.add_success(pkrange_wrapper, endpoint_operation_type, location) - -# TODO: @tvaron3 there should be no in region retries when trying on healthy tentative ----------------------- - + self.global_partition_endpoint_manager_core.record_success(request) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py index c020a5d4e31c..c5613530994d 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py @@ -28,6 +28,7 @@ from azure.core.exceptions import AzureError, ClientAuthenticationError, ServiceRequestError, ServiceResponseError from azure.core.pipeline.policies import AsyncRetryPolicy +from ._global_partition_endpoint_manager_circuit_breaker_async import _GlobalPartitionEndpointManagerForCircuitBreakerAsync from .. import _default_retry_policy, _database_account_retry_policy from .. 
import _endpoint_discovery_retry_policy from .. import _gone_retry_policy @@ -257,6 +258,8 @@ async def send(self, request): """ absolute_timeout = request.context.options.pop('timeout', None) per_request_timeout = request.context.options.pop('connection_timeout', 0) + request_params = request.context.options.get('request_params', None) + global_endpoint_manager = request.context.options.get('global_endpoint_manager', None) retry_error = None retry_active = True response = None @@ -281,8 +284,7 @@ async def send(self, request): # the request ran into a socket timeout or failed to establish a new connection # since request wasn't sent, raise exception immediately to be dealt with in client retry policies if not _has_database_account_header(request.http_request.headers): - # TODO: @tvaron3 record failure here - request.context.global_endpoint_manager.record_failure(request.context.options['request']) + global_endpoint_manager.record_failure(request_params) if retry_settings['connect'] > 0: retry_active = self.increment(retry_settings, response=request, error=err) if retry_active: @@ -300,8 +302,7 @@ async def send(self, request): ClientConnectionError) if (isinstance(err.inner_exception, ClientConnectionError) or _has_read_retryable_headers(request.http_request.headers)): - # TODO: @tvaron3 record failure here - request.context.global_endpoint_manager.record_failure(request.context.options['request']) + global_endpoint_manager.record_failure(request_params) # This logic is based on the _retry.py file from azure-core if retry_settings['read'] > 0: retry_active = self.increment(retry_settings, response=request, error=err) @@ -316,8 +317,7 @@ async def send(self, request): if _has_database_account_header(request.http_request.headers): raise err if self._is_method_retryable(retry_settings, request.http_request): - # TODO: @tvaron3 record failure here ? 
Not sure - request.context.global_endpoint_manager.record_failure(request.context.options['request']) + global_endpoint_manager.record_failure(request_params) retry_active = self.increment(retry_settings, response=request, error=err) if retry_active: await self.sleep(retry_settings, request.context.transport) From 9e51011b472a89ad34b4cc001d586da4ab350fc7 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Wed, 2 Apr 2025 12:05:14 -0700 Subject: [PATCH 026/152] fix use multiple write locations bug --- sdk/cosmos/azure-cosmos/CHANGELOG.md | 2 +- .../azure-cosmos/azure/cosmos/_global_endpoint_manager.py | 4 ++-- .../azure/cosmos/_service_request_retry_policy.py | 2 +- .../azure/cosmos/aio/_global_endpoint_manager_async.py | 6 +++--- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/CHANGELOG.md b/sdk/cosmos/azure-cosmos/CHANGELOG.md index ab16ba4f1cd9..97a88849f18c 100644 --- a/sdk/cosmos/azure-cosmos/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos/CHANGELOG.md @@ -3,7 +3,7 @@ ### 4.10.0b5 (Unreleased) #### Features Added -* Per partition circuit breaker support. It can be enabled through the environment variable `AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER`. See [PR 40302](https://github.com/Azure/azure-sdk-for-python/pull/40302) +* Per partition circuit breaker support. It can be enabled through the environment variable `AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER`. See [PR 40302](https://github.com/Azure/azure-sdk-for-python/pull/40302). 
#### Breaking Changes diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py index 8aa7c388a9b6..52b0b64cc27e 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py @@ -61,8 +61,8 @@ def __init__(self, client): self.last_refresh_time = 0 self._database_account_cache = None - def get_use_multiple_write_locations(self): - return self.location_cache.can_use_multiple_write_locations() + def get_use_multiple_write_locations(self, request): + return self.location_cache.can_use_multiple_write_locations_for_request() def get_refresh_time_interval_in_ms_stub(self): return constants._Constants.DefaultEndpointsRefreshTime diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_service_request_retry_policy.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_service_request_retry_policy.py index 8630714bc6f3..18dd17c9f5ad 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_service_request_retry_policy.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_service_request_retry_policy.py @@ -61,7 +61,7 @@ def ShouldRetry(self): self.request.last_routed_location_endpoint_within_region = self.request.location_endpoint_to_route if (_OperationType.IsReadOnlyOperation(self.request.operation_type) - or self.global_endpoint_manager.get_use_multiple_write_locations()): + or self.global_endpoint_manager.get_use_multiple_write_locations(self.request)): self.update_location_cache() # We just directly got to the next location in case of read requests # We don't retry again on the same region for regional endpoint diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py index f2bf7dcbe5fd..f001772b246a 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py +++ 
b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py @@ -33,6 +33,7 @@ from .. import _constants as constants from .. import exceptions from .._location_cache import LocationCache, current_time_millis +from .._request_object import RequestObject # pylint: disable=protected-access @@ -63,9 +64,8 @@ def __init__(self, client): self.last_refresh_time = 0 self._database_account_cache = None - # TODO: @tvaron3 fix this - def get_use_multiple_write_locations(self): - return self.location_cache.can_use_multiple_write_locations() + def get_use_multiple_write_locations(self, request: RequestObject): + return self.location_cache.can_use_multiple_write_locations_for_request(request) def get_refresh_time_interval_in_ms_stub(self): return constants._Constants.DefaultEndpointsRefreshTime From 8d276515de28d8850ab4d5b0d5e54486f9ae6f75 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Wed, 2 Apr 2025 13:25:42 -0700 Subject: [PATCH 027/152] clean up and revert vs env variable changes --- .../azure-cosmos/azure/cosmos/_constants.py | 6 ------ .../azure/cosmos/_cosmos_client_connection.py | 11 +---------- .../aio/execution_dispatcher.py | 6 ++---- .../execution_dispatcher.py | 11 ++++------- .../azure/cosmos/_global_endpoint_manager.py | 3 --- ...tition_endpoint_manager_circuit_breaker.py | 2 +- ...n_endpoint_manager_circuit_breaker_core.py | 2 +- .../cosmos/_service_request_retry_policy.py | 2 +- .../cosmos/_timeout_failover_retry_policy.py | 1 - .../aio/_cosmos_client_connection_async.py | 11 +---------- .../aio/_global_endpoint_manager_async.py | 3 --- .../tests/test_query_hybrid_search.py | 13 ------------- .../tests/test_query_hybrid_search_async.py | 14 +------------- .../tests/test_query_vector_similarity.py | 19 +------------------ .../test_query_vector_similarity_async.py | 19 +------------------ 15 files changed, 14 insertions(+), 109 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py 
b/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py index 2af40c74d77a..cf029179f1a1 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py @@ -45,12 +45,6 @@ class _Constants: EnableMultipleWritableLocations: Literal["enableMultipleWriteLocations"] = "enableMultipleWriteLocations" # Environment variables - NON_STREAMING_ORDER_BY_DISABLED_CONFIG: str = "AZURE_COSMOS_DISABLE_NON_STREAMING_ORDER_BY" - NON_STREAMING_ORDER_BY_DISABLED_CONFIG_DEFAULT: str = "False" - HS_MAX_ITEMS_CONFIG: str = "AZURE_COSMOS_HYBRID_SEARCH_MAX_ITEMS" - HS_MAX_ITEMS_CONFIG_DEFAULT: int = 1000 - MAX_ITEM_BUFFER_VS_CONFIG: str = "AZURE_COSMOS_MAX_ITEM_BUFFER_VECTOR_SEARCH" - MAX_ITEM_BUFFER_VS_CONFIG_DEFAULT: int = 50000 CIRCUIT_BREAKER_ENABLED_CONFIG: str = "AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER" CIRCUIT_BREAKER_ENABLED_CONFIG_DEFAULT: str = "False" # Only applicable when circuit breaker is enabled ------------------------- diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py index 7bb379a6b6e7..3c8a11218738 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py @@ -2045,7 +2045,6 @@ def PatchItem( # Patch will use WriteEndpoint since it uses PUT operation request_params = RequestObject(resource_type, documents._OperationType.Patch, headers) request_params.set_excluded_location_from_options(options) - request_params.set_excluded_location_from_options(options) request_data = {} if options.get("filterPredicate"): request_data["condition"] = options.get("filterPredicate") @@ -2135,7 +2134,6 @@ def _Batch( documents._OperationType.Batch, options) request_params = RequestObject("docs", documents._OperationType.Batch, headers) request_params.set_excluded_location_from_options(options) - request_params.set_excluded_location_from_options(options) 
return cast( Tuple[List[Dict[str, Any]], CaseInsensitiveDict], self.__Post(path, request_params, batch_operations, headers, **kwargs) @@ -2197,7 +2195,6 @@ def DeleteAllItemsByPartitionKey( "partitionkey", documents._OperationType.Delete, options) request_params = RequestObject("partitionkey", documents._OperationType.Delete, headers) request_params.set_excluded_location_from_options(options) - request_params.set_excluded_location_from_options(options) _, last_response_headers = self.__Post( path=path, request_params=request_params, @@ -2655,7 +2652,6 @@ def Create( request_params = RequestObject(typ, documents._OperationType.Create, headers) request_params.set_excluded_location_from_options(options) - request_params.set_excluded_location_from_options(options) result, last_response_headers = self.__Post(path, request_params, body, headers, **kwargs) self.last_response_headers = last_response_headers @@ -2703,7 +2699,6 @@ def Upsert( # Upsert will use WriteEndpoint since it uses POST operation request_params = RequestObject(typ, documents._OperationType.Upsert, headers) request_params.set_excluded_location_from_options(options) - request_params.set_excluded_location_from_options(options) result, last_response_headers = self.__Post(path, request_params, body, headers, **kwargs) self.last_response_headers = last_response_headers # update session for write request @@ -2748,7 +2743,6 @@ def Replace( # Replace will use WriteEndpoint since it uses PUT operation request_params = RequestObject(typ, documents._OperationType.Replace, headers) request_params.set_excluded_location_from_options(options) - request_params.set_excluded_location_from_options(options) result, last_response_headers = self.__Put(path, request_params, resource, headers, **kwargs) self.last_response_headers = last_response_headers @@ -2791,7 +2785,6 @@ def Read( # Read will use ReadEndpoint since it uses GET operation request_params = RequestObject(typ, documents._OperationType.Read, headers) 
request_params.set_excluded_location_from_options(options) - request_params.set_excluded_location_from_options(options) result, last_response_headers = self.__Get(path, request_params, headers, **kwargs) self.last_response_headers = last_response_headers if response_hook: @@ -2832,7 +2825,6 @@ def DeleteResource( # Delete will use WriteEndpoint since it uses DELETE operation request_params = RequestObject(typ, documents._OperationType.Delete, headers) request_params.set_excluded_location_from_options(options) - request_params.set_excluded_location_from_options(options) result, last_response_headers = self.__Delete(path, request_params, headers, **kwargs) self.last_response_headers = last_response_headers @@ -3206,8 +3198,7 @@ def _GetQueryPlanThroughGateway(self, query: str, resource_link: str, **kwargs: documents._QueryFeature.NonStreamingOrderBy + "," + documents._QueryFeature.HybridSearch + "," + documents._QueryFeature.CountIf) - if os.environ.get(Constants.NON_STREAMING_ORDER_BY_DISABLED_CONFIG, - Constants.NON_STREAMING_ORDER_BY_DISABLED_CONFIG_DEFAULT) == "True": + if os.environ.get('AZURE_COSMOS_DISABLE_NON_STREAMING_ORDER_BY', False): supported_query_features = (documents._QueryFeature.Aggregate + "," + documents._QueryFeature.CompositeAggregate + "," + documents._QueryFeature.Distinct + "," + diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py index 9b4d32a4598b..70df36d6d015 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py @@ -34,7 +34,6 @@ from azure.cosmos.documents import _DistinctType from azure.cosmos.exceptions import CosmosHttpResponseError from azure.cosmos.http_constants import StatusCodes -from ..._constants import _Constants as Constants # pylint: disable=protected-access @@ 
-118,9 +117,8 @@ async def _create_pipelined_execution_context(self, query_execution_info): raise ValueError("Executing a vector search query without TOP or LIMIT can consume many" + " RUs very fast and have long runtimes. Please ensure you are using one" + " of the two filters with your vector search query.") - if total_item_buffer > int(os.environ.get(Constants.MAX_ITEM_BUFFER_VS_CONFIG, - Constants.MAX_ITEM_BUFFER_VS_CONFIG_DEFAULT)): - raise ValueError("Executing a vector search query with more items than the max is not allowed. " + + if total_item_buffer > os.environ.get('AZURE_COSMOS_MAX_ITEM_BUFFER_VECTOR_SEARCH', 50000): + raise ValueError("Executing a vector search query with more items than the max is not allowed." + "Please ensure you are using a limit smaller than the max, or change the max.") execution_context_aggregator =\ non_streaming_order_by_aggregator._NonStreamingOrderByContextAggregator(self._client, diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py index 453a4dc38d25..1155e537c68c 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py @@ -26,7 +26,6 @@ import json import os from azure.cosmos.exceptions import CosmosHttpResponseError -from .._constants import _Constants as Constants from azure.cosmos._execution_context import endpoint_component, multi_execution_aggregator from azure.cosmos._execution_context import non_streaming_order_by_aggregator, hybrid_search_aggregator from azure.cosmos._execution_context.base_execution_context import _QueryExecutionContextBase @@ -57,9 +56,8 @@ def _verify_valid_hybrid_search_query(hybrid_search_query_info): raise ValueError("Executing a hybrid search query without TOP or LIMIT can consume many" + " RUs very fast and have long runtimes. 
Please ensure you are using one" + " of the two filters with your hybrid search query.") - if hybrid_search_query_info['take'] > int(os.environ.get(Constants.HS_MAX_ITEMS_CONFIG, - Constants.HS_MAX_ITEMS_CONFIG_DEFAULT)): - raise ValueError("Executing a hybrid search query with more items than the max is not allowed. " + + if hybrid_search_query_info['take'] > os.environ.get('AZURE_COSMOS_HYBRID_SEARCH_MAX_ITEMS', 1000): + raise ValueError("Executing a hybrid search query with more items than the max is not allowed." + "Please ensure you are using a limit smaller than the max, or change the max.") @@ -151,9 +149,8 @@ def _create_pipelined_execution_context(self, query_execution_info): raise ValueError("Executing a vector search query without TOP or LIMIT can consume many" + " RUs very fast and have long runtimes. Please ensure you are using one" + " of the two filters with your vector search query.") - if total_item_buffer > int(os.environ.get(Constants.MAX_ITEM_BUFFER_VS_CONFIG, - Constants.MAX_ITEM_BUFFER_VS_CONFIG_DEFAULT)): - raise ValueError("Executing a vector search query with more items than the max is not allowed. " + + if total_item_buffer > os.environ.get('AZURE_COSMOS_MAX_ITEM_BUFFER_VECTOR_SEARCH', 50000): + raise ValueError("Executing a vector search query with more items than the max is not allowed." 
+ "Please ensure you are using a limit smaller than the max, or change the max.") execution_context_aggregator = \ non_streaming_order_by_aggregator._NonStreamingOrderByContextAggregator(self._client, diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py index 35a4c60c6ca4..8cf9d06d5486 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py @@ -58,9 +58,6 @@ def __init__(self, client): self.last_refresh_time = 0 self._database_account_cache = None - def get_use_multiple_write_locations(self, request): - return self.location_cache.can_use_multiple_write_locations_for_request() - def get_refresh_time_interval_in_ms_stub(self): return constants._Constants.DefaultEndpointsRefreshTime diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py index 072b946c1402..b95ef0a2a7da 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py @@ -26,7 +26,7 @@ from azure.cosmos._global_partition_endpoint_manager_circuit_breaker_core import \ _GlobalPartitionEndpointManagerForCircuitBreakerCore -from azure.cosmos.aio._global_endpoint_manager_async import _GlobalEndpointManager +from azure.cosmos._global_endpoint_manager import _GlobalEndpointManager from azure.cosmos._request_object import RequestObject if TYPE_CHECKING: from azure.cosmos._cosmos_client_connection import CosmosClientConnection diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py index 
23a7c50047ca..142c51a1ed19 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py @@ -128,4 +128,4 @@ def record_success( pkrange_wrapper = self._create_pkrange_wrapper(request) self.partition_health_tracker.add_success(pkrange_wrapper, endpoint_operation_type, location) -# TODO: @tvaron3 there should be no in region retries when trying on healthy tentative ----------------------- \ No newline at end of file +# TODO: @tvaron3 there should be no in region retries when trying on healthy tentative ----------------------- diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_service_request_retry_policy.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_service_request_retry_policy.py index 18dd17c9f5ad..edd15f20337f 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_service_request_retry_policy.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_service_request_retry_policy.py @@ -61,7 +61,7 @@ def ShouldRetry(self): self.request.last_routed_location_endpoint_within_region = self.request.location_endpoint_to_route if (_OperationType.IsReadOnlyOperation(self.request.operation_type) - or self.global_endpoint_manager.get_use_multiple_write_locations(self.request)): + or self.global_endpoint_manager.can_use_multiple_write_locations(self.request)): self.update_location_cache() # We just directly got to the next location in case of read requests # We don't retry again on the same region for regional endpoint diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py index 505a8edf3d06..f70e27bae70c 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py @@ -5,7 +5,6 @@ Cosmos database service. 
""" from azure.cosmos.documents import _OperationType -from azure.cosmos.http_constants import HttpHeaders class _TimeoutFailoverRetryPolicy(object): diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py index cfef48d64765..09add40f5785 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py @@ -774,7 +774,6 @@ async def Create( request_params = _request_object.RequestObject(typ, documents._OperationType.Create, headers) request_params.set_excluded_location_from_options(options) - request_params.set_excluded_location_from_options(options) result, last_response_headers = await self.__Post(path, request_params, body, headers, **kwargs) self.last_response_headers = last_response_headers @@ -915,7 +914,6 @@ async def Upsert( # Upsert will use WriteEndpoint since it uses POST operation request_params = _request_object.RequestObject(typ, documents._OperationType.Upsert, headers) request_params.set_excluded_location_from_options(options) - request_params.set_excluded_location_from_options(options) result, last_response_headers = await self.__Post(path, request_params, body, headers, **kwargs) self.last_response_headers = last_response_headers # update session for write request @@ -1218,7 +1216,6 @@ async def Read( # Read will use ReadEndpoint since it uses GET operation request_params = _request_object.RequestObject(typ, documents._OperationType.Read, headers) request_params.set_excluded_location_from_options(options) - request_params.set_excluded_location_from_options(options) result, last_response_headers = await self.__Get(path, request_params, headers, **kwargs) self.last_response_headers = last_response_headers if response_hook: @@ -1478,7 +1475,6 @@ async def PatchItem( # Patch will use WriteEndpoint since it uses PUT operation request_params 
= _request_object.RequestObject(typ, documents._OperationType.Patch, headers) request_params.set_excluded_location_from_options(options) - request_params.set_excluded_location_from_options(options) request_data = {} if options.get("filterPredicate"): request_data["condition"] = options.get("filterPredicate") @@ -1584,7 +1580,6 @@ async def Replace( # Replace will use WriteEndpoint since it uses PUT operation request_params = _request_object.RequestObject(typ, documents._OperationType.Replace, headers) request_params.set_excluded_location_from_options(options) - request_params.set_excluded_location_from_options(options) result, last_response_headers = await self.__Put(path, request_params, resource, headers, **kwargs) self.last_response_headers = last_response_headers @@ -1909,7 +1904,6 @@ async def DeleteResource( # Delete will use WriteEndpoint since it uses DELETE operation request_params = _request_object.RequestObject(typ, documents._OperationType.Delete, headers) request_params.set_excluded_location_from_options(options) - request_params.set_excluded_location_from_options(options) result, last_response_headers = await self.__Delete(path, request_params, headers, **kwargs) self.last_response_headers = last_response_headers @@ -2024,7 +2018,6 @@ async def _Batch( documents._OperationType.Batch, options) request_params = _request_object.RequestObject("docs", documents._OperationType.Batch, headers) request_params.set_excluded_location_from_options(options) - request_params.set_excluded_location_from_options(options) result = await self.__Post(path, request_params, batch_operations, headers, **kwargs) return cast(Tuple[List[Dict[str, Any]], CaseInsensitiveDict], result) @@ -3219,8 +3212,7 @@ async def _GetQueryPlanThroughGateway(self, query: str, resource_link: str, **kw documents._QueryFeature.NonStreamingOrderBy + "," + documents._QueryFeature.HybridSearch + "," + documents._QueryFeature.CountIf) - if 
os.environ.get(Constants.NON_STREAMING_ORDER_BY_DISABLED_CONFIG, - Constants.NON_STREAMING_ORDER_BY_DISABLED_CONFIG_DEFAULT) == "True": + if os.environ.get('AZURE_COSMOS_DISABLE_NON_STREAMING_ORDER_BY', False): supported_query_features = (documents._QueryFeature.Aggregate + "," + documents._QueryFeature.CompositeAggregate + "," + documents._QueryFeature.Distinct + "," + @@ -3286,7 +3278,6 @@ async def DeleteAllItemsByPartitionKey( request_params = _request_object.RequestObject("partitionkey", documents._OperationType.Delete, headers) request_params.set_excluded_location_from_options(options) - request_params.set_excluded_location_from_options(options) _, last_response_headers = await self.__Post(path=path, request_params=request_params, req_headers=headers, body=None, **kwargs) self.last_response_headers = last_response_headers diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py index 0917b3de94ca..00438cc2214e 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py @@ -62,9 +62,6 @@ def __init__(self, client): self.last_refresh_time = 0 self._database_account_cache = None - def get_use_multiple_write_locations(self, request: RequestObject): - return self.location_cache.can_use_multiple_write_locations_for_request(request) - def get_refresh_time_interval_in_ms_stub(self): return constants._Constants.DefaultEndpointsRefreshTime diff --git a/sdk/cosmos/azure-cosmos/tests/test_query_hybrid_search.py b/sdk/cosmos/azure-cosmos/tests/test_query_hybrid_search.py index 3a9e8527992e..429498b3071f 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_query_hybrid_search.py +++ b/sdk/cosmos/azure-cosmos/tests/test_query_hybrid_search.py @@ -2,7 +2,6 @@ # Copyright (c) Microsoft Corporation. All rights reserved. 
import json -import os import time import unittest import uuid @@ -95,18 +94,6 @@ def test_wrong_hybrid_search_queries(self): assert ("One of the input values is invalid" in e.message or "Specifying a sort order (ASC or DESC) in the ORDER BY RANK clause is not allowed." in e.message) - def test_hybrid_search_env_variables_async(self): - os.environ["AZURE_COSMOS_HYBRID_SEARCH_MAX_ITEMS"] = "1" - try: - query = "SELECT TOP 1 c.index, c.title FROM c WHERE FullTextContains(c.title, 'John') OR " \ - "FullTextContains(c.text, 'John') ORDER BY RANK FullTextScore(c.title, ['John'])" - results = self.test_container.query_items(query) - [item for item in results] - pytest.fail("Config was not applied properly.") - except ValueError as e: - assert e.args[0] == ("Executing a hybrid search query with more items than the max is not allowed. " - "Please ensure you are using a limit smaller than the max, or change the max.") - def test_hybrid_search_queries(self): query = "SELECT TOP 10 c.index, c.title FROM c WHERE FullTextContains(c.title, 'John') OR " \ "FullTextContains(c.text, 'John') ORDER BY RANK FullTextScore(c.title, ['John'])" diff --git a/sdk/cosmos/azure-cosmos/tests/test_query_hybrid_search_async.py b/sdk/cosmos/azure-cosmos/tests/test_query_hybrid_search_async.py index 4223cc6bdd50..964d9579a2e9 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_query_hybrid_search_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_query_hybrid_search_async.py @@ -1,6 +1,6 @@ # The MIT License (MIT) # Copyright (c) Microsoft Corporation. All rights reserved. -import os + import time import unittest import uuid @@ -98,18 +98,6 @@ async def test_wrong_hybrid_search_queries_async(self): assert ("One of the input values is invalid" in e.message or "Specifying a sort order (ASC or DESC) in the ORDER BY RANK clause is not allowed." 
in e.message) - async def test_hybrid_search_env_variables_async(self): - os.environ["AZURE_COSMOS_HYBRID_SEARCH_MAX_ITEMS"] = "1" - try: - query = "SELECT TOP 1 c.index, c.title FROM c WHERE FullTextContains(c.title, 'John') OR " \ - "FullTextContains(c.text, 'John') ORDER BY RANK FullTextScore(c.title, ['John'])" - results = self.test_container.query_items(query) - [item async for item in results] - pytest.fail("Config was not applied properly.") - except ValueError as e: - assert e.args[0] == ("Executing a hybrid search query with more items than the max is not allowed. " - "Please ensure you are using a limit smaller than the max, or change the max.") - async def test_hybrid_search_queries_async(self): query = "SELECT TOP 10 c.index, c.title FROM c WHERE FullTextContains(c.title, 'John') OR " \ "FullTextContains(c.text, 'John') ORDER BY RANK FullTextScore(c.title, ['John'])" diff --git a/sdk/cosmos/azure-cosmos/tests/test_query_vector_similarity.py b/sdk/cosmos/azure-cosmos/tests/test_query_vector_similarity.py index 96e3eee02936..e2614c5fb85f 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_query_vector_similarity.py +++ b/sdk/cosmos/azure-cosmos/tests/test_query_vector_similarity.py @@ -1,6 +1,6 @@ # The MIT License (MIT) # Copyright (c) Microsoft Corporation. All rights reserved. -import os + import unittest import uuid @@ -121,23 +121,6 @@ def test_wrong_vector_search_queries(self): assert ("One of the input values is invalid." in e.message or "Specifying a sorting order (ASC or DESC) with VectorDistance function is not supported." 
in e.message) - - def test_vector_search_environment_variables(self): - vector_string = vector_test_data.get_embedding_string("I am having a wonderful day.") - query = "SELECT TOP 10 c.text, VectorDistance(c.embedding, [{}]) AS " \ - "SimilarityScore FROM c ORDER BY VectorDistance(c.embedding, [{}])".format(vector_string, vector_string) - os.environ["AZURE_COSMOS_MAX_ITEM_BUFFER_VECTOR_SEARCH"] = "1" - try: - [item for item in self.created_large_container.query_items(query=query, enable_cross_partition_query=True)] - pytest.fail("Config was not set correctly.") - except ValueError as e: - assert e.args[0] == ("Executing a vector search query with more items than the max is not allowed. " - "Please ensure you are using a limit smaller than the max, or change the max.") - - os.environ["AZURE_COSMOS_MAX_ITEM_BUFFER_VECTOR_SEARCH"] = "50000" - os.environ["AZURE_COSMOS_DISABLE_NON_STREAMING_ORDER_BY"] = "False" - [item for item in self.created_large_container.query_items(query=query, enable_cross_partition_query=True)] - def test_ordering_distances(self): # Besides ordering distances, we also verify that the query text properly replaces any set embedding policies # load up previously calculated embedding for the given string diff --git a/sdk/cosmos/azure-cosmos/tests/test_query_vector_similarity_async.py b/sdk/cosmos/azure-cosmos/tests/test_query_vector_similarity_async.py index 716150358ff3..0cb031847a6f 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_query_vector_similarity_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_query_vector_similarity_async.py @@ -1,6 +1,6 @@ # The MIT License (MIT) # Copyright (c) Microsoft Corporation. All rights reserved. -import os + import unittest import uuid @@ -127,23 +127,6 @@ async def test_wrong_vector_search_queries_async(self): assert ("One of the input values is invalid." in e.message or "Specifying a sorting order (ASC or DESC) with VectorDistance function is not supported." 
in e.message) - async def test_vector_search_environment_variables_async(self): - vector_string = vector_test_data.get_embedding_string("I am having a wonderful day.") - query = "SELECT TOP 10 c.text, VectorDistance(c.embedding, [{}]) AS " \ - "SimilarityScore FROM c ORDER BY VectorDistance(c.embedding, [{}])".format(vector_string, vector_string) - os.environ["AZURE_COSMOS_MAX_ITEM_BUFFER_VECTOR_SEARCH"] = "1" - try: - [item async for item in self.created_large_container.query_items(query=query)] - pytest.fail("Config was not set correctly.") - except ValueError as e: - assert e.args[0] == ("Executing a vector search query with more items than the max is not allowed. " - "Please ensure you are using a limit smaller than the max, or change the max.") - - os.environ["AZURE_COSMOS_MAX_ITEM_BUFFER_VECTOR_SEARCH"] = "50000" - - os.environ["AZURE_COSMOS_DISABLE_NON_STREAMING_ORDER_BY"] = "False" - [item async for item in self.created_large_container.query_items(query=query)] - async def test_ordering_distances_async(self): # Besides ordering distances, we also verify that the query text properly replaces any set embedding policies From 90fe5c2ff3a58b9c74a2fb2f17de51e4c3553f80 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Wed, 2 Apr 2025 14:26:31 -0700 Subject: [PATCH 028/152] remove async await --- .../azure-cosmos/azure/cosmos/aio/_asynchronous_request.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py index f8ebf6ccdbb8..81430d8df42c 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py @@ -117,7 +117,6 @@ async def _Request(global_endpoint_manager, request_params, connection_policy, p response = response.http_response headers = copy.copy(response.headers) - await response.load_body() data = response.body() if data: data = 
data.decode("utf-8") From 206be781b0cfa5c87a1ca69e50b5407cdca31e6e Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Wed, 2 Apr 2025 16:25:50 -0700 Subject: [PATCH 029/152] refactor and fix tests --- ...py => _fault_injection_transport_async.py} | 29 ++--- .../azure-cosmos/tests/test_crud_async.py | 3 - .../test_fault_injection_transport_async.py | 104 ++++++++---------- 3 files changed, 59 insertions(+), 77 deletions(-) rename sdk/cosmos/azure-cosmos/tests/{_fault_injection_transport.py => _fault_injection_transport_async.py} (91%) diff --git a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py similarity index 91% rename from sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py rename to sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py index 7f83608b3cd7..230d8f89dfe5 100644 --- a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py +++ b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py @@ -27,30 +27,31 @@ import logging import sys from collections.abc import MutableMapping -from typing import Callable, Optional, Any +from typing import Callable, Optional, Any, Dict, List import aiohttp from azure.core.pipeline.transport import AioHttpTransport, AioHttpTransportResponse from azure.core.rest import HttpRequest, AsyncHttpResponse +from azure.cosmos import documents import test_config from azure.cosmos.exceptions import CosmosHttpResponseError from azure.core.exceptions import ServiceRequestError -class FaultInjectionTransport(AioHttpTransport): +class FaultInjectionTransportAsync(AioHttpTransport): logger = logging.getLogger('azure.cosmos.fault_injection_transport') logger.setLevel(logging.DEBUG) def __init__(self, *, session: Optional[aiohttp.ClientSession] = None, loop=None, session_owner: bool = True, **config): - self.faults = [] - self.requestTransformations = [] - self.responseTransformations = [] + self.faults: List[Dict[str, Any]] = [] + 
self.requestTransformations: List[Dict[str, Any]] = [] + self.responseTransformations: List[Dict[str, Any]] = [] super().__init__(session=session, loop=loop, session_owner=session_owner, **config) def add_fault(self, predicate: Callable[[HttpRequest], bool], fault_factory: Callable[[HttpRequest], asyncio.Task[Exception]]): self.faults.append({"predicate": predicate, "apply": fault_factory}) - def add_response_transformation(self, predicate: Callable[[HttpRequest], bool], response_transformation: Callable[[HttpRequest, Callable[[HttpRequest], asyncio.Task[AsyncHttpResponse]]], asyncio.Task[AsyncHttpResponse]]): + def add_response_transformation(self, predicate: Callable[[HttpRequest], bool], response_transformation: Callable[[HttpRequest, Callable[[HttpRequest], AioHttpTransportResponse]], AioHttpTransportResponse]): self.responseTransformations.append({ "predicate": predicate, "apply": response_transformation}) @@ -142,10 +143,8 @@ def predicate_is_document_operation(r: HttpRequest) -> bool: @staticmethod def predicate_is_write_operation(r: HttpRequest, uri_prefix: str) -> bool: - is_write_document_operation = (r.headers.get('x-ms-thinclient-proxy-resource-type') == 'docs' - and r.headers.get('x-ms-thinclient-proxy-operation-type') != 'Read' - and r.headers.get('x-ms-thinclient-proxy-operation-type') != 'ReadFeed' - and r.headers.get('x-ms-thinclient-proxy-operation-type') != 'Query') + is_write_document_operation = documents._OperationType.IsWriteOperation( + str(r.headers.get('x-ms-thinclient-proxy-operation-type'))) return is_write_document_operation and uri_prefix in r.url @@ -173,14 +172,12 @@ async def error_region_down() -> Exception: async def transform_topology_swr_mrr( write_region_name: str, read_region_name: str, - r: HttpRequest, - inner: Callable[[],asyncio.Task[AsyncHttpResponse]]) -> asyncio.Task[AsyncHttpResponse]: + inner: Callable[[],asyncio.Task[AioHttpTransportResponse]]) -> AioHttpTransportResponse: response = await inner() if not 
FaultInjectionTransport.predicate_is_database_account_call(response.request): return response - await response.load_body() data = response.body() if response.status_code == 200 and data: data = data.decode("utf-8") @@ -200,14 +197,12 @@ async def transform_topology_swr_mrr( async def transform_topology_mwr( first_region_name: str, second_region_name: str, - r: HttpRequest, - inner: Callable[[], asyncio.Task[AsyncHttpResponse]]) -> AsyncHttpResponse: + inner: Callable[[], asyncio.Task[AioHttpTransportResponse]]) -> AioHttpTransportResponse: response = await inner() if not FaultInjectionTransport.predicate_is_database_account_call(response.request): return response - await response.load_body() data = response.body() if response.status_code == 200 and data: data = data.decode("utf-8") @@ -251,7 +246,7 @@ def __init__(self, request: HttpRequest, status_code: int, content:Optional[dict def body(self) -> Optional[bytes]: return self.bytes - def text(self, encoding: Optional[str] = None) -> Optional[str]: + def text(self) -> Optional[str]: return self.json_text async def load_body(self) -> None: diff --git a/sdk/cosmos/azure-cosmos/tests/test_crud_async.py b/sdk/cosmos/azure-cosmos/tests/test_crud_async.py index 160b5bb93cc6..ca6cfad8287d 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_crud_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_crud_async.py @@ -4,8 +4,6 @@ """End-to-end test. 
""" -import json -import os.path import time import unittest import urllib.parse as urllib @@ -17,7 +15,6 @@ from azure.core.exceptions import AzureError, ServiceResponseError from azure.core.pipeline.transport import AsyncioRequestsTransport, AsyncioRequestsTransportResponse -import azure.cosmos._base as base import azure.cosmos.documents as documents import azure.cosmos.exceptions as exceptions import test_config diff --git a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py index 71b6824ef240..2227019cbe63 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py @@ -12,18 +12,18 @@ import pytest from azure.core.pipeline.transport import AioHttpTransport +from azure.core.pipeline.transport._aiohttp import AioHttpTransportResponse from azure.core.rest import HttpRequest, AsyncHttpResponse import test_config -from _fault_injection_transport import FaultInjectionTransport +from _fault_injection_transport_async import FaultInjectionTransportAsync from azure.cosmos import PartitionKey from azure.cosmos.aio import CosmosClient from azure.cosmos.aio._container import ContainerProxy from azure.cosmos.aio._database import DatabaseProxy from azure.cosmos.exceptions import CosmosHttpResponseError -COLLECTION = "created_collection" -MGMT_TIMEOUT = 3.0 +MGMT_TIMEOUT = 3.0 logger = logging.getLogger('azure.cosmos') logger.setLevel(logging.DEBUG) logger.addHandler(logging.StreamHandler(sys.stdout)) @@ -32,21 +32,16 @@ master_key = test_config.TestConfig.masterKey connection_policy = test_config.TestConfig.connectionPolicy TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID - -@pytest.fixture() -def setup(): - return - +single_partition_container_name = os.path.basename(__file__) + str(uuid.uuid4()) @pytest.mark.cosmosEmulator @pytest.mark.asyncio -@pytest.mark.usefixtures("setup") 
class TestFaultInjectionTransportAsync: @classmethod def setup_class(cls): logger.info("starting class: {} execution".format(cls.__name__)) - cls.host = test_config.TestConfig.host - cls.master_key = test_config.TestConfig.masterKey + cls.host = host + cls.master_key = master_key if (cls.master_key == '[YOUR_KEY_HERE]' or cls.host == '[YOUR_ENDPOINT_HERE]'): @@ -55,12 +50,12 @@ def setup_class(cls): "'masterKey' and 'host' at the top of this class to run the " "tests.") - cls.connection_policy = test_config.TestConfig.connectionPolicy - cls.database_id = test_config.TestConfig.TEST_DATABASE_ID - cls.single_partition_container_name= os.path.basename(__file__) + str(uuid.uuid4()) + cls.connection_policy = connection_policy + cls.database_id = TEST_DATABASE_ID + cls.single_partition_container_name = single_partition_container_name - cls.mgmt_client = CosmosClient(host, master_key, consistency_level="Session", - connection_policy=connection_policy, logger=logger) + cls.mgmt_client = CosmosClient(cls.host, cls.master_key, consistency_level="Session", + connection_policy=cls.connection_policy, logger=logger) created_database = cls.mgmt_client.get_database_client(cls.database_id) asyncio.run(asyncio.wait_for( created_database.create_container( @@ -89,7 +84,7 @@ def setup_method_with_custom_transport(self, custom_transport: AioHttpTransport, connection_policy=connection_policy, transport=custom_transport, logger=logger, enable_diagnostics_logging=True, **kwargs) db: DatabaseProxy = client.get_database_client(TEST_DATABASE_ID) - container: ContainerProxy = db.get_container_client(self.single_partition_container_name) + container: ContainerProxy = db.get_container_client(single_partition_container_name) return {"client": client, "db": db, "col": container} @staticmethod @@ -100,16 +95,16 @@ def cleanup_method(initialized_objects: dict[str, Any]): except Exception as close_error: logger.warning(f"Exception trying to close method client. 
{close_error}") - async def test_throws_injected_error(self, setup: object): + async def test_throws_injected_error(self: "TestFaultInjectionTransportAsync"): id_value: str = str(uuid.uuid4()) document_definition = {'id': id_value, 'pk': id_value, 'name': 'sample document', 'key': 'value'} - custom_transport = FaultInjectionTransport() - predicate : Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransport.predicate_req_for_document_with_id(r, id_value) - custom_transport.add_fault(predicate, lambda r: asyncio.create_task(FaultInjectionTransport.error_after_delay( + custom_transport = FaultInjectionTransportAsync() + predicate : Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransportAsync.predicate_req_for_document_with_id(r, id_value) + custom_transport.add_fault(predicate, lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( 500, CosmosHttpResponseError( status_code=502, @@ -126,26 +121,25 @@ async def test_throws_injected_error(self, setup: object): finally: TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) - async def test_swr_mrr_succeeds(self, setup: object): + async def test_swr_mrr_succeeds(self: "TestFaultInjectionTransportAsync"): expected_read_region_uri: str = test_config.TestConfig.local_host expected_write_region_uri: str = expected_read_region_uri.replace("localhost", "127.0.0.1") - custom_transport = FaultInjectionTransport() + custom_transport = FaultInjectionTransportAsync() # Inject rule to disallow writes in the read-only region is_write_operation_in_read_region_predicate: Callable[[HttpRequest], bool] = lambda \ - r: FaultInjectionTransport.predicate_is_write_operation(r, expected_read_region_uri) + r: FaultInjectionTransportAsync.predicate_is_write_operation(r, expected_read_region_uri) custom_transport.add_fault( is_write_operation_in_read_region_predicate, - lambda r: asyncio.create_task(FaultInjectionTransport.error_write_forbidden())) + lambda r: 
asyncio.create_task(FaultInjectionTransportAsync.error_write_forbidden())) # Inject topology transformation that would make Emulator look like a single write region # account with two read regions - is_get_account_predicate: Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransport.predicate_is_database_account_call(r) - emulator_as_multi_region_sm_account_transformation: Callable[[HttpRequest, Callable[[HttpRequest], asyncio.Task[AsyncHttpResponse]]], AsyncHttpResponse] = \ - lambda r, inner: FaultInjectionTransport.transform_topology_swr_mrr( + is_get_account_predicate: Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransportAsync.predicate_is_database_account_call(r) + emulator_as_multi_region_sm_account_transformation: Callable[[HttpRequest, Callable[[], AsyncHttpResponse]], AioHttpTransportResponse] = \ + lambda r, inner: FaultInjectionTransportAsync.transform_topology_swr_mrr( write_region_name="Write Region", read_region_name="Read Region", - r=r, inner=inner) custom_transport.add_response_transformation( is_get_account_predicate, @@ -178,34 +172,33 @@ async def test_swr_mrr_succeeds(self, setup: object): finally: TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) - async def test_swr_mrr_region_down_read_succeeds(self, setup: object): + async def test_swr_mrr_region_down_read_succeeds(self: "TestFaultInjectionTransportAsync"): expected_read_region_uri: str = test_config.TestConfig.local_host expected_write_region_uri: str = expected_read_region_uri.replace("localhost", "127.0.0.1") - custom_transport = FaultInjectionTransport() + custom_transport = FaultInjectionTransportAsync() # Inject rule to disallow writes in the read-only region is_write_operation_in_read_region_predicate: Callable[[HttpRequest], bool] = lambda \ - r: FaultInjectionTransport.predicate_is_write_operation(r, expected_read_region_uri) + r: FaultInjectionTransportAsync.predicate_is_write_operation(r, expected_read_region_uri) custom_transport.add_fault( 
is_write_operation_in_read_region_predicate, - lambda r: asyncio.create_task(FaultInjectionTransport.error_write_forbidden())) + lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_write_forbidden())) # Inject rule to simulate regional outage in "Read Region" is_request_to_read_region: Callable[[HttpRequest], bool] = lambda \ - r: FaultInjectionTransport.predicate_targets_region(r, expected_read_region_uri) + r: FaultInjectionTransportAsync.predicate_targets_region(r, expected_read_region_uri) custom_transport.add_fault( is_request_to_read_region, - lambda r: asyncio.create_task(FaultInjectionTransport.error_region_down())) + lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_region_down())) # Inject topology transformation that would make Emulator look like a single write region # account with two read regions - is_get_account_predicate: Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransport.predicate_is_database_account_call(r) - emulator_as_multi_region_sm_account_transformation: Callable[[HttpRequest, Callable[[HttpRequest], asyncio.Task[AsyncHttpResponse]]], AsyncHttpResponse] = \ - lambda r, inner: FaultInjectionTransport.transform_topology_swr_mrr( + is_get_account_predicate: Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransportAsync.predicate_is_database_account_call(r) + emulator_as_multi_region_sm_account_transformation: Callable[[HttpRequest, Callable[[HttpRequest], AsyncHttpResponse]], AioHttpTransportResponse] = \ + lambda r, inner: FaultInjectionTransportAsync.transform_topology_swr_mrr( write_region_name="Write Region", read_region_name="Read Region", - r=r, inner=inner) custom_transport.add_response_transformation( is_get_account_predicate, @@ -239,26 +232,26 @@ async def test_swr_mrr_region_down_read_succeeds(self, setup: object): finally: TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) - async def test_swr_mrr_region_down_envoy_read_succeeds(self, setup: object): + async def 
test_swr_mrr_region_down_envoy_read_succeeds(self: "TestFaultInjectionTransportAsync"): expected_read_region_uri: str = test_config.TestConfig.local_host expected_write_region_uri: str = expected_read_region_uri.replace("localhost", "127.0.0.1") - custom_transport = FaultInjectionTransport() + custom_transport = FaultInjectionTransportAsync() # Inject rule to disallow writes in the read-only region is_write_operation_in_read_region_predicate: Callable[[HttpRequest], bool] = lambda \ - r: FaultInjectionTransport.predicate_is_write_operation(r, expected_read_region_uri) + r: FaultInjectionTransportAsync.predicate_is_write_operation(r, expected_read_region_uri) custom_transport.add_fault( is_write_operation_in_read_region_predicate, - lambda r: asyncio.create_task(FaultInjectionTransport.error_write_forbidden())) + lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_write_forbidden())) # Inject rule to simulate regional outage in "Read Region" is_request_to_read_region: Callable[[HttpRequest], bool] = lambda \ - r: FaultInjectionTransport.predicate_targets_region(r, expected_read_region_uri) and \ - FaultInjectionTransport.predicate_is_document_operation(r) + r: FaultInjectionTransportAsync.predicate_targets_region(r, expected_read_region_uri) and \ + FaultInjectionTransportAsync.predicate_is_document_operation(r) custom_transport.add_fault( is_request_to_read_region, - lambda r: asyncio.create_task(FaultInjectionTransport.error_after_delay( + lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( 35000, CosmosHttpResponseError( status_code=502, @@ -266,12 +259,11 @@ async def test_swr_mrr_region_down_envoy_read_succeeds(self, setup: object): # Inject topology transformation that would make Emulator look like a single write region # account with two read regions - is_get_account_predicate: Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransport.predicate_is_database_account_call(r) - 
emulator_as_multi_region_sm_account_transformation: Callable[[HttpRequest, Callable[[HttpRequest], asyncio.Task[AsyncHttpResponse]]], AsyncHttpResponse] = \ - lambda r, inner: FaultInjectionTransport.transform_topology_swr_mrr( + is_get_account_predicate: Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransportAsync.predicate_is_database_account_call(r) + emulator_as_multi_region_sm_account_transformation: Callable[[HttpRequest, Callable[[HttpRequest], asyncio.Task[AsyncHttpResponse]]], AioHttpTransportResponse] = \ + lambda r, inner: FaultInjectionTransportAsync.transform_topology_swr_mrr( write_region_name="Write Region", read_region_name="Read Region", - r=r, inner=inner) custom_transport.add_response_transformation( is_get_account_predicate, @@ -306,19 +298,17 @@ async def test_swr_mrr_region_down_envoy_read_succeeds(self, setup: object): TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) - async def test_mwr_succeeds(self, setup: object): + async def test_mwr_succeeds(self: "TestFaultInjectionTransportAsync"): first_region_uri: str = test_config.TestConfig.local_host.replace("localhost", "127.0.0.1") - second_region_uri: str = test_config.TestConfig.local_host - custom_transport = FaultInjectionTransport() + custom_transport = FaultInjectionTransportAsync() # Inject topology transformation that would make Emulator look like a single write region # account with two read regions - is_get_account_predicate: Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransport.predicate_is_database_account_call(r) - emulator_as_multi_write_region_account_transformation: Callable[[HttpRequest, Callable[[HttpRequest], asyncio.Task[AsyncHttpResponse]]], AsyncHttpResponse] = \ - lambda r, inner: FaultInjectionTransport.transform_topology_mwr( + is_get_account_predicate: Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransportAsync.predicate_is_database_account_call(r) + emulator_as_multi_write_region_account_transformation: 
Callable[[HttpRequest, Callable[[HttpRequest], asyncio.Task[AsyncHttpResponse]]], AioHttpTransportResponse] = \ + lambda r, inner: FaultInjectionTransportAsync.transform_topology_mwr( first_region_name="First Region", second_region_name="Second Region", - r=r, inner=inner) custom_transport.add_response_transformation( is_get_account_predicate, From 622589fced33fa4ec415be94bbe8cb051156a804 Mon Sep 17 00:00:00 2001 From: Tomas Varon Saldarriaga Date: Wed, 2 Apr 2025 16:31:59 -0700 Subject: [PATCH 030/152] Fix refactoring --- .../tests/_fault_injection_transport_async.py | 40 +++++++++---------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py index 230d8f89dfe5..dec1699b8742 100644 --- a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py +++ b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py @@ -66,38 +66,38 @@ def __first_item(iterable, condition=lambda x: True): return next((x for x in iterable if condition(x)), None) async def send(self, request: HttpRequest, *, stream: bool = False, proxies: Optional[MutableMapping[str, str]] = None, **config) -> AsyncHttpResponse: - FaultInjectionTransport.logger.info("--> FaultInjectionTransport.Send {} {}".format(request.method, request.url)) + FaultInjectionTransportAsync.logger.info("--> FaultInjectionTransport.Send {} {}".format(request.method, request.url)) # find the first fault Factory with matching predicate if any - first_fault_factory = FaultInjectionTransport.__first_item(iter(self.faults), lambda f: f["predicate"](request)) + first_fault_factory = FaultInjectionTransportAsync.__first_item(iter(self.faults), lambda f: f["predicate"](request)) if first_fault_factory: - FaultInjectionTransport.logger.info("--> FaultInjectionTransport.ApplyFaultInjection") + FaultInjectionTransportAsync.logger.info("--> 
FaultInjectionTransport.ApplyFaultInjection") injected_error = await first_fault_factory["apply"](request) - FaultInjectionTransport.logger.info("Found to-be-injected error {}".format(injected_error)) + FaultInjectionTransportAsync.logger.info("Found to-be-injected error {}".format(injected_error)) raise injected_error # apply the chain of request transformations with matching predicates if any matching_request_transformations = filter(lambda f: f["predicate"](f["predicate"]), iter(self.requestTransformations)) for currentTransformation in matching_request_transformations: - FaultInjectionTransport.logger.info("--> FaultInjectionTransport.ApplyRequestTransformation") + FaultInjectionTransportAsync.logger.info("--> FaultInjectionTransport.ApplyRequestTransformation") request = await currentTransformation["apply"](request) - first_response_transformation = FaultInjectionTransport.__first_item(iter(self.responseTransformations), lambda f: f["predicate"](request)) + first_response_transformation = FaultInjectionTransportAsync.__first_item(iter(self.responseTransformations), lambda f: f["predicate"](request)) - FaultInjectionTransport.logger.info("--> FaultInjectionTransport.BeforeGetResponseTask") + FaultInjectionTransportAsync.logger.info("--> FaultInjectionTransport.BeforeGetResponseTask") get_response_task = asyncio.create_task(super().send(request, stream=stream, proxies=proxies, **config)) - FaultInjectionTransport.logger.info("<-- FaultInjectionTransport.AfterGetResponseTask") + FaultInjectionTransportAsync.logger.info("<-- FaultInjectionTransport.AfterGetResponseTask") if first_response_transformation: - FaultInjectionTransport.logger.info(f"Invoking response transformation") + FaultInjectionTransportAsync.logger.info(f"Invoking response transformation") response = await first_response_transformation["apply"](request, lambda: get_response_task) response.headers["_request"] = request - FaultInjectionTransport.logger.info(f"Received response transformation result 
with status code {response.status_code}") + FaultInjectionTransportAsync.logger.info(f"Received response transformation result with status code {response.status_code}") return response else: - FaultInjectionTransport.logger.info(f"Sending request to {request.url}") + FaultInjectionTransportAsync.logger.info(f"Sending request to {request.url}") response = await get_response_task response.headers["_request"] = request - FaultInjectionTransport.logger.info(f"Received response with status code {response.status_code}") + FaultInjectionTransportAsync.logger.info(f"Received response with status code {response.status_code}") return response @staticmethod @@ -125,8 +125,8 @@ def predicate_req_payload_contains_id(r: HttpRequest, id_value: str): @staticmethod def predicate_req_for_document_with_id(r: HttpRequest, id_value: str) -> bool: - return (FaultInjectionTransport.predicate_url_contains_id(r, id_value) - or FaultInjectionTransport.predicate_req_payload_contains_id(r, id_value)) + return (FaultInjectionTransportAsync.predicate_url_contains_id(r, id_value) + or FaultInjectionTransportAsync.predicate_req_payload_contains_id(r, id_value)) @staticmethod def predicate_is_database_account_call(r: HttpRequest) -> bool: @@ -175,7 +175,7 @@ async def transform_topology_swr_mrr( inner: Callable[[],asyncio.Task[AioHttpTransportResponse]]) -> AioHttpTransportResponse: response = await inner() - if not FaultInjectionTransport.predicate_is_database_account_call(response.request): + if not FaultInjectionTransportAsync.predicate_is_database_account_call(response.request): return response data = response.body() @@ -187,9 +187,9 @@ async def transform_topology_swr_mrr( readable_locations[0]["name"] = write_region_name writable_locations[0]["name"] = write_region_name readable_locations.append({"name": read_region_name, "databaseAccountEndpoint" : test_config.TestConfig.local_host}) - FaultInjectionTransport.logger.info("Transformed Account Topology: {}".format(result)) + 
FaultInjectionTransportAsync.logger.info("Transformed Account Topology: {}".format(result)) request: HttpRequest = response.request - return FaultInjectionTransport.MockHttpResponse(request, 200, result) + return FaultInjectionTransportAsync.MockHttpResponse(request, 200, result) return response @@ -200,7 +200,7 @@ async def transform_topology_mwr( inner: Callable[[], asyncio.Task[AioHttpTransportResponse]]) -> AioHttpTransportResponse: response = await inner() - if not FaultInjectionTransport.predicate_is_database_account_call(response.request): + if not FaultInjectionTransportAsync.predicate_is_database_account_call(response.request): return response data = response.body() @@ -216,9 +216,9 @@ async def transform_topology_mwr( writable_locations.append( {"name": second_region_name, "databaseAccountEndpoint": test_config.TestConfig.local_host}) result["enableMultipleWriteLocations"] = True - FaultInjectionTransport.logger.info("Transformed Account Topology: {}".format(result)) + FaultInjectionTransportAsync.logger.info("Transformed Account Topology: {}".format(result)) request: HttpRequest = response.request - return FaultInjectionTransport.MockHttpResponse(request, 200, result) + return FaultInjectionTransportAsync.MockHttpResponse(request, 200, result) return response From 4dd17ea97d5a877635f9f0472463fb9196144a2a Mon Sep 17 00:00:00 2001 From: Tomas Varon Saldarriaga Date: Wed, 2 Apr 2025 17:01:29 -0700 Subject: [PATCH 031/152] Fix tests --- .../tests/test_fault_injection_transport_async.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py index 2227019cbe63..05fbd6f1f658 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py @@ -30,7 +30,6 @@ host = test_config.TestConfig.host master_key = 
test_config.TestConfig.masterKey -connection_policy = test_config.TestConfig.connectionPolicy TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID single_partition_container_name = os.path.basename(__file__) + str(uuid.uuid4()) @@ -49,13 +48,10 @@ def setup_class(cls): "You must specify your Azure Cosmos account values for " "'masterKey' and 'host' at the top of this class to run the " "tests.") - - cls.connection_policy = connection_policy cls.database_id = TEST_DATABASE_ID cls.single_partition_container_name = single_partition_container_name - cls.mgmt_client = CosmosClient(cls.host, cls.master_key, consistency_level="Session", - connection_policy=cls.connection_policy, logger=logger) + cls.mgmt_client = CosmosClient(cls.host, cls.master_key, consistency_level="Session", logger=logger) created_database = cls.mgmt_client.get_database_client(cls.database_id) asyncio.run(asyncio.wait_for( created_database.create_container( @@ -81,8 +77,7 @@ def teardown_class(cls): def setup_method_with_custom_transport(self, custom_transport: AioHttpTransport, default_endpoint=host, **kwargs): client = CosmosClient(default_endpoint, master_key, consistency_level="Session", - connection_policy=connection_policy, transport=custom_transport, - logger=logger, enable_diagnostics_logging=True, **kwargs) + transport=custom_transport, logger=logger, enable_diagnostics_logging=True, **kwargs) db: DatabaseProxy = client.get_database_client(TEST_DATABASE_ID) container: ContainerProxy = db.get_container_client(single_partition_container_name) return {"client": client, "db": db, "col": container} From e631b74e51d2c136b79cb3d54a5263f5c44eb1db Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Wed, 2 Apr 2025 17:26:00 -0700 Subject: [PATCH 032/152] fix tests --- sdk/cosmos/azure-cosmos/pytest.ini | 2 +- .../tests/_fault_injection_transport_async.py | 8 +-- .../test_fault_injection_transport_async.py | 49 ++++++++++++++++++- 3 files changed, 52 insertions(+), 7 deletions(-) diff --git 
a/sdk/cosmos/azure-cosmos/pytest.ini b/sdk/cosmos/azure-cosmos/pytest.ini index e211052edef0..647aac1464f8 100644 --- a/sdk/cosmos/azure-cosmos/pytest.ini +++ b/sdk/cosmos/azure-cosmos/pytest.ini @@ -1,3 +1,3 @@ [pytest] markers = - cosmosEmulator: marks tests as depending in Cosmos DB Emulator \ No newline at end of file + cosmosEmulator: marks tests as depending in Cosmos DB Emulator diff --git a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py index dec1699b8742..428059231796 100644 --- a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py +++ b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py @@ -27,7 +27,7 @@ import logging import sys from collections.abc import MutableMapping -from typing import Callable, Optional, Any, Dict, List +from typing import Callable, Optional, Any, Dict, List, Awaitable import aiohttp from azure.core.pipeline.transport import AioHttpTransport, AioHttpTransportResponse @@ -48,7 +48,7 @@ def __init__(self, *, session: Optional[aiohttp.ClientSession] = None, loop=None self.responseTransformations: List[Dict[str, Any]] = [] super().__init__(session=session, loop=loop, session_owner=session_owner, **config) - def add_fault(self, predicate: Callable[[HttpRequest], bool], fault_factory: Callable[[HttpRequest], asyncio.Task[Exception]]): + def add_fault(self, predicate: Callable[[HttpRequest], bool], fault_factory: Callable[[HttpRequest], Awaitable[Exception]]): self.faults.append({"predicate": predicate, "apply": fault_factory}) def add_response_transformation(self, predicate: Callable[[HttpRequest], bool], response_transformation: Callable[[HttpRequest, Callable[[HttpRequest], AioHttpTransportResponse]], AioHttpTransportResponse]): @@ -172,7 +172,7 @@ async def error_region_down() -> Exception: async def transform_topology_swr_mrr( write_region_name: str, read_region_name: str, - inner: 
Callable[[],asyncio.Task[AioHttpTransportResponse]]) -> AioHttpTransportResponse: + inner: Callable[[], Awaitable[AioHttpTransportResponse]]) -> AioHttpTransportResponse: response = await inner() if not FaultInjectionTransportAsync.predicate_is_database_account_call(response.request): @@ -197,7 +197,7 @@ async def transform_topology_swr_mrr( async def transform_topology_mwr( first_region_name: str, second_region_name: str, - inner: Callable[[], asyncio.Task[AioHttpTransportResponse]]) -> AioHttpTransportResponse: + inner: Callable[[], Awaitable[AioHttpTransportResponse]]) -> AioHttpTransportResponse: response = await inner() if not FaultInjectionTransportAsync.predicate_is_database_account_call(response.request): diff --git a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py index 05fbd6f1f658..0a186afbb875 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py @@ -75,7 +75,8 @@ def teardown_class(cls): except Exception as closeError: logger.warning("Exception trying to delete database {}. 
{}".format(created_database.id, closeError)) - def setup_method_with_custom_transport(self, custom_transport: AioHttpTransport, default_endpoint=host, **kwargs): + @staticmethod + def setup_method_with_custom_transport(custom_transport: AioHttpTransport, default_endpoint=host, **kwargs): client = CosmosClient(default_endpoint, master_key, consistency_level="Session", transport=custom_transport, logger=logger, enable_diagnostics_logging=True, **kwargs) db: DatabaseProxy = client.get_database_client(TEST_DATABASE_ID) @@ -247,7 +248,7 @@ async def test_swr_mrr_region_down_envoy_read_succeeds(self: "TestFaultInjection custom_transport.add_fault( is_request_to_read_region, lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( - 35000, + 500, CosmosHttpResponseError( status_code=502, message="Some random reverse proxy error.")))) @@ -336,5 +337,49 @@ async def test_mwr_succeeds(self: "TestFaultInjectionTransportAsync"): finally: TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) + # add a test for delays + # add test for complete failures + + async def test_mwr_region_down_succeeds(self: "TestFaultInjectionTransportAsync"): + second_region_uri: str = test_config.TestConfig.local_host + custom_transport = FaultInjectionTransportAsync() + + # Inject topology transformation that would make Emulator look like a single write region + # account with two read regions + is_get_account_predicate: Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransportAsync.predicate_is_database_account_call(r) + emulator_as_multi_write_region_account_transformation: Callable[[HttpRequest, Callable[[HttpRequest], asyncio.Task[AsyncHttpResponse]]], AioHttpTransportResponse] = \ + lambda r, inner: FaultInjectionTransportAsync.transform_topology_mwr( + first_region_name="First Region", + second_region_name="Second Region", + inner=inner) + custom_transport.add_response_transformation( + is_get_account_predicate, + 
emulator_as_multi_write_region_account_transformation) + + id_value: str = str(uuid.uuid4()) + document_definition = {'id': id_value, + 'pk': id_value, + 'name': 'sample document', + 'key': 'value'} + + initialized_objects = self.setup_method_with_custom_transport( + custom_transport, + preferred_locations=["First Region", "Second Region"]) + try: + container: ContainerProxy = initialized_objects["col"] + + start:float = time.perf_counter() + while (time.perf_counter() - start) < 2: + upsert_document = await container.upsert_item(body=document_definition) + request = upsert_document.get_response_headers()["_request"] + assert request.url.startswith(second_region_uri) + read_document = await container.read_item(id_value, partition_key=id_value) + request = read_document.get_response_headers()["_request"] + # Validate the response comes from "East US" (the most preferred read-only region) + assert request.url.startswith(second_region_uri) + + finally: + TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) + if __name__ == '__main__': unittest.main() From 2d5f0d7d624658c9f4a36d205010904b6ba3aa6d Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Wed, 2 Apr 2025 17:52:56 -0700 Subject: [PATCH 033/152] add more tests --- .../tests/_fault_injection_transport_async.py | 4 +-- .../test_fault_injection_transport_async.py | 25 ++++++++++++++++--- 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py index 428059231796..f5b9c8b126dd 100644 --- a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py +++ b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py @@ -26,9 +26,7 @@ import json import logging import sys -from collections.abc import MutableMapping -from typing import Callable, Optional, Any, Dict, List, Awaitable - +from typing import Callable, Optional, Any, Dict, List, Awaitable, 
MutableMapping import aiohttp from azure.core.pipeline.transport import AioHttpTransport, AioHttpTransportResponse from azure.core.rest import HttpRequest, AsyncHttpResponse diff --git a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py index 0a186afbb875..c6e0e5b39b5c 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py @@ -101,17 +101,21 @@ async def test_throws_injected_error(self: "TestFaultInjectionTransportAsync"): custom_transport = FaultInjectionTransportAsync() predicate : Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransportAsync.predicate_req_for_document_with_id(r, id_value) custom_transport.add_fault(predicate, lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( - 500, + 10000, CosmosHttpResponseError( status_code=502, message="Some random reverse proxy error.")))) initialized_objects = self.setup_method_with_custom_transport(custom_transport) + start: float = time.perf_counter() try: container: ContainerProxy = initialized_objects["col"] await container.create_item(body=document_definition) pytest.fail("Expected exception not thrown") except CosmosHttpResponseError as cosmosError: + end = time.perf_counter() - start + # validate response took more than 10 seconds + assert end > 10 if cosmosError.status_code != 502: raise cosmosError finally: @@ -298,7 +302,7 @@ async def test_mwr_succeeds(self: "TestFaultInjectionTransportAsync"): first_region_uri: str = test_config.TestConfig.local_host.replace("localhost", "127.0.0.1") custom_transport = FaultInjectionTransportAsync() - # Inject topology transformation that would make Emulator look like a single write region + # Inject topology transformation that would make Emulator look like a multiple write region account # account with two read regions is_get_account_predicate: 
Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransportAsync.predicate_is_database_account_call(r) emulator_as_multi_write_region_account_transformation: Callable[[HttpRequest, Callable[[HttpRequest], asyncio.Task[AsyncHttpResponse]]], AioHttpTransportResponse] = \ @@ -337,14 +341,14 @@ async def test_mwr_succeeds(self: "TestFaultInjectionTransportAsync"): finally: TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) - # add a test for delays # add test for complete failures async def test_mwr_region_down_succeeds(self: "TestFaultInjectionTransportAsync"): + first_region_uri: str = test_config.TestConfig.local_host.replace("localhost", "127.0.0.1") second_region_uri: str = test_config.TestConfig.local_host custom_transport = FaultInjectionTransportAsync() - # Inject topology transformation that would make Emulator look like a single write region + # Inject topology transformation that would make Emulator look like a multiple write region account # account with two read regions is_get_account_predicate: Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransportAsync.predicate_is_database_account_call(r) emulator_as_multi_write_region_account_transformation: Callable[[HttpRequest, Callable[[HttpRequest], asyncio.Task[AsyncHttpResponse]]], AioHttpTransportResponse] = \ @@ -356,6 +360,19 @@ async def test_mwr_region_down_succeeds(self: "TestFaultInjectionTransportAsync" is_get_account_predicate, emulator_as_multi_write_region_account_transformation) + # Inject rule to simulate regional outage in "First Region" + is_request_to_read_region: Callable[[HttpRequest], bool] = lambda \ + r: FaultInjectionTransportAsync.predicate_targets_region(r, first_region_uri) and \ + FaultInjectionTransportAsync.predicate_is_document_operation(r) + + custom_transport.add_fault( + is_request_to_read_region, + lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( + 500, + CosmosHttpResponseError( + status_code=408, + message="Induced 
Request Timeout")))) + id_value: str = str(uuid.uuid4()) document_definition = {'id': id_value, 'pk': id_value, From f04a50695dc8ad74f2ce25090639e17a67a4320b Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Wed, 2 Apr 2025 17:54:44 -0700 Subject: [PATCH 034/152] add more tests --- .../azure-cosmos/tests/test_fault_injection_transport_async.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py index c6e0e5b39b5c..238e10efba3e 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py @@ -351,7 +351,7 @@ async def test_mwr_region_down_succeeds(self: "TestFaultInjectionTransportAsync" # Inject topology transformation that would make Emulator look like a multiple write region account # account with two read regions is_get_account_predicate: Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransportAsync.predicate_is_database_account_call(r) - emulator_as_multi_write_region_account_transformation: Callable[[HttpRequest, Callable[[HttpRequest], asyncio.Task[AsyncHttpResponse]]], AioHttpTransportResponse] = \ + emulator_as_multi_write_region_account_transformation = \ lambda r, inner: FaultInjectionTransportAsync.transform_topology_mwr( first_region_name="First Region", second_region_name="Second Region", From bcee9cf4888f1e8326ff307454136aba4d8edf54 Mon Sep 17 00:00:00 2001 From: Tomas Varon Saldarriaga Date: Wed, 2 Apr 2025 18:24:06 -0700 Subject: [PATCH 035/152] Add tests --- .../test_fault_injection_transport_async.py | 139 ++++++++++++++++-- 1 file changed, 126 insertions(+), 13 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py index 238e10efba3e..5fcc953df5f2 100644 --- 
a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py @@ -22,6 +22,7 @@ from azure.cosmos.aio._container import ContainerProxy from azure.cosmos.aio._database import DatabaseProxy from azure.cosmos.exceptions import CosmosHttpResponseError +from azure.core.exceptions import ServiceRequestError MGMT_TIMEOUT = 3.0 logger = logging.getLogger('azure.cosmos') @@ -91,7 +92,7 @@ def cleanup_method(initialized_objects: dict[str, Any]): except Exception as close_error: logger.warning(f"Exception trying to close method client. {close_error}") - async def test_throws_injected_error(self: "TestFaultInjectionTransportAsync"): + async def test_throws_injected_error_async(self: "TestFaultInjectionTransportAsync"): id_value: str = str(uuid.uuid4()) document_definition = {'id': id_value, 'pk': id_value, @@ -121,7 +122,7 @@ async def test_throws_injected_error(self: "TestFaultInjectionTransportAsync"): finally: TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) - async def test_swr_mrr_succeeds(self: "TestFaultInjectionTransportAsync"): + async def test_swr_mrr_succeeds_async(self: "TestFaultInjectionTransportAsync"): expected_read_region_uri: str = test_config.TestConfig.local_host expected_write_region_uri: str = expected_read_region_uri.replace("localhost", "127.0.0.1") custom_transport = FaultInjectionTransportAsync() @@ -172,7 +173,7 @@ async def test_swr_mrr_succeeds(self: "TestFaultInjectionTransportAsync"): finally: TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) - async def test_swr_mrr_region_down_read_succeeds(self: "TestFaultInjectionTransportAsync"): + async def test_swr_mrr_region_down_read_succeeds_async(self: "TestFaultInjectionTransportAsync"): expected_read_region_uri: str = test_config.TestConfig.local_host expected_write_region_uri: str = expected_read_region_uri.replace("localhost", "127.0.0.1") custom_transport = 
FaultInjectionTransportAsync() @@ -232,7 +233,7 @@ async def test_swr_mrr_region_down_read_succeeds(self: "TestFaultInjectionTransp finally: TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) - async def test_swr_mrr_region_down_envoy_read_succeeds(self: "TestFaultInjectionTransportAsync"): + async def test_swr_mrr_region_down_envoy_read_succeeds_async(self: "TestFaultInjectionTransportAsync"): expected_read_region_uri: str = test_config.TestConfig.local_host expected_write_region_uri: str = expected_read_region_uri.replace("localhost", "127.0.0.1") custom_transport = FaultInjectionTransportAsync() @@ -298,7 +299,7 @@ async def test_swr_mrr_region_down_envoy_read_succeeds(self: "TestFaultInjection TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) - async def test_mwr_succeeds(self: "TestFaultInjectionTransportAsync"): + async def test_mwr_succeeds_async(self: "TestFaultInjectionTransportAsync"): first_region_uri: str = test_config.TestConfig.local_host.replace("localhost", "127.0.0.1") custom_transport = FaultInjectionTransportAsync() @@ -343,7 +344,7 @@ async def test_mwr_succeeds(self: "TestFaultInjectionTransportAsync"): # add test for complete failures - async def test_mwr_region_down_succeeds(self: "TestFaultInjectionTransportAsync"): + async def test_mwr_region_down_succeeds_async(self: "TestFaultInjectionTransportAsync"): first_region_uri: str = test_config.TestConfig.local_host.replace("localhost", "127.0.0.1") second_region_uri: str = test_config.TestConfig.local_host custom_transport = FaultInjectionTransportAsync() @@ -361,17 +362,13 @@ async def test_mwr_region_down_succeeds(self: "TestFaultInjectionTransportAsync" emulator_as_multi_write_region_account_transformation) # Inject rule to simulate regional outage in "First Region" - is_request_to_read_region: Callable[[HttpRequest], bool] = lambda \ + is_request_to_first_region: Callable[[HttpRequest], bool] = lambda \ r: 
FaultInjectionTransportAsync.predicate_targets_region(r, first_region_uri) and \ FaultInjectionTransportAsync.predicate_is_document_operation(r) custom_transport.add_fault( - is_request_to_read_region, - lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( - 500, - CosmosHttpResponseError( - status_code=408, - message="Induced Request Timeout")))) + is_request_to_first_region, + lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_region_down())) id_value: str = str(uuid.uuid4()) document_definition = {'id': id_value, @@ -387,6 +384,7 @@ async def test_mwr_region_down_succeeds(self: "TestFaultInjectionTransportAsync" start:float = time.perf_counter() while (time.perf_counter() - start) < 2: + # reads and writes should failover to second region upsert_document = await container.upsert_item(body=document_definition) request = upsert_document.get_response_headers()["_request"] assert request.url.startswith(second_region_uri) @@ -398,5 +396,120 @@ async def test_mwr_region_down_succeeds(self: "TestFaultInjectionTransportAsync" finally: TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) + async def test_swr_mrr_all_regions_down_for_read_async(self: "TestFaultInjectionTransportAsync"): + expected_read_region_uri: str = test_config.TestConfig.local_host + expected_write_region_uri: str = expected_read_region_uri.replace("localhost", "127.0.0.1") + custom_transport = FaultInjectionTransportAsync() + + # Inject rule to disallow writes in the read-only region + is_write_operation_in_read_region_predicate: Callable[[HttpRequest], bool] = lambda \ + r: FaultInjectionTransportAsync.predicate_is_write_operation(r, expected_read_region_uri) + + custom_transport.add_fault( + is_write_operation_in_read_region_predicate, + lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_write_forbidden())) + + # Inject topology transformation that would make Emulator look like a multiple write region account + # account with two 
read regions + is_get_account_predicate: Callable[[HttpRequest], bool] = lambda \ + r: FaultInjectionTransportAsync.predicate_is_database_account_call(r) + emulator_as_multi_write_region_account_transformation = \ + lambda r, inner: FaultInjectionTransportAsync.transform_topology_swr_mrr( + write_region_name="Write Region", + read_region_name="Read Region", + inner=inner) + custom_transport.add_response_transformation( + is_get_account_predicate, + emulator_as_multi_write_region_account_transformation) + + # Inject rule to simulate regional outage in "First Region" + is_request_to_first_region: Callable[[HttpRequest], bool] = lambda \ + r: (FaultInjectionTransportAsync.predicate_targets_region(r, expected_write_region_uri) and + FaultInjectionTransportAsync.predicate_is_document_operation(r) and + not FaultInjectionTransportAsync.predicate_is_write_operation(r, expected_write_region_uri)) + + + # Inject rule to simulate regional outage in "Second Region" + is_request_to_second_region: Callable[[HttpRequest], bool] = lambda \ + r: (FaultInjectionTransportAsync.predicate_targets_region(r, expected_read_region_uri) and + FaultInjectionTransportAsync.predicate_is_document_operation(r) and + not FaultInjectionTransportAsync.predicate_is_write_operation(r, expected_write_region_uri)) + + custom_transport.add_fault( + is_request_to_first_region, + lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_region_down())) + custom_transport.add_fault( + is_request_to_second_region, + lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_region_down())) + + id_value: str = str(uuid.uuid4()) + document_definition = {'id': id_value, + 'pk': id_value, + 'name': 'sample document', + 'key': 'value'} + + initialized_objects = self.setup_method_with_custom_transport( + custom_transport, + preferred_locations=["First Region", "Second Region"]) + try: + container: ContainerProxy = initialized_objects["col"] + await container.upsert_item(body=document_definition) + 
with pytest.raises(ServiceRequestError): + await container.read_item(id_value, partition_key=id_value) + finally: + TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) + + async def test_mwr_all_regions_down_async(self: "TestFaultInjectionTransportAsync"): + + first_region_uri: str = test_config.TestConfig.local_host.replace("localhost", "127.0.0.1") + second_region_uri: str = test_config.TestConfig.local_host + custom_transport = FaultInjectionTransportAsync() + + # Inject topology transformation that would make Emulator look like a multiple write region account + # account with two read regions + is_get_account_predicate: Callable[[HttpRequest], bool] = lambda \ + r: FaultInjectionTransportAsync.predicate_is_database_account_call(r) + emulator_as_multi_write_region_account_transformation = \ + lambda r, inner: FaultInjectionTransportAsync.transform_topology_mwr( + first_region_name="First Region", + second_region_name="Second Region", + inner=inner) + custom_transport.add_response_transformation( + is_get_account_predicate, + emulator_as_multi_write_region_account_transformation) + + # Inject rule to simulate regional outage in "First Region" + is_request_to_first_region: Callable[[HttpRequest], bool] = lambda \ + r: FaultInjectionTransportAsync.predicate_targets_region(r, first_region_uri) and \ + FaultInjectionTransportAsync.predicate_is_document_operation(r) + + # Inject rule to simulate regional outage in "Second Region" + is_request_to_second_region: Callable[[HttpRequest], bool] = lambda \ + r: FaultInjectionTransportAsync.predicate_targets_region(r, second_region_uri) and \ + FaultInjectionTransportAsync.predicate_is_document_operation(r) + + custom_transport.add_fault( + is_request_to_first_region, + lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_region_down())) + custom_transport.add_fault( + is_request_to_second_region, + lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_region_down())) + + id_value: str 
= str(uuid.uuid4()) + document_definition = {'id': id_value, + 'pk': id_value, + 'name': 'sample document', + 'key': 'value'} + + initialized_objects = self.setup_method_with_custom_transport( + custom_transport, + preferred_locations=["First Region", "Second Region"]) + try: + container: ContainerProxy = initialized_objects["col"] + with pytest.raises(ServiceRequestError): + await container.upsert_item(body=document_definition) + finally: + TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) + if __name__ == '__main__': unittest.main() From 779f9d19b6a8bacca8d3cee2481b4a0cf34c5dcd Mon Sep 17 00:00:00 2001 From: Tomas Varon Saldarriaga Date: Wed, 2 Apr 2025 23:27:56 -0700 Subject: [PATCH 036/152] fix tests --- sdk/cosmos/azure-cosmos/pytest.ini | 5 ++++- .../tests/_fault_injection_transport_async.py | 6 +++--- .../tests/test_fault_injection_transport_async.py | 12 ++++++------ .../azure-cosmos/tests/test_feed_range_async.py | 2 +- 4 files changed, 14 insertions(+), 11 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/pytest.ini b/sdk/cosmos/azure-cosmos/pytest.ini index 647aac1464f8..0ea65741e343 100644 --- a/sdk/cosmos/azure-cosmos/pytest.ini +++ b/sdk/cosmos/azure-cosmos/pytest.ini @@ -1,3 +1,6 @@ [pytest] markers = - cosmosEmulator: marks tests as depending in Cosmos DB Emulator + cosmosEmulator: marks tests as depending in Cosmos DB Emulator. + cosmosLong: marks tests to be run on a Cosmos DB live account. + cosmosQuery: marks tests running queries on Cosmos DB live account. + cosmosSplit: marks test where there are partition splits on CosmosDB live account. 
diff --git a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py index f5b9c8b126dd..4551b0235bad 100644 --- a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py +++ b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py @@ -221,7 +221,7 @@ async def transform_topology_mwr( return response class MockHttpResponse(AioHttpTransportResponse): - def __init__(self, request: HttpRequest, status_code: int, content:Optional[dict[str, Any]]): + def __init__(self, request: HttpRequest, status_code: int, content:Optional[Dict[str, Any]]): self.request: HttpRequest = request # This is actually never None, and set by all implementations after the call to # __init__ of this class. This class is also a legacy impl, so it's risky to change it @@ -232,7 +232,7 @@ def __init__(self, request: HttpRequest, status_code: int, content:Optional[dict self.reason: Optional[str] = None self.content_type: Optional[str] = None self.block_size: int = 4096 # Default to same as R - self.content: Optional[dict[str, Any]] = None + self.content: Optional[Dict[str, Any]] = None self.json_text: Optional[str] = None self.bytes: Optional[bytes] = None if content: @@ -248,4 +248,4 @@ def text(self) -> Optional[str]: return self.json_text async def load_body(self) -> None: - return \ No newline at end of file + return diff --git a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py index 5fcc953df5f2..3f0507acc038 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py @@ -8,7 +8,7 @@ import time import unittest import uuid -from typing import Any, Callable +from typing import Any, Callable, Awaitable, Dict import pytest from azure.core.pipeline.transport import AioHttpTransport @@ -24,7 +24,7 @@ from 
azure.cosmos.exceptions import CosmosHttpResponseError from azure.core.exceptions import ServiceRequestError -MGMT_TIMEOUT = 3.0 +MGMT_TIMEOUT = 10.0 logger = logging.getLogger('azure.cosmos') logger.setLevel(logging.DEBUG) logger.addHandler(logging.StreamHandler(sys.stdout)) @@ -74,7 +74,7 @@ def teardown_class(cls): try: asyncio.run(asyncio.wait_for(cls.mgmt_client.close(), MGMT_TIMEOUT)) except Exception as closeError: - logger.warning("Exception trying to delete database {}. {}".format(created_database.id, closeError)) + logger.warning("Exception trying to close client {}. {}".format(created_database.id, closeError)) @staticmethod def setup_method_with_custom_transport(custom_transport: AioHttpTransport, default_endpoint=host, **kwargs): @@ -85,7 +85,7 @@ def setup_method_with_custom_transport(custom_transport: AioHttpTransport, defau return {"client": client, "db": db, "col": container} @staticmethod - def cleanup_method(initialized_objects: dict[str, Any]): + def cleanup_method(initialized_objects: Dict[str, Any]): method_client: CosmosClient = initialized_objects["client"] try: asyncio.run(asyncio.wait_for(method_client.close(), MGMT_TIMEOUT)) @@ -261,7 +261,7 @@ async def test_swr_mrr_region_down_envoy_read_succeeds_async(self: "TestFaultInj # Inject topology transformation that would make Emulator look like a single write region # account with two read regions is_get_account_predicate: Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransportAsync.predicate_is_database_account_call(r) - emulator_as_multi_region_sm_account_transformation: Callable[[HttpRequest, Callable[[HttpRequest], asyncio.Task[AsyncHttpResponse]]], AioHttpTransportResponse] = \ + emulator_as_multi_region_sm_account_transformation: Callable[[HttpRequest, Callable[[HttpRequest], Awaitable[AsyncHttpResponse]]], AioHttpTransportResponse] = \ lambda r, inner: FaultInjectionTransportAsync.transform_topology_swr_mrr( write_region_name="Write Region", read_region_name="Read Region", @@ 
-306,7 +306,7 @@ async def test_mwr_succeeds_async(self: "TestFaultInjectionTransportAsync"): # Inject topology transformation that would make Emulator look like a multiple write region account # account with two read regions is_get_account_predicate: Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransportAsync.predicate_is_database_account_call(r) - emulator_as_multi_write_region_account_transformation: Callable[[HttpRequest, Callable[[HttpRequest], asyncio.Task[AsyncHttpResponse]]], AioHttpTransportResponse] = \ + emulator_as_multi_write_region_account_transformation: Callable[[HttpRequest, Callable[[HttpRequest], Awaitable[AsyncHttpResponse]]], AioHttpTransportResponse] = \ lambda r, inner: FaultInjectionTransportAsync.transform_topology_mwr( first_region_name="First Region", second_region_name="Second Region", diff --git a/sdk/cosmos/azure-cosmos/tests/test_feed_range_async.py b/sdk/cosmos/azure-cosmos/tests/test_feed_range_async.py index b540a5a70423..84318f4dc5bb 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_feed_range_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_feed_range_async.py @@ -58,7 +58,7 @@ async def test_feed_range_is_subset_from_pk_async(self): True, False)).to_dict() epk_child_feed_range = await self.container_for_test.feed_range_from_partition_key("1") - assert self.container_for_test.is_feed_range_subset(epk_parent_feed_range, epk_child_feed_range) + assert await self.container_for_test.is_feed_range_subset(epk_parent_feed_range, epk_child_feed_range) if __name__ == '__main__': unittest.main() From b4db22e8c7b2cec48c61030ae8dc377297281271 Mon Sep 17 00:00:00 2001 From: Tomas Varon Saldarriaga Date: Wed, 2 Apr 2025 23:29:04 -0700 Subject: [PATCH 037/152] fix tests --- .../azure-cosmos/tests/test_fault_injection_transport_async.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py index 
3f0507acc038..b781526a807d 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py @@ -342,8 +342,6 @@ async def test_mwr_succeeds_async(self: "TestFaultInjectionTransportAsync"): finally: TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) - # add test for complete failures - async def test_mwr_region_down_succeeds_async(self: "TestFaultInjectionTransportAsync"): first_region_uri: str = test_config.TestConfig.local_host.replace("localhost", "127.0.0.1") second_region_uri: str = test_config.TestConfig.local_host From 9a94d028217a5af4540731d68932a8cd103dc3d2 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Thu, 3 Apr 2025 09:01:48 -0700 Subject: [PATCH 038/152] fix tests --- sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py | 2 +- sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py index d81884d14e55..873c032d7ead 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py @@ -235,7 +235,7 @@ def _get_configured_excluded_locations(self, request: RequestObject): if excluded_locations is None: # If excluded locations were only configured on client(connection_policy), use client level excluded_locations = self.connection_policy.ExcludedLocations - excluded_locations.union(request.excluded_locations_circuit_breaker) + excluded_locations.extend(request.excluded_locations_circuit_breaker) return excluded_locations def _get_applicable_read_regional_endpoints(self, request: RequestObject): diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py index 4f02ab0bb52b..28dc2fefd73b 100644 --- 
a/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py @@ -21,7 +21,7 @@ """Represents a request object. """ -from typing import Optional, Mapping, Any, Dict, Set +from typing import Optional, Mapping, Any, Dict, Set, List class RequestObject(object): @@ -41,7 +41,7 @@ def __init__( self.location_index_to_route: Optional[int] = None self.location_endpoint_to_route: Optional[str] = None self.last_routed_location_endpoint_within_region: Optional[str] = None - self.excluded_locations: Optional[Set[str]] = None + self.excluded_locations: Optional[List[str]] = None self.excluded_locations_circuit_breaker: Set[str] = set() def route_to_location_with_preferred_location_flag( # pylint: disable=name-too-long From eab1b63f2b1f9e2260c60cbca4b3f96a514fad3d Mon Sep 17 00:00:00 2001 From: Tomas Varon Saldarriaga Date: Thu, 3 Apr 2025 09:50:10 -0700 Subject: [PATCH 039/152] fix test --- .../azure-cosmos/tests/test_fault_injection_transport_async.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py index b781526a807d..40c73c31381b 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py @@ -24,7 +24,7 @@ from azure.cosmos.exceptions import CosmosHttpResponseError from azure.core.exceptions import ServiceRequestError -MGMT_TIMEOUT = 10.0 +MGMT_TIMEOUT = 1.0 logger = logging.getLogger('azure.cosmos') logger.setLevel(logging.DEBUG) logger.addHandler(logging.StreamHandler(sys.stdout)) From 93c2d7ddac7bd7f305f0f7b419ae6ab7279455da Mon Sep 17 00:00:00 2001 From: Tomas Varon Saldarriaga Date: Thu, 3 Apr 2025 11:12:03 -0700 Subject: [PATCH 040/152] fix test --- .../test_fault_injection_transport_async.py | 53 ++++++++++--------- 1 file changed, 27 insertions(+), 26 
deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py index 40c73c31381b..4f3a78c28760 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py @@ -9,6 +9,7 @@ import unittest import uuid from typing import Any, Callable, Awaitable, Dict +from unittest import IsolatedAsyncioTestCase import pytest from azure.core.pipeline.transport import AioHttpTransport @@ -36,9 +37,9 @@ @pytest.mark.cosmosEmulator @pytest.mark.asyncio -class TestFaultInjectionTransportAsync: +class TestFaultInjectionTransportAsync(IsolatedAsyncioTestCase): @classmethod - def setup_class(cls): + async def asyncSetUp(cls): logger.info("starting class: {} execution".format(cls.__name__)) cls.host = host cls.master_key = master_key @@ -54,30 +55,30 @@ def setup_class(cls): cls.mgmt_client = CosmosClient(cls.host, cls.master_key, consistency_level="Session", logger=logger) created_database = cls.mgmt_client.get_database_client(cls.database_id) - asyncio.run(asyncio.wait_for( + await asyncio.wait_for( created_database.create_container( cls.single_partition_container_name, partition_key=PartitionKey("/pk")), - MGMT_TIMEOUT)) + MGMT_TIMEOUT) @classmethod - def teardown_class(cls): + async def asyncTearDown(cls): logger.info("tearing down class: {}".format(cls.__name__)) created_database = cls.mgmt_client.get_database_client(cls.database_id) try: - asyncio.run(asyncio.wait_for( + await asyncio.wait_for( created_database.delete_container(cls.single_partition_container_name), - MGMT_TIMEOUT)) + MGMT_TIMEOUT) except Exception as containerDeleteError: logger.warning("Exception trying to delete database {}. 
{}".format(created_database.id, containerDeleteError)) finally: try: - asyncio.run(asyncio.wait_for(cls.mgmt_client.close(), MGMT_TIMEOUT)) + await asyncio.wait_for(cls.mgmt_client.close(), MGMT_TIMEOUT) except Exception as closeError: logger.warning("Exception trying to close client {}. {}".format(created_database.id, closeError)) @staticmethod - def setup_method_with_custom_transport(custom_transport: AioHttpTransport, default_endpoint=host, **kwargs): + async def setup_method_with_custom_transport(custom_transport: AioHttpTransport, default_endpoint=host, **kwargs): client = CosmosClient(default_endpoint, master_key, consistency_level="Session", transport=custom_transport, logger=logger, enable_diagnostics_logging=True, **kwargs) db: DatabaseProxy = client.get_database_client(TEST_DATABASE_ID) @@ -85,7 +86,7 @@ def setup_method_with_custom_transport(custom_transport: AioHttpTransport, defau return {"client": client, "db": db, "col": container} @staticmethod - def cleanup_method(initialized_objects: Dict[str, Any]): + async def cleanup_method(initialized_objects: Dict[str, Any]): method_client: CosmosClient = initialized_objects["client"] try: asyncio.run(asyncio.wait_for(method_client.close(), MGMT_TIMEOUT)) @@ -107,7 +108,7 @@ async def test_throws_injected_error_async(self: "TestFaultInjectionTransportAsy status_code=502, message="Some random reverse proxy error.")))) - initialized_objects = self.setup_method_with_custom_transport(custom_transport) + initialized_objects = await self.setup_method_with_custom_transport(custom_transport) start: float = time.perf_counter() try: container: ContainerProxy = initialized_objects["col"] @@ -120,7 +121,7 @@ async def test_throws_injected_error_async(self: "TestFaultInjectionTransportAsy if cosmosError.status_code != 502: raise cosmosError finally: - TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) + await TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) async def 
test_swr_mrr_succeeds_async(self: "TestFaultInjectionTransportAsync"): expected_read_region_uri: str = test_config.TestConfig.local_host @@ -152,7 +153,7 @@ async def test_swr_mrr_succeeds_async(self: "TestFaultInjectionTransportAsync"): 'name': 'sample document', 'key': 'value'} - initialized_objects = self.setup_method_with_custom_transport( + initialized_objects = await self.setup_method_with_custom_transport( custom_transport, preferred_locations=["Read Region", "Write Region"]) try: @@ -171,7 +172,7 @@ async def test_swr_mrr_succeeds_async(self: "TestFaultInjectionTransportAsync"): assert request.url.startswith(expected_read_region_uri) finally: - TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) + await TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) async def test_swr_mrr_region_down_read_succeeds_async(self: "TestFaultInjectionTransportAsync"): expected_read_region_uri: str = test_config.TestConfig.local_host @@ -211,7 +212,7 @@ async def test_swr_mrr_region_down_read_succeeds_async(self: "TestFaultInjection 'name': 'sample document', 'key': 'value'} - initialized_objects = self.setup_method_with_custom_transport( + initialized_objects = await self.setup_method_with_custom_transport( custom_transport, default_endpoint=expected_write_region_uri, preferred_locations=["Read Region", "Write Region"]) @@ -231,7 +232,7 @@ async def test_swr_mrr_region_down_read_succeeds_async(self: "TestFaultInjection assert request.url.startswith(expected_write_region_uri) finally: - TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) + await TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) async def test_swr_mrr_region_down_envoy_read_succeeds_async(self: "TestFaultInjectionTransportAsync"): expected_read_region_uri: str = test_config.TestConfig.local_host @@ -276,7 +277,7 @@ async def test_swr_mrr_region_down_envoy_read_succeeds_async(self: "TestFaultInj 'name': 'sample document', 'key': 'value'} - 
initialized_objects = self.setup_method_with_custom_transport( + initialized_objects = await self.setup_method_with_custom_transport( custom_transport, default_endpoint=expected_write_region_uri, preferred_locations=["Read Region", "Write Region"]) @@ -296,7 +297,7 @@ async def test_swr_mrr_region_down_envoy_read_succeeds_async(self: "TestFaultInj assert request.url.startswith(expected_write_region_uri) finally: - TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) + await TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) async def test_mwr_succeeds_async(self: "TestFaultInjectionTransportAsync"): @@ -321,7 +322,7 @@ async def test_mwr_succeeds_async(self: "TestFaultInjectionTransportAsync"): 'name': 'sample document', 'key': 'value'} - initialized_objects = self.setup_method_with_custom_transport( + initialized_objects = await self.setup_method_with_custom_transport( custom_transport, preferred_locations=["First Region", "Second Region"]) try: @@ -340,7 +341,7 @@ async def test_mwr_succeeds_async(self: "TestFaultInjectionTransportAsync"): assert request.url.startswith(first_region_uri) finally: - TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) + await TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) async def test_mwr_region_down_succeeds_async(self: "TestFaultInjectionTransportAsync"): first_region_uri: str = test_config.TestConfig.local_host.replace("localhost", "127.0.0.1") @@ -374,7 +375,7 @@ async def test_mwr_region_down_succeeds_async(self: "TestFaultInjectionTransport 'name': 'sample document', 'key': 'value'} - initialized_objects = self.setup_method_with_custom_transport( + initialized_objects = await self.setup_method_with_custom_transport( custom_transport, preferred_locations=["First Region", "Second Region"]) try: @@ -392,7 +393,7 @@ async def test_mwr_region_down_succeeds_async(self: "TestFaultInjectionTransport assert request.url.startswith(second_region_uri) finally: - 
TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) + await TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) async def test_swr_mrr_all_regions_down_for_read_async(self: "TestFaultInjectionTransportAsync"): expected_read_region_uri: str = test_config.TestConfig.local_host @@ -446,7 +447,7 @@ async def test_swr_mrr_all_regions_down_for_read_async(self: "TestFaultInjection 'name': 'sample document', 'key': 'value'} - initialized_objects = self.setup_method_with_custom_transport( + initialized_objects = await self.setup_method_with_custom_transport( custom_transport, preferred_locations=["First Region", "Second Region"]) try: @@ -455,7 +456,7 @@ async def test_swr_mrr_all_regions_down_for_read_async(self: "TestFaultInjection with pytest.raises(ServiceRequestError): await container.read_item(id_value, partition_key=id_value) finally: - TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) + await TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) async def test_mwr_all_regions_down_async(self: "TestFaultInjectionTransportAsync"): @@ -499,7 +500,7 @@ async def test_mwr_all_regions_down_async(self: "TestFaultInjectionTransportAsyn 'name': 'sample document', 'key': 'value'} - initialized_objects = self.setup_method_with_custom_transport( + initialized_objects = await self.setup_method_with_custom_transport( custom_transport, preferred_locations=["First Region", "Second Region"]) try: @@ -507,7 +508,7 @@ async def test_mwr_all_regions_down_async(self: "TestFaultInjectionTransportAsyn with pytest.raises(ServiceRequestError): await container.upsert_item(body=document_definition) finally: - TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) + await TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) if __name__ == '__main__': unittest.main() From 345f3901a8c79690d698ceb6781a2377591332f6 Mon Sep 17 00:00:00 2001 From: Tomas Varon Saldarriaga Date: Thu, 3 Apr 2025 11:46:06 
-0700 Subject: [PATCH 041/152] fix tests --- .../tests/test_fault_injection_transport_async.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py index 4f3a78c28760..f3b56766b4d1 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py @@ -25,7 +25,7 @@ from azure.cosmos.exceptions import CosmosHttpResponseError from azure.core.exceptions import ServiceRequestError -MGMT_TIMEOUT = 1.0 +MGMT_TIMEOUT = 5.0 logger = logging.getLogger('azure.cosmos') logger.setLevel(logging.DEBUG) logger.addHandler(logging.StreamHandler(sys.stdout)) @@ -33,7 +33,6 @@ host = test_config.TestConfig.host master_key = test_config.TestConfig.masterKey TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID -single_partition_container_name = os.path.basename(__file__) + str(uuid.uuid4()) @pytest.mark.cosmosEmulator @pytest.mark.asyncio @@ -51,7 +50,7 @@ async def asyncSetUp(cls): "'masterKey' and 'host' at the top of this class to run the " "tests.") cls.database_id = TEST_DATABASE_ID - cls.single_partition_container_name = single_partition_container_name + cls.single_partition_container_name = os.path.basename(__file__) + str(uuid.uuid4()) cls.mgmt_client = CosmosClient(cls.host, cls.master_key, consistency_level="Session", logger=logger) created_database = cls.mgmt_client.get_database_client(cls.database_id) @@ -77,12 +76,11 @@ async def asyncTearDown(cls): except Exception as closeError: logger.warning("Exception trying to close client {}. 
{}".format(created_database.id, closeError)) - @staticmethod - async def setup_method_with_custom_transport(custom_transport: AioHttpTransport, default_endpoint=host, **kwargs): + async def setup_method_with_custom_transport(self, custom_transport: AioHttpTransport, default_endpoint=host, **kwargs): client = CosmosClient(default_endpoint, master_key, consistency_level="Session", transport=custom_transport, logger=logger, enable_diagnostics_logging=True, **kwargs) db: DatabaseProxy = client.get_database_client(TEST_DATABASE_ID) - container: ContainerProxy = db.get_container_client(single_partition_container_name) + container: ContainerProxy = db.get_container_client(self.single_partition_container_name) return {"client": client, "db": db, "col": container} @staticmethod From fe74aa0ab0fb106d03a995e6ac52dea833e3636c Mon Sep 17 00:00:00 2001 From: Tomas Varon Saldarriaga Date: Thu, 3 Apr 2025 13:43:34 -0700 Subject: [PATCH 042/152] fix async in test --- .../azure-cosmos/tests/test_fault_injection_transport_async.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py index f3b56766b4d1..1df1de05936d 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py @@ -87,7 +87,7 @@ async def setup_method_with_custom_transport(self, custom_transport: AioHttpTran async def cleanup_method(initialized_objects: Dict[str, Any]): method_client: CosmosClient = initialized_objects["client"] try: - asyncio.run(asyncio.wait_for(method_client.close(), MGMT_TIMEOUT)) + await asyncio.wait_for(method_client.close(), MGMT_TIMEOUT) except Exception as close_error: logger.warning(f"Exception trying to close method client. 
{close_error}") From 5bb9f1fb166a5a9803af8287eac54b506da3c8b5 Mon Sep 17 00:00:00 2001 From: Kushagra Thapar Date: Thu, 3 Apr 2025 14:47:32 -0700 Subject: [PATCH 043/152] Added multi-region tests --- sdk/cosmos/live-platform-matrix.json | 37 ++++++++++++++++++++++++++++ sdk/cosmos/test-resources.bicep | 6 +++++ 2 files changed, 43 insertions(+) diff --git a/sdk/cosmos/live-platform-matrix.json b/sdk/cosmos/live-platform-matrix.json index 485a15ca92e8..bca59256d05d 100644 --- a/sdk/cosmos/live-platform-matrix.json +++ b/sdk/cosmos/live-platform-matrix.json @@ -88,6 +88,43 @@ "TestMarkArgument": "cosmosLong" } } + }, + { + "WindowsConfig": { + "Windows2022_38_multi_region": { + "OSVmImage": "env:WINDOWSVMIMAGE", + "Pool": "env:WINDOWSPOOL", + "PythonVersion": "3.8", + "CoverageArg": "--disablecov", + "TestSamples": "false", + "TestMarkArgument": "cosmosMultiRegion", + "ArmConfig": { + "ArmTemplateParameters": "@{ enableMultipleRegions = $true }" + } + }, + "Windows2022_310_multi_region": { + "OSVmImage": "env:WINDOWSVMIMAGE", + "Pool": "env:WINDOWSPOOL", + "PythonVersion": "3.10", + "CoverageArg": "--disablecov", + "TestSamples": "false", + "TestMarkArgument": "cosmosMultiRegion", + "ArmConfig": { + "ArmTemplateParameters": "@{ enableMultipleRegions = $true }" + } + }, + "Windows2022_312_multi_region": { + "OSVmImage": "env:WINDOWSVMIMAGE", + "Pool": "env:WINDOWSPOOL", + "PythonVersion": "3.12", + "CoverageArg": "--disablecov", + "TestSamples": "false", + "TestMarkArgument": "cosmosMultiRegion", + "ArmConfig": { + "ArmTemplateParameters": "@{ enableMultipleRegions = $true }" + } + } + } } ] } diff --git a/sdk/cosmos/test-resources.bicep b/sdk/cosmos/test-resources.bicep index 17d88b0be92a..61588a526eed 100644 --- a/sdk/cosmos/test-resources.bicep +++ b/sdk/cosmos/test-resources.bicep @@ -41,6 +41,12 @@ var multiRegionConfiguration = [ failoverPriority: 1 isZoneRedundant: false } + { + locationName: 'West US 2' + provisioningState: 'Succeeded' + failoverPriority: 2 + 
isZoneRedundant: false + } ] var locationsConfiguration = (enableMultipleRegions ? multiRegionConfiguration : singleRegionConfiguration) var roleDefinitionId = guid(baseName, 'roleDefinitionId') From 996217ae3fec5740cf0bc3eb180d2ba6af725953 Mon Sep 17 00:00:00 2001 From: Allen Kim Date: Thu, 3 Apr 2025 15:08:49 -0700 Subject: [PATCH 044/152] Fix _AddParitionKey to pass options to sub methods --- .../azure/cosmos/_cosmos_client_connection.py | 10 +++++++--- .../azure-cosmos/azure/cosmos/_request_object.py | 2 +- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py index d64da38defb1..acc9ac0010af 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py @@ -3265,7 +3265,7 @@ def _AddPartitionKey( options: Mapping[str, Any] ) -> Dict[str, Any]: collection_link = base.TrimBeginningAndEndingSlashes(collection_link) - partitionKeyDefinition = self._get_partition_key_definition(collection_link) + partitionKeyDefinition = self._get_partition_key_definition(collection_link, options) new_options = dict(options) # If the collection doesn't have a partition key definition, skip it as it's a legacy collection if partitionKeyDefinition: @@ -3367,7 +3367,11 @@ def _UpdateSessionIfRequired( # update session self.session.update_session(response_result, response_headers) - def _get_partition_key_definition(self, collection_link: str) -> Optional[Dict[str, Any]]: + def _get_partition_key_definition( + self, + collection_link: str, + options: Mapping[str, Any] + ) -> Optional[Dict[str, Any]]: partition_key_definition: Optional[Dict[str, Any]] # If the document collection link is present in the cache, then use the cached partitionkey definition if collection_link in self.__container_properties_cache: @@ -3375,7 +3379,7 @@ def 
_get_partition_key_definition(self, collection_link: str) -> Optional[Dict[s partition_key_definition = cached_container.get("partitionKey") # Else read the collection from backend and add it to the cache else: - container = self.ReadContainer(collection_link) + container = self.ReadContainer(collection_link, options) partition_key_definition = container.get("partitionKey") self.__container_properties_cache[collection_link] = _set_properties_cache(container) return partition_key_definition diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py index 94805934ce74..185aa1d89cb8 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py @@ -57,7 +57,7 @@ def clear_route_to_location(self) -> None: def _can_set_excluded_location(self, options: Mapping[str, Any]) -> bool: # If resource types for requests are not one of the followings, excluded locations cannot be set - if self.resource_type.lower() not in ['docs', 'documents', 'partitionkey']: + if self.resource_type.lower() not in ['docs', 'documents', 'partitionkey', 'colls']: return False # If 'excludedLocations' wasn't in the options, excluded locations cannot be set From 41fc9176bec2687e41bc80e7f5254763754c0930 Mon Sep 17 00:00:00 2001 From: Allen Kim Date: Thu, 3 Apr 2025 15:10:39 -0700 Subject: [PATCH 045/152] Added initial live tests --- .../tests/test_excluded_locations.py | 478 ++++++++++++++++++ 1 file changed, 478 insertions(+) create mode 100644 sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py new file mode 100644 index 000000000000..01d1e9e9cf7e --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py @@ -0,0 +1,478 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. 
+ +import logging +import unittest +import uuid +import test_config +import pytest + +import azure.cosmos.cosmos_client as cosmos_client +from azure.cosmos.partition_key import PartitionKey +from azure.cosmos.exceptions import CosmosResourceNotFoundError + + +class MockHandler(logging.Handler): + def __init__(self): + super(MockHandler, self).__init__() + self.messages = [] + + def reset(self): + self.messages = [] + + def emit(self, record): + self.messages.append(record.msg) + +MOCK_HANDLER = MockHandler() +CONFIG = test_config.TestConfig() +HOST = CONFIG.host +KEY = CONFIG.masterKey +DATABASE_ID = CONFIG.TEST_DATABASE_ID +CONTAINER_ID = CONFIG.TEST_SINGLE_PARTITION_CONTAINER_ID +PARTITION_KEY = CONFIG.TEST_CONTAINER_PARTITION_KEY +ITEM_ID = 'doc1' +ITEM_PK_VALUE = 'pk' +TEST_ITEM = {'id': ITEM_ID, PARTITION_KEY: ITEM_PK_VALUE} + +# L0 = "Default" +# L1 = "West US 3" +# L2 = "West US" +# L3 = "East US 2" +# L4 = "Central US" + +L0 = "Default" +L1 = "East US 2" +L2 = "East US" +L3 = "West US 2" +L4 = "Central US" + +CLIENT_ONLY_TEST_DATA = [ + # preferred_locations, client_excluded_locations, excluded_locations_request + # 0. No excluded location + [[L1, L2, L3], [], None], + # 1. Single excluded location + [[L1, L2, L3], [L1], None], + # 2. Multiple excluded locations + [[L1, L2, L3], [L1, L2], None], + # 3. Exclude all locations + [[L1, L2, L3], [L1, L2, L3], None], + # 4. Exclude a location not in preferred locations + [[L1, L2, L3], [L4], None], +] + +CLIENT_AND_REQUEST_TEST_DATA = [ + # preferred_locations, client_excluded_locations, excluded_locations_request + # 0. No client excluded locations + a request excluded location + [[L1, L2, L3], [], [L1]], + # 1. The same client and request excluded location + [[L1, L2, L3], [L1], [L1]], + # 2. Less request excluded locations + [[L1, L2, L3], [L1, L2], [L1]], + # 3. More request excluded locations + [[L1, L2, L3], [L1], [L1, L2]], + # 4. 
All locations were excluded + [[L1, L2, L3], [L1, L2, L3], [L1, L2, L3]], + # 5. No common excluded locations + [[L1, L2, L3], [L1], [L2, L3]], + # 6. Reqeust excluded location not in preferred locations + [[L1, L2, L3], [L1, L2, L3], [L4]], + # 7. Empty excluded locations, remove all client excluded locations + [[L1, L2, L3], [L1, L2], []], +] + +ALL_INPUT_TEST_DATA = CLIENT_ONLY_TEST_DATA + CLIENT_AND_REQUEST_TEST_DATA +# ALL_INPUT_TEST_DATA = CLIENT_ONLY_TEST_DATA +# ALL_INPUT_TEST_DATA = CLIENT_AND_REQUEST_TEST_DATA + +def read_item_test_data(): + client_only_output_data = [ + [L1], # 0 + [L2], # 1 + [L3], # 2 + [L1], # 3 + [L1] # 4 + ] + client_and_request_output_data = [ + [L2], # 0 + [L2], # 1 + [L2], # 2 + [L3], # 3 + [L1], # 4 + [L1], # 5 + [L1], # 6 + [L1], # 7 + ] + all_output_test_data = client_only_output_data + client_and_request_output_data + + all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] + return all_test_data + +def query_items_change_feed_test_data(): + client_only_output_data = [ + [L1, L1, L1], #0 + [L2, L2, L2], #1 + [L3, L3, L3], #2 + [L1, L1, L1], #3 + [L1, L1, L1] #4 + ] + client_and_request_output_data = [ + [L1, L1, L2], #0 + [L2, L2, L2], #1 + [L3, L3, L2], #2 + [L2, L2, L3], #3 + [L1, L1, L1], #4 + [L2, L2, L1], #5 + [L1, L1, L1], #6 + [L3, L3, L1], #7 + ] + all_output_test_data = client_only_output_data + client_and_request_output_data + + all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] + return all_test_data + +def replace_item_test_data(): + client_only_output_data = [ + [L1, L1], #0 + [L2, L2], #1 + [L3, L3], #2 + [L1, L0], #3 + [L1, L1] #4 + ] + client_and_request_output_data = [ + [L2, L2], #0 + [L2, L2], #1 + [L2, L2], #2 + [L3, L3], #3 + [L1, L0], #4 + [L1, L1], #5 + [L1, L1], #6 + [L1, L1], #7 + ] + all_output_test_data = client_only_output_data + client_and_request_output_data 
+ + all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] + return all_test_data + +def patch_item_test_data(): + client_only_output_data = [ + [L1], #0 + [L2], #1 + [L3], #2 + [L0], #3 + [L1] #4 + ] + client_and_request_output_data = [ + [L2], #0 + [L2], #1 + [L2], #2 + [L3], #3 + [L0], #4 + [L1], #5 + [L1], #6 + [L1], #7 + ] + all_output_test_data = client_only_output_data + client_and_request_output_data + + all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] + return all_test_data + +@pytest.fixture(scope="class", autouse=True) +def setup_and_teardown(): + print("Setup: This runs before any tests") + logger = logging.getLogger("azure") + logger.addHandler(MOCK_HANDLER) + logger.setLevel(logging.DEBUG) + + container = cosmos_client.CosmosClient(HOST, KEY).get_database_client(DATABASE_ID).get_container_client(CONTAINER_ID) + container.create_item(body=TEST_ITEM) + + yield + # Code to run after tests + print("Teardown: This runs after all tests") + +@pytest.mark.cosmosMultiRegion +class TestExcludedLocations: + def _init_container(self, preferred_locations, client_excluded_locations, multiple_write_locations = True): + client = cosmos_client.CosmosClient(HOST, KEY, + preferred_locations=preferred_locations, + excluded_locations=client_excluded_locations, + multiple_write_locations=multiple_write_locations) + db = client.get_database_client(DATABASE_ID) + container = db.get_container_client(CONTAINER_ID) + MOCK_HANDLER.reset() + + return client, db, container + + def _verify_endpoint(self, client, expected_locations): + # get mapping for locations + location_mapping = (client.client_connection._global_endpoint_manager. + location_cache.account_locations_by_write_regional_routing_context) + default_endpoint = (client.client_connection._global_endpoint_manager. 
+ location_cache.default_regional_routing_context.get_primary()) + + # get Request URL + msgs = MOCK_HANDLER.messages + req_urls = [url.replace("Request URL: '", "") for url in msgs if 'Request URL:' in url] + + # get location + actual_locations = [] + for req_url in req_urls: + if req_url.startswith(default_endpoint): + actual_locations.append(L0) + else: + for endpoint in location_mapping: + if req_url.startswith(endpoint): + location = location_mapping[endpoint] + actual_locations.append(location) + break + + assert actual_locations == expected_locations + + @pytest.mark.parametrize('test_data', read_item_test_data()) + def test_read_item(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + # Client setup + client, db, container = self._init_container(preferred_locations, client_excluded_locations) + + # API call: read_item + if request_excluded_locations is None: + container.read_item(ITEM_ID, ITEM_PK_VALUE) + else: + container.read_item(ITEM_ID, ITEM_PK_VALUE, excluded_locations=request_excluded_locations) + + # Verify endpoint locations + self._verify_endpoint(client, expected_locations) + + @pytest.mark.parametrize('test_data', read_item_test_data()) + def test_read_all_items(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + # Client setup + client, db, container = self._init_container(preferred_locations, client_excluded_locations) + + # API call: read_all_items + if request_excluded_locations is None: + list(container.read_all_items()) + else: + list(container.read_all_items(excluded_locations=request_excluded_locations)) + + # Verify endpoint locations + self._verify_endpoint(client, expected_locations) + + @pytest.mark.parametrize('test_data', read_item_test_data()) + def test_query_items(self, test_data): + # Init test variables + 
preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + # Client setup and create an item + client, db, container = self._init_container(preferred_locations, client_excluded_locations) + + # API call: query_items + if request_excluded_locations is None: + list(container.query_items(None)) + else: + list(container.query_items(None, excluded_locations=request_excluded_locations)) + + # Verify endpoint locations + self._verify_endpoint(client, expected_locations) + + @pytest.mark.parametrize('test_data', query_items_change_feed_test_data()) + def test_query_items_change_feed(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + + # Client setup and create an item + client, db, container = self._init_container(preferred_locations, client_excluded_locations) + + # API call: query_items_change_feed + if request_excluded_locations is None: + list(container.query_items_change_feed()) + else: + list(container.query_items_change_feed(excluded_locations=request_excluded_locations)) + + # Verify endpoint locations + self._verify_endpoint(client, expected_locations) + + + @pytest.mark.parametrize('test_data', replace_item_test_data()) + def test_replace_item(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + for multiple_write_locations in [True, False]: + # Client setup and create an item + client, db, container = self._init_container(preferred_locations, client_excluded_locations, multiple_write_locations) + + # API call: replace_item + if request_excluded_locations is None: + container.replace_item(ITEM_ID, body=TEST_ITEM) + else: + container.replace_item(ITEM_ID, body=TEST_ITEM, excluded_locations=request_excluded_locations) + + # Verify endpoint locations + if multiple_write_locations: + self._verify_endpoint(client, 
expected_locations) + else: + self._verify_endpoint(client, [expected_locations[0], L1]) + + @pytest.mark.parametrize('test_data', replace_item_test_data()) + def test_upsert_item(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + for multiple_write_locations in [True, False]: + # Client setup and create an item + client, db, container = self._init_container(preferred_locations, client_excluded_locations, multiple_write_locations) + + # API call: upsert_item + body = {'pk': 'pk', 'id': f'doc2-{str(uuid.uuid4())}'} + if request_excluded_locations is None: + container.upsert_item(body=body) + else: + container.upsert_item(body=body, excluded_locations=request_excluded_locations) + + # get location from mock_handler + if multiple_write_locations: + self._verify_endpoint(client, expected_locations) + else: + self._verify_endpoint(client, [expected_locations[0], L1]) + + @pytest.mark.parametrize('test_data', replace_item_test_data()) + def test_create_item(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + for multiple_write_locations in [True, False]: + # Client setup and create an item + client, db, container = self._init_container(preferred_locations, client_excluded_locations, multiple_write_locations) + + # API call: create_item + body = {'pk': 'pk', 'id': f'doc2-{str(uuid.uuid4())}'} + if request_excluded_locations is None: + container.create_item(body=body) + else: + container.create_item(body=body, excluded_locations=request_excluded_locations) + + # get location from mock_handler + if multiple_write_locations: + self._verify_endpoint(client, expected_locations) + else: + self._verify_endpoint(client, [expected_locations[0], L1]) + + @pytest.mark.parametrize('test_data', patch_item_test_data()) + def test_patch_item(self, test_data): + # Init test variables 
+ preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + for multiple_write_locations in [True, False]: + # Client setup and create an item + client, db, container = self._init_container(preferred_locations, client_excluded_locations, + multiple_write_locations) + + # API call: patch_item + operations = [ + {"op": "add", "path": "/test_data", "value": f'Data-{str(uuid.uuid4())}'}, + ] + if request_excluded_locations is None: + container.patch_item(item=ITEM_ID, partition_key=ITEM_PK_VALUE, + patch_operations=operations) + else: + container.patch_item(item=ITEM_ID, partition_key=ITEM_PK_VALUE, + patch_operations=operations, + excluded_locations=request_excluded_locations) + + # get location from mock_handler + if multiple_write_locations: + self._verify_endpoint(client, expected_locations) + else: + self._verify_endpoint(client, [L1]) + + @pytest.mark.parametrize('test_data', patch_item_test_data()) + def test_execute_item_batch(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + for multiple_write_locations in [True, False]: + # Client setup and create an item + client, db, container = self._init_container(preferred_locations, client_excluded_locations, + multiple_write_locations) + + # API call: execute_item_batch + batch_operations = [] + for i in range(3): + batch_operations.append(("create", ({"id": f'Doc-{str(uuid.uuid4())}', PARTITION_KEY: ITEM_PK_VALUE},))) + + if request_excluded_locations is None: + container.execute_item_batch(batch_operations=batch_operations, + partition_key=ITEM_PK_VALUE,) + else: + container.execute_item_batch(batch_operations=batch_operations, + partition_key=ITEM_PK_VALUE, + excluded_locations=request_excluded_locations) + + # get location from mock_handler + if multiple_write_locations: + self._verify_endpoint(client, expected_locations) + else: + 
self._verify_endpoint(client, [L1]) + + @pytest.mark.parametrize('test_data', patch_item_test_data()) + def test_delete_item(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + for multiple_write_locations in [True, False]: + # Client setup + client, db, container = self._init_container(preferred_locations, client_excluded_locations, multiple_write_locations) + + #create before delete + item_id = f'doc2-{str(uuid.uuid4())}' + container.create_item(body={PARTITION_KEY: ITEM_PK_VALUE, 'id': item_id}) + MOCK_HANDLER.reset() + + # API call: read_item + if request_excluded_locations is None: + container.delete_item(item_id, ITEM_PK_VALUE) + else: + container.delete_item(item_id, ITEM_PK_VALUE, excluded_locations=request_excluded_locations) + + # Verify endpoint locations + if multiple_write_locations: + self._verify_endpoint(client, expected_locations) + else: + self._verify_endpoint(client, [L1]) + + # TODO: enable this test once we figure out how to enable delete_all_items_by_partition_key feature + # @pytest.mark.parametrize('test_data', patch_item_test_data()) + # def test_delete_all_items_by_partition_key(self, test_data): + # # Init test variables + # preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + # + # for multiple_write_locations in [True, False]: + # # Client setup + # client, db, container = self._init_container(preferred_locations, client_excluded_locations, multiple_write_locations) + # + # #create before delete + # item_id = f'doc2-{str(uuid.uuid4())}' + # pk_value = f'temp_partition_key_value-{str(uuid.uuid4())}' + # container.create_item(body={PARTITION_KEY: pk_value, 'id': item_id}) + # MOCK_HANDLER.reset() + # + # # API call: read_item + # if request_excluded_locations is None: + # container.delete_all_items_by_partition_key(pk_value) + # else: + # 
container.delete_all_items_by_partition_key(pk_value, excluded_locations=request_excluded_locations) + # + # # Verify endpoint locations + # if multiple_write_locations: + # self._verify_endpoint(client, expected_locations) + # else: + # self._verify_endpoint(client, [L1]) + +if __name__ == "__main__": + unittest.main() From 07b8f39aeaba9a68e7656777bb0732655a20bf4a Mon Sep 17 00:00:00 2001 From: Allen Kim Date: Thu, 3 Apr 2025 15:28:38 -0700 Subject: [PATCH 046/152] Updated live-platform-matrix for multi-region tests --- sdk/cosmos/live-platform-matrix.json | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/sdk/cosmos/live-platform-matrix.json b/sdk/cosmos/live-platform-matrix.json index bca59256d05d..b3242623be78 100644 --- a/sdk/cosmos/live-platform-matrix.json +++ b/sdk/cosmos/live-platform-matrix.json @@ -86,11 +86,7 @@ "CoverageArg": "--disablecov", "TestSamples": "false", "TestMarkArgument": "cosmosLong" - } - } - }, - { - "WindowsConfig": { + }, "Windows2022_38_multi_region": { "OSVmImage": "env:WINDOWSVMIMAGE", "Pool": "env:WINDOWSPOOL", @@ -110,7 +106,9 @@ "TestSamples": "false", "TestMarkArgument": "cosmosMultiRegion", "ArmConfig": { - "ArmTemplateParameters": "@{ enableMultipleRegions = $true }" + "MultiMaster_MultiRegion": { + "ArmTemplateParameters": "@{ enableMultipleRegions = $true }" + } } }, "Windows2022_312_multi_region": { @@ -121,7 +119,9 @@ "TestSamples": "false", "TestMarkArgument": "cosmosMultiRegion", "ArmConfig": { - "ArmTemplateParameters": "@{ enableMultipleRegions = $true }" + "MultiMaster_MultiRegion": { + "ArmTemplateParameters": "@{ enableMultipleRegions = $true }" + } } } } From 1b09739533408a2699443a88df1fe7f59e7fa617 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Thu, 3 Apr 2025 15:31:56 -0700 Subject: [PATCH 047/152] initial sync version of fault injection --- .../tests/_fault_injection_transport.py | 253 ++++++++++++++++++ .../tests/test_fault_injection_transport.py | 99 +++++++ 2 files changed, 352 
insertions(+) create mode 100644 sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py create mode 100644 sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport.py diff --git a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py new file mode 100644 index 000000000000..9cdd18936945 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py @@ -0,0 +1,253 @@ +# The MIT License (MIT) +# Copyright (c) 2014 Microsoft Corporation + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+ +"""AioHttpTransport allowing injection of faults between SDK and Cosmos Gateway +""" + +import json +import logging +import sys +from time import sleep +from typing import Callable, Optional, Any, Dict, List, MutableMapping + +from azure.core.pipeline.transport import HttpRequest, HttpResponse +from azure.core.pipeline.transport._requests_basic import RequestsTransport, RequestsTransportResponse +from requests import Session + +from azure.cosmos import documents + +import test_config +from azure.cosmos.exceptions import CosmosHttpResponseError +from azure.core.exceptions import ServiceRequestError + +class FaultInjectionTransport(RequestsTransport): + logger = logging.getLogger('azure.cosmos.fault_injection_transport') + logger.setLevel(logging.DEBUG) + + def __init__(self, *, session: Optional[Session] = None, loop=None, session_owner: bool = True, **config): + self.faults: List[Dict[str, Any]] = [] + self.requestTransformations: List[Dict[str, Any]] = [] + self.responseTransformations: List[Dict[str, Any]] = [] + super().__init__(session=session, loop=loop, session_owner=session_owner, **config) + + def add_fault(self, predicate: Callable[[HttpRequest], bool], fault_factory: Callable[[HttpRequest], Exception]): + self.faults.append({"predicate": predicate, "apply": fault_factory}) + + def add_response_transformation(self, predicate: Callable[[HttpRequest], bool], response_transformation: Callable[[HttpRequest, Callable[[HttpRequest], RequestsTransportResponse]], RequestsTransportResponse]): + self.responseTransformations.append({ + "predicate": predicate, + "apply": response_transformation}) + + @staticmethod + def __first_item(iterable, condition=lambda x: True): + """ + Returns the first item in the `iterable` that satisfies the `condition`. + + If no item satisfies the condition, it returns None. 
+ """ + return next((x for x in iterable if condition(x)), None) + + def send(self, request: HttpRequest, *, proxies: Optional[MutableMapping[str, str]] = None, **kwargs) -> HttpResponse: + FaultInjectionTransport.logger.info("--> FaultInjectionTransport.Send {} {}".format(request.method, request.url)) + # find the first fault Factory with matching predicate if any + first_fault_factory = FaultInjectionTransport.__first_item(iter(self.faults), lambda f: f["predicate"](request)) + if first_fault_factory: + FaultInjectionTransport.logger.info("--> FaultInjectionTransport.ApplyFaultInjection") + injected_error = first_fault_factory["apply"](request) + FaultInjectionTransport.logger.info("Found to-be-injected error {}".format(injected_error)) + raise injected_error + + # apply the chain of request transformations with matching predicates if any + matching_request_transformations = filter(lambda f: f["predicate"](f["predicate"]), iter(self.requestTransformations)) + for currentTransformation in matching_request_transformations: + FaultInjectionTransport.logger.info("--> FaultInjectionTransport.ApplyRequestTransformation") + request = currentTransformation["apply"](request) + + first_response_transformation = FaultInjectionTransport.__first_item(iter(self.responseTransformations), lambda f: f["predicate"](request)) + + FaultInjectionTransport.logger.info("--> FaultInjectionTransport.BeforeGetResponseTask") + get_response_task = super().send(request, proxies=proxies, **kwargs) + FaultInjectionTransport.logger.info("<-- FaultInjectionTransport.AfterGetResponseTask") + + if first_response_transformation: + FaultInjectionTransport.logger.info(f"Invoking response transformation") + response = first_response_transformation["apply"](request, lambda: get_response_task) + response.headers["_request"] = request + FaultInjectionTransport.logger.info(f"Received response transformation result with status code {response.status_code}") + return response + else: + 
FaultInjectionTransport.logger.info(f"Sending request to {request.url}") + response = get_response_task + response.headers["_request"] = request + FaultInjectionTransport.logger.info(f"Received response with status code {response.status_code}") + return response + + @staticmethod + def predicate_url_contains_id(r: HttpRequest, id_value: str) -> bool: + return id_value in r.url + + @staticmethod + def predicate_targets_region(r: HttpRequest, region_endpoint: str) -> bool: + return r.url.startswith(region_endpoint) + + @staticmethod + def print_call_stack(): + print("Call stack:") + frame = sys._getframe() + while frame: + print(f"File: {frame.f_code.co_filename}, Line: {frame.f_lineno}, Function: {frame.f_code.co_name}") + frame = frame.f_back + + @staticmethod + def predicate_req_payload_contains_id(r: HttpRequest, id_value: str): + if r.body is None: + return False + + return '"id":"{}"'.format(id_value) in r.body + + @staticmethod + def predicate_req_for_document_with_id(r: HttpRequest, id_value: str) -> bool: + return (FaultInjectionTransport.predicate_url_contains_id(r, id_value) + or FaultInjectionTransport.predicate_req_payload_contains_id(r, id_value)) + + @staticmethod + def predicate_is_database_account_call(r: HttpRequest) -> bool: + is_db_account_read = (r.headers.get('x-ms-thinclient-proxy-resource-type') == 'databaseaccount' + and r.headers.get('x-ms-thinclient-proxy-operation-type') == 'Read') + + return is_db_account_read + + @staticmethod + def predicate_is_document_operation(r: HttpRequest) -> bool: + is_document_operation = (r.headers.get('x-ms-thinclient-proxy-resource-type') == 'docs') + + return is_document_operation + + @staticmethod + def predicate_is_write_operation(r: HttpRequest, uri_prefix: str) -> bool: + is_write_document_operation = documents._OperationType.IsWriteOperation( + str(r.headers.get('x-ms-thinclient-proxy-operation-type'))) + + return is_write_document_operation and uri_prefix in r.url + + @staticmethod + def 
error_after_delay(delay_in_ms: int, error: Exception) -> Exception: + sleep(delay_in_ms / 1000.0) + return error + + @staticmethod + def error_write_forbidden() -> Exception: + return CosmosHttpResponseError( + status_code=403, + message="Injected error disallowing writes in this region.", + response=None, + sub_status_code=3, + ) + + @staticmethod + def error_region_down() -> Exception: + return ServiceRequestError( + message="Injected region down.", + ) + + @staticmethod + def transform_topology_swr_mrr( + write_region_name: str, + read_region_name: str, + inner: Callable[[], RequestsTransportResponse]) -> RequestsTransportResponse: + + response = inner() + if not FaultInjectionTransport.predicate_is_database_account_call(response.request): + return response + + data = response.body() + if response.status_code == 200 and data: + data = data.decode("utf-8") + result = json.loads(data) + readable_locations = result["readableLocations"] + writable_locations = result["writableLocations"] + readable_locations[0]["name"] = write_region_name + writable_locations[0]["name"] = write_region_name + readable_locations.append({"name": read_region_name, "databaseAccountEndpoint" : test_config.TestConfig.local_host}) + FaultInjectionTransport.logger.info("Transformed Account Topology: {}".format(result)) + request: HttpRequest = response.request + return FaultInjectionTransport.MockHttpResponse(request, 200, result) + + return response + + @staticmethod + def transform_topology_mwr( + first_region_name: str, + second_region_name: str, + inner: Callable[[], RequestsTransportResponse]) -> RequestsTransportResponse: + + response = inner() + if not FaultInjectionTransport.predicate_is_database_account_call(response.request): + return response + + data = response.body() + if response.status_code == 200 and data: + data = data.decode("utf-8") + result = json.loads(data) + readable_locations = result["readableLocations"] + writable_locations = result["writableLocations"] + 
readable_locations[0]["name"] = first_region_name + writable_locations[0]["name"] = first_region_name + readable_locations.append( + {"name": second_region_name, "databaseAccountEndpoint": test_config.TestConfig.local_host}) + writable_locations.append( + {"name": second_region_name, "databaseAccountEndpoint": test_config.TestConfig.local_host}) + result["enableMultipleWriteLocations"] = True + FaultInjectionTransport.logger.info("Transformed Account Topology: {}".format(result)) + request: HttpRequest = response.request + return FaultInjectionTransport.MockHttpResponse(request, 200, result) + + return response + + class MockHttpResponse(RequestsTransportResponse): + def __init__(self, request: HttpRequest, status_code: int, content:Optional[Dict[str, Any]]): + self.request: HttpRequest = request + # This is actually never None, and set by all implementations after the call to + # __init__ of this class. This class is also a legacy impl, so it's risky to change it + # for low benefits The new "rest" implementation does define correctly status_code + # as non-optional. 
+ self.status_code: int = status_code + self.headers: MutableMapping[str, str] = {} + self.reason: Optional[str] = None + self.content_type: Optional[str] = None + self.block_size: int = 4096 # Default to same as R + self.content: Optional[Dict[str, Any]] = None + self.json_text: str = "" + self.bytes: bytes = b"" + if content: + self.content = content + self.json_text = json.dumps(content) + self.bytes = self.json_text.encode("utf-8") + + + def body(self) -> bytes: + return self.bytes + + def text(self, encoding: Optional[str] = None) -> str: + return self.json_text + + def load_body(self) -> None: + return diff --git a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport.py b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport.py new file mode 100644 index 000000000000..291163159696 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport.py @@ -0,0 +1,99 @@ +import logging +import os +import sys +import time +import uuid +from typing import Callable + +import pytest +from azure.core.pipeline.transport._requests_basic import RequestsTransport +from azure.core.rest import HttpRequest + +import test_config +from _fault_injection_transport_async import FaultInjectionTransportAsync +from azure.cosmos import PartitionKey +from azure.cosmos import CosmosClient +from azure.cosmos.container import ContainerProxy +from azure.cosmos.database import DatabaseProxy +from azure.cosmos.exceptions import CosmosHttpResponseError + +logger = logging.getLogger('azure.cosmos') +logger.setLevel(logging.DEBUG) +logger.addHandler(logging.StreamHandler(sys.stdout)) + +host = test_config.TestConfig.host +master_key = test_config.TestConfig.masterKey +TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID +SINGLE_PARTITION_CONTAINER_NAME = os.path.basename(__file__) + str(uuid.uuid4()) + +@pytest.mark.unittest +@pytest.mark.cosmosEmulator +class TestFaultInjectionTransport: + + @classmethod + async def setup_class(cls): + logger.info("starting 
class: {} execution".format(cls.__name__)) + cls.host = host + cls.master_key = master_key + + if (cls.master_key == '[YOUR_KEY_HERE]' or + cls.host == '[YOUR_ENDPOINT_HERE]'): + raise Exception( + "You must specify your Azure Cosmos account values for " + "'masterKey' and 'host' at the top of this class to run the " + "tests.") + cls.database_id = TEST_DATABASE_ID + cls.single_partition_container_name = SINGLE_PARTITION_CONTAINER_NAME + + cls.mgmt_client = CosmosClient(cls.host, cls.master_key, consistency_level="Session", logger=logger) + created_database = cls.mgmt_client.get_database_client(cls.database_id) + created_database.create_container( + cls.single_partition_container_name, + partition_key=PartitionKey("/pk")) + + + @classmethod + async def teardown_class(cls): + logger.info("tearing down class: {}".format(cls.__name__)) + created_database = cls.mgmt_client.get_database_client(cls.database_id) + try: + created_database.delete_container(cls.single_partition_container_name), + except Exception as containerDeleteError: + logger.warning("Exception trying to delete database {}. 
{}".format(created_database.id, containerDeleteError)) + + @staticmethod + def setup_method_with_custom_transport(custom_transport: RequestsTransport, default_endpoint=host, **kwargs): + client = CosmosClient(default_endpoint, master_key, consistency_level="Session", + transport=custom_transport, logger=logger, enable_diagnostics_logging=True, **kwargs) + db: DatabaseProxy = client.get_database_client(TEST_DATABASE_ID) + container: ContainerProxy = db.get_container_client(SINGLE_PARTITION_CONTAINER_NAME) + return {"client": client, "db": db, "col": container} + + + def test_throws_injected_error_async(self: "TestFaultInjectionTransport"): + id_value: str = str(uuid.uuid4()) + document_definition = {'id': id_value, + 'pk': id_value, + 'name': 'sample document', + 'key': 'value'} + + custom_transport = FaultInjectionTransportAsync() + predicate : Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransportAsync.predicate_req_for_document_with_id(r, id_value) + custom_transport.add_fault(predicate, lambda r: FaultInjectionTransportAsync.error_after_delay( + 10000, + CosmosHttpResponseError( + status_code=502, + message="Some random reverse proxy error."))) + + initialized_objects = TestFaultInjectionTransport.setup_method_with_custom_transport(custom_transport) + start: float = time.perf_counter() + try: + container: ContainerProxy = initialized_objects["col"] + container.create_item(body=document_definition) + pytest.fail("Expected exception not thrown") + except CosmosHttpResponseError as cosmosError: + end = time.perf_counter() - start + # validate response took more than 10 seconds + assert end > 10 + if cosmosError.status_code != 502: + raise cosmosError \ No newline at end of file From 2fb3dc93c455cf90c9008e3a670c85068f4b9e8c Mon Sep 17 00:00:00 2001 From: Tomas Varon Saldarriaga Date: Thu, 3 Apr 2025 15:59:57 -0700 Subject: [PATCH 048/152] add all sync tests --- .../tests/test_fault_injection_transport.py | 392 +++++++++++++++++- 1 file changed, 381 
insertions(+), 11 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport.py b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport.py index 291163159696..ef1f28b14adc 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport.py +++ b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport.py @@ -2,6 +2,7 @@ import os import sys import time +import unittest import uuid from typing import Callable @@ -10,12 +11,13 @@ from azure.core.rest import HttpRequest import test_config -from _fault_injection_transport_async import FaultInjectionTransportAsync from azure.cosmos import PartitionKey from azure.cosmos import CosmosClient from azure.cosmos.container import ContainerProxy from azure.cosmos.database import DatabaseProxy from azure.cosmos.exceptions import CosmosHttpResponseError +from tests._fault_injection_transport import FaultInjectionTransport +from azure.core.exceptions import ServiceRequestError logger = logging.getLogger('azure.cosmos') logger.setLevel(logging.DEBUG) @@ -31,7 +33,7 @@ class TestFaultInjectionTransport: @classmethod - async def setup_class(cls): + def setup_class(cls): logger.info("starting class: {} execution".format(cls.__name__)) cls.host = host cls.master_key = master_key @@ -47,13 +49,11 @@ async def setup_class(cls): cls.mgmt_client = CosmosClient(cls.host, cls.master_key, consistency_level="Session", logger=logger) created_database = cls.mgmt_client.get_database_client(cls.database_id) - created_database.create_container( - cls.single_partition_container_name, - partition_key=PartitionKey("/pk")) + created_database.create_container(cls.single_partition_container_name, partition_key=PartitionKey("/pk")) @classmethod - async def teardown_class(cls): + def teardown_class(cls): logger.info("tearing down class: {}".format(cls.__name__)) created_database = cls.mgmt_client.get_database_client(cls.database_id) try: @@ -70,16 +70,16 @@ def 
setup_method_with_custom_transport(custom_transport: RequestsTransport, defa return {"client": client, "db": db, "col": container} - def test_throws_injected_error_async(self: "TestFaultInjectionTransport"): + def test_throws_injected_error(self: "TestFaultInjectionTransport"): id_value: str = str(uuid.uuid4()) document_definition = {'id': id_value, 'pk': id_value, 'name': 'sample document', 'key': 'value'} - custom_transport = FaultInjectionTransportAsync() - predicate : Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransportAsync.predicate_req_for_document_with_id(r, id_value) - custom_transport.add_fault(predicate, lambda r: FaultInjectionTransportAsync.error_after_delay( + custom_transport = FaultInjectionTransport() + predicate : Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransport.predicate_req_for_document_with_id(r, id_value) + custom_transport.add_fault(predicate, lambda r: FaultInjectionTransport.error_after_delay( 10000, CosmosHttpResponseError( status_code=502, @@ -96,4 +96,374 @@ def test_throws_injected_error_async(self: "TestFaultInjectionTransport"): # validate response took more than 10 seconds assert end > 10 if cosmosError.status_code != 502: - raise cosmosError \ No newline at end of file + raise cosmosError + + def test_swr_mrr_succeeds(self: "TestFaultInjectionTransport"): + expected_read_region_uri: str = test_config.TestConfig.local_host + expected_write_region_uri: str = expected_read_region_uri.replace("localhost", "127.0.0.1") + custom_transport = FaultInjectionTransport() + # Inject rule to disallow writes in the read-only region + is_write_operation_in_read_region_predicate: Callable[[HttpRequest], bool] = lambda \ + r: FaultInjectionTransport.predicate_is_write_operation(r, expected_read_region_uri) + + custom_transport.add_fault( + is_write_operation_in_read_region_predicate, + lambda r: FaultInjectionTransport.error_write_forbidden()) + + # Inject topology transformation that would make Emulator look like a 
single write region + # account with two read regions + is_get_account_predicate: Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransport.predicate_is_database_account_call(r) + emulator_as_multi_region_sm_account_transformation = \ + lambda r, inner: FaultInjectionTransport.transform_topology_swr_mrr( + write_region_name="Write Region", + read_region_name="Read Region", + inner=inner) + custom_transport.add_response_transformation( + is_get_account_predicate, + emulator_as_multi_region_sm_account_transformation) + + id_value: str = str(uuid.uuid4()) + document_definition = {'id': id_value, + 'pk': id_value, + 'name': 'sample document', + 'key': 'value'} + + initialized_objects = self.setup_method_with_custom_transport( + custom_transport, + preferred_locations=["Read Region", "Write Region"]) + container: ContainerProxy = initialized_objects["col"] + + created_document = container.create_item(body=document_definition) + request: HttpRequest = created_document.get_response_headers()["_request"] + # Validate the response comes from "Write Region" (the write region) + assert request.url.startswith(expected_write_region_uri) + start: float = time.perf_counter() + + while (time.perf_counter() - start) < 2: + read_document = container.read_item(id_value, partition_key=id_value) + request = read_document.get_response_headers()["_request"] + # Validate the response comes from "Read Region" (the most preferred read-only region) + assert request.url.startswith(expected_read_region_uri) + + + def test_swr_mrr_region_down_read_succeeds(self: "TestFaultInjectionTransport"): + expected_read_region_uri: str = test_config.TestConfig.local_host + expected_write_region_uri: str = expected_read_region_uri.replace("localhost", "127.0.0.1") + custom_transport = FaultInjectionTransport() + # Inject rule to disallow writes in the read-only region + is_write_operation_in_read_region_predicate: Callable[[HttpRequest], bool] = lambda \ + r: 
FaultInjectionTransport.predicate_is_write_operation(r, expected_read_region_uri) + + custom_transport.add_fault( + is_write_operation_in_read_region_predicate, + lambda r: FaultInjectionTransport.error_write_forbidden()) + + # Inject rule to simulate regional outage in "Read Region" + is_request_to_read_region: Callable[[HttpRequest], bool] = lambda \ + r: FaultInjectionTransport.predicate_targets_region(r, expected_read_region_uri) + + custom_transport.add_fault( + is_request_to_read_region, + lambda r: FaultInjectionTransport.error_region_down()) + + # Inject topology transformation that would make Emulator look like a single write region + # account with two read regions + is_get_account_predicate: Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransport.predicate_is_database_account_call(r) + emulator_as_multi_region_sm_account_transformation = \ + lambda r, inner: FaultInjectionTransport.transform_topology_swr_mrr( + write_region_name="Write Region", + read_region_name="Read Region", + inner=inner) + custom_transport.add_response_transformation( + is_get_account_predicate, + emulator_as_multi_region_sm_account_transformation) + + id_value: str = str(uuid.uuid4()) + document_definition = {'id': id_value, + 'pk': id_value, + 'name': 'sample document', + 'key': 'value'} + + initialized_objects = self.setup_method_with_custom_transport( + custom_transport, + default_endpoint=expected_write_region_uri, + preferred_locations=["Read Region", "Write Region"]) + container: ContainerProxy = initialized_objects["col"] + + created_document = container.create_item(body=document_definition) + request: HttpRequest = created_document.get_response_headers()["_request"] + # Validate the response comes from "South Central US" (the write region) + assert request.url.startswith(expected_write_region_uri) + start:float = time.perf_counter() + + while (time.perf_counter() - start) < 2: + read_document = container.read_item(id_value, partition_key=id_value) + request = 
read_document.get_response_headers()["_request"] + # Validate the response comes from "Write Region" ("Read Region" the most preferred read-only region is down) + assert request.url.startswith(expected_write_region_uri) + + + def test_swr_mrr_region_down_envoy_read_succeeds(self: "TestFaultInjectionTransport"): + expected_read_region_uri: str = test_config.TestConfig.local_host + expected_write_region_uri: str = expected_read_region_uri.replace("localhost", "127.0.0.1") + custom_transport = FaultInjectionTransport() + # Inject rule to disallow writes in the read-only region + is_write_operation_in_read_region_predicate: Callable[[HttpRequest], bool] = lambda \ + r: FaultInjectionTransport.predicate_is_write_operation(r, expected_read_region_uri) + + custom_transport.add_fault( + is_write_operation_in_read_region_predicate, + lambda r: FaultInjectionTransport.error_write_forbidden()) + + # Inject rule to simulate regional outage in "Read Region" + is_request_to_read_region: Callable[[HttpRequest], bool] = lambda \ + r: FaultInjectionTransport.predicate_targets_region(r, expected_read_region_uri) and \ + FaultInjectionTransport.predicate_is_document_operation(r) + + custom_transport.add_fault( + is_request_to_read_region, + lambda r: FaultInjectionTransport.error_after_delay( + 500, + CosmosHttpResponseError( + status_code=502, + message="Some random reverse proxy error."))) + + # Inject topology transformation that would make Emulator look like a single write region + # account with two read regions + is_get_account_predicate: Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransport.predicate_is_database_account_call(r) + emulator_as_multi_region_sm_account_transformation = \ + lambda r, inner: FaultInjectionTransport.transform_topology_swr_mrr( + write_region_name="Write Region", + read_region_name="Read Region", + inner=inner) + custom_transport.add_response_transformation( + is_get_account_predicate, + emulator_as_multi_region_sm_account_transformation) 
+ + id_value: str = str(uuid.uuid4()) + document_definition = {'id': id_value, + 'pk': id_value, + 'name': 'sample document', + 'key': 'value'} + + initialized_objects = self.setup_method_with_custom_transport( + custom_transport, + default_endpoint=expected_write_region_uri, + preferred_locations=["Read Region", "Write Region"]) + container: ContainerProxy = initialized_objects["col"] + + created_document = container.create_item(body=document_definition) + request: HttpRequest = created_document.get_response_headers()["_request"] + # Validate the response comes from "South Central US" (the write region) + assert request.url.startswith(expected_write_region_uri) + start:float = time.perf_counter() + + while (time.perf_counter() - start) < 2: + read_document = container.read_item(id_value, partition_key=id_value) + request = read_document.get_response_headers()["_request"] + # Validate the response comes from "Write Region" ("Read Region" the most preferred read-only region is down) + assert request.url.startswith(expected_write_region_uri) + + + + def test_mwr_succeeds(self: "TestFaultInjectionTransport"): + first_region_uri: str = test_config.TestConfig.local_host.replace("localhost", "127.0.0.1") + custom_transport = FaultInjectionTransport() + + # Inject topology transformation that would make Emulator look like a multiple write region account + # account with two read regions + is_get_account_predicate: Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransport.predicate_is_database_account_call(r) + emulator_as_multi_write_region_account_transformation = \ + lambda r, inner: FaultInjectionTransport.transform_topology_mwr( + first_region_name="First Region", + second_region_name="Second Region", + inner=inner) + custom_transport.add_response_transformation( + is_get_account_predicate, + emulator_as_multi_write_region_account_transformation) + + id_value: str = str(uuid.uuid4()) + document_definition = {'id': id_value, + 'pk': id_value, + 'name': 'sample 
document', + 'key': 'value'} + + initialized_objects = self.setup_method_with_custom_transport( + custom_transport, + preferred_locations=["First Region", "Second Region"]) + container: ContainerProxy = initialized_objects["col"] + + created_document = container.create_item(body=document_definition) + request: HttpRequest = created_document.get_response_headers()["_request"] + # Validate the response comes from "South Central US" (the write region) + assert request.url.startswith(first_region_uri) + start:float = time.perf_counter() + + while (time.perf_counter() - start) < 2: + read_document = container.read_item(id_value, partition_key=id_value) + request = read_document.get_response_headers()["_request"] + # Validate the response comes from "East US" (the most preferred read-only region) + assert request.url.startswith(first_region_uri) + + + def test_mwr_region_down_succeeds(self: "TestFaultInjectionTransport"): + first_region_uri: str = test_config.TestConfig.local_host.replace("localhost", "127.0.0.1") + second_region_uri: str = test_config.TestConfig.local_host + custom_transport = FaultInjectionTransport() + + # Inject topology transformation that would make Emulator look like a multiple write region account + # account with two read regions + is_get_account_predicate: Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransport.predicate_is_database_account_call(r) + emulator_as_multi_write_region_account_transformation = \ + lambda r, inner: FaultInjectionTransport.transform_topology_mwr( + first_region_name="First Region", + second_region_name="Second Region", + inner=inner) + custom_transport.add_response_transformation( + is_get_account_predicate, + emulator_as_multi_write_region_account_transformation) + + # Inject rule to simulate regional outage in "First Region" + is_request_to_first_region: Callable[[HttpRequest], bool] = lambda \ + r: FaultInjectionTransport.predicate_targets_region(r, first_region_uri) and \ + 
FaultInjectionTransport.predicate_is_document_operation(r) + + custom_transport.add_fault( + is_request_to_first_region, + lambda r: FaultInjectionTransport.error_region_down()) + + id_value: str = str(uuid.uuid4()) + document_definition = {'id': id_value, + 'pk': id_value, + 'name': 'sample document', + 'key': 'value'} + + initialized_objects = self.setup_method_with_custom_transport( + custom_transport, + preferred_locations=["First Region", "Second Region"]) + container: ContainerProxy = initialized_objects["col"] + + start:float = time.perf_counter() + while (time.perf_counter() - start) < 2: + # reads and writes should failover to second region + upsert_document = container.upsert_item(body=document_definition) + request = upsert_document.get_response_headers()["_request"] + assert request.url.startswith(second_region_uri) + read_document = container.read_item(id_value, partition_key=id_value) + request = read_document.get_response_headers()["_request"] + # Validate the response comes from "East US" (the most preferred read-only region) + assert request.url.startswith(second_region_uri) + + + def test_swr_mrr_all_regions_down_for_read(self: "TestFaultInjectionTransport"): + expected_read_region_uri: str = test_config.TestConfig.local_host + expected_write_region_uri: str = expected_read_region_uri.replace("localhost", "127.0.0.1") + custom_transport = FaultInjectionTransport() + + # Inject rule to disallow writes in the read-only region + is_write_operation_in_read_region_predicate: Callable[[HttpRequest], bool] = lambda \ + r: FaultInjectionTransport.predicate_is_write_operation(r, expected_read_region_uri) + + custom_transport.add_fault( + is_write_operation_in_read_region_predicate, + lambda r: FaultInjectionTransport.error_write_forbidden()) + + # Inject topology transformation that would make Emulator look like a multiple write region account + # account with two read regions + is_get_account_predicate: Callable[[HttpRequest], bool] = lambda \ + r: 
FaultInjectionTransport.predicate_is_database_account_call(r) + emulator_as_multi_write_region_account_transformation = \ + lambda r, inner: FaultInjectionTransport.transform_topology_swr_mrr( + write_region_name="Write Region", + read_region_name="Read Region", + inner=inner) + custom_transport.add_response_transformation( + is_get_account_predicate, + emulator_as_multi_write_region_account_transformation) + + # Inject rule to simulate regional outage in "First Region" + is_request_to_first_region: Callable[[HttpRequest], bool] = lambda \ + r: (FaultInjectionTransport.predicate_targets_region(r, expected_write_region_uri) and + FaultInjectionTransport.predicate_is_document_operation(r) and + not FaultInjectionTransport.predicate_is_write_operation(r, expected_write_region_uri)) + + + # Inject rule to simulate regional outage in "Second Region" + is_request_to_second_region: Callable[[HttpRequest], bool] = lambda \ + r: (FaultInjectionTransport.predicate_targets_region(r, expected_read_region_uri) and + FaultInjectionTransport.predicate_is_document_operation(r) and + not FaultInjectionTransport.predicate_is_write_operation(r, expected_write_region_uri)) + + custom_transport.add_fault( + is_request_to_first_region, + lambda r: FaultInjectionTransport.error_region_down()) + custom_transport.add_fault( + is_request_to_second_region, + lambda r: FaultInjectionTransport.error_region_down()) + + id_value: str = str(uuid.uuid4()) + document_definition = {'id': id_value, + 'pk': id_value, + 'name': 'sample document', + 'key': 'value'} + + initialized_objects = self.setup_method_with_custom_transport( + custom_transport, + preferred_locations=["First Region", "Second Region"]) + container: ContainerProxy = initialized_objects["col"] + container.upsert_item(body=document_definition) + with pytest.raises(ServiceRequestError): + container.read_item(id_value, partition_key=id_value) + + def test_mwr_all_regions_down(self: "TestFaultInjectionTransport"): + + first_region_uri: 
str = test_config.TestConfig.local_host.replace("localhost", "127.0.0.1") + second_region_uri: str = test_config.TestConfig.local_host + custom_transport = FaultInjectionTransport() + + # Inject topology transformation that would make Emulator look like a multiple write region account + # account with two read regions + is_get_account_predicate: Callable[[HttpRequest], bool] = lambda \ + r: FaultInjectionTransport.predicate_is_database_account_call(r) + emulator_as_multi_write_region_account_transformation = \ + lambda r, inner: FaultInjectionTransport.transform_topology_mwr( + first_region_name="First Region", + second_region_name="Second Region", + inner=inner) + custom_transport.add_response_transformation( + is_get_account_predicate, + emulator_as_multi_write_region_account_transformation) + + # Inject rule to simulate regional outage in "First Region" + is_request_to_first_region: Callable[[HttpRequest], bool] = lambda \ + r: FaultInjectionTransport.predicate_targets_region(r, first_region_uri) and \ + FaultInjectionTransport.predicate_is_document_operation(r) + + # Inject rule to simulate regional outage in "Second Region" + is_request_to_second_region: Callable[[HttpRequest], bool] = lambda \ + r: FaultInjectionTransport.predicate_targets_region(r, second_region_uri) and \ + FaultInjectionTransport.predicate_is_document_operation(r) + + custom_transport.add_fault( + is_request_to_first_region, + lambda r: FaultInjectionTransport.error_region_down()) + custom_transport.add_fault( + is_request_to_second_region, + lambda r: FaultInjectionTransport.error_region_down()) + + id_value: str = str(uuid.uuid4()) + document_definition = {'id': id_value, + 'pk': id_value, + 'name': 'sample document', + 'key': 'value'} + + initialized_objects = self.setup_method_with_custom_transport( + custom_transport, + preferred_locations=["First Region", "Second Region"]) + container: ContainerProxy = initialized_objects["col"] + with pytest.raises(ServiceRequestError): + 
container.upsert_item(body=document_definition) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file From 7b81482b7bd282883a3aefb0c5cd0674484dc203 Mon Sep 17 00:00:00 2001 From: Tomas Varon Saldarriaga Date: Thu, 3 Apr 2025 16:08:49 -0700 Subject: [PATCH 049/152] add new error and fix logs --- .../tests/_fault_injection_transport.py | 10 ++++++++-- .../tests/_fault_injection_transport_async.py | 20 ++++++++++++------- 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py index 9cdd18936945..628456d95158 100644 --- a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py +++ b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py @@ -19,7 +19,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -"""AioHttpTransport allowing injection of faults between SDK and Cosmos Gateway +"""RequestTransport allowing injection of faults between SDK and Cosmos Gateway """ import json @@ -36,7 +36,7 @@ import test_config from azure.cosmos.exceptions import CosmosHttpResponseError -from azure.core.exceptions import ServiceRequestError +from azure.core.exceptions import ServiceRequestError, ServiceResponseError class FaultInjectionTransport(RequestsTransport): logger = logging.getLogger('azure.cosmos.fault_injection_transport') @@ -168,6 +168,12 @@ def error_region_down() -> Exception: message="Injected region down.", ) + @staticmethod + def error_service_response() -> Exception: + return ServiceResponseError( + message="Injected Service Response Error.", + ) + @staticmethod def transform_topology_swr_mrr( write_region_name: str, diff --git a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py index 4551b0235bad..13dda0dc7e20 100644 --- 
a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py +++ b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py @@ -34,10 +34,10 @@ import test_config from azure.cosmos.exceptions import CosmosHttpResponseError -from azure.core.exceptions import ServiceRequestError +from azure.core.exceptions import ServiceRequestError, ServiceResponseError class FaultInjectionTransportAsync(AioHttpTransport): - logger = logging.getLogger('azure.cosmos.fault_injection_transport') + logger = logging.getLogger('azure.cosmos.fault_injection_transport_async') logger.setLevel(logging.DEBUG) def __init__(self, *, session: Optional[aiohttp.ClientSession] = None, loop=None, session_owner: bool = True, **config): @@ -64,11 +64,11 @@ def __first_item(iterable, condition=lambda x: True): return next((x for x in iterable if condition(x)), None) async def send(self, request: HttpRequest, *, stream: bool = False, proxies: Optional[MutableMapping[str, str]] = None, **config) -> AsyncHttpResponse: - FaultInjectionTransportAsync.logger.info("--> FaultInjectionTransport.Send {} {}".format(request.method, request.url)) + FaultInjectionTransportAsync.logger.info("--> FaultInjectionTransportAsync.Send {} {}".format(request.method, request.url)) # find the first fault Factory with matching predicate if any first_fault_factory = FaultInjectionTransportAsync.__first_item(iter(self.faults), lambda f: f["predicate"](request)) if first_fault_factory: - FaultInjectionTransportAsync.logger.info("--> FaultInjectionTransport.ApplyFaultInjection") + FaultInjectionTransportAsync.logger.info("--> FaultInjectionTransportAsync.ApplyFaultInjection") injected_error = await first_fault_factory["apply"](request) FaultInjectionTransportAsync.logger.info("Found to-be-injected error {}".format(injected_error)) raise injected_error @@ -76,14 +76,14 @@ async def send(self, request: HttpRequest, *, stream: bool = False, proxies: Opt # apply the chain of request transformations with matching 
predicates if any matching_request_transformations = filter(lambda f: f["predicate"](f["predicate"]), iter(self.requestTransformations)) for currentTransformation in matching_request_transformations: - FaultInjectionTransportAsync.logger.info("--> FaultInjectionTransport.ApplyRequestTransformation") + FaultInjectionTransportAsync.logger.info("--> FaultInjectionTransportAsync.ApplyRequestTransformation") request = await currentTransformation["apply"](request) first_response_transformation = FaultInjectionTransportAsync.__first_item(iter(self.responseTransformations), lambda f: f["predicate"](request)) - FaultInjectionTransportAsync.logger.info("--> FaultInjectionTransport.BeforeGetResponseTask") + FaultInjectionTransportAsync.logger.info("--> FaultInjectionTransportAsync.BeforeGetResponseTask") get_response_task = asyncio.create_task(super().send(request, stream=stream, proxies=proxies, **config)) - FaultInjectionTransportAsync.logger.info("<-- FaultInjectionTransport.AfterGetResponseTask") + FaultInjectionTransportAsync.logger.info("<-- FaultInjectionTransportAsync.AfterGetResponseTask") if first_response_transformation: FaultInjectionTransportAsync.logger.info(f"Invoking response transformation") @@ -166,6 +166,12 @@ async def error_region_down() -> Exception: message="Injected region down.", ) + @staticmethod + async def error_service_response() -> Exception: + return ServiceResponseError( + message="Injected Service Response Error.", + ) + @staticmethod async def transform_topology_swr_mrr( write_region_name: str, From f355e306d4c998b9e897f24a6d21b881da1bb730 Mon Sep 17 00:00:00 2001 From: Tomas Varon Saldarriaga Date: Thu, 3 Apr 2025 16:28:45 -0700 Subject: [PATCH 050/152] fix test --- sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport.py b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport.py index 
ef1f28b14adc..304fa8d50f0d 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport.py +++ b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport.py @@ -16,7 +16,7 @@ from azure.cosmos.container import ContainerProxy from azure.cosmos.database import DatabaseProxy from azure.cosmos.exceptions import CosmosHttpResponseError -from tests._fault_injection_transport import FaultInjectionTransport +from _fault_injection_transport import FaultInjectionTransport from azure.core.exceptions import ServiceRequestError logger = logging.getLogger('azure.cosmos') From 8495c5139742060f74301dd0441c8e5b4fff787a Mon Sep 17 00:00:00 2001 From: Allen Kim Date: Fri, 4 Apr 2025 10:34:42 -0700 Subject: [PATCH 051/152] Add cosmosQuery mark to TestQuery --- sdk/cosmos/azure-cosmos/tests/test_query.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_query.py b/sdk/cosmos/azure-cosmos/tests/test_query.py index 28262aa0f7e3..2a99263ed457 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_query.py +++ b/sdk/cosmos/azure-cosmos/tests/test_query.py @@ -17,7 +17,7 @@ from azure.cosmos.documents import _DistinctType from azure.cosmos.partition_key import PartitionKey - +@pytest.mark.cosmosQuery class TestQuery(unittest.TestCase): """Test to ensure escaping of non-ascii characters from partition key""" From b29980c0ed0feda7b3fb04adb5d4560bd813ccd8 Mon Sep 17 00:00:00 2001 From: Allen Kim Date: Fri, 4 Apr 2025 10:35:05 -0700 Subject: [PATCH 052/152] Correct spelling --- sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py index 01d1e9e9cf7e..7a93e6e89ec7 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py +++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py @@ -74,7 +74,7 @@ def emit(self, record): [[L1, L2, L3], 
[L1, L2, L3], [L1, L2, L3]], # 5. No common excluded locations [[L1, L2, L3], [L1], [L2, L3]], - # 6. Reqeust excluded location not in preferred locations + # 6. Request excluded location not in preferred locations [[L1, L2, L3], [L1, L2, L3], [L4]], # 7. Empty excluded locations, remove all client excluded locations [[L1, L2, L3], [L1, L2], []], From 5e79172c2da25361e80a098222b712542b1b19ea Mon Sep 17 00:00:00 2001 From: Allen Kim Date: Fri, 4 Apr 2025 10:35:37 -0700 Subject: [PATCH 053/152] Fixed live platform matrix syntax --- sdk/cosmos/live-platform-matrix.json | 44 ++++++++-------------------- 1 file changed, 13 insertions(+), 31 deletions(-) diff --git a/sdk/cosmos/live-platform-matrix.json b/sdk/cosmos/live-platform-matrix.json index b3242623be78..494c5fc62cea 100644 --- a/sdk/cosmos/live-platform-matrix.json +++ b/sdk/cosmos/live-platform-matrix.json @@ -86,43 +86,25 @@ "CoverageArg": "--disablecov", "TestSamples": "false", "TestMarkArgument": "cosmosLong" - }, + } + } + }, + { + "DESIRED_CONSISTENCIES": "[\"Session\"]", + "ACCOUNT_CONSISTENCY": "Session", + "ArmConfig": { + "MultiMaster_MultiRegion": { + "ArmTemplateParameters": "@{ enableMultipleWriteLocations = $true; defaultConsistencyLevel = 'Session'; enableMultipleRegions = $true }" + } + }, + "WindowsConfig": { "Windows2022_38_multi_region": { "OSVmImage": "env:WINDOWSVMIMAGE", "Pool": "env:WINDOWSPOOL", "PythonVersion": "3.8", "CoverageArg": "--disablecov", "TestSamples": "false", - "TestMarkArgument": "cosmosMultiRegion", - "ArmConfig": { - "ArmTemplateParameters": "@{ enableMultipleRegions = $true }" - } - }, - "Windows2022_310_multi_region": { - "OSVmImage": "env:WINDOWSVMIMAGE", - "Pool": "env:WINDOWSPOOL", - "PythonVersion": "3.10", - "CoverageArg": "--disablecov", - "TestSamples": "false", - "TestMarkArgument": "cosmosMultiRegion", - "ArmConfig": { - "MultiMaster_MultiRegion": { - "ArmTemplateParameters": "@{ enableMultipleRegions = $true }" - } - } - }, - "Windows2022_312_multi_region": { - 
"OSVmImage": "env:WINDOWSVMIMAGE", - "Pool": "env:WINDOWSPOOL", - "PythonVersion": "3.12", - "CoverageArg": "--disablecov", - "TestSamples": "false", - "TestMarkArgument": "cosmosMultiRegion", - "ArmConfig": { - "MultiMaster_MultiRegion": { - "ArmTemplateParameters": "@{ enableMultipleRegions = $true }" - } - } + "TestMarkArgument": "cosmosMultiRegion" } } } From fd40cd724873ad9ab46520a18304a5a900fde7b1 Mon Sep 17 00:00:00 2001 From: Allen Kim Date: Fri, 4 Apr 2025 11:42:08 -0700 Subject: [PATCH 054/152] Changed Multi-regions --- .../tests/test_excluded_locations.py | 18 +++++++++--------- sdk/cosmos/live-platform-matrix.json | 6 ++---- sdk/cosmos/test-resources.bicep | 6 +++--- 3 files changed, 14 insertions(+), 16 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py index 7a93e6e89ec7..2d17bad85ba5 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py +++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py @@ -34,18 +34,18 @@ def emit(self, record): ITEM_PK_VALUE = 'pk' TEST_ITEM = {'id': ITEM_ID, PARTITION_KEY: ITEM_PK_VALUE} -# L0 = "Default" -# L1 = "West US 3" -# L2 = "West US" -# L3 = "East US 2" -# L4 = "Central US" - L0 = "Default" -L1 = "East US 2" -L2 = "East US" -L3 = "West US 2" +L1 = "West US 3" +L2 = "West US" +L3 = "East US 2" L4 = "Central US" +# L0 = "Default" +# L1 = "East US 2" +# L2 = "East US" +# L3 = "West US 2" +# L4 = "Central US" + CLIENT_ONLY_TEST_DATA = [ # preferred_locations, client_excluded_locations, excluded_locations_request # 0. 
No excluded location diff --git a/sdk/cosmos/live-platform-matrix.json b/sdk/cosmos/live-platform-matrix.json index 494c5fc62cea..7a02486a8827 100644 --- a/sdk/cosmos/live-platform-matrix.json +++ b/sdk/cosmos/live-platform-matrix.json @@ -90,18 +90,16 @@ } }, { - "DESIRED_CONSISTENCIES": "[\"Session\"]", - "ACCOUNT_CONSISTENCY": "Session", "ArmConfig": { "MultiMaster_MultiRegion": { "ArmTemplateParameters": "@{ enableMultipleWriteLocations = $true; defaultConsistencyLevel = 'Session'; enableMultipleRegions = $true }" } }, "WindowsConfig": { - "Windows2022_38_multi_region": { + "Windows2022_312": { "OSVmImage": "env:WINDOWSVMIMAGE", "Pool": "env:WINDOWSPOOL", - "PythonVersion": "3.8", + "PythonVersion": "3.12", "CoverageArg": "--disablecov", "TestSamples": "false", "TestMarkArgument": "cosmosMultiRegion" diff --git a/sdk/cosmos/test-resources.bicep b/sdk/cosmos/test-resources.bicep index 61588a526eed..b05dead26737 100644 --- a/sdk/cosmos/test-resources.bicep +++ b/sdk/cosmos/test-resources.bicep @@ -30,19 +30,19 @@ var singleRegionConfiguration = [ ] var multiRegionConfiguration = [ { - locationName: 'East US 2' + locationName: 'West US 3' provisioningState: 'Succeeded' failoverPriority: 0 isZoneRedundant: false } { - locationName: 'East US' + locationName: 'West US' provisioningState: 'Succeeded' failoverPriority: 1 isZoneRedundant: false } { - locationName: 'West US 2' + locationName: 'East US 2' provisioningState: 'Succeeded' failoverPriority: 2 isZoneRedundant: false From 85e1206f7771cb0f7a82e7460ae9ca4b247f5712 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Fri, 4 Apr 2025 14:56:01 -0700 Subject: [PATCH 055/152] first ppcb test --- .../azure-cosmos/azure/cosmos/_constants.py | 2 +- ...tition_endpoint_manager_circuit_breaker.py | 1 + ...n_endpoint_manager_circuit_breaker_core.py | 15 +- .../azure/cosmos/_partition_health_tracker.py | 2 +- .../tests/test_ppcb_sm_mrr_async.py | 159 ++++++++++++++++++ 5 files changed, 172 insertions(+), 7 deletions(-) create mode 
100644 sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py index cf029179f1a1..38848f3a5d72 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py @@ -53,7 +53,7 @@ class _Constants: CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_WRITE: str = "AZURE_COSMOS_CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_WRITE" CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_WRITE_DEFAULT: int = 5 FAILURE_PERCENTAGE_TOLERATED = "AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED" - FAILURE_PERCENTAGE_TOLERATED_DEFAULT: int = 70 + FAILURE_PERCENTAGE_TOLERATED_DEFAULT: int = 90 STALE_PARTITION_UNAVAILABILITY_CHECK = "AZURE_COSMOS_STALE_PARTITION_UNAVAILABILITY_CHECK_IN_SECONDS" STALE_PARTITION_UNAVAILABILITY_CHECK_DEFAULT: int = 120 # ------------------------------------------------------------------------- diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py index b95ef0a2a7da..3f38be706e73 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py @@ -57,6 +57,7 @@ def record_failure( self.global_partition_endpoint_manager_core.record_failure(request) def resolve_service_endpoint(self, request): + # TODO: @tvaron3 check here if it is healthy tentative and move it back to Unhealthy request = self.global_partition_endpoint_manager_core.add_excluded_locations_to_request(request) return super(_GlobalPartitionEndpointManagerForCircuitBreaker, self).resolve_service_endpoint(request) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py 
b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py index 142c51a1ed19..3f60391db4e5 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py @@ -81,10 +81,11 @@ def _create_pkrange_wrapper(self, request: RequestObject) -> PartitionKeyRangeWr # get the partition key range for the given partition key target_container_link = None for container_link, properties in self.client._container_properties_cache: - # TODO: @tvaron3 consider moving this to a constant with other usages if properties["_rid"] == container_rid: target_container_link = container_link - # throw exception if it is not found + if target_container_link: + raise RuntimeError("Illegal state: the container cache is not properly initialized.") + # TODO: @tvaron3 check different clients and create them in different ways pkrange = self.client._routing_map_provider.get_overlapping_range(target_container_link, partition_key) return PartitionKeyRangeWrapper(pkrange, container_rid) @@ -98,7 +99,11 @@ def record_failure( documents._OperationType.IsWriteOperation(request.operation_type)) else EndpointOperationType.ReadType location = self.location_cache.get_location_from_endpoint(str(request.location_endpoint_to_route)) pkrange_wrapper = self._create_pkrange_wrapper(request) - self.partition_health_tracker.add_failure(pkrange_wrapper, endpoint_operation_type, location) + self.partition_health_tracker.add_failure(pkrange_wrapper, endpoint_operation_type, str(location)) + + # TODO: @tvaron3 lower request timeout to 5.5 seconds for recovering + # TODO: @tvaron3 exponential backoff for recovering + def add_excluded_locations_to_request(self, request: RequestObject) -> RequestObject: if self.is_circuit_breaker_applicable(request): @@ -112,7 +117,7 @@ def mark_partition_unavailable(self, request: RequestObject) -> None: """ 
Mark the partition unavailable from the given request. """ - location = self.location_cache.get_location_from_endpoint(request.location_endpoint_to_route) + location = self.location_cache.get_location_from_endpoint(str(request.location_endpoint_to_route)) pkrange_wrapper = self._create_pkrange_wrapper(request) self.partition_health_tracker.mark_partition_unavailable(pkrange_wrapper, location) @@ -124,7 +129,7 @@ def record_success( #convert operation_type to either Read or Write endpoint_operation_type = EndpointOperationType.WriteType if ( documents._OperationType.IsWriteOperation(request.operation_type)) else EndpointOperationType.ReadType - location = self.location_cache.get_location_from_endpoint(request.location_endpoint_to_route) + location = self.location_cache.get_location_from_endpoint(str(request.location_endpoint_to_route)) pkrange_wrapper = self._create_pkrange_wrapper(request) self.partition_health_tracker.add_success(pkrange_wrapper, endpoint_operation_type, location) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py index 74665a5e7eb5..9fc64dbb5f63 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py @@ -168,7 +168,7 @@ def add_failure( self, pkrange_wrapper: PartitionKeyRangeWrapper, operation_type: str, - location: Optional[str] + location: str ) -> None: # Retrieve the failure rate threshold from the environment. failure_rate_threshold = int(os.getenv(Constants.FAILURE_PERCENTAGE_TOLERATED, diff --git a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py new file mode 100644 index 000000000000..0a42e0469364 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py @@ -0,0 +1,159 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. 
+import asyncio +import os +import unittest +import uuid +from typing import Dict, Any + +import pytest +import pytest_asyncio +from azure.core.pipeline.transport._aiohttp import AioHttpTransport + +from azure.cosmos import PartitionKey +from azure.cosmos._partition_health_tracker import HEALTH_STATUS, UNHEALTHY, UNHEALTHY_TENTATIVE +from azure.cosmos.aio import CosmosClient +from azure.cosmos.exceptions import CosmosHttpResponseError +from tests import test_config +from tests._fault_injection_transport_async import FaultInjectionTransportAsync + +COLLECTION = "created_collection" +@pytest_asyncio.fixture(scope='class') +async def setup(): + if (TestPPCBSmMrrAsync.master_key == '[YOUR_KEY_HERE]' or + TestPPCBSmMrrAsync.host == '[YOUR_ENDPOINT_HERE]'): + raise Exception( + "You must specify your Azure Cosmos account values for " + "'masterKey' and 'host' at the top of this class to run the " + "tests.") + os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "True" + client = CosmosClient(TestPPCBSmMrrAsync.host, TestPPCBSmMrrAsync.master_key, consistency_level="Session") + created_database = client.get_database_client(TestPPCBSmMrrAsync.TEST_DATABASE_ID) + created_collection = await created_database.create_container(TestPPCBSmMrrAsync.TEST_CONTAINER_SINGLE_PARTITION_ID, + partition_key=PartitionKey("/pk"), throughput=10000) + yield { + COLLECTION: created_collection + } + + await created_database.delete_container(TestPPCBSmMrrAsync.TEST_CONTAINER_SINGLE_PARTITION_ID) + await client.close() + os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "False" + + + + +def error_codes(): + + return [408, 500, 502, 503] + + +@pytest.mark.cosmosEmulator +@pytest.mark.asyncio +@pytest.mark.usefixtures("setup") +class TestPPCBSmMrrAsync: + host = test_config.TestConfig.host + master_key = test_config.TestConfig.masterKey + connectionPolicy = test_config.TestConfig.connectionPolicy + TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID + TEST_CONTAINER_SINGLE_PARTITION_ID = 
os.path.basename(__file__) + str(uuid.uuid4()) + + async def setup_method_with_custom_transport(self, custom_transport: AioHttpTransport, default_endpoint=host, **kwargs): + client = CosmosClient(default_endpoint, self.master_key, consistency_level="Session", + transport=custom_transport, **kwargs) + db = client.get_database_client(self.TEST_DATABASE_ID) + container = db.get_container_client(self.TEST_CONTAINER_SINGLE_PARTITION_ID) + return {"client": client, "db": db, "col": container} + + @staticmethod + async def cleanup_method(initialized_objects: Dict[str, Any]): + method_client: CosmosClient = initialized_objects["client"] + await method_client.close() + + async def create_custom_transport_sm_mrr(self): + custom_transport = FaultInjectionTransportAsync() + # Inject rule to disallow writes in the read-only region + is_write_operation_in_read_region_predicate = lambda \ + r: FaultInjectionTransportAsync.predicate_is_write_operation(r, self.host) + + custom_transport.add_fault( + is_write_operation_in_read_region_predicate, + lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_write_forbidden())) + + # Inject topology transformation that would make Emulator look like a single write region + # account with two read regions + is_get_account_predicate = lambda r: FaultInjectionTransportAsync.predicate_is_database_account_call(r) + emulator_as_multi_region_sm_account_transformation = \ + lambda r, inner: FaultInjectionTransportAsync.transform_topology_swr_mrr( + write_region_name="Write Region", + read_region_name="Read Region", + inner=inner) + custom_transport.add_response_transformation( + is_get_account_predicate, + emulator_as_multi_region_sm_account_transformation) + return custom_transport + + + + @pytest.mark.parametrize("error_code", error_codes()) + async def test_consecutive_failure_threshold_async(self, setup, error_code): + expected_read_region_uri = self.host + expected_write_region_uri = self.host.replace("localhost", "127.0.0.1") + 
custom_transport = await self.create_custom_transport_sm_mrr() + id_value = 'failoverDoc-' + str(uuid.uuid4()) + document_definition = {'id': id_value, + 'pk': 'pk1', + 'name': 'sample document', + 'key': 'value'} + predicate = lambda r: FaultInjectionTransportAsync.predicate_req_for_document_with_id(r, id_value) + custom_transport.add_fault(predicate, lambda r: asyncio.create_task(CosmosHttpResponseError( + status_code=error_code, + message="Some injected fault."))) + + custom_setup = await self.setup_method_with_custom_transport(custom_transport, default_endpoint=self.host) + container = custom_setup['col'] + + # writes should fail in sm mrr with circuit breaker and should not mark unavailable a partition + for i in range(6): + with pytest.raises(CosmosHttpResponseError): + await container.create_item(body=document_definition) + + TestPPCBSmMrrAsync.validate_unhealthy_partitions(container.client_connection._global_endpoint_manager, 0) + + + + # create item with client without fault injection + await setup[COLLECTION].create_item(body=document_definition) + + # reads should failover and only the relevant partition should be marked as unavailable + await container.read_item(item=document_definition['id'], partition_key=document_definition['pk']) + # partition should not have been marked unavailable after one error + TestPPCBSmMrrAsync.validate_unhealthy_partitions(container.client_connection._global_endpoint_manager, 0) + + for i in range(10): + await container.read_item(item=document_definition['id'], partition_key=document_definition['pk']) + + # the partition should have been marked as unavailable + TestPPCBSmMrrAsync.validate_unhealthy_partitions(container.client_connection._global_endpoint_manager, 1) + + + @staticmethod + def validate_unhealthy_partitions(global_endpoint_manager, expected_unhealthy_partitions): + health_info_map = global_endpoint_manager.global_partition_endpoint_manager_core.partition_health_tracker.pkrange_wrapper_to_health_info + 
unhealthy_partitions = 0 + for pk_range_wrapper, location_to_health_info in health_info_map.items(): + for location, health_info in location_to_health_info: + if health_info[HEALTH_STATUS] == UNHEALTHY_TENTATIVE or health_info[HEALTH_STATUS] == UNHEALTHY: + unhealthy_partitions += 1 + assert len(health_info_map) == expected_unhealthy_partitions + assert unhealthy_partitions == expected_unhealthy_partitions + + + + + + + # test_failure_rate_threshold + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file From 34e3d82ba9d2d2cceafd538e455e96b4b5fccf58 Mon Sep 17 00:00:00 2001 From: Tomas Varon Saldarriaga Date: Mon, 7 Apr 2025 10:54:42 -0400 Subject: [PATCH 056/152] fix test --- ...n_endpoint_manager_circuit_breaker_core.py | 20 ++--- .../azure/cosmos/_location_cache.py | 3 +- .../azure/cosmos/_routing/routing_range.py | 3 + .../azure/cosmos/aio/_retry_utility_async.py | 8 +- .../tests/test_ppcb_sm_mrr_async.py | 80 +++++++++---------- 5 files changed, 58 insertions(+), 56 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py index 3f60391db4e5..014809cac5b4 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py @@ -67,12 +67,12 @@ def is_circuit_breaker_applicable(self, request: RequestObject) -> bool: if request.resource_type != ResourceType.Document: return False - if request.operation_type != documents._OperationType.QueryPlan: + if request.operation_type == documents._OperationType.QueryPlan: return False return True - def _create_pkrange_wrapper(self, request: RequestObject) -> PartitionKeyRangeWrapper: + def _create_pk_range_wrapper(self, request: RequestObject) -> PartitionKeyRangeWrapper: """ Create a 
PartitionKeyRangeWrapper object. """ @@ -80,14 +80,14 @@ def _create_pkrange_wrapper(self, request: RequestObject) -> PartitionKeyRangeWr partition_key = request.headers[HttpHeaders.PartitionKey] # get the partition key range for the given partition key target_container_link = None - for container_link, properties in self.client._container_properties_cache: + for container_link, properties in self.client._container_properties_cache.items(): if properties["_rid"] == container_rid: target_container_link = container_link - if target_container_link: + if not target_container_link: raise RuntimeError("Illegal state: the container cache is not properly initialized.") # TODO: @tvaron3 check different clients and create them in different ways - pkrange = self.client._routing_map_provider.get_overlapping_range(target_container_link, partition_key) - return PartitionKeyRangeWrapper(pkrange, container_rid) + pk_range = await self.client._routing_map_provider.get_overlapping_ranges(target_container_link, partition_key) + return PartitionKeyRangeWrapper(pk_range, container_rid) def record_failure( self, @@ -98,7 +98,7 @@ def record_failure( endpoint_operation_type = EndpointOperationType.WriteType if ( documents._OperationType.IsWriteOperation(request.operation_type)) else EndpointOperationType.ReadType location = self.location_cache.get_location_from_endpoint(str(request.location_endpoint_to_route)) - pkrange_wrapper = self._create_pkrange_wrapper(request) + pkrange_wrapper = self._create_pk_range_wrapper(request) self.partition_health_tracker.add_failure(pkrange_wrapper, endpoint_operation_type, str(location)) # TODO: @tvaron3 lower request timeout to 5.5 seconds for recovering @@ -107,7 +107,7 @@ def record_failure( def add_excluded_locations_to_request(self, request: RequestObject) -> RequestObject: if self.is_circuit_breaker_applicable(request): - pkrange_wrapper = self._create_pkrange_wrapper(request) + pkrange_wrapper = self._create_pk_range_wrapper(request) 
request.set_excluded_locations_from_circuit_breaker( self.partition_health_tracker.get_excluded_locations(pkrange_wrapper) ) @@ -118,7 +118,7 @@ def mark_partition_unavailable(self, request: RequestObject) -> None: Mark the partition unavailable from the given request. """ location = self.location_cache.get_location_from_endpoint(str(request.location_endpoint_to_route)) - pkrange_wrapper = self._create_pkrange_wrapper(request) + pkrange_wrapper = self._create_pk_range_wrapper(request) self.partition_health_tracker.mark_partition_unavailable(pkrange_wrapper, location) def record_success( @@ -130,7 +130,7 @@ def record_success( endpoint_operation_type = EndpointOperationType.WriteType if ( documents._OperationType.IsWriteOperation(request.operation_type)) else EndpointOperationType.ReadType location = self.location_cache.get_location_from_endpoint(str(request.location_endpoint_to_route)) - pkrange_wrapper = self._create_pkrange_wrapper(request) + pkrange_wrapper = self._create_pk_range_wrapper(request) self.partition_health_tracker.add_success(pkrange_wrapper, endpoint_operation_type, location) # TODO: @tvaron3 there should be no in region retries when trying on healthy tentative ----------------------- diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py index 873c032d7ead..cb9ef0840a3c 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py @@ -201,8 +201,7 @@ def get_read_regional_routing_contexts(self): return self.read_regional_routing_contexts def get_location_from_endpoint(self, endpoint: str) -> str: - regional_routing_context = RegionalRoutingContext(endpoint, endpoint) - return self.account_locations_by_read_regional_endpoints[regional_routing_context] + return self.account_locations_by_read_regional_routing_context[endpoint] def get_write_regional_routing_context(self): return 
self.get_write_regional_routing_contexts()[0].get_primary() diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_range.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_range.py index 4e3d603ef0d8..21a22ca89f61 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_range.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_range.py @@ -248,3 +248,6 @@ def __eq__(self, other): return False return self.partition_key_range == other.partition_key_range and self.collection_rid == other.collection_rid + def __hash__(self): + return hash((self.partition_key_range, self.collection_rid)) + diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py index c5613530994d..e1094736b88b 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py @@ -103,9 +103,9 @@ async def ExecuteAsync(client, global_endpoint_manager, function, *args, **kwarg try: if args: result = await ExecuteFunctionAsync(function, global_endpoint_manager, *args, **kwargs) + global_endpoint_manager.record_success(args[0]) else: result = await ExecuteFunctionAsync(function, *args, **kwargs) - global_endpoint_manager.record_success(request) if not client.last_response_headers: client.last_response_headers = {} @@ -200,7 +200,7 @@ async def ExecuteAsync(client, global_endpoint_manager, function, *args, **kwarg if client_timeout: kwargs['timeout'] = client_timeout - (time.time() - start_time) if kwargs['timeout'] <= 0: - global_endpoint_manager.record_failure(request) + global_endpoint_manager.record_failure(args[0]) raise exceptions.CosmosClientTimeoutError() except ServiceRequestError as e: @@ -258,8 +258,8 @@ async def send(self, request): """ absolute_timeout = request.context.options.pop('timeout', None) per_request_timeout = request.context.options.pop('connection_timeout', 0) - 
request_params = request.context.options.get('request_params', None) - global_endpoint_manager = request.context.options.get('global_endpoint_manager', None) + request_params = request.context.options.pop('request_params', None) + global_endpoint_manager = request.context.options.pop('global_endpoint_manager', None) retry_error = None retry_active = True response = None diff --git a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py index 1360ae5630ba..a8e39ae361da 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py @@ -9,6 +9,7 @@ import pytest import pytest_asyncio from azure.core.pipeline.transport._aiohttp import AioHttpTransport +from azure.core.exceptions import ServiceResponseError from azure.cosmos import PartitionKey from azure.cosmos._partition_health_tracker import HEALTH_STATUS, UNHEALTHY, UNHEALTHY_TENTATIVE @@ -18,19 +19,16 @@ from tests._fault_injection_transport_async import FaultInjectionTransportAsync COLLECTION = "created_collection" -@pytest_asyncio.fixture(scope='class') +@pytest_asyncio.fixture() async def setup(): - if (TestPPCBSmMrrAsync.master_key == '[YOUR_KEY_HERE]' or - TestPPCBSmMrrAsync.host == '[YOUR_ENDPOINT_HERE]'): - raise Exception( - "You must specify your Azure Cosmos account values for " - "'masterKey' and 'host' at the top of this class to run the " - "tests.") os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "True" client = CosmosClient(TestPPCBSmMrrAsync.host, TestPPCBSmMrrAsync.master_key, consistency_level="Session") created_database = client.get_database_client(TestPPCBSmMrrAsync.TEST_DATABASE_ID) + # print(TestPPCBSmMrrAsync.TEST_DATABASE_ID) + await client.create_database_if_not_exists(TestPPCBSmMrrAsync.TEST_DATABASE_ID) created_collection = await created_database.create_container(TestPPCBSmMrrAsync.TEST_CONTAINER_SINGLE_PARTITION_ID, - partition_key=PartitionKey("/pk"), 
throughput=10000) + partition_key=PartitionKey("/pk"), + offer_throughput=10000) yield { COLLECTION: created_collection } @@ -39,13 +37,15 @@ async def setup(): await client.close() os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "False" - - - -def error_codes(): - - return [408, 500, 502, 503] - +def errors(): + errors_list = [] + error_codes = [408, 500, 502, 503] + for error_code in error_codes: + errors_list.append(CosmosHttpResponseError( + status_code=error_code, + message="Some injected fault.")) + errors_list.append(ServiceResponseError(message="Injected Service Response Error.")) + return errors_list @pytest.mark.cosmosEmulator @pytest.mark.asyncio @@ -59,6 +59,7 @@ class TestPPCBSmMrrAsync: async def setup_method_with_custom_transport(self, custom_transport: AioHttpTransport, default_endpoint=host, **kwargs): client = CosmosClient(default_endpoint, self.master_key, consistency_level="Session", + preferred_locations=["Write Region", "Read Region"], transport=custom_transport, **kwargs) db = client.get_database_client(self.TEST_DATABASE_ID) container = db.get_container_client(self.TEST_CONTAINER_SINGLE_PARTITION_ID) @@ -92,45 +93,48 @@ async def create_custom_transport_sm_mrr(self): emulator_as_multi_region_sm_account_transformation) return custom_transport - - - @pytest.mark.parametrize("error_code", error_codes()) - async def test_consecutive_failure_threshold_async(self, setup, error_code): + @pytest.mark.parametrize("error", errors()) + async def test_consecutive_failure_threshold_async(self, setup, error): + expected_read_region_uri = self.host + expected_write_region_uri = expected_read_region_uri.replace("localhost", "127.0.0.1") custom_transport = await self.create_custom_transport_sm_mrr() id_value = 'failoverDoc-' + str(uuid.uuid4()) document_definition = {'id': id_value, 'pk': 'pk1', 'name': 'sample document', 'key': 'value'} - predicate = lambda r: FaultInjectionTransportAsync.predicate_req_for_document_with_id(r, id_value) - 
custom_transport.add_fault(predicate, lambda r: asyncio.create_task(CosmosHttpResponseError( - status_code=error_code, - message="Some injected fault."))) + predicate = lambda r: (FaultInjectionTransportAsync.predicate_req_for_document_with_id(r, id_value) and + FaultInjectionTransportAsync.predicate_targets_region(r, expected_write_region_uri)) + custom_transport.add_fault(predicate, lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( + 0, + error + ))) custom_setup = await self.setup_method_with_custom_transport(custom_transport, default_endpoint=self.host) container = custom_setup['col'] # writes should fail in sm mrr with circuit breaker and should not mark unavailable a partition for i in range(6): - with pytest.raises(CosmosHttpResponseError): + with pytest.raises(CosmosHttpResponseError or ServiceResponseError): await container.create_item(body=document_definition) TestPPCBSmMrrAsync.validate_unhealthy_partitions(container.client_connection._global_endpoint_manager, 0) - - # create item with client without fault injection await setup[COLLECTION].create_item(body=document_definition) - # reads should failover and only the relevant partition should be marked as unavailable + # reads should fail over and only the relevant partition should be marked as unavailable await container.read_item(item=document_definition['id'], partition_key=document_definition['pk']) # partition should not have been marked unavailable after one error TestPPCBSmMrrAsync.validate_unhealthy_partitions(container.client_connection._global_endpoint_manager, 0) - for i in range(10): - await container.read_item(item=document_definition['id'], partition_key=document_definition['pk']) + for i in range(11): + read_resp = await container.read_item(item=document_definition['id'], partition_key=document_definition['pk']) + request = read_resp.get_response_headers()["_request"] + # Validate the response comes from "Read Region" (the most preferred read-only region) + assert 
request.url.startswith(expected_read_region_uri) - # the partition should have been marked as unavailable + # the partition should have been marked as unavailable after breaking read threshold TestPPCBSmMrrAsync.validate_unhealthy_partitions(container.client_connection._global_endpoint_manager, 1) @@ -139,19 +143,15 @@ def validate_unhealthy_partitions(global_endpoint_manager, expected_unhealthy_pa health_info_map = global_endpoint_manager.global_partition_endpoint_manager_core.partition_health_tracker.pkrange_wrapper_to_health_info unhealthy_partitions = 0 for pk_range_wrapper, location_to_health_info in health_info_map.items(): - for location, health_info in location_to_health_info: - if health_info[HEALTH_STATUS] == UNHEALTHY_TENTATIVE or health_info[HEALTH_STATUS] == UNHEALTHY: + for location, health_info in location_to_health_info.items(): + health_status = health_info.unavailability_info.get(HEALTH_STATUS) + if health_status == UNHEALTHY_TENTATIVE or health_status == UNHEALTHY: unhealthy_partitions += 1 - assert len(health_info_map) == expected_unhealthy_partitions assert unhealthy_partitions == expected_unhealthy_partitions - - - - - # test_failure_rate_threshold - + # test_failure_rate_threshold - add service response error + # test service request marks only a partition unavailable not an entire region if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main() From ce1466618076f815b2312a376c2539856b1ff0dc Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Mon, 7 Apr 2025 18:05:33 -0400 Subject: [PATCH 057/152] refactor due to pk range wrapper needing io call and pylint --- .../azure/cosmos/_cosmos_client_connection.py | 4 +- .../azure/cosmos/_global_endpoint_manager.py | 8 +- ...tition_endpoint_manager_circuit_breaker.py | 54 +++++++++---- ...n_endpoint_manager_circuit_breaker_core.py | 81 +++++++------------ .../azure/cosmos/_location_cache.py | 2 +- .../azure/cosmos/_partition_health_tracker.py | 37 ++++----- 
.../azure/cosmos/_request_object.py | 4 +- .../azure/cosmos/_retry_utility.py | 28 ++++--- .../azure/cosmos/_routing/routing_range.py | 1 - .../cosmos/_service_request_retry_policy.py | 9 ++- .../cosmos/_service_response_retry_policy.py | 8 +- .../azure/cosmos/_session_retry_policy.py | 12 ++- .../azure/cosmos/_synchronized_request.py | 12 ++- .../cosmos/_timeout_failover_retry_policy.py | 8 +- .../azure/cosmos/aio/_asynchronous_request.py | 8 +- .../aio/_cosmos_client_connection_async.py | 3 +- .../aio/_global_endpoint_manager_async.py | 8 +- ..._endpoint_manager_circuit_breaker_async.py | 62 +++++++++----- .../azure/cosmos/aio/_retry_utility_async.py | 20 ++--- 19 files changed, 215 insertions(+), 154 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py index 298d3032877b..5e9847c9ad93 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py @@ -48,7 +48,7 @@ HttpResponse # pylint: disable=no-legacy-azure-core-http-response-import from . import _base as base -from . import _global_endpoint_manager as global_endpoint_manager +from ._global_partition_endpoint_manager_circuit_breaker import _GlobalPartitionEndpointManagerForCircuitBreaker from . import _query_iterable as query_iterable from . import _runtime_constants as runtime_constants from . 
import _session @@ -164,7 +164,7 @@ def __init__( # pylint: disable=too-many-statements self.last_response_headers: CaseInsensitiveDict = CaseInsensitiveDict() self.UseMultipleWriteLocations = False - self._global_endpoint_manager = global_endpoint_manager._GlobalEndpointManager(self) + self._global_endpoint_manager = _GlobalPartitionEndpointManagerForCircuitBreaker(self) retry_policy = None if isinstance(self.connection_policy.ConnectionRetryConfiguration, HTTPPolicy): diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py index 8cf9d06d5486..62dd60a30da0 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py @@ -30,6 +30,8 @@ from . import _constants as constants from . import exceptions +from ._request_object import RequestObject +from ._routing.routing_range import PartitionKeyRangeWrapper from .documents import DatabaseAccount from ._location_cache import LocationCache, current_time_millis @@ -67,7 +69,11 @@ def get_write_endpoint(self): def get_read_endpoint(self): return self.location_cache.get_read_regional_routing_context() - def resolve_service_endpoint(self, request): + def resolve_service_endpoint( + self, + request: RequestObject, + pk_range_wrapper: PartitionKeyRangeWrapper # pylint: disable=unused-argument + ) -> str: return self.location_cache.resolve_service_endpoint(request) def mark_endpoint_unavailable_for_read(self, endpoint, refresh_cache): diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py index 3f38be706e73..288205cb2411 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py 
@@ -28,6 +28,9 @@ from azure.cosmos._global_endpoint_manager import _GlobalEndpointManager from azure.cosmos._request_object import RequestObject +from azure.cosmos._routing.routing_range import PartitionKeyRangeWrapper +from azure.cosmos.http_constants import HttpHeaders + if TYPE_CHECKING: from azure.cosmos._cosmos_client_connection import CosmosClientConnection @@ -42,34 +45,55 @@ class _GlobalPartitionEndpointManagerForCircuitBreaker(_GlobalEndpointManager): def __init__(self, client: "CosmosClientConnection"): super(_GlobalPartitionEndpointManagerForCircuitBreaker, self).__init__(client) - self.global_partition_endpoint_manager_core = _GlobalPartitionEndpointManagerForCircuitBreakerCore(client, self.location_cache) + self.global_partition_endpoint_manager_core = ( + _GlobalPartitionEndpointManagerForCircuitBreakerCore(client, self.location_cache)) def is_circuit_breaker_applicable(self, request: RequestObject) -> bool: - """ - Check if circuit breaker is applicable for a request. - """ return self.global_partition_endpoint_manager_core.is_circuit_breaker_applicable(request) + + def create_pk_range_wrapper(self, request: RequestObject) -> PartitionKeyRangeWrapper: + container_rid = request.headers[HttpHeaders.IntendedCollectionRID] + partition_key = request.headers[HttpHeaders.PartitionKey] + # get the partition key range for the given partition key + target_container_link = None + for container_link, properties in self.Client._container_properties_cache.items(): # pylint: disable=protected-access + if properties["_rid"] == container_rid: + target_container_link = container_link + if not target_container_link: + raise RuntimeError("Illegal state: the container cache is not properly initialized.") + # TODO: @tvaron3 check different clients and create them in different ways + pk_range = (self.Client._routing_map_provider # pylint: disable=protected-access + .get_overlapping_ranges(target_container_link, partition_key)) + return PartitionKeyRangeWrapper(pk_range, 
container_rid) + def record_failure( self, request: RequestObject ) -> None: - self.global_partition_endpoint_manager_core.record_failure(request) + if self.is_circuit_breaker_applicable(request): + pk_range_wrapper = self.create_pk_range_wrapper(request) + self.global_partition_endpoint_manager_core.record_failure(request, pk_range_wrapper) - def resolve_service_endpoint(self, request): + def resolve_service_endpoint(self, request: RequestObject, pk_range_wrapper: PartitionKeyRangeWrapper) -> str: # TODO: @tvaron3 check here if it is healthy tentative and move it back to Unhealthy - request = self.global_partition_endpoint_manager_core.add_excluded_locations_to_request(request) - return super(_GlobalPartitionEndpointManagerForCircuitBreaker, self).resolve_service_endpoint(request) + if self.is_circuit_breaker_applicable(request): + request = self.global_partition_endpoint_manager_core.add_excluded_locations_to_request(request, + pk_range_wrapper) + return super(_GlobalPartitionEndpointManagerForCircuitBreaker, self).resolve_service_endpoint(request, + pk_range_wrapper) - def mark_partition_unavailable(self, request: RequestObject) -> None: - """ - Mark the partition unavailable from the given request. 
- """ - self.global_partition_endpoint_manager_core.mark_partition_unavailable(request) + def mark_partition_unavailable( + self, + request: RequestObject, + pk_range_wrapper: PartitionKeyRangeWrapper + ) -> None: + self.global_partition_endpoint_manager_core.mark_partition_unavailable(request, pk_range_wrapper) def record_success( self, request: RequestObject ) -> None: - self.global_partition_endpoint_manager_core.record_success(request) - + if self.global_partition_endpoint_manager_core.is_circuit_breaker_applicable(request): + pk_range_wrapper = self.create_pk_range_wrapper(request) + self.global_partition_endpoint_manager_core.record_success(request, pk_range_wrapper) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py index 014809cac5b4..1f900215f08d 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py @@ -30,7 +30,7 @@ from azure.cosmos._routing.routing_range import PartitionKeyRangeWrapper from azure.cosmos._location_cache import EndpointOperationType, LocationCache from azure.cosmos._request_object import RequestObject -from azure.cosmos.http_constants import ResourceType, HttpHeaders +from azure.cosmos.http_constants import ResourceType from azure.cosmos._constants import _Constants as Constants @@ -49,9 +49,6 @@ def __init__(self, client, location_cache: LocationCache): self.client = client def is_circuit_breaker_applicable(self, request: RequestObject) -> bool: - """ - Check if circuit breaker is applicable for a request. 
- """ if not request: return False @@ -61,76 +58,54 @@ def is_circuit_breaker_applicable(self, request: RequestObject) -> bool: return False if (not self.location_cache.can_use_multiple_write_locations_for_request(request) - and documents._OperationType.IsWriteOperation(request.operation_type)): + and documents._OperationType.IsWriteOperation(request.operation_type)): # pylint: disable=protected-access return False if request.resource_type != ResourceType.Document: return False - if request.operation_type == documents._OperationType.QueryPlan: + if request.operation_type == documents._OperationType.QueryPlan: # pylint: disable=protected-access return False return True - def _create_pk_range_wrapper(self, request: RequestObject) -> PartitionKeyRangeWrapper: - """ - Create a PartitionKeyRangeWrapper object. - """ - container_rid = request.headers[HttpHeaders.IntendedCollectionRID] - partition_key = request.headers[HttpHeaders.PartitionKey] - # get the partition key range for the given partition key - target_container_link = None - for container_link, properties in self.client._container_properties_cache.items(): - if properties["_rid"] == container_rid: - target_container_link = container_link - if not target_container_link: - raise RuntimeError("Illegal state: the container cache is not properly initialized.") - # TODO: @tvaron3 check different clients and create them in different ways - pk_range = await self.client._routing_map_provider.get_overlapping_ranges(target_container_link, partition_key) - return PartitionKeyRangeWrapper(pk_range, container_rid) - def record_failure( self, - request: RequestObject + request: RequestObject, + pk_range_wrapper: PartitionKeyRangeWrapper ) -> None: - if self.is_circuit_breaker_applicable(request): - #convert operation_type to EndpointOperationType - endpoint_operation_type = EndpointOperationType.WriteType if ( - documents._OperationType.IsWriteOperation(request.operation_type)) else EndpointOperationType.ReadType - location = 
self.location_cache.get_location_from_endpoint(str(request.location_endpoint_to_route)) - pkrange_wrapper = self._create_pk_range_wrapper(request) - self.partition_health_tracker.add_failure(pkrange_wrapper, endpoint_operation_type, str(location)) - + #convert operation_type to EndpointOperationType + endpoint_operation_type = (EndpointOperationType.WriteType if ( + documents._OperationType.IsWriteOperation(request.operation_type)) # pylint: disable=protected-access + else EndpointOperationType.ReadType) + location = self.location_cache.get_location_from_endpoint(str(request.location_endpoint_to_route)) + self.partition_health_tracker.add_failure(pk_range_wrapper, endpoint_operation_type, str(location)) # TODO: @tvaron3 lower request timeout to 5.5 seconds for recovering # TODO: @tvaron3 exponential backoff for recovering - - def add_excluded_locations_to_request(self, request: RequestObject) -> RequestObject: - if self.is_circuit_breaker_applicable(request): - pkrange_wrapper = self._create_pk_range_wrapper(request) - request.set_excluded_locations_from_circuit_breaker( - self.partition_health_tracker.get_excluded_locations(pkrange_wrapper) - ) + def add_excluded_locations_to_request( + self, + request: RequestObject, + pk_range_wrapper: PartitionKeyRangeWrapper + ) -> RequestObject: + request.set_excluded_locations_from_circuit_breaker( + self.partition_health_tracker.get_excluded_locations(pk_range_wrapper) + ) return request - def mark_partition_unavailable(self, request: RequestObject) -> None: - """ - Mark the partition unavailable from the given request. 
- """ + def mark_partition_unavailable(self, request: RequestObject, pk_range_wrapper: PartitionKeyRangeWrapper) -> None: location = self.location_cache.get_location_from_endpoint(str(request.location_endpoint_to_route)) - pkrange_wrapper = self._create_pk_range_wrapper(request) - self.partition_health_tracker.mark_partition_unavailable(pkrange_wrapper, location) + self.partition_health_tracker.mark_partition_unavailable(pk_range_wrapper, location) def record_success( self, - request: RequestObject + request: RequestObject, + pk_range_wrapper: PartitionKeyRangeWrapper ) -> None: - if self.is_circuit_breaker_applicable(request): - #convert operation_type to either Read or Write - endpoint_operation_type = EndpointOperationType.WriteType if ( - documents._OperationType.IsWriteOperation(request.operation_type)) else EndpointOperationType.ReadType - location = self.location_cache.get_location_from_endpoint(str(request.location_endpoint_to_route)) - pkrange_wrapper = self._create_pk_range_wrapper(request) - self.partition_health_tracker.add_success(pkrange_wrapper, endpoint_operation_type, location) + #convert operation_type to either Read or Write + endpoint_operation_type = EndpointOperationType.WriteType if ( + documents._OperationType.IsWriteOperation(request.operation_type)) else EndpointOperationType.ReadType # pylint: disable=protected-access + location = self.location_cache.get_location_from_endpoint(str(request.location_endpoint_to_route)) + self.partition_health_tracker.add_success(pk_range_wrapper, endpoint_operation_type, location) # TODO: @tvaron3 there should be no in region retries when trying on healthy tentative ----------------------- diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py index cb9ef0840a3c..1d47ef51b5e0 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py @@ -25,7 +25,7 @@ import 
collections import logging import time -from typing import Set, Mapping, List, Optional +from typing import Set, Mapping, List from urllib.parse import urlparse from . import documents diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py index 9fc64dbb5f63..04849ca1cb4b 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py @@ -22,10 +22,10 @@ """Internal class for partition health tracker for circuit breaker. """ import os -from typing import Dict, Set, Any, Optional -from ._constants import _Constants as Constants +from typing import Dict, Set, Any +from azure.cosmos._routing.routing_range import PartitionKeyRangeWrapper from azure.cosmos._location_cache import current_time_millis, EndpointOperationType -from azure.cosmos._routing.routing_range import PartitionKeyRangeWrapper, Range +from ._constants import _Constants as Constants MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 @@ -113,9 +113,10 @@ def _transition_health_status_on_failure( if location in region_to_partition_health: # healthy tentative -> unhealthy # if the operation type is not empty, we are in the healthy tentative state - self.pkrange_wrapper_to_health_info[pkrange_wrapper][location].unavailability_info[HEALTH_STATUS] = UNHEALTHY + region_to_partition_health[location].unavailability_info[HEALTH_STATUS] = UNHEALTHY # reset the last unavailability check time stamp - self.pkrange_wrapper_to_health_info[pkrange_wrapper][location].unavailability_info[LAST_UNAVAILABILITY_CHECK_TIME_STAMP] = UNHEALTHY + region_to_partition_health[location].unavailability_info[LAST_UNAVAILABILITY_CHECK_TIME_STAMP] \ + = UNHEALTHY else: # healthy -> unhealthy tentative # if the operation type is empty, we are in the unhealthy tentative state @@ -135,21 +136,22 @@ def _transition_health_status_on_success( # healthy tentative -> healthy 
self.pkrange_wrapper_to_health_info[pkrange_wrapper].pop(location, None) - def _check_stale_partition_info(self, pkrange_wrapper: PartitionKeyRangeWrapper) -> None: + def _check_stale_partition_info(self, pk_range_wrapper: PartitionKeyRangeWrapper) -> None: current_time = current_time_millis() - stale_partition_unavailability_check = int(os.getenv(Constants.STALE_PARTITION_UNAVAILABILITY_CHECK, + stale_partition_unavailability_check = int(os.environ.get(Constants.STALE_PARTITION_UNAVAILABILITY_CHECK, Constants.STALE_PARTITION_UNAVAILABILITY_CHECK_DEFAULT)) * 1000 - if pkrange_wrapper in self.pkrange_wrapper_to_health_info: - for location, partition_health_info in self.pkrange_wrapper_to_health_info[pkrange_wrapper].items(): - elapsed_time = current_time - partition_health_info.unavailability_info[LAST_UNAVAILABILITY_CHECK_TIME_STAMP] + if pk_range_wrapper in self.pkrange_wrapper_to_health_info: + for _, partition_health_info in self.pkrange_wrapper_to_health_info[pk_range_wrapper].items(): + elapsed_time = (current_time - + partition_health_info.unavailability_info[LAST_UNAVAILABILITY_CHECK_TIME_STAMP]) current_health_status = partition_health_info.unavailability_info[HEALTH_STATUS] # check if the partition key range is still unavailable if ((current_health_status == UNHEALTHY and elapsed_time > stale_partition_unavailability_check) or (current_health_status == UNHEALTHY_TENTATIVE and elapsed_time > INITIAL_UNAVAILABLE_TIME)): # unhealthy or unhealthy tentative -> healthy tentative - self.pkrange_wrapper_to_health_info[pkrange_wrapper][location].unavailability_info[HEALTH_STATUS] = HEALTHY_TENTATIVE + partition_health_info.unavailability_info[HEALTH_STATUS] = HEALTHY_TENTATIVE if current_time - self.last_refresh < REFRESH_INTERVAL: # all partition stats reset every minute @@ -160,8 +162,7 @@ def get_excluded_locations(self, pkrange_wrapper: PartitionKeyRangeWrapper) -> S self._check_stale_partition_info(pkrange_wrapper) if pkrange_wrapper in 
self.pkrange_wrapper_to_health_info: return set(self.pkrange_wrapper_to_health_info[pkrange_wrapper].keys()) - else: - return set() + return set() def add_failure( @@ -171,7 +172,7 @@ def add_failure( location: str ) -> None: # Retrieve the failure rate threshold from the environment. - failure_rate_threshold = int(os.getenv(Constants.FAILURE_PERCENTAGE_TOLERATED, + failure_rate_threshold = int(os.environ.get(Constants.FAILURE_PERCENTAGE_TOLERATED, Constants.FAILURE_PERCENTAGE_TOLERATED_DEFAULT)) # Ensure that the health info dictionary is properly initialized. @@ -201,7 +202,7 @@ def add_failure( setattr(health_info, consecutive_attr, getattr(health_info, consecutive_attr) + 1) # Retrieve the consecutive failure threshold from the environment. - consecutive_failure_threshold = int(os.getenv(env_key, default_consec_threshold)) + consecutive_failure_threshold = int(os.environ.get(env_key, default_consec_threshold)) # Call the threshold checker with the current stats. self._check_thresholds( @@ -256,6 +257,6 @@ def add_success(self, pkrange_wrapper: PartitionKeyRangeWrapper, operation_type: def _reset_partition_health_tracker_stats(self) -> None: - for pkrange_wrapper in self.pkrange_wrapper_to_health_info: - for location in self.pkrange_wrapper_to_health_info[pkrange_wrapper]: - self.pkrange_wrapper_to_health_info[pkrange_wrapper][location].reset_health_stats() + for locations in self.pkrange_wrapper_to_health_info.values(): + for health_info in locations.values(): + health_info.reset_health_stats() diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py index 28dc2fefd73b..dace40aba2fb 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py @@ -24,7 +24,7 @@ from typing import Optional, Mapping, Any, Dict, Set, List -class RequestObject(object): +class RequestObject(object): # pylint: disable=too-many-instance-attributes 
def __init__( self, resource_type: str, @@ -84,5 +84,5 @@ def set_excluded_location_from_options(self, options: Mapping[str, Any]) -> None if self._can_set_excluded_location(options): self.excluded_locations = options['excludedLocations'] - def set_excluded_locations_from_circuit_breaker(self, excluded_locations: Set[str]) -> None: + def set_excluded_locations_from_circuit_breaker(self, excluded_locations: Set[str]) -> None: # pylint: disable=name-too-long self.excluded_locations_circuit_breaker = excluded_locations diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py index 927ed7a41baa..44c6b088696e 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py @@ -45,7 +45,7 @@ # pylint: disable=protected-access, disable=too-many-lines, disable=too-many-statements, disable=too-many-branches -def Execute(client, global_endpoint_manager, function, *args, **kwargs): +def Execute(client, global_endpoint_manager, function, *args, **kwargs): # pylint: disable=too-many-locals """Executes the function with passed parameters applying all retry policies :param object client: @@ -58,6 +58,9 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs): :returns: the result of running the passed in function as a (result, headers) tuple :rtype: tuple of (dict, dict) """ + pk_range_wrapper = None + if global_endpoint_manager.is_circuit_breaker_applicable(args[0]): + pk_range_wrapper = global_endpoint_manager.create_pk_range_wrapper(args[0]) # instantiate all retry policies here to be applied for each request execution endpointDiscovery_retry_policy = _endpoint_discovery_retry_policy.EndpointDiscoveryRetryPolicy( client.connection_policy, global_endpoint_manager, *args @@ -73,19 +76,19 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs): defaultRetry_policy = _default_retry_policy.DefaultRetryPolicy(*args) 
sessionRetry_policy = _session_retry_policy._SessionRetryPolicy( - client.connection_policy.EnableEndpointDiscovery, global_endpoint_manager, *args + client.connection_policy.EnableEndpointDiscovery, global_endpoint_manager, pk_range_wrapper, *args ) partition_key_range_gone_retry_policy = _gone_retry_policy.PartitionKeyRangeGoneRetryPolicy(client, *args) timeout_failover_retry_policy = _timeout_failover_retry_policy._TimeoutFailoverRetryPolicy( - client.connection_policy, global_endpoint_manager, *args + client.connection_policy, global_endpoint_manager, pk_range_wrapper, *args ) service_response_retry_policy = _service_response_retry_policy.ServiceResponseRetryPolicy( - client.connection_policy, global_endpoint_manager, *args, + client.connection_policy, global_endpoint_manager, pk_range_wrapper, *args, ) service_request_retry_policy = _service_request_retry_policy.ServiceRequestRetryPolicy( - client.connection_policy, global_endpoint_manager, *args, + client.connection_policy, global_endpoint_manager, pk_range_wrapper, *args, ) # HttpRequest we would need to modify for Container Recreate Retry Policy request = None @@ -104,6 +107,7 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs): try: if args: result = ExecuteFunction(function, global_endpoint_manager, *args, **kwargs) + global_endpoint_manager.record_success(args[0]) else: result = ExecuteFunction(function, *args, **kwargs) if not client.last_response_headers: @@ -172,9 +176,9 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs): retry_policy.container_rid = cached_container["_rid"] request.headers[retry_policy._intended_headers] = retry_policy.container_rid - elif e.status_code == StatusCodes.REQUEST_TIMEOUT: - retry_policy = timeout_failover_retry_policy - elif e.status_code >= StatusCodes.INTERNAL_SERVER_ERROR: + elif e.status_code == StatusCodes.REQUEST_TIMEOUT or e.status_code >= StatusCodes.INTERNAL_SERVER_ERROR: + # record the failure for circuit breaker 
tracking + global_endpoint_manager.record_failure(args[0]) retry_policy = timeout_failover_retry_policy else: retry_policy = defaultRetry_policy @@ -200,6 +204,7 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs): if client_timeout: kwargs['timeout'] = client_timeout - (time.time() - start_time) if kwargs['timeout'] <= 0: + global_endpoint_manager.record_failure(args[0]) raise exceptions.CosmosClientTimeoutError() except ServiceRequestError as e: @@ -291,7 +296,8 @@ def send(self, request): """ absolute_timeout = request.context.options.pop('timeout', None) per_request_timeout = request.context.options.pop('connection_timeout', 0) - + request_params = request.context.options.pop('request_params', None) + global_endpoint_manager = request.context.options.pop('global_endpoint_manager', None) retry_error = None retry_active = True response = None @@ -317,6 +323,7 @@ def send(self, request): # since request wasn't sent, raise exception immediately to be dealt with in client retry policies # This logic is based on the _retry.py file from azure-core if not _has_database_account_header(request.http_request.headers): + global_endpoint_manager.record_failure(request_params) if retry_settings['connect'] > 0: retry_active = self.increment(retry_settings, response=request, error=err) if retry_active: @@ -329,7 +336,7 @@ def send(self, request): if (not _has_read_retryable_headers(request.http_request.headers) or _has_database_account_header(request.http_request.headers)): raise err - + global_endpoint_manager.record_failure(request_params) # This logic is based on the _retry.py file from azure-core if retry_settings['read'] > 0: retry_active = self.increment(retry_settings, response=request, error=err) @@ -342,6 +349,7 @@ def send(self, request): if _has_database_account_header(request.http_request.headers): raise err if self._is_method_retryable(retry_settings, request.http_request): + global_endpoint_manager.record_failure(request_params) retry_active 
= self.increment(retry_settings, response=request, error=err) if retry_active: self.sleep(retry_settings, request.context.transport) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_range.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_range.py index 21a22ca89f61..e31682725828 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_range.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_range.py @@ -250,4 +250,3 @@ def __eq__(self, other): def __hash__(self): return hash((self.partition_key_range, self.collection_rid)) - diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_service_request_retry_policy.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_service_request_retry_policy.py index edd15f20337f..b49185512fa4 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_service_request_retry_policy.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_service_request_retry_policy.py @@ -13,9 +13,10 @@ class ServiceRequestRetryPolicy(object): - def __init__(self, connection_policy, global_endpoint_manager, *args): + def __init__(self, connection_policy, global_endpoint_manager, pk_range_wrapper, *args): self.args = args self.global_endpoint_manager = global_endpoint_manager + self.pk_range_wrapper = pk_range_wrapper self.total_retries = len(self.global_endpoint_manager.location_cache.read_regional_routing_contexts) self.total_in_region_retries = 1 self.in_region_retry_count = 0 @@ -45,7 +46,7 @@ def ShouldRetry(self): return False if self.global_endpoint_manager.is_circuit_breaker_applicable(self.request): - self.global_endpoint_manager.mark_partition_unavailable(self.request) + self.global_endpoint_manager.mark_partition_unavailable(self.request, self.pk_range_wrapper) else: refresh_cache = self.request.last_routed_location_endpoint_within_region is not None # This logic is for the last retry and mark the region unavailable @@ -99,7 +100,7 @@ def resolve_current_region_service_endpoint(self): # resolve the next service endpoint in the 
same region # since we maintain 2 endpoints per region for write operations self.request.route_to_location_with_preferred_location_flag(0, True) - return self.global_endpoint_manager.resolve_service_endpoint(self.request) + return self.global_endpoint_manager.resolve_service_endpoint(self.request, self.pk_range_wrapper) # This function prepares the request to go to the next region def resolve_next_region_service_endpoint(self): @@ -113,7 +114,7 @@ def resolve_next_region_service_endpoint(self): self.request.route_to_location_with_preferred_location_flag(0, True) # Resolve the endpoint for the request and pin the resolution to the resolved endpoint # This enables marking the endpoint unavailability on endpoint failover/unreachability - return self.global_endpoint_manager.resolve_service_endpoint(self.request) + return self.global_endpoint_manager.resolve_service_endpoint(self.request, self.pk_range_wrapper) def mark_endpoint_unavailable(self, unavailable_endpoint, refresh_cache: bool): if _OperationType.IsReadOnlyOperation(self.request.operation_type): diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_service_response_retry_policy.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_service_response_retry_policy.py index 83a856f39d33..330ffb5929a5 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_service_response_retry_policy.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_service_response_retry_policy.py @@ -12,15 +12,17 @@ class ServiceResponseRetryPolicy(object): - def __init__(self, connection_policy, global_endpoint_manager, *args): + def __init__(self, connection_policy, global_endpoint_manager, pk_range_wrapper, *args): self.args = args self.global_endpoint_manager = global_endpoint_manager + self.pk_range_wrapper = pk_range_wrapper self.total_retries = len(self.global_endpoint_manager.location_cache.read_regional_routing_contexts) self.failover_retry_count = 0 self.connection_policy = connection_policy self.request = args[0] if args else None if self.request: - 
self.location_endpoint = self.global_endpoint_manager.resolve_service_endpoint(self.request) + self.location_endpoint = self.global_endpoint_manager.resolve_service_endpoint(self.request, + pk_range_wrapper) self.logger = logging.getLogger('azure.cosmos.ServiceResponseRetryPolicy') def ShouldRetry(self): @@ -57,4 +59,4 @@ def resolve_next_region_service_endpoint(self): self.request.route_to_location_with_preferred_location_flag(self.failover_retry_count, True) # Resolve the endpoint for the request and pin the resolution to the resolved endpoint # This enables marking the endpoint unavailability on endpoint failover/unreachability - return self.global_endpoint_manager.resolve_service_endpoint(self.request) + return self.global_endpoint_manager.resolve_service_endpoint(self.request, self.pk_range_wrapper) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_session_retry_policy.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_session_retry_policy.py index 1614f337de5b..e52fbe996e11 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_session_retry_policy.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_session_retry_policy.py @@ -41,10 +41,11 @@ class _SessionRetryPolicy(object): Max_retry_attempt_count = 1 Retry_after_in_milliseconds = 0 - def __init__(self, endpoint_discovery_enable, global_endpoint_manager, *args): + def __init__(self, endpoint_discovery_enable, global_endpoint_manager, pk_range_wrapper, *args): self.global_endpoint_manager = global_endpoint_manager self._max_retry_attempt_count = _SessionRetryPolicy.Max_retry_attempt_count self.session_token_retry_count = 0 + self.pk_range_wrapper = pk_range_wrapper self.retry_after_in_milliseconds = _SessionRetryPolicy.Retry_after_in_milliseconds self.endpoint_discovery_enable = endpoint_discovery_enable self.request = args[0] if args else None @@ -57,7 +58,8 @@ def __init__(self, endpoint_discovery_enable, global_endpoint_manager, *args): # Resolve the endpoint for the request and pin the resolution to the resolved 
endpoint # This enables marking the endpoint unavailability on endpoint failover/unreachability - self.location_endpoint = self.global_endpoint_manager.resolve_service_endpoint(self.request) + self.location_endpoint = self.global_endpoint_manager.resolve_service_endpoint(self.request, + self.pk_range_wrapper) self.request.route_to_location(self.location_endpoint) def ShouldRetry(self, _exception): @@ -98,7 +100,8 @@ def ShouldRetry(self, _exception): # Resolve the endpoint for the request and pin the resolution to the resolved endpoint # This enables marking the endpoint unavailability on endpoint failover/unreachability - self.location_endpoint = self.global_endpoint_manager.resolve_service_endpoint(self.request) + self.location_endpoint = self.global_endpoint_manager.resolve_service_endpoint(self.request, + self.pk_range_wrapper) self.request.route_to_location(self.location_endpoint) return True @@ -113,6 +116,7 @@ def ShouldRetry(self, _exception): # Resolve the endpoint for the request and pin the resolution to the resolved endpoint # This enables marking the endpoint unavailability on endpoint failover/unreachability - self.location_endpoint = self.global_endpoint_manager.resolve_service_endpoint(self.request) + self.location_endpoint = self.global_endpoint_manager.resolve_service_endpoint(self.request, + self.pk_range_wrapper) self.request.route_to_location(self.location_endpoint) return True diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py index 68e37caf1d9d..43516430bdb6 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py @@ -65,7 +65,7 @@ def _request_body_from_data(data): return None -def _Request(global_endpoint_manager, request_params, connection_policy, pipeline_client, request, **kwargs): +def _Request(global_endpoint_manager, request_params, connection_policy, 
pipeline_client, request, **kwargs): # pylint: disable=too-many-statements """Makes one http request using the requests module. :param _GlobalEndpointManager global_endpoint_manager: @@ -104,7 +104,11 @@ def _Request(global_endpoint_manager, request_params, connection_policy, pipelin if request_params.endpoint_override: base_url = request_params.endpoint_override else: - base_url = global_endpoint_manager.resolve_service_endpoint(request_params) + pk_range_wrapper = None + if global_endpoint_manager.is_circuit_breaker_applicable(request_params): + # Circuit breaker is applicable, so we need to use the endpoint from the request + pk_range_wrapper = global_endpoint_manager.create_pk_range_wrapper(request_params) + base_url = global_endpoint_manager.resolve_service_endpoint(request_params, pk_range_wrapper) if not request.url.startswith(base_url): request.url = _replace_url_prefix(request.url, base_url) @@ -132,6 +136,8 @@ def _Request(global_endpoint_manager, request_params, connection_policy, pipelin read_timeout=read_timeout, connection_verify=kwargs.pop("connection_verify", ca_certs), connection_cert=kwargs.pop("connection_cert", cert_files), + request_params=request_params, + global_endpoint_manager=global_endpoint_manager, **kwargs ) else: @@ -142,6 +148,8 @@ def _Request(global_endpoint_manager, request_params, connection_policy, pipelin read_timeout=read_timeout, # If SSL is disabled, verify = false connection_verify=kwargs.pop("connection_verify", is_ssl_enabled), + request_params=request_params, + global_endpoint_manager=global_endpoint_manager, **kwargs ) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py index f70e27bae70c..69bc973c3346 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py @@ -9,9 +9,10 @@ class 
_TimeoutFailoverRetryPolicy(object): - def __init__(self, connection_policy, global_endpoint_manager, *args): + def __init__(self, connection_policy, global_endpoint_manager, pk_range_wrapper, *args): self.retry_after_in_milliseconds = 500 self.global_endpoint_manager = global_endpoint_manager + self.pk_range_wrapper = pk_range_wrapper # If an account only has 1 region, then we still want to retry once on the same region self._max_retry_attempt_count = (len(self.global_endpoint_manager.location_cache.read_regional_routing_contexts) + 1) @@ -26,9 +27,6 @@ def ShouldRetry(self, _exception): :returns: a boolean stating whether the request should be retried :rtype: bool """ - # record the failure for circuit breaker tracking - self.global_endpoint_manager.record_failure(self.request) - # we don't retry on write operations for timeouts or any internal server errors if self.request and (not _OperationType.IsReadOnlyOperation(self.request.operation_type)): return False @@ -57,4 +55,4 @@ def resolve_next_region_service_endpoint(self): self.request.route_to_location_with_preferred_location_flag(self.retry_count, True) # Resolve the endpoint for the request and pin the resolution to the resolved endpoint # This enables marking the endpoint unavailability on endpoint failover/unreachability - return self.global_endpoint_manager.resolve_service_endpoint(self.request) + return self.global_endpoint_manager.resolve_service_endpoint(self.request, self.pk_range_wrapper) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py index 25f6ac203d85..4fda37ea0a87 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py @@ -34,7 +34,7 @@ from .._synchronized_request import _request_body_from_data, _replace_url_prefix -async def _Request(global_endpoint_manager, request_params, connection_policy, 
pipeline_client, request, **kwargs): +async def _Request(global_endpoint_manager, request_params, connection_policy, pipeline_client, request, **kwargs): # pylint: disable=too-many-statements """Makes one http request using the requests module. :param _GlobalEndpointManager global_endpoint_manager: @@ -73,7 +73,11 @@ async def _Request(global_endpoint_manager, request_params, connection_policy, p if request_params.endpoint_override: base_url = request_params.endpoint_override else: - base_url = global_endpoint_manager.resolve_service_endpoint(request_params) + pk_range_wrapper = None + if global_endpoint_manager.is_circuit_breaker_applicable(request_params): + # Circuit breaker is applicable, so we need to use the endpoint from the request + pk_range_wrapper = await global_endpoint_manager.create_pk_range_wrapper(request_params) + base_url = global_endpoint_manager.resolve_service_endpoint(request_params, pk_range_wrapper) if not request.url.startswith(base_url): request.url = _replace_url_prefix(request.url, base_url) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py index 8fbdb0fb9a83..31ae9dc334cd 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py @@ -48,13 +48,14 @@ DistributedTracingPolicy, ProxyPolicy) from azure.core.utils import CaseInsensitiveDict +from azure.cosmos.aio._global_partition_endpoint_manager_circuit_breaker_async import ( + _GlobalPartitionEndpointManagerForCircuitBreakerAsync) from .. import _base as base from .._base import _set_properties_cache from .. 
import documents from .._change_feed.aio.change_feed_iterable import ChangeFeedIterable from .._change_feed.change_feed_state import ChangeFeedState -from azure.cosmos.aio._global_partition_endpoint_manager_circuit_breaker_async import _GlobalPartitionEndpointManagerForCircuitBreakerAsync from .._routing import routing_range from ..documents import ConnectionPolicy, DatabaseAccount from .._constants import _Constants as Constants diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py index 00438cc2214e..0fe666f1983c 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py @@ -25,7 +25,6 @@ import asyncio # pylint: disable=do-not-import-asyncio import logging -from asyncio import CancelledError # pylint: disable=do-not-import-asyncio from typing import Tuple from azure.core.exceptions import AzureError @@ -35,6 +34,7 @@ from .. 
import exceptions from .._location_cache import LocationCache, current_time_millis from .._request_object import RequestObject +from .._routing.routing_range import PartitionKeyRangeWrapper # pylint: disable=protected-access @@ -71,7 +71,11 @@ def get_write_endpoint(self): def get_read_endpoint(self): return self.location_cache.get_read_regional_routing_context() - def resolve_service_endpoint(self, request): + def resolve_service_endpoint( + self, + request: RequestObject, + pk_range_wrapper: PartitionKeyRangeWrapper # pylint: disable=unused-argument + ) -> str: return self.location_cache.resolve_service_endpoint(request) def mark_endpoint_unavailable_for_read(self, endpoint, refresh_cache): diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py index 71bb628c31a0..3aadc9b6aba0 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py @@ -21,13 +21,16 @@ """Internal class for global endpoint manager for circuit breaker. 
""" -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from azure.cosmos._global_partition_endpoint_manager_circuit_breaker_core import \ _GlobalPartitionEndpointManagerForCircuitBreakerCore +from azure.cosmos._routing.routing_range import PartitionKeyRangeWrapper from azure.cosmos.aio._global_endpoint_manager_async import _GlobalEndpointManager from azure.cosmos._request_object import RequestObject +from azure.cosmos.http_constants import HttpHeaders + if TYPE_CHECKING: from azure.cosmos.aio._cosmos_client_connection_async import CosmosClientConnection @@ -42,32 +45,53 @@ class _GlobalPartitionEndpointManagerForCircuitBreakerAsync(_GlobalEndpointManag def __init__(self, client: "CosmosClientConnection"): super(_GlobalPartitionEndpointManagerForCircuitBreakerAsync, self).__init__(client) - self.global_partition_endpoint_manager_core = _GlobalPartitionEndpointManagerForCircuitBreakerCore(client, self.location_cache) + self.global_partition_endpoint_manager_core = ( + _GlobalPartitionEndpointManagerForCircuitBreakerCore(client, self.location_cache)) + + async def create_pk_range_wrapper(self, request: RequestObject) -> PartitionKeyRangeWrapper: + container_rid = request.headers[HttpHeaders.IntendedCollectionRID] + partition_key = request.headers[HttpHeaders.PartitionKey] + # get the partition key range for the given partition key + target_container_link = None + for container_link, properties in self.client._container_properties_cache.items(): # pylint: disable=protected-access + if properties["_rid"] == container_rid: + target_container_link = container_link + if not target_container_link: + raise RuntimeError("Illegal state: the container cache is not properly initialized.") + # TODO: @tvaron3 check different clients and create them in different ways + pk_range = await (self.client._routing_map_provider # pylint: disable=protected-access + .get_overlapping_ranges(target_container_link, partition_key)) + return 
PartitionKeyRangeWrapper(pk_range, container_rid) def is_circuit_breaker_applicable(self, request: RequestObject) -> bool: - """ - Check if circuit breaker is applicable for a request. - """ return self.global_partition_endpoint_manager_core.is_circuit_breaker_applicable(request) - def record_failure( + async def record_failure( self, request: RequestObject ) -> None: - self.global_partition_endpoint_manager_core.record_failure(request) - - def resolve_service_endpoint(self, request): - request = self.global_partition_endpoint_manager_core.add_excluded_locations_to_request(request) - return super(_GlobalPartitionEndpointManagerForCircuitBreakerAsync, self).resolve_service_endpoint(request) - - def mark_partition_unavailable(self, request: RequestObject) -> None: - """ - Mark the partition unavailable from the given request. - """ - self.global_partition_endpoint_manager_core.mark_partition_unavailable(request) + if self.is_circuit_breaker_applicable(request): + pk_range_wrapper = await self.create_pk_range_wrapper(request) + self.global_partition_endpoint_manager_core.record_failure(request, pk_range_wrapper) + + def resolve_service_endpoint(self, request: RequestObject, pk_range_wrapper: PartitionKeyRangeWrapper): + if self.global_partition_endpoint_manager_core.is_circuit_breaker_applicable(request) and pk_range_wrapper: + request = self.global_partition_endpoint_manager_core.add_excluded_locations_to_request(request, + pk_range_wrapper) + return (super(_GlobalPartitionEndpointManagerForCircuitBreakerAsync, self) + .resolve_service_endpoint(request, pk_range_wrapper)) + + def mark_partition_unavailable( + self, + request: RequestObject, + pk_range_wrapper: PartitionKeyRangeWrapper + ) -> None: + self.global_partition_endpoint_manager_core.mark_partition_unavailable(request, pk_range_wrapper) - def record_success( + async def record_success( self, request: RequestObject ) -> None: - self.global_partition_endpoint_manager_core.record_success(request) + if 
self.global_partition_endpoint_manager_core.is_circuit_breaker_applicable(request): + pk_range_wrapper = await self.create_pk_range_wrapper(request) + self.global_partition_endpoint_manager_core.record_success(request, pk_range_wrapper) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py index e1094736b88b..6e94c9260b8b 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py @@ -28,7 +28,6 @@ from azure.core.exceptions import AzureError, ClientAuthenticationError, ServiceRequestError, ServiceResponseError from azure.core.pipeline.policies import AsyncRetryPolicy -from ._global_partition_endpoint_manager_circuit_breaker_async import _GlobalPartitionEndpointManagerForCircuitBreakerAsync from .. import _default_retry_policy, _database_account_retry_policy from .. import _endpoint_discovery_retry_policy from .. 
import _gone_retry_policy @@ -59,6 +58,9 @@ async def ExecuteAsync(client, global_endpoint_manager, function, *args, **kwarg :returns: the result of running the passed in function as a (result, headers) tuple :rtype: tuple of (dict, dict) """ + pk_range_wrapper = None + if args and global_endpoint_manager.is_circuit_breaker_applicable(args[0]): + pk_range_wrapper = await global_endpoint_manager.create_pk_range_wrapper(args[0]) # instantiate all retry policies here to be applied for each request execution endpointDiscovery_retry_policy = _endpoint_discovery_retry_policy.EndpointDiscoveryRetryPolicy( client.connection_policy, global_endpoint_manager, *args @@ -74,17 +76,17 @@ async def ExecuteAsync(client, global_endpoint_manager, function, *args, **kwarg defaultRetry_policy = _default_retry_policy.DefaultRetryPolicy(*args) sessionRetry_policy = _session_retry_policy._SessionRetryPolicy( - client.connection_policy.EnableEndpointDiscovery, global_endpoint_manager, *args + client.connection_policy.EnableEndpointDiscovery, global_endpoint_manager, pk_range_wrapper, *args ) partition_key_range_gone_retry_policy = _gone_retry_policy.PartitionKeyRangeGoneRetryPolicy(client, *args) timeout_failover_retry_policy = _timeout_failover_retry_policy._TimeoutFailoverRetryPolicy( - client.connection_policy, global_endpoint_manager, *args + client.connection_policy, global_endpoint_manager, pk_range_wrapper, *args ) service_response_retry_policy = _service_response_retry_policy.ServiceResponseRetryPolicy( - client.connection_policy, global_endpoint_manager, *args, + client.connection_policy, global_endpoint_manager, pk_range_wrapper, *args, ) service_request_retry_policy = _service_request_retry_policy.ServiceRequestRetryPolicy( - client.connection_policy, global_endpoint_manager, *args, + client.connection_policy, global_endpoint_manager, pk_range_wrapper, *args, ) # HttpRequest we would need to modify for Container Recreate Retry Policy request = None @@ -103,7 +105,7 @@ async def 
ExecuteAsync(client, global_endpoint_manager, function, *args, **kwarg try: if args: result = await ExecuteFunctionAsync(function, global_endpoint_manager, *args, **kwargs) - global_endpoint_manager.record_success(args[0]) + await global_endpoint_manager.record_success(args[0]) else: result = await ExecuteFunctionAsync(function, *args, **kwargs) if not client.last_response_headers: @@ -172,9 +174,9 @@ async def ExecuteAsync(client, global_endpoint_manager, function, *args, **kwarg retry_policy.container_rid = cached_container["_rid"] request.headers[retry_policy._intended_headers] = retry_policy.container_rid - elif e.status_code == StatusCodes.REQUEST_TIMEOUT: - retry_policy = timeout_failover_retry_policy - elif e.status_code >= StatusCodes.INTERNAL_SERVER_ERROR: + elif e.status_code == StatusCodes.REQUEST_TIMEOUT or e.status_code >= StatusCodes.INTERNAL_SERVER_ERROR: + # record the failure for circuit breaker tracking + await global_endpoint_manager.record_failure(args[0]) retry_policy = timeout_failover_retry_policy else: retry_policy = defaultRetry_policy From 29305f46c2109abcceb80bd5e1b6737a124669c5 Mon Sep 17 00:00:00 2001 From: Allen Kim Date: Mon, 7 Apr 2025 15:58:01 -0700 Subject: [PATCH 058/152] Added client level ExcludedLocation for async --- sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client.py index 683f16288cd3..647f6d59f615 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client.py @@ -84,6 +84,7 @@ def _build_connection_policy(kwargs: Dict[str, Any]) -> ConnectionPolicy: policy.ProxyConfiguration = kwargs.pop('proxy_config', policy.ProxyConfiguration) policy.EnableEndpointDiscovery = kwargs.pop('enable_endpoint_discovery', policy.EnableEndpointDiscovery) policy.PreferredLocations = 
kwargs.pop('preferred_locations', policy.PreferredLocations) + policy.ExcludedLocations = kwargs.pop('excluded_locations', policy.ExcludedLocations) policy.UseMultipleWriteLocations = kwargs.pop('multiple_write_locations', policy.UseMultipleWriteLocations) # SSL config From c77b4e726b8ab2d7c7f6426c21ef7afb1e7b10e4 Mon Sep 17 00:00:00 2001 From: Allen Kim Date: Mon, 7 Apr 2025 16:39:55 -0700 Subject: [PATCH 059/152] Update Live test settings --- sdk/cosmos/live-platform-matrix.json | 10 +++++----- sdk/cosmos/test-resources.bicep | 6 ------ 2 files changed, 5 insertions(+), 11 deletions(-) diff --git a/sdk/cosmos/live-platform-matrix.json b/sdk/cosmos/live-platform-matrix.json index 7a02486a8827..dc3ad3c32e17 100644 --- a/sdk/cosmos/live-platform-matrix.json +++ b/sdk/cosmos/live-platform-matrix.json @@ -90,11 +90,6 @@ } }, { - "ArmConfig": { - "MultiMaster_MultiRegion": { - "ArmTemplateParameters": "@{ enableMultipleWriteLocations = $true; defaultConsistencyLevel = 'Session'; enableMultipleRegions = $true }" - } - }, "WindowsConfig": { "Windows2022_312": { "OSVmImage": "env:WINDOWSVMIMAGE", @@ -104,6 +99,11 @@ "TestSamples": "false", "TestMarkArgument": "cosmosMultiRegion" } + }, + "ArmConfig": { + "MultiMaster_MultiRegion": { + "ArmTemplateParameters": "@{ enableMultipleWriteLocations = $true; defaultConsistencyLevel = 'Session'; enableMultipleRegions = $true }" + } } } ] diff --git a/sdk/cosmos/test-resources.bicep b/sdk/cosmos/test-resources.bicep index b05dead26737..88abe955f8d8 100644 --- a/sdk/cosmos/test-resources.bicep +++ b/sdk/cosmos/test-resources.bicep @@ -41,12 +41,6 @@ var multiRegionConfiguration = [ failoverPriority: 1 isZoneRedundant: false } - { - locationName: 'East US 2' - provisioningState: 'Succeeded' - failoverPriority: 2 - isZoneRedundant: false - } ] var locationsConfiguration = (enableMultipleRegions ? 
multiRegionConfiguration : singleRegionConfiguration) var roleDefinitionId = guid(baseName, 'roleDefinitionId') From d82fa74255e899556f48f8797ff8afbe7ad595bc Mon Sep 17 00:00:00 2001 From: Allen Kim Date: Mon, 7 Apr 2025 16:40:32 -0700 Subject: [PATCH 060/152] Added Async tests --- .../tests/test_excluded_locations.py | 76 ++- .../tests/test_excluded_locations_async.py | 470 ++++++++++++++++++ 2 files changed, 504 insertions(+), 42 deletions(-) create mode 100644 sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py index 2d17bad85ba5..13e9ba713653 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py +++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py @@ -38,46 +38,42 @@ def emit(self, record): L1 = "West US 3" L2 = "West US" L3 = "East US 2" -L4 = "Central US" # L0 = "Default" # L1 = "East US 2" # L2 = "East US" # L3 = "West US 2" -# L4 = "Central US" CLIENT_ONLY_TEST_DATA = [ # preferred_locations, client_excluded_locations, excluded_locations_request # 0. No excluded location - [[L1, L2, L3], [], None], + [[L1, L2], [], None], # 1. Single excluded location - [[L1, L2, L3], [L1], None], - # 2. Multiple excluded locations - [[L1, L2, L3], [L1, L2], None], - # 3. Exclude all locations - [[L1, L2, L3], [L1, L2, L3], None], - # 4. Exclude a location not in preferred locations - [[L1, L2, L3], [L4], None], + [[L1, L2], [L1], None], + # 2. Exclude all locations + [[L1, L2], [L1, L2], None], + # 3. Exclude a location not in preferred locations + [[L1, L2], [L3], None], ] CLIENT_AND_REQUEST_TEST_DATA = [ # preferred_locations, client_excluded_locations, excluded_locations_request # 0. No client excluded locations + a request excluded location - [[L1, L2, L3], [], [L1]], + [[L1, L2], [], [L1]], # 1. The same client and request excluded location - [[L1, L2, L3], [L1], [L1]], + [[L1, L2], [L1], [L1]], # 2. 
Less request excluded locations - [[L1, L2, L3], [L1, L2], [L1]], + [[L1, L2], [L1, L2], [L1]], # 3. More request excluded locations - [[L1, L2, L3], [L1], [L1, L2]], + [[L1, L2], [L1], [L1, L2]], # 4. All locations were excluded - [[L1, L2, L3], [L1, L2, L3], [L1, L2, L3]], + [[L1, L2], [L1, L2], [L1, L2]], # 5. No common excluded locations - [[L1, L2, L3], [L1], [L2, L3]], + [[L1, L2], [L1], [L2]], # 6. Request excluded location not in preferred locations - [[L1, L2, L3], [L1, L2, L3], [L4]], + [[L1, L2], [L1, L2], [L3]], # 7. Empty excluded locations, remove all client excluded locations - [[L1, L2, L3], [L1, L2], []], + [[L1, L2], [L1, L2], []], ] ALL_INPUT_TEST_DATA = CLIENT_ONLY_TEST_DATA + CLIENT_AND_REQUEST_TEST_DATA @@ -88,15 +84,14 @@ def read_item_test_data(): client_only_output_data = [ [L1], # 0 [L2], # 1 - [L3], # 2 + [L1], # 2 [L1], # 3 - [L1] # 4 ] client_and_request_output_data = [ [L2], # 0 [L2], # 1 [L2], # 2 - [L3], # 3 + [L1], # 3 [L1], # 4 [L1], # 5 [L1], # 6 @@ -109,21 +104,20 @@ def read_item_test_data(): def query_items_change_feed_test_data(): client_only_output_data = [ - [L1, L1, L1], #0 - [L2, L2, L2], #1 - [L3, L3, L3], #2 - [L1, L1, L1], #3 - [L1, L1, L1] #4 + [L1, L1, L1, L1], #0 + [L2, L2, L2, L2], #1 + [L1, L1, L1, L1], #2 + [L1, L1, L1, L1] #3 ] client_and_request_output_data = [ - [L1, L1, L2], #0 - [L2, L2, L2], #1 - [L3, L3, L2], #2 - [L2, L2, L3], #3 - [L1, L1, L1], #4 - [L2, L2, L1], #5 - [L1, L1, L1], #6 - [L3, L3, L1], #7 + [L1, L1, L2, L2], #0 + [L2, L2, L2, L2], #1 + [L1, L1, L2, L2], #2 + [L2, L2, L1, L1], #3 + [L1, L1, L1, L1], #4 + [L2, L2, L1, L1], #5 + [L1, L1, L1, L1], #6 + [L1, L1, L1, L1], #7 ] all_output_test_data = client_only_output_data + client_and_request_output_data @@ -134,15 +128,14 @@ def replace_item_test_data(): client_only_output_data = [ [L1, L1], #0 [L2, L2], #1 - [L3, L3], #2 - [L1, L0], #3 - [L1, L1] #4 + [L1, L0], #2 + [L1, L1] #3 ] client_and_request_output_data = [ [L2, L2], #0 [L2, L2], #1 
[L2, L2], #2 - [L3, L3], #3 + [L1, L0], #3 [L1, L0], #4 [L1, L1], #5 [L1, L1], #6 @@ -157,7 +150,6 @@ def patch_item_test_data(): client_only_output_data = [ [L1], #0 [L2], #1 - [L3], #2 [L0], #3 [L1] #4 ] @@ -165,7 +157,7 @@ def patch_item_test_data(): [L2], #0 [L2], #1 [L2], #2 - [L3], #3 + [L0], #3 [L0], #4 [L1], #5 [L1], #6 @@ -290,9 +282,9 @@ def test_query_items_change_feed(self, test_data): # API call: query_items_change_feed if request_excluded_locations is None: - list(container.query_items_change_feed()) + items = list(container.query_items_change_feed(start_time="Beginning")) else: - list(container.query_items_change_feed(excluded_locations=request_excluded_locations)) + items = list(container.query_items_change_feed(start_time="Beginning", excluded_locations=request_excluded_locations)) # Verify endpoint locations self._verify_endpoint(client, expected_locations) diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py new file mode 100644 index 000000000000..7564071de4f9 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py @@ -0,0 +1,470 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. 
+ +import logging +import unittest +import uuid +import test_config +import pytest +import pytest_asyncio + +from azure.cosmos.aio import CosmosClient +from azure.cosmos.partition_key import PartitionKey + + +class MockHandler(logging.Handler): + def __init__(self): + super(MockHandler, self).__init__() + self.messages = [] + + def reset(self): + self.messages = [] + + def emit(self, record): + self.messages.append(record.msg) + +MOCK_HANDLER = MockHandler() +CONFIG = test_config.TestConfig() +HOST = CONFIG.host +KEY = CONFIG.masterKey +DATABASE_ID = CONFIG.TEST_DATABASE_ID +CONTAINER_ID = CONFIG.TEST_SINGLE_PARTITION_CONTAINER_ID +PARTITION_KEY = CONFIG.TEST_CONTAINER_PARTITION_KEY +ITEM_ID = 'doc1' +ITEM_PK_VALUE = 'pk' +TEST_ITEM = {'id': ITEM_ID, PARTITION_KEY: ITEM_PK_VALUE} + +L0 = "Default" +L1 = "West US 3" +L2 = "West US" +L3 = "East US 2" + +# L0 = "Default" +# L1 = "East US 2" +# L2 = "East US" +# L3 = "West US 2" + +CLIENT_ONLY_TEST_DATA = [ + # preferred_locations, client_excluded_locations, excluded_locations_request + # 0. No excluded location + [[L1, L2], [], None], + # 1. Single excluded location + [[L1, L2], [L1], None], + # 2. Exclude all locations + [[L1, L2], [L1, L2], None], + # 3. Exclude a location not in preferred locations + [[L1, L2], [L3], None], +] + +CLIENT_AND_REQUEST_TEST_DATA = [ + # preferred_locations, client_excluded_locations, excluded_locations_request + # 0. No client excluded locations + a request excluded location + [[L1, L2], [], [L1]], + # 1. The same client and request excluded location + [[L1, L2], [L1], [L1]], + # 2. Less request excluded locations + [[L1, L2], [L1, L2], [L1]], + # 3. More request excluded locations + [[L1, L2], [L1], [L1, L2]], + # 4. All locations were excluded + [[L1, L2], [L1, L2], [L1, L2]], + # 5. No common excluded locations + [[L1, L2], [L1], [L2, L3]], + # 6. Request excluded location not in preferred locations + [[L1, L2], [L1, L2], [L3]], + # 7. 
Empty excluded locations, remove all client excluded locations + [[L1, L2], [L1, L2], []], +] + +ALL_INPUT_TEST_DATA = CLIENT_ONLY_TEST_DATA + CLIENT_AND_REQUEST_TEST_DATA + +def read_item_test_data(): + client_only_output_data = [ + [L1], # 0 + [L2], # 1 + [L1], # 2 + [L1] # 3 + ] + client_and_request_output_data = [ + [L2], # 0 + [L2], # 1 + [L2], # 2 + [L1], # 3 + [L1], # 4 + [L1], # 5 + [L1], # 6 + [L1], # 7 + ] + all_output_test_data = client_only_output_data + client_and_request_output_data + + all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] + return all_test_data + +def query_items_change_feed_test_data(): + client_only_output_data = [ + [L1, L1, L1], #0 + [L2, L2, L2], #1 + [L1, L1, L1], #2 + [L1, L1, L1] #3 + ] + client_and_request_output_data = [ + [L1, L2, L2], #0 + [L2, L2, L2], #1 + [L1, L2, L2], #2 + [L2, L1, L1], #3 + [L1, L1, L1], #4 + [L2, L1, L1], #5 + [L1, L1, L1], #6 + [L1, L1, L1], #7 + ] + all_output_test_data = client_only_output_data + client_and_request_output_data + + all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] + return all_test_data + +def replace_item_test_data(): + client_only_output_data = [ + [L1], #0 + [L2], #1 + [L0], #2 + [L1] #3 + ] + client_and_request_output_data = [ + [L2], #0 + [L2], #1 + [L2], #2 + [L0], #3 + [L0], #4 + [L1], #5 + [L1], #6 + [L1], #7 + ] + all_output_test_data = client_only_output_data + client_and_request_output_data + + all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] + return all_test_data + +def patch_item_test_data(): + client_only_output_data = [ + [L1], #0 + [L2], #1 + [L0], #2 + [L1] #3 + ] + client_and_request_output_data = [ + [L2], #0 + [L2], #1 + [L2], #2 + [L0], #3 + [L0], #4 + [L1], #5 + [L1], #6 + [L1], #7 + ] + all_output_test_data = client_only_output_data + 
client_and_request_output_data + + all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] + return all_test_data + +@pytest_asyncio.fixture(scope="class", autouse=True) +async def setup_and_teardown(): + print("Setup: This runs before any tests") + logger = logging.getLogger("azure") + logger.addHandler(MOCK_HANDLER) + logger.setLevel(logging.DEBUG) + + test_client = CosmosClient(HOST, KEY) + container = test_client.get_database_client(DATABASE_ID).get_container_client(CONTAINER_ID) + await container.create_item(body=TEST_ITEM) + + yield + await test_client.close() + +@pytest.mark.cosmosMultiRegion +@pytest.mark.asyncio +@pytest.mark.usefixtures("setup_and_teardown") +class TestExcludedLocations: + async def _init_container(self, preferred_locations, client_excluded_locations, multiple_write_locations = True): + client = CosmosClient(HOST, KEY, + preferred_locations=preferred_locations, + excluded_locations=client_excluded_locations, + multiple_write_locations=multiple_write_locations) + db = await client.create_database_if_not_exists(DATABASE_ID) + container = await db.create_container_if_not_exists(CONTAINER_ID, PartitionKey(path='/' + PARTITION_KEY, kind='Hash')) + MOCK_HANDLER.reset() + + return client, db, container + + async def _verify_endpoint(self, client, expected_locations): + # get mapping for locations + location_mapping = (client.client_connection._global_endpoint_manager. + location_cache.account_locations_by_write_regional_routing_context) + default_endpoint = (client.client_connection._global_endpoint_manager. 
+ location_cache.default_regional_routing_context.get_primary()) + + # get Request URL + msgs = MOCK_HANDLER.messages + req_urls = [url.replace("Request URL: '", "") for url in msgs if 'Request URL:' in url] + + # get location + actual_locations = [] + for req_url in req_urls: + if req_url.startswith(default_endpoint): + actual_locations.append(L0) + else: + for endpoint in location_mapping: + if req_url.startswith(endpoint): + location = location_mapping[endpoint] + actual_locations.append(location) + break + + assert actual_locations == expected_locations + + @pytest.mark.parametrize('test_data', read_item_test_data()) + async def test_read_item(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + # Client setup + client, db, container = await self._init_container(preferred_locations, client_excluded_locations) + + # API call: read_item + if request_excluded_locations is None: + await container.read_item(ITEM_ID, ITEM_PK_VALUE) + else: + await container.read_item(ITEM_ID, ITEM_PK_VALUE, excluded_locations=request_excluded_locations) + + # Verify endpoint locations + await self._verify_endpoint(client, expected_locations) + + @pytest.mark.parametrize('test_data', read_item_test_data()) + async def test_read_all_items(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + # Client setup + client, db, container = await self._init_container(preferred_locations, client_excluded_locations) + + # API call: read_all_items + if request_excluded_locations is None: + all_items = [item async for item in container.read_all_items()] + else: + all_items = [item async for item in container.read_all_items(excluded_locations=request_excluded_locations)] + + # Verify endpoint locations + await self._verify_endpoint(client, expected_locations) + + @pytest.mark.parametrize('test_data', 
read_item_test_data()) + async def test_query_items(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + # Client setup and create an item + client, db, container = await self._init_container(preferred_locations, client_excluded_locations) + + # API call: query_items + if request_excluded_locations is None: + all_items = [item async for item in container.query_items(None)] + else: + all_items = [item async for item in container.query_items(None, excluded_locations=request_excluded_locations)] + + # Verify endpoint locations + await self._verify_endpoint(client, expected_locations) + + @pytest.mark.parametrize('test_data', query_items_change_feed_test_data()) + async def test_query_items_change_feed(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + + # Client setup and create an item + client, db, container = await self._init_container(preferred_locations, client_excluded_locations) + + # API call: query_items_change_feed + if request_excluded_locations is None: + all_items = [item async for item in container.query_items_change_feed(start_time="Beginning")] + else: + all_items = [item async for item in container.query_items_change_feed(start_time="Beginning", excluded_locations=request_excluded_locations)] + + # Verify endpoint locations + await self._verify_endpoint(client, expected_locations) + + + @pytest.mark.parametrize('test_data', replace_item_test_data()) + async def test_replace_item(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + for multiple_write_locations in [True, False]: + # Client setup and create an item + client, db, container = await self._init_container(preferred_locations, client_excluded_locations, multiple_write_locations) + + # API 
call: replace_item + if request_excluded_locations is None: + await container.replace_item(ITEM_ID, body=TEST_ITEM) + else: + await container.replace_item(ITEM_ID, body=TEST_ITEM, excluded_locations=request_excluded_locations) + + # Verify endpoint locations + if multiple_write_locations: + await self._verify_endpoint(client, expected_locations) + else: + await self._verify_endpoint(client, [L1]) + + @pytest.mark.parametrize('test_data', replace_item_test_data()) + async def test_upsert_item(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + for multiple_write_locations in [True, False]: + # Client setup and create an item + client, db, container = await self._init_container(preferred_locations, client_excluded_locations, multiple_write_locations) + + # API call: upsert_item + body = {'pk': 'pk', 'id': f'doc2-{str(uuid.uuid4())}'} + if request_excluded_locations is None: + await container.upsert_item(body=body) + else: + await container.upsert_item(body=body, excluded_locations=request_excluded_locations) + + # get location from mock_handler + if multiple_write_locations: + await self._verify_endpoint(client, expected_locations) + else: + await self._verify_endpoint(client, [L1]) + + @pytest.mark.parametrize('test_data', replace_item_test_data()) + async def test_create_item(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + for multiple_write_locations in [True, False]: + # Client setup and create an item + client, db, container = await self._init_container(preferred_locations, client_excluded_locations, multiple_write_locations) + + # API call: create_item + body = {'pk': 'pk', 'id': f'doc2-{str(uuid.uuid4())}'} + if request_excluded_locations is None: + await container.create_item(body=body) + else: + await container.create_item(body=body, 
excluded_locations=request_excluded_locations) + + # get location from mock_handler + if multiple_write_locations: + await self._verify_endpoint(client, expected_locations) + else: + await self._verify_endpoint(client, [L1]) + + @pytest.mark.parametrize('test_data', patch_item_test_data()) + async def test_patch_item(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + for multiple_write_locations in [True, False]: + # Client setup and create an item + client, db, container = await self._init_container(preferred_locations, client_excluded_locations, + multiple_write_locations) + + # API call: patch_item + operations = [ + {"op": "add", "path": "/test_data", "value": f'Data-{str(uuid.uuid4())}'}, + ] + if request_excluded_locations is None: + await container.patch_item(item=ITEM_ID, partition_key=ITEM_PK_VALUE, + patch_operations=operations) + else: + await container.patch_item(item=ITEM_ID, partition_key=ITEM_PK_VALUE, + patch_operations=operations, + excluded_locations=request_excluded_locations) + + # get location from mock_handler + if multiple_write_locations: + await self._verify_endpoint(client, expected_locations) + else: + await self._verify_endpoint(client, [L1]) + + @pytest.mark.parametrize('test_data', patch_item_test_data()) + async def test_execute_item_batch(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + for multiple_write_locations in [True, False]: + # Client setup and create an item + client, db, container = await self._init_container(preferred_locations, client_excluded_locations, + multiple_write_locations) + + # API call: execute_item_batch + batch_operations = [] + for i in range(3): + batch_operations.append(("create", ({"id": f'Doc-{str(uuid.uuid4())}', PARTITION_KEY: ITEM_PK_VALUE},))) + + if request_excluded_locations is None: + await 
container.execute_item_batch(batch_operations=batch_operations, + partition_key=ITEM_PK_VALUE,) + else: + await container.execute_item_batch(batch_operations=batch_operations, + partition_key=ITEM_PK_VALUE, + excluded_locations=request_excluded_locations) + + # get location from mock_handler + if multiple_write_locations: + await self._verify_endpoint(client, expected_locations) + else: + await self._verify_endpoint(client, [L1]) + + @pytest.mark.parametrize('test_data', patch_item_test_data()) + async def test_delete_item(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + for multiple_write_locations in [True, False]: + # Client setup + client, db, container = await self._init_container(preferred_locations, client_excluded_locations, multiple_write_locations) + + #create before delete + item_id = f'doc2-{str(uuid.uuid4())}' + await container.create_item(body={PARTITION_KEY: ITEM_PK_VALUE, 'id': item_id}) + MOCK_HANDLER.reset() + + # API call: delete_item + if request_excluded_locations is None: + await container.delete_item(item_id, ITEM_PK_VALUE) + else: + await container.delete_item(item_id, ITEM_PK_VALUE, excluded_locations=request_excluded_locations) + + # Verify endpoint locations + if multiple_write_locations: + await self._verify_endpoint(client, expected_locations) + else: + await self._verify_endpoint(client, [L1]) + + # TODO: enable this test once we figure out how to enable delete_all_items_by_partition_key feature + # @pytest.mark.parametrize('test_data', patch_item_test_data()) + # def test_delete_all_items_by_partition_key(self, test_data): + # # Init test variables + # preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + # + # for multiple_write_locations in [True, False]: + # # Client setup + # client, db, container = self._init_container(preferred_locations, client_excluded_locations,
multiple_write_locations) + # + # #create before delete + # item_id = f'doc2-{str(uuid.uuid4())}' + # pk_value = f'temp_partition_key_value-{str(uuid.uuid4())}' + # container.create_item(body={PARTITION_KEY: pk_value, 'id': item_id}) + # MOCK_HANDLER.reset() + # + # # API call: read_item + # if request_excluded_locations is None: + # container.delete_all_items_by_partition_key(pk_value) + # else: + # container.delete_all_items_by_partition_key(pk_value, excluded_locations=request_excluded_locations) + # + # # Verify endpoint locations + # if multiple_write_locations: + # self._verify_endpoint(client, expected_locations) + # else: + # self._verify_endpoint(client, [L1]) + +if __name__ == "__main__": + unittest.main() From 56108892418867fb21d6c7dad05c6ca0a2fbf982 Mon Sep 17 00:00:00 2001 From: Allen Kim Date: Mon, 7 Apr 2025 16:55:49 -0700 Subject: [PATCH 061/152] Add more live tests for all other Python versions --- sdk/cosmos/live-platform-matrix.json | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/sdk/cosmos/live-platform-matrix.json b/sdk/cosmos/live-platform-matrix.json index dc3ad3c32e17..6763c1c06562 100644 --- a/sdk/cosmos/live-platform-matrix.json +++ b/sdk/cosmos/live-platform-matrix.json @@ -91,6 +91,22 @@ }, { "WindowsConfig": { + "Windows2022_38": { + "OSVmImage": "env:WINDOWSVMIMAGE", + "Pool": "env:WINDOWSPOOL", + "PythonVersion": "3.8", + "CoverageArg": "--disablecov", + "TestSamples": "false", + "TestMarkArgument": "cosmosMultiRegion" + }, + "Windows2022_310": { + "OSVmImage": "env:WINDOWSVMIMAGE", + "Pool": "env:WINDOWSPOOL", + "PythonVersion": "3.10", + "CoverageArg": "--disablecov", + "TestSamples": "false", + "TestMarkArgument": "cosmosMultiRegion" + }, "Windows2022_312": { "OSVmImage": "env:WINDOWSVMIMAGE", "Pool": "env:WINDOWSPOOL", From f4cb8b3ba9c1507793af77547281741e221b7af1 Mon Sep 17 00:00:00 2001 From: Allen Kim Date: Mon, 7 Apr 2025 17:07:29 -0700 Subject: [PATCH 062/152] Fix Async test failure --- 
sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py | 2 +- sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py index 13e9ba713653..2159c6c97425 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py +++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py @@ -176,7 +176,7 @@ def setup_and_teardown(): logger.setLevel(logging.DEBUG) container = cosmos_client.CosmosClient(HOST, KEY).get_database_client(DATABASE_ID).get_container_client(CONTAINER_ID) - container.create_item(body=TEST_ITEM) + container.upsert_item(body=TEST_ITEM) yield # Code to run after tests diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py index 7564071de4f9..b0079e753039 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py @@ -175,7 +175,7 @@ async def setup_and_teardown(): test_client = CosmosClient(HOST, KEY) container = test_client.get_database_client(DATABASE_ID).get_container_client(CONTAINER_ID) - await container.create_item(body=TEST_ITEM) + await container.upsert_item(body=TEST_ITEM) yield await test_client.close() From e98ab571cd172ade4ad46d04b691b7179de96775 Mon Sep 17 00:00:00 2001 From: Tomas Varon Saldarriaga Date: Tue, 8 Apr 2025 12:02:02 -0400 Subject: [PATCH 063/152] add test for failure_rate threshold --- ...n_endpoint_manager_circuit_breaker_core.py | 4 +- .../azure/cosmos/_partition_health_tracker.py | 122 +++++++++++------- ..._endpoint_manager_circuit_breaker_async.py | 16 ++- .../azure/cosmos/aio/_retry_utility_async.py | 8 +- .../tests/test_ppcb_sm_mrr_async.py | 92 +++++++++++-- 5 files changed, 171 insertions(+), 71 deletions(-) diff --git 
a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py index 1f900215f08d..577b8410f435 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py @@ -26,7 +26,7 @@ from azure.cosmos import documents -from azure.cosmos._partition_health_tracker import PartitionHealthTracker +from azure.cosmos._partition_health_tracker import _PartitionHealthTracker from azure.cosmos._routing.routing_range import PartitionKeyRangeWrapper from azure.cosmos._location_cache import EndpointOperationType, LocationCache from azure.cosmos._request_object import RequestObject @@ -44,7 +44,7 @@ class _GlobalPartitionEndpointManagerForCircuitBreakerCore(object): def __init__(self, client, location_cache: LocationCache): - self.partition_health_tracker = PartitionHealthTracker() + self.partition_health_tracker = _PartitionHealthTracker() self.location_cache = location_cache self.client = client diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py index 04849ca1cb4b..f6c75c274ab6 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py @@ -21,6 +21,7 @@ """Internal class for partition health tracker for circuit breaker. 
""" +import logging import os from typing import Dict, Set, Any from azure.cosmos._routing.routing_range import PartitionKeyRangeWrapper @@ -74,8 +75,18 @@ def reset_health_stats(self) -> None: self.read_consecutive_failure_count = 0 self.write_consecutive_failure_count = 0 + def __str__(self) -> str: + return (f"{self.__class__.__name__}: {self.unavailability_info}\n" + f"write failure count: {self.write_failure_count}\n" + f"read failure count: {self.read_failure_count}\n" + f"write success count: {self.write_success_count}\n" + f"read success count: {self.read_success_count}\n" + f"write consecutive failure count: {self.write_consecutive_failure_count}\n" + f"read consecutive failure count: {self.read_consecutive_failure_count}\n") -class PartitionHealthTracker(object): +logger = logging.getLogger("azure.cosmos._PartitionHealthTracker") + +class _PartitionHealthTracker(object): """ This internal class implements the logic for tracking health thresholds for a partition. """ @@ -83,7 +94,7 @@ class PartitionHealthTracker(object): def __init__(self) -> None: # partition -> regions -> health info - self.pkrange_wrapper_to_health_info: Dict[PartitionKeyRangeWrapper, Dict[str, _PartitionHealthInfo]] = {} + self.pk_range_wrapper_to_health_info: Dict[PartitionKeyRangeWrapper, Dict[str, _PartitionHealthInfo]] = {} self.last_refresh = current_time_millis() # TODO: @tvaron3 look for useful places to add logs @@ -97,26 +108,27 @@ def _transition_health_status_on_failure( pkrange_wrapper: PartitionKeyRangeWrapper, location: str ) -> None: + logger.warn("{} has been marked as unavailable.".format(pkrange_wrapper)) current_time = current_time_millis() - if pkrange_wrapper not in self.pkrange_wrapper_to_health_info: + if pkrange_wrapper not in self.pk_range_wrapper_to_health_info: # healthy -> unhealthy tentative partition_health_info = _PartitionHealthInfo() partition_health_info.unavailability_info = { LAST_UNAVAILABILITY_CHECK_TIME_STAMP: current_time, HEALTH_STATUS: 
UNHEALTHY_TENTATIVE } - self.pkrange_wrapper_to_health_info[pkrange_wrapper] = { + self.pk_range_wrapper_to_health_info[pkrange_wrapper] = { location: partition_health_info } else: - region_to_partition_health = self.pkrange_wrapper_to_health_info[pkrange_wrapper] - if location in region_to_partition_health: + region_to_partition_health = self.pk_range_wrapper_to_health_info[pkrange_wrapper] + if location in region_to_partition_health and region_to_partition_health[location].unavailability_info: # healthy tentative -> unhealthy # if the operation type is not empty, we are in the healthy tentative state region_to_partition_health[location].unavailability_info[HEALTH_STATUS] = UNHEALTHY # reset the last unavailability check time stamp region_to_partition_health[location].unavailability_info[LAST_UNAVAILABILITY_CHECK_TIME_STAMP] \ - = UNHEALTHY + = current_time else: # healthy -> unhealthy tentative # if the operation type is empty, we are in the unhealthy tentative state @@ -125,49 +137,55 @@ def _transition_health_status_on_failure( LAST_UNAVAILABILITY_CHECK_TIME_STAMP: current_time, HEALTH_STATUS: UNHEALTHY_TENTATIVE } - self.pkrange_wrapper_to_health_info[pkrange_wrapper][location] = partition_health_info + self.pk_range_wrapper_to_health_info[pkrange_wrapper][location] = partition_health_info def _transition_health_status_on_success( self, pkrange_wrapper: PartitionKeyRangeWrapper, location: str ) -> None: - if pkrange_wrapper in self.pkrange_wrapper_to_health_info: + if pkrange_wrapper in self.pk_range_wrapper_to_health_info: # healthy tentative -> healthy - self.pkrange_wrapper_to_health_info[pkrange_wrapper].pop(location, None) + self.pk_range_wrapper_to_health_info[pkrange_wrapper].pop(location, None) def _check_stale_partition_info(self, pk_range_wrapper: PartitionKeyRangeWrapper) -> None: current_time = current_time_millis() stale_partition_unavailability_check = int(os.environ.get(Constants.STALE_PARTITION_UNAVAILABILITY_CHECK, 
Constants.STALE_PARTITION_UNAVAILABILITY_CHECK_DEFAULT)) * 1000 - if pk_range_wrapper in self.pkrange_wrapper_to_health_info: - for _, partition_health_info in self.pkrange_wrapper_to_health_info[pk_range_wrapper].items(): - elapsed_time = (current_time - - partition_health_info.unavailability_info[LAST_UNAVAILABILITY_CHECK_TIME_STAMP]) - current_health_status = partition_health_info.unavailability_info[HEALTH_STATUS] - # check if the partition key range is still unavailable - if ((current_health_status == UNHEALTHY and elapsed_time > stale_partition_unavailability_check) - or (current_health_status == UNHEALTHY_TENTATIVE - and elapsed_time > INITIAL_UNAVAILABLE_TIME)): - # unhealthy or unhealthy tentative -> healthy tentative - partition_health_info.unavailability_info[HEALTH_STATUS] = HEALTHY_TENTATIVE - - if current_time - self.last_refresh < REFRESH_INTERVAL: + if pk_range_wrapper in self.pk_range_wrapper_to_health_info: + for _, partition_health_info in self.pk_range_wrapper_to_health_info[pk_range_wrapper].items(): + if partition_health_info.unavailability_info: + elapsed_time = (current_time - + partition_health_info.unavailability_info[LAST_UNAVAILABILITY_CHECK_TIME_STAMP]) + current_health_status = partition_health_info.unavailability_info[HEALTH_STATUS] + # check if the partition key range is still unavailable + if ((current_health_status == UNHEALTHY and elapsed_time > stale_partition_unavailability_check) + or (current_health_status == UNHEALTHY_TENTATIVE + and elapsed_time > INITIAL_UNAVAILABLE_TIME)): + # unhealthy or unhealthy tentative -> healthy tentative + partition_health_info.unavailability_info[HEALTH_STATUS] = HEALTHY_TENTATIVE + + if current_time - self.last_refresh > REFRESH_INTERVAL: # all partition stats reset every minute self._reset_partition_health_tracker_stats() - def get_excluded_locations(self, pkrange_wrapper: PartitionKeyRangeWrapper) -> Set[str]: - self._check_stale_partition_info(pkrange_wrapper) - if pkrange_wrapper in 
self.pkrange_wrapper_to_health_info: - return set(self.pkrange_wrapper_to_health_info[pkrange_wrapper].keys()) - return set() + def get_excluded_locations(self, pk_range_wrapper: PartitionKeyRangeWrapper) -> Set[str]: + self._check_stale_partition_info(pk_range_wrapper) + excluded_locations = set() + if pk_range_wrapper in self.pk_range_wrapper_to_health_info: + for location, partition_health_info in self.pk_range_wrapper_to_health_info[pk_range_wrapper].items(): + if partition_health_info.unavailability_info: + health_status = partition_health_info.unavailability_info[HEALTH_STATUS] + if health_status == UNHEALTHY_TENTATIVE or health_status == UNHEALTHY: + excluded_locations.add(location) + return excluded_locations def add_failure( self, - pkrange_wrapper: PartitionKeyRangeWrapper, + pk_range_wrapper: PartitionKeyRangeWrapper, operation_type: str, location: str ) -> None: @@ -176,12 +194,12 @@ def add_failure( Constants.FAILURE_PERCENTAGE_TOLERATED_DEFAULT)) # Ensure that the health info dictionary is properly initialized. - if pkrange_wrapper not in self.pkrange_wrapper_to_health_info: - self.pkrange_wrapper_to_health_info[pkrange_wrapper] = {} - if location not in self.pkrange_wrapper_to_health_info[pkrange_wrapper]: - self.pkrange_wrapper_to_health_info[pkrange_wrapper][location] = _PartitionHealthInfo() + if pk_range_wrapper not in self.pk_range_wrapper_to_health_info: + self.pk_range_wrapper_to_health_info[pk_range_wrapper] = {} + if location not in self.pk_range_wrapper_to_health_info[pk_range_wrapper]: + self.pk_range_wrapper_to_health_info[pk_range_wrapper][location] = _PartitionHealthInfo() - health_info = self.pkrange_wrapper_to_health_info[pkrange_wrapper][location] + health_info = self.pk_range_wrapper_to_health_info[pk_range_wrapper][location] # Determine attribute names and environment variables based on the operation type. 
if operation_type == EndpointOperationType.WriteType: @@ -189,24 +207,24 @@ def add_failure( failure_attr = 'write_failure_count' consecutive_attr = 'write_consecutive_failure_count' env_key = Constants.CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_WRITE - default_consec_threshold = Constants.CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_WRITE_DEFAULT + default_consecutive_threshold = Constants.CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_WRITE_DEFAULT else: success_attr = 'read_success_count' failure_attr = 'read_failure_count' consecutive_attr = 'read_consecutive_failure_count' env_key = Constants.CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_READ - default_consec_threshold = Constants.CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_READ_DEFAULT + default_consecutive_threshold = Constants.CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_READ_DEFAULT # Increment failure and consecutive failure counts. setattr(health_info, failure_attr, getattr(health_info, failure_attr) + 1) setattr(health_info, consecutive_attr, getattr(health_info, consecutive_attr) + 1) # Retrieve the consecutive failure threshold from the environment. - consecutive_failure_threshold = int(os.environ.get(env_key, default_consec_threshold)) + consecutive_failure_threshold = int(os.environ.get(env_key, default_consecutive_threshold)) # Call the threshold checker with the current stats. 
self._check_thresholds( - pkrange_wrapper, + pk_range_wrapper, getattr(health_info, success_attr), getattr(health_info, failure_attr), getattr(health_info, consecutive_attr), @@ -214,10 +232,13 @@ def add_failure( failure_rate_threshold, consecutive_failure_threshold ) + print(self.pk_range_wrapper_to_health_info[pk_range_wrapper][location]) + print(pk_range_wrapper) + print(location) def _check_thresholds( self, - pkrange_wrapper: PartitionKeyRangeWrapper, + pk_range_wrapper: PartitionKeyRangeWrapper, successes: int, failures: int, consecutive_failures: int, @@ -232,20 +253,20 @@ def _check_thresholds( failures, failure_rate_threshold ): - self._transition_health_status_on_failure(pkrange_wrapper, location) + self._transition_health_status_on_failure(pk_range_wrapper, location) # add to consecutive failures and check that threshold was not exceeded if consecutive_failures >= consecutive_failure_threshold: - self._transition_health_status_on_failure(pkrange_wrapper, location) + self._transition_health_status_on_failure(pk_range_wrapper, location) - def add_success(self, pkrange_wrapper: PartitionKeyRangeWrapper, operation_type: str, location: str) -> None: + def add_success(self, pk_range_wrapper: PartitionKeyRangeWrapper, operation_type: str, location: str) -> None: # Ensure that the health info dictionary is initialized. 
- if pkrange_wrapper not in self.pkrange_wrapper_to_health_info: - self.pkrange_wrapper_to_health_info[pkrange_wrapper] = {} - if location not in self.pkrange_wrapper_to_health_info[pkrange_wrapper]: - self.pkrange_wrapper_to_health_info[pkrange_wrapper][location] = _PartitionHealthInfo() + if pk_range_wrapper not in self.pk_range_wrapper_to_health_info: + self.pk_range_wrapper_to_health_info[pk_range_wrapper] = {} + if location not in self.pk_range_wrapper_to_health_info[pk_range_wrapper]: + self.pk_range_wrapper_to_health_info[pk_range_wrapper][location] = _PartitionHealthInfo() - health_info = self.pkrange_wrapper_to_health_info[pkrange_wrapper][location] + health_info = self.pk_range_wrapper_to_health_info[pk_range_wrapper][location] if operation_type == EndpointOperationType.WriteType: health_info.write_success_count += 1 @@ -253,10 +274,13 @@ def add_success(self, pkrange_wrapper: PartitionKeyRangeWrapper, operation_type: else: health_info.read_success_count += 1 health_info.read_consecutive_failure_count = 0 - self._transition_health_status_on_success(pkrange_wrapper, operation_type) + self._transition_health_status_on_success(pk_range_wrapper, operation_type) + print(self.pk_range_wrapper_to_health_info[pk_range_wrapper][location]) + print(pk_range_wrapper) + print(location) def _reset_partition_health_tracker_stats(self) -> None: - for locations in self.pkrange_wrapper_to_health_info.values(): + for locations in self.pk_range_wrapper_to_health_info.values(): for health_info in locations.values(): health_info.reset_health_stats() diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py index 3aadc9b6aba0..2ba64690da7b 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py +++ 
b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py @@ -23,9 +23,10 @@ """ from typing import TYPE_CHECKING +from azure.cosmos import PartitionKey from azure.cosmos._global_partition_endpoint_manager_circuit_breaker_core import \ _GlobalPartitionEndpointManagerForCircuitBreakerCore -from azure.cosmos._routing.routing_range import PartitionKeyRangeWrapper +from azure.cosmos._routing.routing_range import PartitionKeyRangeWrapper, Range from azure.cosmos.aio._global_endpoint_manager_async import _GlobalEndpointManager from azure.cosmos._request_object import RequestObject @@ -50,18 +51,23 @@ def __init__(self, client: "CosmosClientConnection"): async def create_pk_range_wrapper(self, request: RequestObject) -> PartitionKeyRangeWrapper: container_rid = request.headers[HttpHeaders.IntendedCollectionRID] - partition_key = request.headers[HttpHeaders.PartitionKey] + partition_key_value = request.headers[HttpHeaders.PartitionKey] # get the partition key range for the given partition key target_container_link = None for container_link, properties in self.client._container_properties_cache.items(): # pylint: disable=protected-access if properties["_rid"] == container_rid: target_container_link = container_link + partition_key_definition = properties["partitionKey"] + partition_key = PartitionKey(path=partition_key_definition["paths"], kind=partition_key_definition["kind"]) + + epk_range = [partition_key._get_epk_range_for_partition_key(partition_key_value)] if not target_container_link: raise RuntimeError("Illegal state: the container cache is not properly initialized.") # TODO: @tvaron3 check different clients and create them in different ways - pk_range = await (self.client._routing_map_provider # pylint: disable=protected-access - .get_overlapping_ranges(target_container_link, partition_key)) - return PartitionKeyRangeWrapper(pk_range, container_rid) + partition_ranges = await (self.client._routing_map_provider # pylint: 
disable=protected-access + .get_overlapping_ranges(target_container_link, epk_range)) + partition_range = Range.PartitionKeyRangeToRange(partition_ranges[0]) + return PartitionKeyRangeWrapper(partition_range, container_rid) def is_circuit_breaker_applicable(self, request: RequestObject) -> bool: return self.global_partition_endpoint_manager_core.is_circuit_breaker_applicable(request) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py index 6e94c9260b8b..5d4d680b50e4 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py @@ -202,7 +202,7 @@ async def ExecuteAsync(client, global_endpoint_manager, function, *args, **kwarg if client_timeout: kwargs['timeout'] = client_timeout - (time.time() - start_time) if kwargs['timeout'] <= 0: - global_endpoint_manager.record_failure(args[0]) + await global_endpoint_manager.record_failure(args[0]) raise exceptions.CosmosClientTimeoutError() except ServiceRequestError as e: @@ -286,7 +286,7 @@ async def send(self, request): # the request ran into a socket timeout or failed to establish a new connection # since request wasn't sent, raise exception immediately to be dealt with in client retry policies if not _has_database_account_header(request.http_request.headers): - global_endpoint_manager.record_failure(request_params) + await global_endpoint_manager.record_failure(request_params) if retry_settings['connect'] > 0: retry_active = self.increment(retry_settings, response=request, error=err) if retry_active: @@ -304,7 +304,7 @@ async def send(self, request): ClientConnectionError) if (isinstance(err.inner_exception, ClientConnectionError) or _has_read_retryable_headers(request.http_request.headers)): - global_endpoint_manager.record_failure(request_params) + await global_endpoint_manager.record_failure(request_params) # This logic is based on the 
_retry.py file from azure-core if retry_settings['read'] > 0: retry_active = self.increment(retry_settings, response=request, error=err) @@ -319,7 +319,7 @@ async def send(self, request): if _has_database_account_header(request.http_request.headers): raise err if self._is_method_retryable(retry_settings, request.http_request): - global_endpoint_manager.record_failure(request_params) + await global_endpoint_manager.record_failure(request_params) retry_active = self.increment(retry_settings, response=request, error=err) if retry_active: await self.sleep(retry_settings, request.context.transport) diff --git a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py index a8e39ae361da..8902a72d2f6d 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py @@ -43,7 +43,7 @@ def errors(): for error_code in error_codes: errors_list.append(CosmosHttpResponseError( status_code=error_code, - message="Some injected fault.")) + message="Some injected error.")) errors_list.append(ServiceResponseError(message="Injected Service Response Error.")) return errors_list @@ -115,10 +115,11 @@ async def test_consecutive_failure_threshold_async(self, setup, error): # writes should fail in sm mrr with circuit breaker and should not mark unavailable a partition for i in range(6): - with pytest.raises(CosmosHttpResponseError or ServiceResponseError): + with pytest.raises((CosmosHttpResponseError, ServiceResponseError)): await container.create_item(body=document_definition) + global_endpoint_manager = container.client_connection._global_endpoint_manager - TestPPCBSmMrrAsync.validate_unhealthy_partitions(container.client_connection._global_endpoint_manager, 0) + TestPPCBSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 0) # create item with client without fault injection await setup[COLLECTION].create_item(body=document_definition) @@ -126,32 
+127,101 @@ async def test_consecutive_failure_threshold_async(self, setup, error): # reads should fail over and only the relevant partition should be marked as unavailable await container.read_item(item=document_definition['id'], partition_key=document_definition['pk']) # partition should not have been marked unavailable after one error - TestPPCBSmMrrAsync.validate_unhealthy_partitions(container.client_connection._global_endpoint_manager, 0) + TestPPCBSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 0) - for i in range(11): + for i in range(10): read_resp = await container.read_item(item=document_definition['id'], partition_key=document_definition['pk']) request = read_resp.get_response_headers()["_request"] # Validate the response comes from "Read Region" (the most preferred read-only region) assert request.url.startswith(expected_read_region_uri) # the partition should have been marked as unavailable after breaking read threshold - TestPPCBSmMrrAsync.validate_unhealthy_partitions(container.client_connection._global_endpoint_manager, 1) + TestPPCBSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 1) + @pytest.mark.parametrize("error", errors()) + async def test_failure_rate_threshold_async(self, setup, error): + expected_read_region_uri = self.host + expected_write_region_uri = expected_read_region_uri.replace("localhost", "127.0.0.1") + custom_transport = await self.create_custom_transport_sm_mrr() + id_value = 'failoverDoc-' + str(uuid.uuid4()) + # two documents targeted to same partition, one will always fail and the other will succeed + document_definition = {'id': id_value, + 'pk': 'pk1', + 'name': 'sample document', + 'key': 'value'} + document_definition_2 = {'id': str(uuid.uuid4()), + 'pk': 'pk1', + 'name': 'sample document', + 'key': 'value'} + predicate = lambda r: (FaultInjectionTransportAsync.predicate_req_for_document_with_id(r, id_value) and + FaultInjectionTransportAsync.predicate_targets_region(r, 
expected_write_region_uri)) + custom_transport.add_fault(predicate, + lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( + 0, + error + ))) + + custom_setup = await self.setup_method_with_custom_transport(custom_transport, default_endpoint=self.host) + container = custom_setup['col'] + + # writes should fail in sm mrr with circuit breaker and should not mark unavailable a partition + for i in range(6): + if i % 2 == 0: + with pytest.raises((CosmosHttpResponseError, ServiceResponseError)): + await container.upsert_item(body=document_definition) + else: + await container.upsert_item(body=document_definition_2) + global_endpoint_manager = container.client_connection._global_endpoint_manager + + TestPPCBSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 0) + + # create item with client without fault injection + await setup[COLLECTION].create_item(body=document_definition) + + # reads should fail over and only the relevant partition should be marked as unavailable + await container.read_item(item=document_definition['id'], partition_key=document_definition['pk']) + # partition should not have been marked unavailable after one error + TestPPCBSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 0) + # lower minimum requests for testing + global_endpoint_manager.global_partition_endpoint_manager_core.partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 10 + try: + for i in range(20): + if i == 8: + read_resp = await container.read_item(item=document_definition_2['id'], + partition_key=document_definition_2['pk']) + else: + read_resp = await container.read_item(item=document_definition['id'], + partition_key=document_definition['pk']) + request = read_resp.get_response_headers()["_request"] + # Validate the response comes from "Read Region" (the most preferred read-only region) + assert request.url.startswith(expected_read_region_uri) + + # the partition should have been marked as unavailable after breaking 
read threshold + TestPPCBSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 1) + finally: + # restore minimum requests + global_endpoint_manager.global_partition_endpoint_manager_core.partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 @staticmethod - def validate_unhealthy_partitions(global_endpoint_manager, expected_unhealthy_partitions): - health_info_map = global_endpoint_manager.global_partition_endpoint_manager_core.partition_health_tracker.pkrange_wrapper_to_health_info + def validate_unhealthy_partitions(global_endpoint_manager, + expected_unhealthy_partitions): + health_info_map = global_endpoint_manager.global_partition_endpoint_manager_core.partition_health_tracker.pk_range_wrapper_to_health_info unhealthy_partitions = 0 for pk_range_wrapper, location_to_health_info in health_info_map.items(): for location, health_info in location_to_health_info.items(): health_status = health_info.unavailability_info.get(HEALTH_STATUS) if health_status == UNHEALTHY_TENTATIVE or health_status == UNHEALTHY: unhealthy_partitions += 1 - assert unhealthy_partitions == expected_unhealthy_partitions + else: + assert health_info.read_consecutive_failure_count < 10 + assert health_info.write_failure_count == 0 + assert health_info.write_consecutive_failure_count == 0 + assert unhealthy_partitions == expected_unhealthy_partitions - # test_failure_rate_threshold - add service response error - # test service request marks only a partition unavailable not an entire region + # test_failure_rate_threshold - add service response error - across operation types + # test service request marks only a partition unavailable not an entire region - across operation types + # test cosmos client timeout if __name__ == '__main__': unittest.main() From 4f081681f2cd87a103bf65edd4b61a12283cfcb7 Mon Sep 17 00:00:00 2001 From: Allen Kim Date: Tue, 8 Apr 2025 10:23:28 -0700 Subject: [PATCH 064/152] Fix live test failures --- .../tests/test_excluded_locations.py | 18 +++-- 
.../tests/test_excluded_locations_async.py | 66 ++++++++++--------- 2 files changed, 46 insertions(+), 38 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py index 2159c6c97425..9af367303107 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py +++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py @@ -168,6 +168,12 @@ def patch_item_test_data(): all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] return all_test_data +def _create_item_with_excluded_locations(container, body, excluded_locations): + if excluded_locations is None: + container.create_item(body=body) + else: + container.create_item(body=body, excluded_locations=excluded_locations) + @pytest.fixture(scope="class", autouse=True) def setup_and_teardown(): print("Setup: This runs before any tests") @@ -344,10 +350,7 @@ def test_create_item(self, test_data): # API call: create_item body = {'pk': 'pk', 'id': f'doc2-{str(uuid.uuid4())}'} - if request_excluded_locations is None: - container.create_item(body=body) - else: - container.create_item(body=body, excluded_locations=request_excluded_locations) + _create_item_with_excluded_locations(container, body, request_excluded_locations) # get location from mock_handler if multiple_write_locations: @@ -421,12 +424,13 @@ def test_delete_item(self, test_data): # Client setup client, db, container = self._init_container(preferred_locations, client_excluded_locations, multiple_write_locations) - #create before delete + # create before delete item_id = f'doc2-{str(uuid.uuid4())}' - container.create_item(body={PARTITION_KEY: ITEM_PK_VALUE, 'id': item_id}) + body = {PARTITION_KEY: ITEM_PK_VALUE, 'id': item_id} + _create_item_with_excluded_locations(container, body, request_excluded_locations) MOCK_HANDLER.reset() - # API call: read_item + # API call: delete_item if 
request_excluded_locations is None: container.delete_item(item_id, ITEM_PK_VALUE) else: diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py index b0079e753039..dd6ce3776f68 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py @@ -80,20 +80,20 @@ def emit(self, record): def read_item_test_data(): client_only_output_data = [ - [L1], # 0 - [L2], # 1 - [L1], # 2 - [L1] # 3 + [L1, L1], # 0 + [L2, L2], # 1 + [L1, L1], # 2 + [L1, L1], # 3 ] client_and_request_output_data = [ - [L2], # 0 - [L2], # 1 - [L2], # 2 - [L1], # 3 - [L1], # 4 - [L1], # 5 - [L1], # 6 - [L1], # 7 + [L2, L2], # 0 + [L2, L2], # 1 + [L2, L2], # 2 + [L1, L1], # 3 + [L1, L1], # 4 + [L1, L1], # 5 + [L1, L1], # 6 + [L1, L1], # 7 ] all_output_test_data = client_only_output_data + client_and_request_output_data @@ -102,20 +102,20 @@ def read_item_test_data(): def query_items_change_feed_test_data(): client_only_output_data = [ - [L1, L1, L1], #0 - [L2, L2, L2], #1 - [L1, L1, L1], #2 - [L1, L1, L1] #3 + [L1, L1, L1, L1], #0 + [L2, L2, L2, L2], #1 + [L1, L1, L1, L1], #2 + [L1, L1, L1, L1] #3 ] client_and_request_output_data = [ - [L1, L2, L2], #0 - [L2, L2, L2], #1 - [L1, L2, L2], #2 - [L2, L1, L1], #3 - [L1, L1, L1], #4 - [L2, L1, L1], #5 - [L1, L1, L1], #6 - [L1, L1, L1], #7 + [L1, L2, L2, L2], #0 + [L2, L2, L2, L2], #1 + [L1, L2, L2, L2], #2 + [L2, L1, L1, L1], #3 + [L1, L1, L1, L1], #4 + [L2, L1, L1, L1], #5 + [L1, L1, L1, L1], #6 + [L1, L1, L1, L1], #7 ] all_output_test_data = client_only_output_data + client_and_request_output_data @@ -166,6 +166,12 @@ def patch_item_test_data(): all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] return all_test_data +async def _create_item_with_excluded_locations(container, body, excluded_locations): + if 
excluded_locations is None: + await container.create_item(body=body) + else: + await container.create_item(body=body, excluded_locations=excluded_locations) + @pytest_asyncio.fixture(scope="class", autouse=True) async def setup_and_teardown(): print("Setup: This runs before any tests") @@ -344,10 +350,7 @@ async def test_create_item(self, test_data): # API call: create_item body = {'pk': 'pk', 'id': f'doc2-{str(uuid.uuid4())}'} - if request_excluded_locations is None: - await container.create_item(body=body) - else: - await container.create_item(body=body, excluded_locations=request_excluded_locations) + await _create_item_with_excluded_locations(container, body, request_excluded_locations) # get location from mock_handler if multiple_write_locations: @@ -421,12 +424,13 @@ async def test_delete_item(self, test_data): # Client setup client, db, container = await self._init_container(preferred_locations, client_excluded_locations, multiple_write_locations) - #create before delete + # create before delete item_id = f'doc2-{str(uuid.uuid4())}' - await container.create_item(body={PARTITION_KEY: ITEM_PK_VALUE, 'id': item_id}) + body = {PARTITION_KEY: ITEM_PK_VALUE, 'id': item_id} + await _create_item_with_excluded_locations(container, body, request_excluded_locations) MOCK_HANDLER.reset() - # API call: read_item + # API call: delete_item if request_excluded_locations is None: await container.delete_item(item_id, ITEM_PK_VALUE) else: From 36407c691b76d8ad8bcf2127f91f3c1cfa2dab2f Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Tue, 8 Apr 2025 13:41:23 -0400 Subject: [PATCH 065/152] fix pylint and cspell --- .../azure/cosmos/_partition_health_tracker.py | 25 +++++------ ..._endpoint_manager_circuit_breaker_async.py | 9 ++-- .../tests/test_ppcb_sm_mrr_async.py | 45 +++++++++++++------ 3 files changed, 48 insertions(+), 31 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py 
b/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py index f6c75c274ab6..4b7c0522f5c2 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py @@ -99,29 +99,28 @@ def __init__(self) -> None: # TODO: @tvaron3 look for useful places to add logs - def mark_partition_unavailable(self, pkrange_wrapper: PartitionKeyRangeWrapper, location: str) -> None: + def mark_partition_unavailable(self, pk_range_wrapper: PartitionKeyRangeWrapper, location: str) -> None: # mark the partition key range as unavailable - self._transition_health_status_on_failure(pkrange_wrapper, location) + self._transition_health_status_on_failure(pk_range_wrapper, location) def _transition_health_status_on_failure( self, - pkrange_wrapper: PartitionKeyRangeWrapper, + pk_range_wrapper: PartitionKeyRangeWrapper, location: str ) -> None: - logger.warn("{} has been marked as unavailable.".format(pkrange_wrapper)) - current_time = current_time_millis() - if pkrange_wrapper not in self.pk_range_wrapper_to_health_info: + logger.warning("%s has been marked as unavailable.", pk_range_wrapper) current_time = current_time_millis() + if pk_range_wrapper not in self.pk_range_wrapper_to_health_info: # healthy -> unhealthy tentative partition_health_info = _PartitionHealthInfo() partition_health_info.unavailability_info = { LAST_UNAVAILABILITY_CHECK_TIME_STAMP: current_time, HEALTH_STATUS: UNHEALTHY_TENTATIVE } - self.pk_range_wrapper_to_health_info[pkrange_wrapper] = { + self.pk_range_wrapper_to_health_info[pk_range_wrapper] = { location: partition_health_info } else: - region_to_partition_health = self.pk_range_wrapper_to_health_info[pkrange_wrapper] + region_to_partition_health = self.pk_range_wrapper_to_health_info[pk_range_wrapper] if location in region_to_partition_health and region_to_partition_health[location].unavailability_info: # healthy tentative -> unhealthy # if the operation type is not empty, 
we are in the healthy tentative state @@ -137,16 +136,16 @@ def _transition_health_status_on_failure( LAST_UNAVAILABILITY_CHECK_TIME_STAMP: current_time, HEALTH_STATUS: UNHEALTHY_TENTATIVE } - self.pk_range_wrapper_to_health_info[pkrange_wrapper][location] = partition_health_info + self.pk_range_wrapper_to_health_info[pk_range_wrapper][location] = partition_health_info def _transition_health_status_on_success( self, - pkrange_wrapper: PartitionKeyRangeWrapper, + pk_range_wrapper: PartitionKeyRangeWrapper, location: str ) -> None: - if pkrange_wrapper in self.pk_range_wrapper_to_health_info: + if pk_range_wrapper in self.pk_range_wrapper_to_health_info: # healthy tentative -> healthy - self.pk_range_wrapper_to_health_info[pkrange_wrapper].pop(location, None) + self.pk_range_wrapper_to_health_info[pk_range_wrapper].pop(location, None) def _check_stale_partition_info(self, pk_range_wrapper: PartitionKeyRangeWrapper) -> None: current_time = current_time_millis() @@ -178,7 +177,7 @@ def get_excluded_locations(self, pk_range_wrapper: PartitionKeyRangeWrapper) -> for location, partition_health_info in self.pk_range_wrapper_to_health_info[pk_range_wrapper].items(): if partition_health_info.unavailability_info: health_status = partition_health_info.unavailability_info[HEALTH_STATUS] - if health_status == UNHEALTHY_TENTATIVE or health_status == UNHEALTHY: + if health_status in (UNHEALTHY_TENTATIVE, UNHEALTHY): excluded_locations.add(location) return excluded_locations diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py index 2ba64690da7b..c1badbd8a167 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py @@ -37,7 +37,7 @@ -class 
_GlobalPartitionEndpointManagerForCircuitBreakerAsync(_GlobalEndpointManager): +class _GlobalPartitionEndpointManagerForCircuitBreakerAsync(_GlobalEndpointManager): # pylint: disable=protected-access """ This internal class implements the logic for partition endpoint management for geo-replicated database accounts. @@ -54,17 +54,18 @@ async def create_pk_range_wrapper(self, request: RequestObject) -> PartitionKeyR partition_key_value = request.headers[HttpHeaders.PartitionKey] # get the partition key range for the given partition key target_container_link = None - for container_link, properties in self.client._container_properties_cache.items(): # pylint: disable=protected-access + for container_link, properties in self.client._container_properties_cache.items(): if properties["_rid"] == container_rid: target_container_link = container_link partition_key_definition = properties["partitionKey"] - partition_key = PartitionKey(path=partition_key_definition["paths"], kind=partition_key_definition["kind"]) + partition_key = PartitionKey(path=partition_key_definition["paths"], + kind=partition_key_definition["kind"]) epk_range = [partition_key._get_epk_range_for_partition_key(partition_key_value)] if not target_container_link: raise RuntimeError("Illegal state: the container cache is not properly initialized.") # TODO: @tvaron3 check different clients and create them in different ways - partition_ranges = await (self.client._routing_map_provider # pylint: disable=protected-access + partition_ranges = await (self.client._routing_map_provider .get_overlapping_ranges(target_container_link, epk_range)) partition_range = Range.PartitionKeyRangeToRange(partition_ranges[0]) return PartitionKeyRangeWrapper(partition_range, container_rid) diff --git a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py index 8902a72d2f6d..e26c3a270d83 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py +++ 
b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py @@ -11,29 +11,29 @@ from azure.core.pipeline.transport._aiohttp import AioHttpTransport from azure.core.exceptions import ServiceResponseError +import test_config from azure.cosmos import PartitionKey from azure.cosmos._partition_health_tracker import HEALTH_STATUS, UNHEALTHY, UNHEALTHY_TENTATIVE from azure.cosmos.aio import CosmosClient from azure.cosmos.exceptions import CosmosHttpResponseError -from tests import test_config -from tests._fault_injection_transport_async import FaultInjectionTransportAsync +from _fault_injection_transport_async import FaultInjectionTransportAsync COLLECTION = "created_collection" @pytest_asyncio.fixture() async def setup(): os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "True" - client = CosmosClient(TestPPCBSmMrrAsync.host, TestPPCBSmMrrAsync.master_key, consistency_level="Session") - created_database = client.get_database_client(TestPPCBSmMrrAsync.TEST_DATABASE_ID) + client = CosmosClient(TestPerPartitionCircuitBreakerSmMrrAsync.host, TestPerPartitionCircuitBreakerSmMrrAsync.master_key, consistency_level="Session") + created_database = client.get_database_client(TestPerPartitionCircuitBreakerSmMrrAsync.TEST_DATABASE_ID) # print(TestPPCBSmMrrAsync.TEST_DATABASE_ID) - await client.create_database_if_not_exists(TestPPCBSmMrrAsync.TEST_DATABASE_ID) - created_collection = await created_database.create_container(TestPPCBSmMrrAsync.TEST_CONTAINER_SINGLE_PARTITION_ID, + await client.create_database_if_not_exists(TestPerPartitionCircuitBreakerSmMrrAsync.TEST_DATABASE_ID) + created_collection = await created_database.create_container(TestPerPartitionCircuitBreakerSmMrrAsync.TEST_CONTAINER_SINGLE_PARTITION_ID, partition_key=PartitionKey("/pk"), offer_throughput=10000) yield { COLLECTION: created_collection } - await created_database.delete_container(TestPPCBSmMrrAsync.TEST_CONTAINER_SINGLE_PARTITION_ID) + await 
created_database.delete_container(TestPerPartitionCircuitBreakerSmMrrAsync.TEST_CONTAINER_SINGLE_PARTITION_ID) await client.close() os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "False" @@ -50,7 +50,7 @@ def errors(): @pytest.mark.cosmosEmulator @pytest.mark.asyncio @pytest.mark.usefixtures("setup") -class TestPPCBSmMrrAsync: +class TestPerPartitionCircuitBreakerSmMrrAsync: host = test_config.TestConfig.host master_key = test_config.TestConfig.masterKey connectionPolicy = test_config.TestConfig.connectionPolicy @@ -70,6 +70,23 @@ async def cleanup_method(initialized_objects: Dict[str, Any]): method_client: CosmosClient = initialized_objects["client"] await method_client.close() + async def perform_write_operation(operation, container, id, pk): + document_definition = {'id': id, + 'pk': pk, + 'name': 'sample document', + 'key': 'value'} + if operation == "create": + await container.create_item(body=document_definition) + elif operation == "upsert": + await container.upsert_item(body=document_definition) + elif operation == "replace": + await container.replace_item(item=document_definition['id'], body=document_definition) + elif operation == "delete": + await container.delete_item(item=document_definition['id'], partition_key=document_definition['pk']) + elif operation == "read": + await container.read_item(item=document_definition['id'], partition_key=document_definition['pk']) + + async def create_custom_transport_sm_mrr(self): custom_transport = FaultInjectionTransportAsync() # Inject rule to disallow writes in the read-only region @@ -119,7 +136,7 @@ async def test_consecutive_failure_threshold_async(self, setup, error): await container.create_item(body=document_definition) global_endpoint_manager = container.client_connection._global_endpoint_manager - TestPPCBSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 0) + TestPerPartitionCircuitBreakerSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 0) # create item with client 
without fault injection await setup[COLLECTION].create_item(body=document_definition) @@ -127,7 +144,7 @@ async def test_consecutive_failure_threshold_async(self, setup, error): # reads should fail over and only the relevant partition should be marked as unavailable await container.read_item(item=document_definition['id'], partition_key=document_definition['pk']) # partition should not have been marked unavailable after one error - TestPPCBSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 0) + TestPerPartitionCircuitBreakerSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 0) for i in range(10): read_resp = await container.read_item(item=document_definition['id'], partition_key=document_definition['pk']) @@ -136,7 +153,7 @@ async def test_consecutive_failure_threshold_async(self, setup, error): assert request.url.startswith(expected_read_region_uri) # the partition should have been marked as unavailable after breaking read threshold - TestPPCBSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 1) + TestPerPartitionCircuitBreakerSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 1) @pytest.mark.parametrize("error", errors()) async def test_failure_rate_threshold_async(self, setup, error): @@ -173,7 +190,7 @@ async def test_failure_rate_threshold_async(self, setup, error): await container.upsert_item(body=document_definition_2) global_endpoint_manager = container.client_connection._global_endpoint_manager - TestPPCBSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 0) + TestPerPartitionCircuitBreakerSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 0) # create item with client without fault injection await setup[COLLECTION].create_item(body=document_definition) @@ -181,7 +198,7 @@ async def test_failure_rate_threshold_async(self, setup, error): # reads should fail over and only the relevant partition should be marked as unavailable await 
container.read_item(item=document_definition['id'], partition_key=document_definition['pk']) # partition should not have been marked unavailable after one error - TestPPCBSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 0) + TestPerPartitionCircuitBreakerSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 0) # lower minimum requests for testing global_endpoint_manager.global_partition_endpoint_manager_core.partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 10 try: @@ -197,7 +214,7 @@ async def test_failure_rate_threshold_async(self, setup, error): assert request.url.startswith(expected_read_region_uri) # the partition should have been marked as unavailable after breaking read threshold - TestPPCBSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 1) + TestPerPartitionCircuitBreakerSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 1) finally: # restore minimum requests global_endpoint_manager.global_partition_endpoint_manager_core.partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 From 4e2fd6b691478cd75efaf2d62f65b82f5cc23416 Mon Sep 17 00:00:00 2001 From: Allen Kim Date: Tue, 8 Apr 2025 10:46:37 -0700 Subject: [PATCH 066/152] Fix live test failures --- .../tests/test_excluded_locations_async.py | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py index dd6ce3776f68..4a39b6a78c2c 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py @@ -80,20 +80,20 @@ def emit(self, record): def read_item_test_data(): client_only_output_data = [ - [L1, L1], # 0 - [L2, L2], # 1 - [L1, L1], # 2 - [L1, L1], # 3 + [L1], # 0 + [L2], # 1 + [L1], # 2 + [L1], # 3 ] client_and_request_output_data = [ - [L2, L2], # 0 - [L2, L2], # 1 - [L2, L2], # 2 - [L1, 
L1], # 3 - [L1, L1], # 4 - [L1, L1], # 5 - [L1, L1], # 6 - [L1, L1], # 7 + [L2], # 0 + [L2], # 1 + [L2], # 2 + [L1], # 3 + [L1], # 4 + [L1], # 5 + [L1], # 6 + [L1], # 7 ] all_output_test_data = client_only_output_data + client_and_request_output_data From 1baf872d58142af58af08dffb6e2c20a8cad1771 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Tue, 8 Apr 2025 13:51:20 -0400 Subject: [PATCH 067/152] fix pylint --- .../azure-cosmos/azure/cosmos/_partition_health_tracker.py | 3 ++- ...global_partition_endpoint_manager_circuit_breaker_async.py | 4 ++-- sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py | 3 ++- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py index 4b7c0522f5c2..9f30bac2bd2c 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py @@ -108,7 +108,8 @@ def _transition_health_status_on_failure( pk_range_wrapper: PartitionKeyRangeWrapper, location: str ) -> None: - logger.warning("%s has been marked as unavailable.", pk_range_wrapper) current_time = current_time_millis() + logger.warning("%s has been marked as unavailable.", pk_range_wrapper) + current_time = current_time_millis() if pk_range_wrapper not in self.pk_range_wrapper_to_health_info: # healthy -> unhealthy tentative partition_health_info = _PartitionHealthInfo() diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py index c1badbd8a167..8fe8b9a79afe 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py @@ -36,8 +36,8 @@ from 
azure.cosmos.aio._cosmos_client_connection_async import CosmosClientConnection - -class _GlobalPartitionEndpointManagerForCircuitBreakerAsync(_GlobalEndpointManager): # pylint: disable=protected-access +# pylint: disable=protected-access +class _GlobalPartitionEndpointManagerForCircuitBreakerAsync(_GlobalEndpointManager): """ This internal class implements the logic for partition endpoint management for geo-replicated database accounts. diff --git a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py index e26c3a270d83..4a6b37b23f0b 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py @@ -236,7 +236,8 @@ def validate_unhealthy_partitions(global_endpoint_manager, assert unhealthy_partitions == expected_unhealthy_partitions - # test_failure_rate_threshold - add service response error - across operation types + # test_failure_rate_threshold - add service response error - across operation types - test recovering the partition again + # # test service request marks only a partition unavailable not an entire region - across operation types # test cosmos client timeout From e0dab2977be9f6519644aa15ef5569ddd14ad83c Mon Sep 17 00:00:00 2001 From: Allen Kim Date: Tue, 8 Apr 2025 11:11:30 -0700 Subject: [PATCH 068/152] Fix live test failures --- .../tests/test_excluded_locations_async.py | 26 +++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py index 4a39b6a78c2c..e5c1dcfeed26 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py @@ -100,6 +100,28 @@ def read_item_test_data(): all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, 
all_output_test_data)] return all_test_data +def read_all_item_test_data(): + client_only_output_data = [ + [L1, L1], # 0 + [L2, L2], # 1 + [L1, L1], # 2 + [L1, L1], # 3 + ] + client_and_request_output_data = [ + [L2, L2], # 0 + [L2, L2], # 1 + [L2, L2], # 2 + [L1, L1], # 3 + [L1, L1], # 4 + [L1, L1], # 5 + [L1, L1], # 6 + [L1, L1], # 7 + ] + all_output_test_data = client_only_output_data + client_and_request_output_data + + all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] + return all_test_data + def query_items_change_feed_test_data(): client_only_output_data = [ [L1, L1, L1, L1], #0 @@ -243,7 +265,7 @@ async def test_read_item(self, test_data): # Verify endpoint locations await self._verify_endpoint(client, expected_locations) - @pytest.mark.parametrize('test_data', read_item_test_data()) + @pytest.mark.parametrize('test_data', read_all_item_test_data()) async def test_read_all_items(self, test_data): # Init test variables preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data @@ -260,7 +282,7 @@ async def test_read_all_items(self, test_data): # Verify endpoint locations await self._verify_endpoint(client, expected_locations) - @pytest.mark.parametrize('test_data', read_item_test_data()) + @pytest.mark.parametrize('test_data', read_all_item_test_data()) async def test_query_items(self, test_data): # Init test variables preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data From 798c12f513f2822e8e270ffc00111d05f57fdd28 Mon Sep 17 00:00:00 2001 From: Allen Kim Date: Tue, 8 Apr 2025 11:38:36 -0700 Subject: [PATCH 069/152] Add test_delete_all_items_by_partition_key --- .../tests/test_excluded_locations.py | 55 ++++++++++--------- 1 file changed, 28 insertions(+), 27 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py 
b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py index 9af367303107..0a63136700c0 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py +++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py @@ -442,33 +442,34 @@ def test_delete_item(self, test_data): else: self._verify_endpoint(client, [L1]) - # TODO: enable this test once we figure out how to enable delete_all_items_by_partition_key feature - # @pytest.mark.parametrize('test_data', patch_item_test_data()) - # def test_delete_all_items_by_partition_key(self, test_data): - # # Init test variables - # preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data - # - # for multiple_write_locations in [True, False]: - # # Client setup - # client, db, container = self._init_container(preferred_locations, client_excluded_locations, multiple_write_locations) - # - # #create before delete - # item_id = f'doc2-{str(uuid.uuid4())}' - # pk_value = f'temp_partition_key_value-{str(uuid.uuid4())}' - # container.create_item(body={PARTITION_KEY: pk_value, 'id': item_id}) - # MOCK_HANDLER.reset() - # - # # API call: read_item - # if request_excluded_locations is None: - # container.delete_all_items_by_partition_key(pk_value) - # else: - # container.delete_all_items_by_partition_key(pk_value, excluded_locations=request_excluded_locations) - # - # # Verify endpoint locations - # if multiple_write_locations: - # self._verify_endpoint(client, expected_locations) - # else: - # self._verify_endpoint(client, [L1]) + @pytest.mark.parametrize('test_data', patch_item_test_data()) + def test_delete_all_items_by_partition_key(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + for multiple_write_locations in [True, False]: + # Client setup + client, db, container = self._init_container(preferred_locations, client_excluded_locations, + multiple_write_locations) + 
+ # create before delete + item_id = f'doc2-{str(uuid.uuid4())}' + pk_value = f'temp_partition_key_value-{str(uuid.uuid4())}' + body = {PARTITION_KEY: pk_value, 'id': item_id} + _create_item_with_excluded_locations(container, body, request_excluded_locations) + MOCK_HANDLER.reset() + + # API call: delete_item + if request_excluded_locations is None: + container.delete_all_items_by_partition_key(pk_value) + else: + container.delete_all_items_by_partition_key(pk_value, excluded_locations=request_excluded_locations) + + # Verify endpoint locations + if multiple_write_locations: + self._verify_endpoint(client, expected_locations) + else: + self._verify_endpoint(client, [L1]) if __name__ == "__main__": unittest.main() From 2c5b8fce52682014b4214f930a29f2b97b13e91c Mon Sep 17 00:00:00 2001 From: Allen Kim Date: Tue, 8 Apr 2025 15:01:29 -0700 Subject: [PATCH 070/152] Remove test_delete_all_items_by_partition_key --- .../tests/test_excluded_locations.py | 29 ------------------- .../tests/test_excluded_locations_async.py | 28 ------------------ 2 files changed, 57 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py index 0a63136700c0..d99f9a3b3fda 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py +++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py @@ -442,34 +442,5 @@ def test_delete_item(self, test_data): else: self._verify_endpoint(client, [L1]) - @pytest.mark.parametrize('test_data', patch_item_test_data()) - def test_delete_all_items_by_partition_key(self, test_data): - # Init test variables - preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data - - for multiple_write_locations in [True, False]: - # Client setup - client, db, container = self._init_container(preferred_locations, client_excluded_locations, - multiple_write_locations) - - # create before delete - item_id = f'doc2-{str(uuid.uuid4())}' 
- pk_value = f'temp_partition_key_value-{str(uuid.uuid4())}' - body = {PARTITION_KEY: pk_value, 'id': item_id} - _create_item_with_excluded_locations(container, body, request_excluded_locations) - MOCK_HANDLER.reset() - - # API call: delete_item - if request_excluded_locations is None: - container.delete_all_items_by_partition_key(pk_value) - else: - container.delete_all_items_by_partition_key(pk_value, excluded_locations=request_excluded_locations) - - # Verify endpoint locations - if multiple_write_locations: - self._verify_endpoint(client, expected_locations) - else: - self._verify_endpoint(client, [L1]) - if __name__ == "__main__": unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py index e5c1dcfeed26..109f9ff7207a 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py @@ -464,33 +464,5 @@ async def test_delete_item(self, test_data): else: await self._verify_endpoint(client, [L1]) - # TODO: enable this test once we figure out how to enable delete_all_items_by_partition_key feature - # @pytest.mark.parametrize('test_data', patch_item_test_data()) - # def test_delete_all_items_by_partition_key(self, test_data): - # # Init test variables - # preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data - # - # for multiple_write_locations in [True, False]: - # # Client setup - # client, db, container = self._init_container(preferred_locations, client_excluded_locations, multiple_write_locations) - # - # #create before delete - # item_id = f'doc2-{str(uuid.uuid4())}' - # pk_value = f'temp_partition_key_value-{str(uuid.uuid4())}' - # container.create_item(body={PARTITION_KEY: pk_value, 'id': item_id}) - # MOCK_HANDLER.reset() - # - # # API call: read_item - # if request_excluded_locations is None: - # 
container.delete_all_items_by_partition_key(pk_value) - # else: - # container.delete_all_items_by_partition_key(pk_value, excluded_locations=request_excluded_locations) - # - # # Verify endpoint locations - # if multiple_write_locations: - # self._verify_endpoint(client, expected_locations) - # else: - # self._verify_endpoint(client, [L1]) - if __name__ == "__main__": unittest.main() From 739e09006c8023ed9b415d6d08a92234ed4890e8 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Tue, 8 Apr 2025 20:11:37 -0400 Subject: [PATCH 071/152] fix and add tests --- .../azure/cosmos/_retry_utility.py | 2 +- .../tests/test_ppcb_sm_mrr_async.py | 167 ++++++++++++------ 2 files changed, 110 insertions(+), 59 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py index 44c6b088696e..7d27885f10db 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py @@ -59,7 +59,7 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs): # pylin :rtype: tuple of (dict, dict) """ pk_range_wrapper = None - if global_endpoint_manager.is_circuit_breaker_applicable(args[0]): + if args and global_endpoint_manager.is_circuit_breaker_applicable(args[0]): pk_range_wrapper = global_endpoint_manager.create_pk_range_wrapper(args[0]) # instantiate all retry policies here to be applied for each request execution endpointDiscovery_retry_policy = _endpoint_discovery_retry_policy.EndpointDiscoveryRetryPolicy( diff --git a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py index 4a6b37b23f0b..c34c6ee8e731 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py @@ -24,7 +24,6 @@ async def setup(): os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "True" client = 
CosmosClient(TestPerPartitionCircuitBreakerSmMrrAsync.host, TestPerPartitionCircuitBreakerSmMrrAsync.master_key, consistency_level="Session") created_database = client.get_database_client(TestPerPartitionCircuitBreakerSmMrrAsync.TEST_DATABASE_ID) - # print(TestPPCBSmMrrAsync.TEST_DATABASE_ID) await client.create_database_if_not_exists(TestPerPartitionCircuitBreakerSmMrrAsync.TEST_DATABASE_ID) created_collection = await created_database.create_container(TestPerPartitionCircuitBreakerSmMrrAsync.TEST_CONTAINER_SINGLE_PARTITION_ID, partition_key=PartitionKey("/pk"), @@ -37,15 +36,24 @@ async def setup(): await client.close() os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "False" -def errors(): - errors_list = [] +def operations_and_errors(): + write_operations = ["create", "upsert", "replace", "delete", "patch", "batch"] + read_operations = ["read", "query", "changefeed"] + errors = [] error_codes = [408, 500, 502, 503] for error_code in error_codes: - errors_list.append(CosmosHttpResponseError( + errors.append(CosmosHttpResponseError( status_code=error_code, message="Some injected error.")) - errors_list.append(ServiceResponseError(message="Injected Service Response Error.")) - return errors_list + errors.append(ServiceResponseError(message="Injected Service Response Error.")) + params = [] + for write_operation in write_operations: + for read_operation in read_operations: + for error in errors: + params.append((write_operation, read_operation, error)) + + return params + @pytest.mark.cosmosEmulator @pytest.mark.asyncio @@ -70,21 +78,52 @@ async def cleanup_method(initialized_objects: Dict[str, Any]): method_client: CosmosClient = initialized_objects["client"] await method_client.close() - async def perform_write_operation(operation, container, id, pk): - document_definition = {'id': id, - 'pk': pk, - 'name': 'sample document', - 'key': 'value'} + @staticmethod + async def perform_write_operation(operation, container, doc_id, pk): + doc = {'id': doc_id, + 'pk': 
pk, + 'name': 'sample document', + 'key': 'value'} if operation == "create": - await container.create_item(body=document_definition) + await container.create_item(body=doc) elif operation == "upsert": - await container.upsert_item(body=document_definition) + await container.upsert_item(body=doc) elif operation == "replace": - await container.replace_item(item=document_definition['id'], body=document_definition) + new_doc = {'id': doc_id, + 'pk': pk, + 'name': 'sample document' + str(uuid), + 'key': 'value'} + await container.replace_item(item=doc['id'], body=new_doc) elif operation == "delete": - await container.delete_item(item=document_definition['id'], partition_key=document_definition['pk']) - elif operation == "read": - await container.read_item(item=document_definition['id'], partition_key=document_definition['pk']) + await container.create_item(body=doc) + await container.delete_item(item=doc['id'], partition_key=doc['pk']) + elif operation == "patch": + operations = [{"op": "incr", "path": "/company", "value": 3}] + await container.patch_item(item=doc['id'], partition_key=doc['pk'], operations=operations) + elif operation == "batch": + batch_operations = [ + ("create", (doc, )), + ("upsert", (doc,)), + ("upsert", (doc,)), + ("upsert", (doc,)), + ] + await container.execute_item_batch(batch_operations, partition_key=doc['pk']) + + @staticmethod + async def perform_read_operation(operation, container, doc_id, pk, expected_read_region_uri): + if operation == "read": + read_resp = await container.read_item(item=doc_id, partition_key=pk) + request = read_resp.get_response_headers()["_request"] + # Validate the response comes from "Read Region" (the most preferred read-only region) + assert request.url.startswith(expected_read_region_uri) + elif operation == "query": + query = "SELECT * FROM c WHERE c.id = @id AND c.pk = @pk" + parameters = [{"name": "@id", "value": doc_id}, {"name": "@pk", "value": pk}] + async for _ in container.query_items(query=query, 
parameters=parameters): + pass + elif operation == "changefeed": + async for _ in container.query_items_change_feed(): + pass async def create_custom_transport_sm_mrr(self): @@ -110,8 +149,8 @@ async def create_custom_transport_sm_mrr(self): emulator_as_multi_region_sm_account_transformation) return custom_transport - @pytest.mark.parametrize("error", errors()) - async def test_consecutive_failure_threshold_async(self, setup, error): + @pytest.mark.parametrize("write_operation, read_operation, error", operations_and_errors()) + async def test_consecutive_failure_threshold_async(self, setup, write_operation, read_operation, error): expected_read_region_uri = self.host expected_write_region_uri = expected_read_region_uri.replace("localhost", "127.0.0.1") custom_transport = await self.create_custom_transport_sm_mrr() @@ -133,7 +172,10 @@ async def test_consecutive_failure_threshold_async(self, setup, error): # writes should fail in sm mrr with circuit breaker and should not mark unavailable a partition for i in range(6): with pytest.raises((CosmosHttpResponseError, ServiceResponseError)): - await container.create_item(body=document_definition) + await TestPerPartitionCircuitBreakerSmMrrAsync.perform_write_operation(write_operation, + container, + document_definition['id'], + document_definition['pk']) global_endpoint_manager = container.client_connection._global_endpoint_manager TestPerPartitionCircuitBreakerSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 0) @@ -142,21 +184,26 @@ async def test_consecutive_failure_threshold_async(self, setup, error): await setup[COLLECTION].create_item(body=document_definition) # reads should fail over and only the relevant partition should be marked as unavailable - await container.read_item(item=document_definition['id'], partition_key=document_definition['pk']) + await TestPerPartitionCircuitBreakerSmMrrAsync.perform_read_operation(read_operation, + container, + document_definition['id'], + 
document_definition['pk'], + expected_read_region_uri) # partition should not have been marked unavailable after one error TestPerPartitionCircuitBreakerSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 0) for i in range(10): - read_resp = await container.read_item(item=document_definition['id'], partition_key=document_definition['pk']) - request = read_resp.get_response_headers()["_request"] - # Validate the response comes from "Read Region" (the most preferred read-only region) - assert request.url.startswith(expected_read_region_uri) + await TestPerPartitionCircuitBreakerSmMrrAsync.perform_read_operation(read_operation, + container, + document_definition['id'], + document_definition['pk'], + expected_read_region_uri) # the partition should have been marked as unavailable after breaking read threshold TestPerPartitionCircuitBreakerSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 1) - @pytest.mark.parametrize("error", errors()) - async def test_failure_rate_threshold_async(self, setup, error): + @pytest.mark.parametrize("write_operation, read_operation, error", operations_and_errors()) + async def test_failure_rate_threshold_async(self, setup, write_operation, read_operation, error): expected_read_region_uri = self.host expected_write_region_uri = expected_read_region_uri.replace("localhost", "127.0.0.1") custom_transport = await self.create_custom_transport_sm_mrr() @@ -166,10 +213,10 @@ async def test_failure_rate_threshold_async(self, setup, error): 'pk': 'pk1', 'name': 'sample document', 'key': 'value'} - document_definition_2 = {'id': str(uuid.uuid4()), - 'pk': 'pk1', - 'name': 'sample document', - 'key': 'value'} + doc_2 = {'id': str(uuid.uuid4()), + 'pk': 'pk1', + 'name': 'sample document', + 'key': 'value'} predicate = lambda r: (FaultInjectionTransportAsync.predicate_req_for_document_with_id(r, id_value) and FaultInjectionTransportAsync.predicate_targets_region(r, expected_write_region_uri)) 
custom_transport.add_fault(predicate, @@ -180,45 +227,50 @@ async def test_failure_rate_threshold_async(self, setup, error): custom_setup = await self.setup_method_with_custom_transport(custom_transport, default_endpoint=self.host) container = custom_setup['col'] - - # writes should fail in sm mrr with circuit breaker and should not mark unavailable a partition - for i in range(6): - if i % 2 == 0: - with pytest.raises((CosmosHttpResponseError, ServiceResponseError)): - await container.upsert_item(body=document_definition) - else: - await container.upsert_item(body=document_definition_2) global_endpoint_manager = container.client_connection._global_endpoint_manager - - TestPerPartitionCircuitBreakerSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 0) - - # create item with client without fault injection - await setup[COLLECTION].create_item(body=document_definition) - - # reads should fail over and only the relevant partition should be marked as unavailable - await container.read_item(item=document_definition['id'], partition_key=document_definition['pk']) - # partition should not have been marked unavailable after one error - TestPerPartitionCircuitBreakerSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 0) # lower minimum requests for testing global_endpoint_manager.global_partition_endpoint_manager_core.partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 10 try: + # writes should fail in sm mrr with circuit breaker and should not mark unavailable a partition + for i in range(14): + if i == 9: + await container.upsert_item(body=doc_2) + with pytest.raises((CosmosHttpResponseError, ServiceResponseError)): + await TestPerPartitionCircuitBreakerSmMrrAsync.perform_write_operation(write_operation, + container, + document_definition['id'], + document_definition['pk']) + + TestPerPartitionCircuitBreakerSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 0) + + # create item with client without fault injection + 
await setup[COLLECTION].create_item(body=document_definition) + + # reads should fail over and only the relevant partition should be marked as unavailable + await container.read_item(item=document_definition['id'], partition_key=document_definition['pk']) + # partition should not have been marked unavailable after one error + TestPerPartitionCircuitBreakerSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 0) for i in range(20): if i == 8: - read_resp = await container.read_item(item=document_definition_2['id'], - partition_key=document_definition_2['pk']) + read_resp = await container.read_item(item=doc_2['id'], + partition_key=doc_2['pk']) + request = read_resp.get_response_headers()["_request"] + # Validate the response comes from "Read Region" (the most preferred read-only region) + assert request.url.startswith(expected_read_region_uri) else: - read_resp = await container.read_item(item=document_definition['id'], - partition_key=document_definition['pk']) - request = read_resp.get_response_headers()["_request"] - # Validate the response comes from "Read Region" (the most preferred read-only region) - assert request.url.startswith(expected_read_region_uri) - + await TestPerPartitionCircuitBreakerSmMrrAsync.perform_read_operation(read_operation, + container, + document_definition['id'], + document_definition['pk'], + expected_read_region_uri) # the partition should have been marked as unavailable after breaking read threshold TestPerPartitionCircuitBreakerSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 1) finally: # restore minimum requests global_endpoint_manager.global_partition_endpoint_manager_core.partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 + # look at the urls for verifying fall back + @staticmethod def validate_unhealthy_partitions(global_endpoint_manager, expected_unhealthy_partitions): @@ -237,7 +289,6 @@ def validate_unhealthy_partitions(global_endpoint_manager, assert unhealthy_partitions == 
expected_unhealthy_partitions # test_failure_rate_threshold - add service response error - across operation types - test recovering the partition again - # # test service request marks only a partition unavailable not an entire region - across operation types # test cosmos client timeout From d5c380a4098d8aaa3dcfeee65f6467d104714671 Mon Sep 17 00:00:00 2001 From: Tomas Varon Saldarriaga Date: Tue, 8 Apr 2025 23:33:54 -0400 Subject: [PATCH 072/152] add collection rid to batch --- ..._endpoint_manager_circuit_breaker_async.py | 39 +++++++++++-------- .../azure-cosmos/azure/cosmos/container.py | 2 + .../tests/test_ppcb_sm_mrr_async.py | 2 +- 3 files changed, 25 insertions(+), 18 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py index 8fe8b9a79afe..d9a06317804e 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py @@ -51,23 +51,28 @@ def __init__(self, client: "CosmosClientConnection"): async def create_pk_range_wrapper(self, request: RequestObject) -> PartitionKeyRangeWrapper: container_rid = request.headers[HttpHeaders.IntendedCollectionRID] - partition_key_value = request.headers[HttpHeaders.PartitionKey] - # get the partition key range for the given partition key - target_container_link = None - for container_link, properties in self.client._container_properties_cache.items(): - if properties["_rid"] == container_rid: - target_container_link = container_link - partition_key_definition = properties["partitionKey"] - partition_key = PartitionKey(path=partition_key_definition["paths"], - kind=partition_key_definition["kind"]) - - epk_range = [partition_key._get_epk_range_for_partition_key(partition_key_value)] - if not 
target_container_link: - raise RuntimeError("Illegal state: the container cache is not properly initialized.") - # TODO: @tvaron3 check different clients and create them in different ways - partition_ranges = await (self.client._routing_map_provider - .get_overlapping_ranges(target_container_link, epk_range)) - partition_range = Range.PartitionKeyRangeToRange(partition_ranges[0]) + print(request.headers) + if request.headers.get(HttpHeaders.PartitionKey): + partition_key_value = request.headers[HttpHeaders.PartitionKey] + # get the partition key range for the given partition key + target_container_link = None + for container_link, properties in self.client._container_properties_cache.items(): + if properties["_rid"] == container_rid: + target_container_link = container_link + partition_key_definition = properties["partitionKey"] + partition_key = PartitionKey(path=partition_key_definition["paths"], + kind=partition_key_definition["kind"]) + + epk_range = [partition_key._get_epk_range_for_partition_key(partition_key_value)] + if not target_container_link: + raise RuntimeError("Illegal state: the container cache is not properly initialized.") + # TODO: @tvaron3 check different clients and create them in different ways + partition_ranges = await (self.client._routing_map_provider + .get_overlapping_ranges(target_container_link, epk_range)) + partition_range = Range.PartitionKeyRangeToRange(partition_ranges[0]) + elif request.headers.get(HttpHeaders.PartitionKeyRangeID): + pk_range_id = request.headers[HttpHeaders.PartitionKeyRangeID] + return PartitionKeyRangeWrapper(partition_range, container_rid) def is_circuit_breaker_applicable(self, request: RequestObject) -> bool: diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/container.py b/sdk/cosmos/azure-cosmos/azure/cosmos/container.py index 3f48824c9f7d..c659ae746b4a 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/container.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/container.py @@ -1088,6 +1088,8 @@ def 
execute_item_batch( request_options = build_options(kwargs) request_options["partitionKey"] = self._set_partition_key(partition_key) request_options["disableAutomaticIdGeneration"] = True + container_properties = self._get_properties() + request_options["containerRID"] = container_properties["_rid"] return self.client_connection.Batch( collection_link=self.container_link, batch_operations=batch_operations, options=request_options, **kwargs) diff --git a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py index c34c6ee8e731..0be283a289cc 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py @@ -99,7 +99,7 @@ async def perform_write_operation(operation, container, doc_id, pk): await container.delete_item(item=doc['id'], partition_key=doc['pk']) elif operation == "patch": operations = [{"op": "incr", "path": "/company", "value": 3}] - await container.patch_item(item=doc['id'], partition_key=doc['pk'], operations=operations) + await container.patch_item(item=doc['id'], partition_key=doc['pk'], patch_operations=operations) elif operation == "batch": batch_operations = [ ("create", (doc, )), From e7f7265e7548b3ed3a215e0c0621d9b28ae711fe Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Wed, 9 Apr 2025 10:43:18 -0400 Subject: [PATCH 073/152] add partition key range id to partition key range to cache --- .../_routing/aio/routing_map_provider.py | 38 ++++++++++++++++--- ..._endpoint_manager_circuit_breaker_async.py | 30 +++++++++------ .../tests/test_ppcb_sm_mrr_async.py | 3 +- 3 files changed, 53 insertions(+), 18 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/aio/routing_map_provider.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/aio/routing_map_provider.py index e70ae355c495..f59513d05d24 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/aio/routing_map_provider.py +++ 
b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/aio/routing_map_provider.py @@ -22,11 +22,13 @@ """Internal class for partition key range cache implementation in the Azure Cosmos database service. """ +from typing import Dict, Any from ... import _base from ..collection_routing_map import CollectionRoutingMap from .. import routing_range + # pylint: disable=protected-access @@ -58,13 +60,21 @@ async def get_overlapping_ranges(self, collection_link, partition_key_ranges, ** :return: List of overlapping partition key ranges. :rtype: list """ - cl = self._documentClient - collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) + await self.initialize_collection_routing_map_if_needed(collection_link, str(collection_id), **kwargs) + + return self._collection_routing_map_by_item[collection_id].get_overlapping_ranges(partition_key_ranges) + async def initialize_collection_routing_map_if_needed( + self, + collection_link: str, + collection_id: str, + **kwargs: Dict[str, Any] + ): + client = self._documentClient collection_routing_map = self._collection_routing_map_by_item.get(collection_id) if collection_routing_map is None: - collection_pk_ranges = [pk async for pk in cl._ReadPartitionKeyRanges(collection_link, **kwargs)] + collection_pk_ranges = [pk async for pk in client._ReadPartitionKeyRanges(collection_link, **kwargs)] # for large collections, a split may complete between the read partition key ranges query page responses, # causing the partitionKeyRanges to have both the children ranges and their parents. Therefore, we need # to discard the parent ranges to have a valid routing map. 
@@ -72,8 +82,18 @@ async def get_overlapping_ranges(self, collection_link, partition_key_ranges, ** collection_routing_map = CollectionRoutingMap.CompleteRoutingMap( [(r, True) for r in collection_pk_ranges], collection_id ) - self._collection_routing_map_by_item[collection_id] = collection_routing_map - return collection_routing_map.get_overlapping_ranges(partition_key_ranges) + self._collection_routing_map_by_item[collection_id] = collection_routing_map + + async def get_range_by_partition_key_range_id( + self, + collection_link: str, + partition_key_range_id: int, + **kwargs: Dict[str, Any] + ) -> Dict[str, Any]: + collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) + await self.initialize_collection_routing_map_if_needed(collection_link, str(collection_id), **kwargs) + + return self._collection_routing_map_by_item[collection_id].get_range_by_partition_key_range_id(partition_key_range_id) @staticmethod def _discard_parent_ranges(partitionKeyRanges): @@ -196,3 +216,11 @@ async def get_overlapping_ranges(self, collection_link, partition_key_ranges, ** pass return target_partition_key_ranges + + async def get_range_by_partition_key_range_id( + self, + collection_link: str, + partition_key_range_id: int, + **kwargs: Dict[str, Any] + ) -> Dict[str, Any]: + return await super().get_range_by_partition_key_range_id(collection_link, partition_key_range_id, **kwargs) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py index d9a06317804e..77757b6a2752 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py @@ -49,29 +49,35 @@ def __init__(self, client: "CosmosClientConnection"): self.global_partition_endpoint_manager_core = ( 
_GlobalPartitionEndpointManagerForCircuitBreakerCore(client, self.location_cache)) - async def create_pk_range_wrapper(self, request: RequestObject) -> PartitionKeyRangeWrapper: + async def create_pk_range_wrapper(self, request: RequestObject, kwargs) -> PartitionKeyRangeWrapper: container_rid = request.headers[HttpHeaders.IntendedCollectionRID] print(request.headers) + target_container_link = None + partition_key = None + # TODO: @tvaron3 see if the container link is absolutely necessary or change container cache + for container_link, properties in self.client._container_properties_cache.items(): + if properties["_rid"] == container_rid: + target_container_link = container_link + partition_key_definition = properties["partitionKey"] + partition_key = PartitionKey(path=partition_key_definition["paths"], + kind=partition_key_definition["kind"]) + + if not target_container_link or not partition_key: + raise RuntimeError("Illegal state: the container cache is not properly initialized.") + if request.headers.get(HttpHeaders.PartitionKey): partition_key_value = request.headers[HttpHeaders.PartitionKey] # get the partition key range for the given partition key - target_container_link = None - for container_link, properties in self.client._container_properties_cache.items(): - if properties["_rid"] == container_rid: - target_container_link = container_link - partition_key_definition = properties["partitionKey"] - partition_key = PartitionKey(path=partition_key_definition["paths"], - kind=partition_key_definition["kind"]) - - epk_range = [partition_key._get_epk_range_for_partition_key(partition_key_value)] - if not target_container_link: - raise RuntimeError("Illegal state: the container cache is not properly initialized.") # TODO: @tvaron3 check different clients and create them in different ways + epk_range = [partition_key._get_epk_range_for_partition_key(partition_key_value)] partition_ranges = await (self.client._routing_map_provider 
.get_overlapping_ranges(target_container_link, epk_range)) partition_range = Range.PartitionKeyRangeToRange(partition_ranges[0]) elif request.headers.get(HttpHeaders.PartitionKeyRangeID): pk_range_id = request.headers[HttpHeaders.PartitionKeyRangeID] + range = await (self.client._routing_map_provider + .get_range_by_partition_key_range_id(target_container_link, pk_range_id)) + partition_range = Range.PartitionKeyRangeToRange(range) return PartitionKeyRangeWrapper(partition_range, container_rid) diff --git a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py index 0be283a289cc..99f7a822c29a 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py @@ -201,6 +201,7 @@ async def test_consecutive_failure_threshold_async(self, setup, write_operation, # the partition should have been marked as unavailable after breaking read threshold TestPerPartitionCircuitBreakerSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 1) + # test recovering the partition again @pytest.mark.parametrize("write_operation, read_operation, error", operations_and_errors()) async def test_failure_rate_threshold_async(self, setup, write_operation, read_operation, error): @@ -269,7 +270,7 @@ async def test_failure_rate_threshold_async(self, setup, write_operation, read_o # restore minimum requests global_endpoint_manager.global_partition_endpoint_manager_core.partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 - # look at the urls for verifying fall back + # look at the urls for verifying fall back and use another id for same partition @staticmethod def validate_unhealthy_partitions(global_endpoint_manager, From 38f80331b1c32162d84ab7c718b4dd878ec046be Mon Sep 17 00:00:00 2001 From: Tomas Varon Saldarriaga Date: Wed, 9 Apr 2025 11:43:43 -0400 Subject: [PATCH 074/152] address failures --- 
...obal_partition_endpoint_manager_circuit_breaker_async.py | 2 +- sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py index 77757b6a2752..5231ed5c06c4 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py @@ -49,7 +49,7 @@ def __init__(self, client: "CosmosClientConnection"): self.global_partition_endpoint_manager_core = ( _GlobalPartitionEndpointManagerForCircuitBreakerCore(client, self.location_cache)) - async def create_pk_range_wrapper(self, request: RequestObject, kwargs) -> PartitionKeyRangeWrapper: + async def create_pk_range_wrapper(self, request: RequestObject) -> PartitionKeyRangeWrapper: container_rid = request.headers[HttpHeaders.IntendedCollectionRID] print(request.headers) target_container_link = None diff --git a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py index 99f7a822c29a..4b2b33512bd0 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py @@ -119,8 +119,9 @@ async def perform_read_operation(operation, container, doc_id, pk, expected_read elif operation == "query": query = "SELECT * FROM c WHERE c.id = @id AND c.pk = @pk" parameters = [{"name": "@id", "value": doc_id}, {"name": "@pk", "value": pk}] - async for _ in container.query_items(query=query, parameters=parameters): + async for _ in container.query_items(query=query, partition_key=pk, parameters=parameters): pass + # need to do query with no pk and with feed range elif operation == "changefeed": async for _ 
in container.query_items_change_feed(): pass @@ -149,6 +150,9 @@ async def create_custom_transport_sm_mrr(self): emulator_as_multi_region_sm_account_transformation) return custom_transport + + # split this into write and read tests + @pytest.mark.parametrize("write_operation, read_operation, error", operations_and_errors()) async def test_consecutive_failure_threshold_async(self, setup, write_operation, read_operation, error): expected_read_region_uri = self.host From 828a99b5b044ea8b5d70b368eb8d2189afe72932 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Wed, 9 Apr 2025 11:52:22 -0400 Subject: [PATCH 075/152] update tests --- sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py index 4b2b33512bd0..113574a7725f 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py @@ -163,7 +163,7 @@ async def test_consecutive_failure_threshold_async(self, setup, write_operation, 'pk': 'pk1', 'name': 'sample document', 'key': 'value'} - predicate = lambda r: (FaultInjectionTransportAsync.predicate_req_for_document_with_id(r, id_value) and + predicate = lambda r: (FaultInjectionTransportAsync.predicate_is_document_operation(r) and FaultInjectionTransportAsync.predicate_targets_region(r, expected_write_region_uri)) custom_transport.add_fault(predicate, lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( 0, From 2b9b58fbc885781cda6c039bb231853d6c652b79 Mon Sep 17 00:00:00 2001 From: Allen Kim Date: Wed, 9 Apr 2025 19:40:18 -0700 Subject: [PATCH 076/152] Added missing doc for excluded_locations in async client --- sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client.py 
b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client.py index 647f6d59f615..e5e526670629 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client.py @@ -162,6 +162,7 @@ class CosmosClient: # pylint: disable=client-accepts-api-version-keyword :keyword bool enable_endpoint_discovery: Enable endpoint discovery for geo-replicated database accounts. (Default: True) :keyword list[str] preferred_locations: The preferred locations for geo-replicated database accounts. + :keyword list[str] excluded_locations: The excluded locations to be skipped from preferred locations. The locations :keyword bool enable_diagnostics_logging: Enable the CosmosHttpLogging policy. Must be used along with a logger to work. :keyword ~logging.Logger logger: Logger to be used for collecting request diagnostics. Can be passed in at client From 1c98b48f15660200d77a87e8811080bde636df42 Mon Sep 17 00:00:00 2001 From: Allen Kim Date: Wed, 9 Apr 2025 19:40:57 -0700 Subject: [PATCH 077/152] Remove duplicate functions --- .../tests/test_excluded_locations.py | 123 +++++++++--------- .../tests/test_excluded_locations_async.py | 101 ++++++-------- 2 files changed, 99 insertions(+), 125 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py index d99f9a3b3fda..49b7f0553871 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py +++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py @@ -188,51 +188,50 @@ def setup_and_teardown(): # Code to run after tests print("Teardown: This runs after all tests") -@pytest.mark.cosmosMultiRegion -class TestExcludedLocations: - def _init_container(self, preferred_locations, client_excluded_locations, multiple_write_locations = True): - client = cosmos_client.CosmosClient(HOST, KEY, - preferred_locations=preferred_locations, - excluded_locations=client_excluded_locations, - 
multiple_write_locations=multiple_write_locations) - db = client.get_database_client(DATABASE_ID) - container = db.get_container_client(CONTAINER_ID) - MOCK_HANDLER.reset() - - return client, db, container - - def _verify_endpoint(self, client, expected_locations): - # get mapping for locations - location_mapping = (client.client_connection._global_endpoint_manager. - location_cache.account_locations_by_write_regional_routing_context) - default_endpoint = (client.client_connection._global_endpoint_manager. - location_cache.default_regional_routing_context.get_primary()) - - # get Request URL - msgs = MOCK_HANDLER.messages - req_urls = [url.replace("Request URL: '", "") for url in msgs if 'Request URL:' in url] - - # get location - actual_locations = [] - for req_url in req_urls: - if req_url.startswith(default_endpoint): - actual_locations.append(L0) - else: - for endpoint in location_mapping: - if req_url.startswith(endpoint): - location = location_mapping[endpoint] - actual_locations.append(location) - break +def _init_container(preferred_locations, client_excluded_locations, multiple_write_locations = True): + client = cosmos_client.CosmosClient(HOST, KEY, + preferred_locations=preferred_locations, + excluded_locations=client_excluded_locations, + multiple_write_locations=multiple_write_locations) + db = client.get_database_client(DATABASE_ID) + container = db.get_container_client(CONTAINER_ID) + MOCK_HANDLER.reset() + + return client, db, container + +def _verify_endpoint(messages, client, expected_locations): + # get mapping for locations + location_mapping = (client.client_connection._global_endpoint_manager. + location_cache.account_locations_by_write_regional_routing_context) + default_endpoint = (client.client_connection._global_endpoint_manager. 
+ location_cache.default_regional_routing_context.get_primary()) + + # get Request URL + req_urls = [url.replace("Request URL: '", "") for url in messages if 'Request URL:' in url] + + # get location + actual_locations = [] + for req_url in req_urls: + if req_url.startswith(default_endpoint): + actual_locations.append(L0) + else: + for endpoint in location_mapping: + if req_url.startswith(endpoint): + location = location_mapping[endpoint] + actual_locations.append(location) + break - assert actual_locations == expected_locations + assert actual_locations == expected_locations +@pytest.mark.cosmosMultiRegion +class TestExcludedLocations: @pytest.mark.parametrize('test_data', read_item_test_data()) def test_read_item(self, test_data): # Init test variables preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data # Client setup - client, db, container = self._init_container(preferred_locations, client_excluded_locations) + client, db, container = _init_container(preferred_locations, client_excluded_locations) # API call: read_item if request_excluded_locations is None: @@ -241,7 +240,7 @@ def test_read_item(self, test_data): container.read_item(ITEM_ID, ITEM_PK_VALUE, excluded_locations=request_excluded_locations) # Verify endpoint locations - self._verify_endpoint(client, expected_locations) + _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) @pytest.mark.parametrize('test_data', read_item_test_data()) def test_read_all_items(self, test_data): @@ -249,7 +248,7 @@ def test_read_all_items(self, test_data): preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data # Client setup - client, db, container = self._init_container(preferred_locations, client_excluded_locations) + client, db, container = _init_container(preferred_locations, client_excluded_locations) # API call: read_all_items if request_excluded_locations is None: @@ -258,7 +257,7 @@ def 
test_read_all_items(self, test_data): list(container.read_all_items(excluded_locations=request_excluded_locations)) # Verify endpoint locations - self._verify_endpoint(client, expected_locations) + _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) @pytest.mark.parametrize('test_data', read_item_test_data()) def test_query_items(self, test_data): @@ -266,7 +265,7 @@ def test_query_items(self, test_data): preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data # Client setup and create an item - client, db, container = self._init_container(preferred_locations, client_excluded_locations) + client, db, container = _init_container(preferred_locations, client_excluded_locations) # API call: query_items if request_excluded_locations is None: @@ -275,7 +274,7 @@ def test_query_items(self, test_data): list(container.query_items(None, excluded_locations=request_excluded_locations)) # Verify endpoint locations - self._verify_endpoint(client, expected_locations) + _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) @pytest.mark.parametrize('test_data', query_items_change_feed_test_data()) def test_query_items_change_feed(self, test_data): @@ -284,7 +283,7 @@ def test_query_items_change_feed(self, test_data): # Client setup and create an item - client, db, container = self._init_container(preferred_locations, client_excluded_locations) + client, db, container = _init_container(preferred_locations, client_excluded_locations) # API call: query_items_change_feed if request_excluded_locations is None: @@ -293,7 +292,7 @@ def test_query_items_change_feed(self, test_data): items = list(container.query_items_change_feed(start_time="Beginning", excluded_locations=request_excluded_locations)) # Verify endpoint locations - self._verify_endpoint(client, expected_locations) + _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) @pytest.mark.parametrize('test_data', replace_item_test_data()) 
@@ -303,7 +302,7 @@ def test_replace_item(self, test_data): for multiple_write_locations in [True, False]: # Client setup and create an item - client, db, container = self._init_container(preferred_locations, client_excluded_locations, multiple_write_locations) + client, db, container = _init_container(preferred_locations, client_excluded_locations, multiple_write_locations) # API call: replace_item if request_excluded_locations is None: @@ -313,9 +312,9 @@ def test_replace_item(self, test_data): # Verify endpoint locations if multiple_write_locations: - self._verify_endpoint(client, expected_locations) + _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) else: - self._verify_endpoint(client, [expected_locations[0], L1]) + _verify_endpoint(client, [expected_locations[0], L1]) @pytest.mark.parametrize('test_data', replace_item_test_data()) def test_upsert_item(self, test_data): @@ -324,7 +323,7 @@ def test_upsert_item(self, test_data): for multiple_write_locations in [True, False]: # Client setup and create an item - client, db, container = self._init_container(preferred_locations, client_excluded_locations, multiple_write_locations) + client, db, container = _init_container(preferred_locations, client_excluded_locations, multiple_write_locations) # API call: upsert_item body = {'pk': 'pk', 'id': f'doc2-{str(uuid.uuid4())}'} @@ -335,9 +334,9 @@ def test_upsert_item(self, test_data): # get location from mock_handler if multiple_write_locations: - self._verify_endpoint(client, expected_locations) + _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) else: - self._verify_endpoint(client, [expected_locations[0], L1]) + _verify_endpoint(client, [expected_locations[0], L1]) @pytest.mark.parametrize('test_data', replace_item_test_data()) def test_create_item(self, test_data): @@ -346,7 +345,7 @@ def test_create_item(self, test_data): for multiple_write_locations in [True, False]: # Client setup and create an item - client, db, container = 
self._init_container(preferred_locations, client_excluded_locations, multiple_write_locations) + client, db, container = _init_container(preferred_locations, client_excluded_locations, multiple_write_locations) # API call: create_item body = {'pk': 'pk', 'id': f'doc2-{str(uuid.uuid4())}'} @@ -354,9 +353,9 @@ def test_create_item(self, test_data): # get location from mock_handler if multiple_write_locations: - self._verify_endpoint(client, expected_locations) + _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) else: - self._verify_endpoint(client, [expected_locations[0], L1]) + _verify_endpoint(client, [expected_locations[0], L1]) @pytest.mark.parametrize('test_data', patch_item_test_data()) def test_patch_item(self, test_data): @@ -365,7 +364,7 @@ def test_patch_item(self, test_data): for multiple_write_locations in [True, False]: # Client setup and create an item - client, db, container = self._init_container(preferred_locations, client_excluded_locations, + client, db, container = _init_container(preferred_locations, client_excluded_locations, multiple_write_locations) # API call: patch_item @@ -382,9 +381,9 @@ def test_patch_item(self, test_data): # get location from mock_handler if multiple_write_locations: - self._verify_endpoint(client, expected_locations) + _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) else: - self._verify_endpoint(client, [L1]) + _verify_endpoint(client, [L1]) @pytest.mark.parametrize('test_data', patch_item_test_data()) def test_execute_item_batch(self, test_data): @@ -393,7 +392,7 @@ def test_execute_item_batch(self, test_data): for multiple_write_locations in [True, False]: # Client setup and create an item - client, db, container = self._init_container(preferred_locations, client_excluded_locations, + client, db, container = _init_container(preferred_locations, client_excluded_locations, multiple_write_locations) # API call: execute_item_batch @@ -411,9 +410,9 @@ def test_execute_item_batch(self, 
test_data): # get location from mock_handler if multiple_write_locations: - self._verify_endpoint(client, expected_locations) + _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) else: - self._verify_endpoint(client, [L1]) + _verify_endpoint(client, [L1]) @pytest.mark.parametrize('test_data', patch_item_test_data()) def test_delete_item(self, test_data): @@ -422,7 +421,7 @@ def test_delete_item(self, test_data): for multiple_write_locations in [True, False]: # Client setup - client, db, container = self._init_container(preferred_locations, client_excluded_locations, multiple_write_locations) + client, db, container = _init_container(preferred_locations, client_excluded_locations, multiple_write_locations) # create before delete item_id = f'doc2-{str(uuid.uuid4())}' @@ -438,9 +437,9 @@ def test_delete_item(self, test_data): # Verify endpoint locations if multiple_write_locations: - self._verify_endpoint(client, expected_locations) + _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) else: - self._verify_endpoint(client, [L1]) + _verify_endpoint(client, [L1]) if __name__ == "__main__": unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py index 109f9ff7207a..50c0b69acd76 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py @@ -10,7 +10,7 @@ from azure.cosmos.aio import CosmosClient from azure.cosmos.partition_key import PartitionKey - +from test_excluded_locations import _verify_endpoint class MockHandler(logging.Handler): def __init__(self): @@ -208,53 +208,28 @@ async def setup_and_teardown(): yield await test_client.close() +async def _init_container(preferred_locations, client_excluded_locations, multiple_write_locations = True): + client = CosmosClient(HOST, KEY, + preferred_locations=preferred_locations, + 
excluded_locations=client_excluded_locations, + multiple_write_locations=multiple_write_locations) + db = await client.create_database_if_not_exists(DATABASE_ID) + container = await db.create_container_if_not_exists(CONTAINER_ID, PartitionKey(path='/' + PARTITION_KEY, kind='Hash')) + MOCK_HANDLER.reset() + + return client, db, container + @pytest.mark.cosmosMultiRegion @pytest.mark.asyncio @pytest.mark.usefixtures("setup_and_teardown") class TestExcludedLocations: - async def _init_container(self, preferred_locations, client_excluded_locations, multiple_write_locations = True): - client = CosmosClient(HOST, KEY, - preferred_locations=preferred_locations, - excluded_locations=client_excluded_locations, - multiple_write_locations=multiple_write_locations) - db = await client.create_database_if_not_exists(DATABASE_ID) - container = await db.create_container_if_not_exists(CONTAINER_ID, PartitionKey(path='/' + PARTITION_KEY, kind='Hash')) - MOCK_HANDLER.reset() - - return client, db, container - - async def _verify_endpoint(self, client, expected_locations): - # get mapping for locations - location_mapping = (client.client_connection._global_endpoint_manager. - location_cache.account_locations_by_write_regional_routing_context) - default_endpoint = (client.client_connection._global_endpoint_manager. 
- location_cache.default_regional_routing_context.get_primary()) - - # get Request URL - msgs = MOCK_HANDLER.messages - req_urls = [url.replace("Request URL: '", "") for url in msgs if 'Request URL:' in url] - - # get location - actual_locations = [] - for req_url in req_urls: - if req_url.startswith(default_endpoint): - actual_locations.append(L0) - else: - for endpoint in location_mapping: - if req_url.startswith(endpoint): - location = location_mapping[endpoint] - actual_locations.append(location) - break - - assert actual_locations == expected_locations - @pytest.mark.parametrize('test_data', read_item_test_data()) async def test_read_item(self, test_data): # Init test variables preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data # Client setup - client, db, container = await self._init_container(preferred_locations, client_excluded_locations) + client, db, container = await _init_container(preferred_locations, client_excluded_locations) # API call: read_item if request_excluded_locations is None: @@ -263,7 +238,7 @@ async def test_read_item(self, test_data): await container.read_item(ITEM_ID, ITEM_PK_VALUE, excluded_locations=request_excluded_locations) # Verify endpoint locations - await self._verify_endpoint(client, expected_locations) + _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) @pytest.mark.parametrize('test_data', read_all_item_test_data()) async def test_read_all_items(self, test_data): @@ -271,7 +246,7 @@ async def test_read_all_items(self, test_data): preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data # Client setup - client, db, container = await self._init_container(preferred_locations, client_excluded_locations) + client, db, container = await _init_container(preferred_locations, client_excluded_locations) # API call: read_all_items if request_excluded_locations is None: @@ -280,7 +255,7 @@ async def 
test_read_all_items(self, test_data): all_items = [item async for item in container.read_all_items(excluded_locations=request_excluded_locations)] # Verify endpoint locations - await self._verify_endpoint(client, expected_locations) + _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) @pytest.mark.parametrize('test_data', read_all_item_test_data()) async def test_query_items(self, test_data): @@ -288,7 +263,7 @@ async def test_query_items(self, test_data): preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data # Client setup and create an item - client, db, container = await self._init_container(preferred_locations, client_excluded_locations) + client, db, container = await _init_container(preferred_locations, client_excluded_locations) # API call: query_items if request_excluded_locations is None: @@ -297,7 +272,7 @@ async def test_query_items(self, test_data): all_items = [item async for item in container.query_items(None, excluded_locations=request_excluded_locations)] # Verify endpoint locations - await self._verify_endpoint(client, expected_locations) + _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) @pytest.mark.parametrize('test_data', query_items_change_feed_test_data()) async def test_query_items_change_feed(self, test_data): @@ -306,7 +281,7 @@ async def test_query_items_change_feed(self, test_data): # Client setup and create an item - client, db, container = await self._init_container(preferred_locations, client_excluded_locations) + client, db, container = await _init_container(preferred_locations, client_excluded_locations) # API call: query_items_change_feed if request_excluded_locations is None: @@ -315,7 +290,7 @@ async def test_query_items_change_feed(self, test_data): all_items = [item async for item in container.query_items_change_feed(start_time="Beginning", excluded_locations=request_excluded_locations)] # Verify endpoint locations - await 
self._verify_endpoint(client, expected_locations) + _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) @pytest.mark.parametrize('test_data', replace_item_test_data()) @@ -325,7 +300,7 @@ async def test_replace_item(self, test_data): for multiple_write_locations in [True, False]: # Client setup and create an item - client, db, container = await self._init_container(preferred_locations, client_excluded_locations, multiple_write_locations) + client, db, container = await _init_container(preferred_locations, client_excluded_locations, multiple_write_locations) # API call: replace_item if request_excluded_locations is None: @@ -335,9 +310,9 @@ async def test_replace_item(self, test_data): # Verify endpoint locations if multiple_write_locations: - await self._verify_endpoint(client, expected_locations) + _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) else: - await self._verify_endpoint(client, [L1]) + _verify_endpoint(client, [L1]) @pytest.mark.parametrize('test_data', replace_item_test_data()) async def test_upsert_item(self, test_data): @@ -346,7 +321,7 @@ async def test_upsert_item(self, test_data): for multiple_write_locations in [True, False]: # Client setup and create an item - client, db, container = await self._init_container(preferred_locations, client_excluded_locations, multiple_write_locations) + client, db, container = await _init_container(preferred_locations, client_excluded_locations, multiple_write_locations) # API call: upsert_item body = {'pk': 'pk', 'id': f'doc2-{str(uuid.uuid4())}'} @@ -357,9 +332,9 @@ async def test_upsert_item(self, test_data): # get location from mock_handler if multiple_write_locations: - await self._verify_endpoint(client, expected_locations) + _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) else: - await self._verify_endpoint(client, [L1]) + _verify_endpoint(client, [L1]) @pytest.mark.parametrize('test_data', replace_item_test_data()) async def test_create_item(self, 
test_data): @@ -368,7 +343,7 @@ async def test_create_item(self, test_data): for multiple_write_locations in [True, False]: # Client setup and create an item - client, db, container = await self._init_container(preferred_locations, client_excluded_locations, multiple_write_locations) + client, db, container = await _init_container(preferred_locations, client_excluded_locations, multiple_write_locations) # API call: create_item body = {'pk': 'pk', 'id': f'doc2-{str(uuid.uuid4())}'} @@ -376,9 +351,9 @@ async def test_create_item(self, test_data): # get location from mock_handler if multiple_write_locations: - await self._verify_endpoint(client, expected_locations) + _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) else: - await self._verify_endpoint(client, [L1]) + _verify_endpoint(client, [L1]) @pytest.mark.parametrize('test_data', patch_item_test_data()) async def test_patch_item(self, test_data): @@ -387,7 +362,7 @@ async def test_patch_item(self, test_data): for multiple_write_locations in [True, False]: # Client setup and create an item - client, db, container = await self._init_container(preferred_locations, client_excluded_locations, + client, db, container = await _init_container(preferred_locations, client_excluded_locations, multiple_write_locations) # API call: patch_item @@ -404,9 +379,9 @@ async def test_patch_item(self, test_data): # get location from mock_handler if multiple_write_locations: - await self._verify_endpoint(client, expected_locations) + _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) else: - await self._verify_endpoint(client, [L1]) + _verify_endpoint(client, [L1]) @pytest.mark.parametrize('test_data', patch_item_test_data()) async def test_execute_item_batch(self, test_data): @@ -415,7 +390,7 @@ async def test_execute_item_batch(self, test_data): for multiple_write_locations in [True, False]: # Client setup and create an item - client, db, container = await self._init_container(preferred_locations, 
client_excluded_locations, + client, db, container = await _init_container(preferred_locations, client_excluded_locations, multiple_write_locations) # API call: execute_item_batch @@ -433,9 +408,9 @@ async def test_execute_item_batch(self, test_data): # get location from mock_handler if multiple_write_locations: - await self._verify_endpoint(client, expected_locations) + _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) else: - await self._verify_endpoint(client, [L1]) + _verify_endpoint(client, [L1]) @pytest.mark.parametrize('test_data', patch_item_test_data()) async def test_delete_item(self, test_data): @@ -444,7 +419,7 @@ async def test_delete_item(self, test_data): for multiple_write_locations in [True, False]: # Client setup - client, db, container = await self._init_container(preferred_locations, client_excluded_locations, multiple_write_locations) + client, db, container = await _init_container(preferred_locations, client_excluded_locations, multiple_write_locations) # create before delete item_id = f'doc2-{str(uuid.uuid4())}' @@ -460,9 +435,9 @@ async def test_delete_item(self, test_data): # Verify endpoint locations if multiple_write_locations: - await self._verify_endpoint(client, expected_locations) + _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) else: - await self._verify_endpoint(client, [L1]) + _verify_endpoint(client, [L1]) if __name__ == "__main__": unittest.main() From b5accfa44f8ac6d74681444ce09dd26404401d40 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Thu, 10 Apr 2025 18:55:29 -0400 Subject: [PATCH 078/152] add more operations --- .../tests/test_ppcb_sm_mrr_async.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py index 113574a7725f..ee5e3fcd6fa9 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py +++ 
b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py @@ -38,7 +38,7 @@ async def setup(): def operations_and_errors(): write_operations = ["create", "upsert", "replace", "delete", "patch", "batch"] - read_operations = ["read", "query", "changefeed"] + read_operations = ["read", "query", "changefeed", "read_all_items", "delete_all_items_by_partition_key"] errors = [] error_codes = [408, 500, 502, 503] for error_code in error_codes: @@ -108,6 +108,11 @@ async def perform_write_operation(operation, container, doc_id, pk): ("upsert", (doc,)), ] await container.execute_item_batch(batch_operations, partition_key=doc['pk']) + elif operation == "delete_all_items_by_partition_key": + await container.create_item(body=doc) + await container.create_item(body=doc) + await container.create_item(body=doc) + await container.delete_all_items_by_partition_key(pk) @staticmethod async def perform_read_operation(operation, container, doc_id, pk, expected_read_region_uri): @@ -119,12 +124,18 @@ async def perform_read_operation(operation, container, doc_id, pk, expected_read elif operation == "query": query = "SELECT * FROM c WHERE c.id = @id AND c.pk = @pk" parameters = [{"name": "@id", "value": doc_id}, {"name": "@pk", "value": pk}] - async for _ in container.query_items(query=query, partition_key=pk, parameters=parameters): - pass + async for item in container.query_items(query=query, partition_key=pk, parameters=parameters): + assert item['id'] == doc_id # need to do query with no pk and with feed range elif operation == "changefeed": async for _ in container.query_items_change_feed(): pass + elif operation == "read_all_items": + async for item in container.read_all_items(partition_key=pk): + assert item['pk'] == pk + + + async def create_custom_transport_sm_mrr(self): From 8324a71e8ce47cf0501109b3fba0407cf058d65d Mon Sep 17 00:00:00 2001 From: Allen Kim Date: Fri, 11 Apr 2025 15:36:28 -0700 Subject: [PATCH 079/152] Fix live tests with multi write locations --- 
.../azure-cosmos/tests/test_excluded_locations.py | 12 ++++++------ .../tests/test_excluded_locations_async.py | 12 ++++++------ 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py index 49b7f0553871..4b517796dd3a 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py +++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py @@ -314,7 +314,7 @@ def test_replace_item(self, test_data): if multiple_write_locations: _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) else: - _verify_endpoint(client, [expected_locations[0], L1]) + _verify_endpoint(MOCK_HANDLER.messages, client, [expected_locations[0], L1]) @pytest.mark.parametrize('test_data', replace_item_test_data()) def test_upsert_item(self, test_data): @@ -336,7 +336,7 @@ def test_upsert_item(self, test_data): if multiple_write_locations: _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) else: - _verify_endpoint(client, [expected_locations[0], L1]) + _verify_endpoint(MOCK_HANDLER.messages, client, [expected_locations[0], L1]) @pytest.mark.parametrize('test_data', replace_item_test_data()) def test_create_item(self, test_data): @@ -355,7 +355,7 @@ def test_create_item(self, test_data): if multiple_write_locations: _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) else: - _verify_endpoint(client, [expected_locations[0], L1]) + _verify_endpoint(MOCK_HANDLER.messages, client, [expected_locations[0], L1]) @pytest.mark.parametrize('test_data', patch_item_test_data()) def test_patch_item(self, test_data): @@ -383,7 +383,7 @@ def test_patch_item(self, test_data): if multiple_write_locations: _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) else: - _verify_endpoint(client, [L1]) + _verify_endpoint(MOCK_HANDLER.messages, client, [L1]) @pytest.mark.parametrize('test_data', patch_item_test_data()) def 
test_execute_item_batch(self, test_data): @@ -412,7 +412,7 @@ def test_execute_item_batch(self, test_data): if multiple_write_locations: _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) else: - _verify_endpoint(client, [L1]) + _verify_endpoint(MOCK_HANDLER.messages, client, [L1]) @pytest.mark.parametrize('test_data', patch_item_test_data()) def test_delete_item(self, test_data): @@ -439,7 +439,7 @@ def test_delete_item(self, test_data): if multiple_write_locations: _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) else: - _verify_endpoint(client, [L1]) + _verify_endpoint(MOCK_HANDLER.messages, client, [L1]) if __name__ == "__main__": unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py index 50c0b69acd76..11ababfdfafd 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py @@ -312,7 +312,7 @@ async def test_replace_item(self, test_data): if multiple_write_locations: _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) else: - _verify_endpoint(client, [L1]) + _verify_endpoint(MOCK_HANDLER.messages, client, [L1]) @pytest.mark.parametrize('test_data', replace_item_test_data()) async def test_upsert_item(self, test_data): @@ -334,7 +334,7 @@ async def test_upsert_item(self, test_data): if multiple_write_locations: _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) else: - _verify_endpoint(client, [L1]) + _verify_endpoint(MOCK_HANDLER.messages, client, [L1]) @pytest.mark.parametrize('test_data', replace_item_test_data()) async def test_create_item(self, test_data): @@ -353,7 +353,7 @@ async def test_create_item(self, test_data): if multiple_write_locations: _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) else: - _verify_endpoint(client, [L1]) + _verify_endpoint(MOCK_HANDLER.messages, client, 
[L1]) @pytest.mark.parametrize('test_data', patch_item_test_data()) async def test_patch_item(self, test_data): @@ -381,7 +381,7 @@ async def test_patch_item(self, test_data): if multiple_write_locations: _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) else: - _verify_endpoint(client, [L1]) + _verify_endpoint(MOCK_HANDLER.messages, client, [L1]) @pytest.mark.parametrize('test_data', patch_item_test_data()) async def test_execute_item_batch(self, test_data): @@ -410,7 +410,7 @@ async def test_execute_item_batch(self, test_data): if multiple_write_locations: _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) else: - _verify_endpoint(client, [L1]) + _verify_endpoint(MOCK_HANDLER.messages, client, [L1]) @pytest.mark.parametrize('test_data', patch_item_test_data()) async def test_delete_item(self, test_data): @@ -437,7 +437,7 @@ async def test_delete_item(self, test_data): if multiple_write_locations: _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) else: - _verify_endpoint(client, [L1]) + _verify_endpoint(MOCK_HANDLER.messages, client, [L1]) if __name__ == "__main__": unittest.main() From b65f07d5b733fb067d7c58ad43d269f8997b49d8 Mon Sep 17 00:00:00 2001 From: Allen Kim Date: Fri, 11 Apr 2025 15:38:50 -0700 Subject: [PATCH 080/152] Fixed bug with endpoint routing with multi write region partition key API calls --- .../azure/cosmos/_cosmos_client_connection.py | 4 ++-- sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py | 1 + sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py | 9 +++++++-- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py index 2ad76e73766d..2de1dedf58e7 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py @@ -2192,8 +2192,8 @@ def 
DeleteAllItemsByPartitionKey( path = '{}{}/{}'.format(path, "operations", "partitionkeydelete") collection_id = base.GetResourceIdOrFullNameFromLink(collection_link) headers = base.GetHeaders(self, self.default_headers, "post", path, collection_id, - "partitionkey", documents._OperationType.Delete, options) - request_params = RequestObject("partitionkey", documents._OperationType.Delete) + http_constants.ResourceType.PartitionKey, documents._OperationType.Delete, options) + request_params = RequestObject(http_constants.ResourceType.PartitionKey, documents._OperationType.Delete) request_params.set_excluded_location_from_options(options) _, last_response_headers = self.__Post( path=path, diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py index 02b293e29b4b..b2207b4431b6 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py @@ -500,6 +500,7 @@ def can_use_multiple_write_locations(self): def can_use_multiple_write_locations_for_request(self, request): # pylint: disable=name-too-long return self.can_use_multiple_write_locations() and ( request.resource_type == http_constants.ResourceType.Document + or request.resource_type == http_constants.ResourceType.PartitionKey or ( request.resource_type == http_constants.ResourceType.StoredProcedure and request.operation_type == documents._OperationType.ExecuteJavaScript diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py index 185aa1d89cb8..e5b39c221016 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py @@ -22,7 +22,7 @@ """Represents a request object. """ from typing import Optional, Mapping, Any - +from . 
import http_constants class RequestObject(object): def __init__(self, resource_type: str, operation_type: str, endpoint_override: Optional[str] = None) -> None: @@ -57,7 +57,12 @@ def clear_route_to_location(self) -> None: def _can_set_excluded_location(self, options: Mapping[str, Any]) -> bool: # If resource types for requests are not one of the followings, excluded locations cannot be set - if self.resource_type.lower() not in ['docs', 'documents', 'partitionkey', 'colls']: + acceptable_resource_types = [ + http_constants.ResourceType.Document, + http_constants.ResourceType.PartitionKey, + http_constants.ResourceType.Collection, + ] + if self.resource_type.lower() not in acceptable_resource_types: return False # If 'excludedLocations' wasn't in the options, excluded locations cannot be set From 4a144d9dbe8ddf575eec9c9441607df28ccb6893 Mon Sep 17 00:00:00 2001 From: Allen Kim Date: Fri, 11 Apr 2025 15:41:39 -0700 Subject: [PATCH 081/152] Adding emulator tests for delete_all_items_by_partition_key API --- .../tests/_fault_injection_transport.py | 34 +++ .../tests/test_excluded_locations_emulator.py | 127 +++++++++ .../test_excluded_locations_emulator_async.py | 254 ++++++++++++++++++ 3 files changed, 415 insertions(+) create mode 100644 sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator.py create mode 100644 sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator_async.py diff --git a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py index 628456d95158..0a0a81026ae7 100644 --- a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py +++ b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py @@ -228,6 +228,40 @@ def transform_topology_mwr( return response + @staticmethod + def transform_topology_mwr_with_url( + first_region_name: str, + first_region_url: str, + second_region_name: str, + second_region_url: str, + inner: Callable[[], RequestsTransportResponse]) -> 
RequestsTransportResponse: + + response = inner() + if not FaultInjectionTransport.predicate_is_database_account_call(response.request): + return response + + data = response.body() + if response.status_code == 200 and data: + readable_locations = [ + {"name": first_region_name, "databaseAccountEndpoint": first_region_url}, + {"name": second_region_name, "databaseAccountEndpoint": second_region_url} + ] + writeable_locations = [ + {"name": first_region_name, "databaseAccountEndpoint": first_region_url}, + {"name": second_region_name, "databaseAccountEndpoint": second_region_url} + ] + + data = data.decode("utf-8") + result = json.loads(data) + result["readableLocations"] = readable_locations + result["writableLocations"] = writeable_locations + result["enableMultipleWriteLocations"] = True + FaultInjectionTransport.logger.info("Transformed Account Topology: {}".format(result)) + request: HttpRequest = response.request + return FaultInjectionTransport.MockHttpResponse(request, 200, result) + + return response + class MockHttpResponse(RequestsTransportResponse): def __init__(self, request: HttpRequest, status_code: int, content:Optional[Dict[str, Any]]): self.request: HttpRequest = request diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator.py new file mode 100644 index 000000000000..b39058aef6ba --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator.py @@ -0,0 +1,127 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. 
+ +import logging +import unittest +import uuid +import test_config +import sys +import pytest +from typing import Callable, List, Mapping, Any + +from azure.core.rest import HttpRequest +from _fault_injection_transport import FaultInjectionTransport +from azure.cosmos.container import ContainerProxy +from test_excluded_locations import L1, L2, CLIENT_ONLY_TEST_DATA, CLIENT_AND_REQUEST_TEST_DATA +from test_fault_injection_transport import TestFaultInjectionTransport + +logger = logging.getLogger('azure.cosmos') +logger.setLevel(logging.DEBUG) +logger.addHandler(logging.StreamHandler(sys.stdout)) + +CONFIG = test_config.TestConfig() +HOST = CONFIG.host +KEY = CONFIG.masterKey +DATABASE_ID = CONFIG.TEST_DATABASE_ID +CONTAINER_ID = CONFIG.TEST_SINGLE_PARTITION_CONTAINER_ID + +L1_URL = test_config.TestConfig.local_host +L2_URL = L1_URL.replace("localhost", "127.0.0.1") +URL_TO_LOCATIONS = { + L1_URL: L1, + L2_URL: L2 +} + +ALL_INPUT_TEST_DATA = CLIENT_ONLY_TEST_DATA + CLIENT_AND_REQUEST_TEST_DATA + +def delete_all_items_by_partition_key_test_data() -> List[str]: + client_only_output_data = [ + L1, #0 + L2, #1 + L1, #3 + L1 #4 + ] + client_and_request_output_data = [ + L2, #0 + L2, #1 + L2, #2 + L1, #3 + L1, #4 + L1, #5 + L1, #6 + L1, #7 + ] + all_output_test_data = client_only_output_data + client_and_request_output_data + # all_output_test_data = client_and_request_output_data + + all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] + return all_test_data + +def _get_location(initialized_objects: Mapping[str, Any]) -> str: + # get Request URL + header = initialized_objects['client'].client_connection.last_response_headers + request_url = header["_request"].url + + # verify + location = "" + for url in URL_TO_LOCATIONS: + if request_url.startswith(url): + location = URL_TO_LOCATIONS[url] + break + return location + +@pytest.mark.unittest +@pytest.mark.cosmosEmulator +class TestExcludedLocations: + 
@pytest.mark.parametrize('test_data', delete_all_items_by_partition_key_test_data()) + def test_delete_all_items_by_partition_key(self: "TestExcludedLocations", test_data: List[List[str]]): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_location = test_data + + # Inject topology transformation that would make Emulator look like a multiple write region account + # account with two read regions + custom_transport = FaultInjectionTransport() + is_get_account_predicate: Callable[[HttpRequest], bool] = lambda \ + r: FaultInjectionTransport.predicate_is_database_account_call(r) + emulator_as_multi_write_region_account_transformation = \ + lambda r, inner: FaultInjectionTransport.transform_topology_mwr_with_url( + first_region_name=L1, + first_region_url=L1_URL, + second_region_name=L2, + second_region_url=L2_URL, + inner=inner) + custom_transport.add_response_transformation( + is_get_account_predicate, + emulator_as_multi_write_region_account_transformation) + + for multiple_write_locations in [True, False]: + # Create client + initialized_objects = TestFaultInjectionTransport.setup_method_with_custom_transport( + custom_transport, + HOST, + preferred_locations=preferred_locations, + excluded_locations=client_excluded_locations, + multiple_write_locations=multiple_write_locations, + ) + container: ContainerProxy = initialized_objects["col"] + + # create an item + id_value: str = str(uuid.uuid4()) + document_definition = {'id': id_value, 'pk': id_value} + container.create_item(body=document_definition) + + # API call: delete_all_items_by_partition_key + if request_excluded_locations is None: + container.delete_all_items_by_partition_key(id_value) + else: + container.delete_all_items_by_partition_key(id_value, excluded_locations=request_excluded_locations) + + # Verify endpoint locations + actual_location = _get_location(initialized_objects) + if multiple_write_locations: + assert actual_location == 
expected_location + else: + assert actual_location == L1 + +if __name__ == "__main__": + unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator_async.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator_async.py new file mode 100644 index 000000000000..f468c80649ca --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator_async.py @@ -0,0 +1,254 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. + +import logging +import unittest +import uuid +import test_config +import pytest +import pytest_asyncio + +from azure.cosmos.aio import CosmosClient +from azure.cosmos.partition_key import PartitionKey +from test_excluded_locations import _verify_endpoint + +class MockHandler(logging.Handler): + def __init__(self): + super(MockHandler, self).__init__() + self.messages = [] + + def reset(self): + self.messages = [] + + def emit(self, record): + self.messages.append(record.msg) + +MOCK_HANDLER = MockHandler() +CONFIG = test_config.TestConfig() +HOST = CONFIG.host +KEY = CONFIG.masterKey +DATABASE_ID = CONFIG.TEST_DATABASE_ID +CONTAINER_ID = CONFIG.TEST_SINGLE_PARTITION_CONTAINER_ID +PARTITION_KEY = CONFIG.TEST_CONTAINER_PARTITION_KEY +ITEM_ID = 'doc1' +ITEM_PK_VALUE = 'pk' +TEST_ITEM = {'id': ITEM_ID, PARTITION_KEY: ITEM_PK_VALUE} + +L0 = "Default" +L1 = "West US 3" +L2 = "West US" +L3 = "East US 2" + +# L0 = "Default" +# L1 = "East US 2" +# L2 = "East US" +# L3 = "West US 2" + +CLIENT_ONLY_TEST_DATA = [ + # preferred_locations, client_excluded_locations, excluded_locations_request + # 0. No excluded location + [[L1, L2], [], None], + # 1. Single excluded location + [[L1, L2], [L1], None], + # 2. Exclude all locations + [[L1, L2], [L1, L2], None], + # 3. Exclude a location not in preferred locations + [[L1, L2], [L3], None], +] + +CLIENT_AND_REQUEST_TEST_DATA = [ + # preferred_locations, client_excluded_locations, excluded_locations_request + # 0. 
No client excluded locations + a request excluded location + [[L1, L2], [], [L1]], + # 1. The same client and request excluded location + [[L1, L2], [L1], [L1]], + # 2. Less request excluded locations + [[L1, L2], [L1, L2], [L1]], + # 3. More request excluded locations + [[L1, L2], [L1], [L1, L2]], + # 4. All locations were excluded + [[L1, L2], [L1, L2], [L1, L2]], + # 5. No common excluded locations + [[L1, L2], [L1], [L2, L3]], + # 6. Request excluded location not in preferred locations + [[L1, L2], [L1, L2], [L3]], + # 7. Empty excluded locations, remove all client excluded locations + [[L1, L2], [L1, L2], []], +] + +ALL_INPUT_TEST_DATA = CLIENT_ONLY_TEST_DATA + CLIENT_AND_REQUEST_TEST_DATA + +def read_item_test_data(): + client_only_output_data = [ + [L1], # 0 + [L2], # 1 + [L1], # 2 + [L1], # 3 + ] + client_and_request_output_data = [ + [L2], # 0 + [L2], # 1 + [L2], # 2 + [L1], # 3 + [L1], # 4 + [L1], # 5 + [L1], # 6 + [L1], # 7 + ] + all_output_test_data = client_only_output_data + client_and_request_output_data + + all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] + return all_test_data + +def read_all_item_test_data(): + client_only_output_data = [ + [L1, L1], # 0 + [L2, L2], # 1 + [L1, L1], # 2 + [L1, L1], # 3 + ] + client_and_request_output_data = [ + [L2, L2], # 0 + [L2, L2], # 1 + [L2, L2], # 2 + [L1, L1], # 3 + [L1, L1], # 4 + [L1, L1], # 5 + [L1, L1], # 6 + [L1, L1], # 7 + ] + all_output_test_data = client_only_output_data + client_and_request_output_data + + all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] + return all_test_data + +def query_items_change_feed_test_data(): + client_only_output_data = [ + [L1, L1, L1, L1], #0 + [L2, L2, L2, L2], #1 + [L1, L1, L1, L1], #2 + [L1, L1, L1, L1] #3 + ] + client_and_request_output_data = [ + [L1, L2, L2, L2], #0 + [L2, L2, L2, L2], #1 + [L1, L2, L2, L2], #2 + [L2, 
L1, L1, L1], #3 + [L1, L1, L1, L1], #4 + [L2, L1, L1, L1], #5 + [L1, L1, L1, L1], #6 + [L1, L1, L1, L1], #7 + ] + all_output_test_data = client_only_output_data + client_and_request_output_data + + all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] + return all_test_data + +def replace_item_test_data(): + client_only_output_data = [ + [L1], #0 + [L2], #1 + [L0], #2 + [L1] #3 + ] + client_and_request_output_data = [ + [L2], #0 + [L2], #1 + [L2], #2 + [L0], #3 + [L0], #4 + [L1], #5 + [L1], #6 + [L1], #7 + ] + all_output_test_data = client_only_output_data + client_and_request_output_data + + all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] + return all_test_data + +def patch_item_test_data(): + client_only_output_data = [ + [L1], #0 + [L2], #1 + [L0], #2 + [L1] #3 + ] + client_and_request_output_data = [ + [L2], #0 + [L2], #1 + [L2], #2 + [L0], #3 + [L0], #4 + [L1], #5 + [L1], #6 + [L1], #7 + ] + all_output_test_data = client_only_output_data + client_and_request_output_data + + all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] + return all_test_data + +async def _create_item_with_excluded_locations(container, body, excluded_locations): + if excluded_locations is None: + await container.create_item(body=body) + else: + await container.create_item(body=body, excluded_locations=excluded_locations) + +@pytest_asyncio.fixture(scope="class", autouse=True) +async def setup_and_teardown(): + print("Setup: This runs before any tests") + logger = logging.getLogger("azure") + logger.addHandler(MOCK_HANDLER) + logger.setLevel(logging.DEBUG) + + test_client = CosmosClient(HOST, KEY) + container = test_client.get_database_client(DATABASE_ID).get_container_client(CONTAINER_ID) + await container.upsert_item(body=TEST_ITEM) + + yield + await test_client.close() + 
+async def _init_container(preferred_locations, client_excluded_locations, multiple_write_locations = True): + client = CosmosClient(HOST, KEY, + preferred_locations=preferred_locations, + excluded_locations=client_excluded_locations, + multiple_write_locations=multiple_write_locations) + db = await client.create_database_if_not_exists(DATABASE_ID) + container = await db.create_container_if_not_exists(CONTAINER_ID, PartitionKey(path='/' + PARTITION_KEY, kind='Hash')) + MOCK_HANDLER.reset() + + return client, db, container + +@pytest.mark.cosmosMultiRegion +@pytest.mark.asyncio +@pytest.mark.usefixtures("setup_and_teardown") +class TestExcludedLocations: + @pytest.mark.parametrize('test_data', patch_item_test_data()) + async def test_delete_item(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + for multiple_write_locations in [True, False]: + # Client setup + client, db, container = await _init_container(preferred_locations, client_excluded_locations, multiple_write_locations) + + # create before delete + item_id = f'doc2-{str(uuid.uuid4())}' + body = {PARTITION_KEY: ITEM_PK_VALUE, 'id': item_id} + await _create_item_with_excluded_locations(container, body, request_excluded_locations) + MOCK_HANDLER.reset() + + # API call: delete_item + if request_excluded_locations is None: + await container.delete_item(item_id, ITEM_PK_VALUE) + else: + await container.delete_item(item_id, ITEM_PK_VALUE, excluded_locations=request_excluded_locations) + + # Verify endpoint locations + if multiple_write_locations: + _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) + else: + _verify_endpoint(client, [L1]) + +if __name__ == "__main__": + unittest.main() From 9c68f753aaf551ff5e064f6d9740fc6ab3cae2fe Mon Sep 17 00:00:00 2001 From: Allen Kim Date: Mon, 14 Apr 2025 10:49:10 -0700 Subject: [PATCH 082/152] minimized duplicate codes --- 
.../tests/_fault_injection_transport.py | 52 +++++-------------- .../tests/test_excluded_locations_emulator.py | 31 +++++------ .../tests/test_fault_injection_transport.py | 14 +++-- 3 files changed, 36 insertions(+), 61 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py index 0a0a81026ae7..9a99229ec995 100644 --- a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py +++ b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py @@ -203,7 +203,10 @@ def transform_topology_swr_mrr( def transform_topology_mwr( first_region_name: str, second_region_name: str, - inner: Callable[[], RequestsTransportResponse]) -> RequestsTransportResponse: + inner: Callable[[], RequestsTransportResponse], + first_region_url: str = None, + second_region_url: str = test_config.TestConfig.local_host + ) -> RequestsTransportResponse: response = inner() if not FaultInjectionTransport.predicate_is_database_account_call(response.request): @@ -215,46 +218,17 @@ def transform_topology_mwr( result = json.loads(data) readable_locations = result["readableLocations"] writable_locations = result["writableLocations"] - readable_locations[0]["name"] = first_region_name - writable_locations[0]["name"] = first_region_name + + if first_region_url is None: + first_region_url = readable_locations[0]["databaseAccountEndpoint"] + readable_locations[0] = \ + {"name": first_region_name, "databaseAccountEndpoint": first_region_url} + writable_locations[0] = \ + {"name": first_region_name, "databaseAccountEndpoint": first_region_url} readable_locations.append( - {"name": second_region_name, "databaseAccountEndpoint": test_config.TestConfig.local_host}) + {"name": second_region_name, "databaseAccountEndpoint": second_region_url}) writable_locations.append( - {"name": second_region_name, "databaseAccountEndpoint": test_config.TestConfig.local_host}) - result["enableMultipleWriteLocations"] = True - 
FaultInjectionTransport.logger.info("Transformed Account Topology: {}".format(result)) - request: HttpRequest = response.request - return FaultInjectionTransport.MockHttpResponse(request, 200, result) - - return response - - @staticmethod - def transform_topology_mwr_with_url( - first_region_name: str, - first_region_url: str, - second_region_name: str, - second_region_url: str, - inner: Callable[[], RequestsTransportResponse]) -> RequestsTransportResponse: - - response = inner() - if not FaultInjectionTransport.predicate_is_database_account_call(response.request): - return response - - data = response.body() - if response.status_code == 200 and data: - readable_locations = [ - {"name": first_region_name, "databaseAccountEndpoint": first_region_url}, - {"name": second_region_name, "databaseAccountEndpoint": second_region_url} - ] - writeable_locations = [ - {"name": first_region_name, "databaseAccountEndpoint": first_region_url}, - {"name": second_region_name, "databaseAccountEndpoint": second_region_url} - ] - - data = data.decode("utf-8") - result = json.loads(data) - result["readableLocations"] = readable_locations - result["writableLocations"] = writeable_locations + {"name": second_region_name, "databaseAccountEndpoint": second_region_url}) result["enableMultipleWriteLocations"] = True FaultInjectionTransport.logger.info("Transformed Account Topology: {}".format(result)) request: HttpRequest = response.request diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator.py index b39058aef6ba..7fcb558827f7 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator.py +++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator.py @@ -1,11 +1,9 @@ # The MIT License (MIT) # Copyright (c) Microsoft Corporation. All rights reserved. 
-import logging import unittest import uuid import test_config -import sys import pytest from typing import Callable, List, Mapping, Any @@ -15,15 +13,7 @@ from test_excluded_locations import L1, L2, CLIENT_ONLY_TEST_DATA, CLIENT_AND_REQUEST_TEST_DATA from test_fault_injection_transport import TestFaultInjectionTransport -logger = logging.getLogger('azure.cosmos') -logger.setLevel(logging.DEBUG) -logger.addHandler(logging.StreamHandler(sys.stdout)) - CONFIG = test_config.TestConfig() -HOST = CONFIG.host -KEY = CONFIG.masterKey -DATABASE_ID = CONFIG.TEST_DATABASE_ID -CONTAINER_ID = CONFIG.TEST_SINGLE_PARTITION_CONTAINER_ID L1_URL = test_config.TestConfig.local_host L2_URL = L1_URL.replace("localhost", "127.0.0.1") @@ -52,21 +42,22 @@ def delete_all_items_by_partition_key_test_data() -> List[str]: L1, #7 ] all_output_test_data = client_only_output_data + client_and_request_output_data - # all_output_test_data = client_and_request_output_data all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] return all_test_data -def _get_location(initialized_objects: Mapping[str, Any]) -> str: +def _get_location( + initialized_objects: Mapping[str, Any], + url_to_locations: Mapping[str, str] = URL_TO_LOCATIONS) -> str: # get Request URL header = initialized_objects['client'].client_connection.last_response_headers request_url = header["_request"].url # verify location = "" - for url in URL_TO_LOCATIONS: + for url in url_to_locations: if request_url.startswith(url): - location = URL_TO_LOCATIONS[url] + location = url_to_locations[url] break return location @@ -84,12 +75,13 @@ def test_delete_all_items_by_partition_key(self: "TestExcludedLocations", test_d is_get_account_predicate: Callable[[HttpRequest], bool] = lambda \ r: FaultInjectionTransport.predicate_is_database_account_call(r) emulator_as_multi_write_region_account_transformation = \ - lambda r, inner: 
FaultInjectionTransport.transform_topology_mwr_with_url( + lambda r, inner: FaultInjectionTransport.transform_topology_mwr( first_region_name=L1, - first_region_url=L1_URL, second_region_name=L2, + inner=inner, + first_region_url=L1_URL, second_region_url=L2_URL, - inner=inner) + ) custom_transport.add_response_transformation( is_get_account_predicate, emulator_as_multi_write_region_account_transformation) @@ -98,7 +90,10 @@ def test_delete_all_items_by_partition_key(self: "TestExcludedLocations", test_d # Create client initialized_objects = TestFaultInjectionTransport.setup_method_with_custom_transport( custom_transport, - HOST, + default_endpoint=CONFIG.host, + key=CONFIG.masterKey, + database_id=CONFIG.TEST_DATABASE_ID, + container_id=CONFIG.TEST_SINGLE_PARTITION_CONTAINER_ID, preferred_locations=preferred_locations, excluded_locations=client_excluded_locations, multiple_write_locations=multiple_write_locations, diff --git a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport.py b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport.py index 304fa8d50f0d..4d7ea16ee58e 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport.py +++ b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport.py @@ -62,11 +62,17 @@ def teardown_class(cls): logger.warning("Exception trying to delete database {}. 
{}".format(created_database.id, containerDeleteError)) @staticmethod - def setup_method_with_custom_transport(custom_transport: RequestsTransport, default_endpoint=host, **kwargs): - client = CosmosClient(default_endpoint, master_key, consistency_level="Session", + def setup_method_with_custom_transport( + custom_transport: RequestsTransport, + default_endpoint: str = host, + key: str = master_key, + database_id: str = TEST_DATABASE_ID, + container_id: str = SINGLE_PARTITION_CONTAINER_NAME, + **kwargs): + client = CosmosClient(default_endpoint, key, consistency_level="Session", transport=custom_transport, logger=logger, enable_diagnostics_logging=True, **kwargs) - db: DatabaseProxy = client.get_database_client(TEST_DATABASE_ID) - container: ContainerProxy = db.get_container_client(SINGLE_PARTITION_CONTAINER_NAME) + db: DatabaseProxy = client.get_database_client(database_id) + container: ContainerProxy = db.get_container_client(container_id) return {"client": client, "db": db, "col": container} From 225bc26133a7d944e6becf60fa3ac33f46c0db1b Mon Sep 17 00:00:00 2001 From: Allen Kim Date: Mon, 14 Apr 2025 11:35:09 -0700 Subject: [PATCH 083/152] Added Async emulator tests --- .../tests/_fault_injection_transport_async.py | 18 +- .../tests/test_excluded_locations_emulator.py | 4 +- .../test_excluded_locations_emulator_async.py | 284 ++++-------------- .../test_fault_injection_transport_async.py | 34 ++- 4 files changed, 101 insertions(+), 239 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py index 13dda0dc7e20..6bdeb4ed49c9 100644 --- a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py +++ b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py @@ -201,7 +201,10 @@ async def transform_topology_swr_mrr( async def transform_topology_mwr( first_region_name: str, second_region_name: str, - inner: Callable[[], 
Awaitable[AioHttpTransportResponse]]) -> AioHttpTransportResponse: + inner: Callable[[], Awaitable[AioHttpTransportResponse]], + first_region_url: str = None, + second_region_url: str = test_config.TestConfig.local_host + ) -> AioHttpTransportResponse: response = await inner() if not FaultInjectionTransportAsync.predicate_is_database_account_call(response.request): @@ -213,12 +216,17 @@ async def transform_topology_mwr( result = json.loads(data) readable_locations = result["readableLocations"] writable_locations = result["writableLocations"] - readable_locations[0]["name"] = first_region_name - writable_locations[0]["name"] = first_region_name + + if first_region_url is None: + first_region_url = readable_locations[0]["databaseAccountEndpoint"] + readable_locations[0] = \ + {"name": first_region_name, "databaseAccountEndpoint": first_region_url} + writable_locations[0] = \ + {"name": first_region_name, "databaseAccountEndpoint": first_region_url} readable_locations.append( - {"name": second_region_name, "databaseAccountEndpoint": test_config.TestConfig.local_host}) + {"name": second_region_name, "databaseAccountEndpoint": second_region_url}) writable_locations.append( - {"name": second_region_name, "databaseAccountEndpoint": test_config.TestConfig.local_host}) + {"name": second_region_name, "databaseAccountEndpoint": second_region_url}) result["enableMultipleWriteLocations"] = True FaultInjectionTransportAsync.logger.info("Transformed Account Topology: {}".format(result)) request: HttpRequest = response.request diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator.py index 7fcb558827f7..96b3fc185afb 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator.py +++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator.py @@ -46,7 +46,7 @@ def delete_all_items_by_partition_key_test_data() -> List[str]: all_test_data = [input_data + [output_data] for 
input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] return all_test_data -def _get_location( +def get_location( initialized_objects: Mapping[str, Any], url_to_locations: Mapping[str, str] = URL_TO_LOCATIONS) -> str: # get Request URL @@ -112,7 +112,7 @@ def test_delete_all_items_by_partition_key(self: "TestExcludedLocations", test_d container.delete_all_items_by_partition_key(id_value, excluded_locations=request_excluded_locations) # Verify endpoint locations - actual_location = _get_location(initialized_objects) + actual_location = get_location(initialized_objects) if multiple_write_locations: assert actual_location == expected_location else: diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator_async.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator_async.py index f468c80649ca..f706ca9e465e 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator_async.py @@ -1,254 +1,100 @@ # The MIT License (MIT) # Copyright (c) Microsoft Corporation. All rights reserved. 
-import logging import unittest import uuid import test_config import pytest -import pytest_asyncio +from typing import Callable, List, Mapping, Any -from azure.cosmos.aio import CosmosClient -from azure.cosmos.partition_key import PartitionKey -from test_excluded_locations import _verify_endpoint +from azure.core.pipeline.transport import AioHttpTransport +from _fault_injection_transport_async import FaultInjectionTransportAsync +from azure.cosmos.aio._container import ContainerProxy +from test_excluded_locations import L1, L2, CLIENT_ONLY_TEST_DATA, CLIENT_AND_REQUEST_TEST_DATA +from test_excluded_locations_emulator import L1_URL, L2_URL, get_location +from test_fault_injection_transport_async import TestFaultInjectionTransportAsync -class MockHandler(logging.Handler): - def __init__(self): - super(MockHandler, self).__init__() - self.messages = [] - - def reset(self): - self.messages = [] - - def emit(self, record): - self.messages.append(record.msg) - -MOCK_HANDLER = MockHandler() CONFIG = test_config.TestConfig() -HOST = CONFIG.host -KEY = CONFIG.masterKey -DATABASE_ID = CONFIG.TEST_DATABASE_ID -CONTAINER_ID = CONFIG.TEST_SINGLE_PARTITION_CONTAINER_ID -PARTITION_KEY = CONFIG.TEST_CONTAINER_PARTITION_KEY -ITEM_ID = 'doc1' -ITEM_PK_VALUE = 'pk' -TEST_ITEM = {'id': ITEM_ID, PARTITION_KEY: ITEM_PK_VALUE} - -L0 = "Default" -L1 = "West US 3" -L2 = "West US" -L3 = "East US 2" - -# L0 = "Default" -# L1 = "East US 2" -# L2 = "East US" -# L3 = "West US 2" - -CLIENT_ONLY_TEST_DATA = [ - # preferred_locations, client_excluded_locations, excluded_locations_request - # 0. No excluded location - [[L1, L2], [], None], - # 1. Single excluded location - [[L1, L2], [L1], None], - # 2. Exclude all locations - [[L1, L2], [L1, L2], None], - # 3. Exclude a location not in preferred locations - [[L1, L2], [L3], None], -] - -CLIENT_AND_REQUEST_TEST_DATA = [ - # preferred_locations, client_excluded_locations, excluded_locations_request - # 0. 
No client excluded locations + a request excluded location - [[L1, L2], [], [L1]], - # 1. The same client and request excluded location - [[L1, L2], [L1], [L1]], - # 2. Less request excluded locations - [[L1, L2], [L1, L2], [L1]], - # 3. More request excluded locations - [[L1, L2], [L1], [L1, L2]], - # 4. All locations were excluded - [[L1, L2], [L1, L2], [L1, L2]], - # 5. No common excluded locations - [[L1, L2], [L1], [L2, L3]], - # 6. Request excluded location not in preferred locations - [[L1, L2], [L1, L2], [L3]], - # 7. Empty excluded locations, remove all client excluded locations - [[L1, L2], [L1, L2], []], -] ALL_INPUT_TEST_DATA = CLIENT_ONLY_TEST_DATA + CLIENT_AND_REQUEST_TEST_DATA -def read_item_test_data(): - client_only_output_data = [ - [L1], # 0 - [L2], # 1 - [L1], # 2 - [L1], # 3 - ] - client_and_request_output_data = [ - [L2], # 0 - [L2], # 1 - [L2], # 2 - [L1], # 3 - [L1], # 4 - [L1], # 5 - [L1], # 6 - [L1], # 7 - ] - all_output_test_data = client_only_output_data + client_and_request_output_data - - all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] - return all_test_data - -def read_all_item_test_data(): +def delete_all_items_by_partition_key_test_data() -> List[str]: client_only_output_data = [ - [L1, L1], # 0 - [L2, L2], # 1 - [L1, L1], # 2 - [L1, L1], # 3 + L1, #0 + L2, #1 + L1, #3 + L1 #4 ] client_and_request_output_data = [ - [L2, L2], # 0 - [L2, L2], # 1 - [L2, L2], # 2 - [L1, L1], # 3 - [L1, L1], # 4 - [L1, L1], # 5 - [L1, L1], # 6 - [L1, L1], # 7 + L2, #0 + L2, #1 + L2, #2 + L1, #3 + L1, #4 + L1, #5 + L1, #6 + L1, #7 ] all_output_test_data = client_only_output_data + client_and_request_output_data all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] return all_test_data -def query_items_change_feed_test_data(): - client_only_output_data = [ - [L1, L1, L1, L1], #0 - [L2, L2, L2, L2], #1 - [L1, L1, 
L1, L1], #2 - [L1, L1, L1, L1] #3 - ] - client_and_request_output_data = [ - [L1, L2, L2, L2], #0 - [L2, L2, L2, L2], #1 - [L1, L2, L2, L2], #2 - [L2, L1, L1, L1], #3 - [L1, L1, L1, L1], #4 - [L2, L1, L1, L1], #5 - [L1, L1, L1, L1], #6 - [L1, L1, L1, L1], #7 - ] - all_output_test_data = client_only_output_data + client_and_request_output_data - - all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] - return all_test_data - -def replace_item_test_data(): - client_only_output_data = [ - [L1], #0 - [L2], #1 - [L0], #2 - [L1] #3 - ] - client_and_request_output_data = [ - [L2], #0 - [L2], #1 - [L2], #2 - [L0], #3 - [L0], #4 - [L1], #5 - [L1], #6 - [L1], #7 - ] - all_output_test_data = client_only_output_data + client_and_request_output_data - - all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] - return all_test_data - -def patch_item_test_data(): - client_only_output_data = [ - [L1], #0 - [L2], #1 - [L0], #2 - [L1] #3 - ] - client_and_request_output_data = [ - [L2], #0 - [L2], #1 - [L2], #2 - [L0], #3 - [L0], #4 - [L1], #5 - [L1], #6 - [L1], #7 - ] - all_output_test_data = client_only_output_data + client_and_request_output_data - - all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] - return all_test_data - -async def _create_item_with_excluded_locations(container, body, excluded_locations): - if excluded_locations is None: - await container.create_item(body=body) - else: - await container.create_item(body=body, excluded_locations=excluded_locations) - -@pytest_asyncio.fixture(scope="class", autouse=True) -async def setup_and_teardown(): - print("Setup: This runs before any tests") - logger = logging.getLogger("azure") - logger.addHandler(MOCK_HANDLER) - logger.setLevel(logging.DEBUG) - - test_client = CosmosClient(HOST, KEY) - container = 
test_client.get_database_client(DATABASE_ID).get_container_client(CONTAINER_ID) - await container.upsert_item(body=TEST_ITEM) - - yield - await test_client.close() - -async def _init_container(preferred_locations, client_excluded_locations, multiple_write_locations = True): - client = CosmosClient(HOST, KEY, - preferred_locations=preferred_locations, - excluded_locations=client_excluded_locations, - multiple_write_locations=multiple_write_locations) - db = await client.create_database_if_not_exists(DATABASE_ID) - container = await db.create_container_if_not_exists(CONTAINER_ID, PartitionKey(path='/' + PARTITION_KEY, kind='Hash')) - MOCK_HANDLER.reset() - - return client, db, container - -@pytest.mark.cosmosMultiRegion +@pytest.mark.cosmosEmulator @pytest.mark.asyncio -@pytest.mark.usefixtures("setup_and_teardown") -class TestExcludedLocations: - @pytest.mark.parametrize('test_data', patch_item_test_data()) - async def test_delete_item(self, test_data): +class TestExcludedLocationsAsync: + @pytest.mark.parametrize('test_data', delete_all_items_by_partition_key_test_data()) + async def test_delete_all_items_by_partition_key(self: "TestExcludedLocationsAsync", test_data: List[List[str]]): # Init test variables - preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + preferred_locations, client_excluded_locations, request_excluded_locations, expected_location = test_data + + # Inject topology transformation that would make Emulator look like a multiple write region account + # account with two read regions + custom_transport = FaultInjectionTransportAsync() + is_get_account_predicate: Callable[[AioHttpTransport], bool] = lambda \ + r: FaultInjectionTransportAsync.predicate_is_database_account_call(r) + emulator_as_multi_write_region_account_transformation = \ + lambda r, inner: FaultInjectionTransportAsync.transform_topology_mwr( + first_region_name=L1, + first_region_url=L1_URL, + inner=inner, + 
second_region_name=L2, + second_region_url=L2_URL) + custom_transport.add_response_transformation( + is_get_account_predicate, + emulator_as_multi_write_region_account_transformation) for multiple_write_locations in [True, False]: - # Client setup - client, db, container = await _init_container(preferred_locations, client_excluded_locations, multiple_write_locations) - - # create before delete - item_id = f'doc2-{str(uuid.uuid4())}' - body = {PARTITION_KEY: ITEM_PK_VALUE, 'id': item_id} - await _create_item_with_excluded_locations(container, body, request_excluded_locations) - MOCK_HANDLER.reset() - - # API call: delete_item + # Create client + initialized_objects = await TestFaultInjectionTransportAsync.setup_method_with_custom_transport( + custom_transport, + default_endpoint=CONFIG.host, + key=CONFIG.masterKey, + database_id=CONFIG.TEST_DATABASE_ID, + container_id=CONFIG.TEST_SINGLE_PARTITION_CONTAINER_ID, + preferred_locations=preferred_locations, + excluded_locations=client_excluded_locations, + multiple_write_locations=multiple_write_locations, + ) + container: ContainerProxy = initialized_objects["col"] + + # create an item + id_value: str = str(uuid.uuid4()) + document_definition = {'id': id_value, 'pk': id_value} + await container.create_item(body=document_definition) + + # API call: delete_all_items_by_partition_key if request_excluded_locations is None: - await container.delete_item(item_id, ITEM_PK_VALUE) + await container.delete_all_items_by_partition_key(id_value) else: - await container.delete_item(item_id, ITEM_PK_VALUE, excluded_locations=request_excluded_locations) + await container.delete_all_items_by_partition_key(id_value, excluded_locations=request_excluded_locations) # Verify endpoint locations + actual_location = get_location(initialized_objects) if multiple_write_locations: - _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) + assert actual_location == expected_location else: - _verify_endpoint(client, [L1]) + assert 
actual_location == L1 if __name__ == "__main__": unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py index 1df1de05936d..83535510c983 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py @@ -33,6 +33,7 @@ host = test_config.TestConfig.host master_key = test_config.TestConfig.masterKey TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID +SINGLE_PARTITION_CONTAINER_NAME = os.path.basename(__file__) + str(uuid.uuid4()) @pytest.mark.cosmosEmulator @pytest.mark.asyncio @@ -50,7 +51,7 @@ async def asyncSetUp(cls): "'masterKey' and 'host' at the top of this class to run the " "tests.") cls.database_id = TEST_DATABASE_ID - cls.single_partition_container_name = os.path.basename(__file__) + str(uuid.uuid4()) + cls.single_partition_container_name = SINGLE_PARTITION_CONTAINER_NAME cls.mgmt_client = CosmosClient(cls.host, cls.master_key, consistency_level="Session", logger=logger) created_database = cls.mgmt_client.get_database_client(cls.database_id) @@ -76,11 +77,18 @@ async def asyncTearDown(cls): except Exception as closeError: logger.warning("Exception trying to close client {}. 
{}".format(created_database.id, closeError)) - async def setup_method_with_custom_transport(self, custom_transport: AioHttpTransport, default_endpoint=host, **kwargs): - client = CosmosClient(default_endpoint, master_key, consistency_level="Session", + @staticmethod + async def setup_method_with_custom_transport( + custom_transport: AioHttpTransport, + default_endpoint: str = host, + key: str = master_key, + database_id: str = TEST_DATABASE_ID, + container_id: str = SINGLE_PARTITION_CONTAINER_NAME, + **kwargs): + client = CosmosClient(default_endpoint, key, consistency_level="Session", transport=custom_transport, logger=logger, enable_diagnostics_logging=True, **kwargs) - db: DatabaseProxy = client.get_database_client(TEST_DATABASE_ID) - container: ContainerProxy = db.get_container_client(self.single_partition_container_name) + db: DatabaseProxy = client.get_database_client(database_id) + container: ContainerProxy = db.get_container_client(container_id) return {"client": client, "db": db, "col": container} @staticmethod @@ -106,7 +114,7 @@ async def test_throws_injected_error_async(self: "TestFaultInjectionTransportAsy status_code=502, message="Some random reverse proxy error.")))) - initialized_objects = await self.setup_method_with_custom_transport(custom_transport) + initialized_objects = await TestFaultInjectionTransportAsync.setup_method_with_custom_transport(custom_transport) start: float = time.perf_counter() try: container: ContainerProxy = initialized_objects["col"] @@ -151,7 +159,7 @@ async def test_swr_mrr_succeeds_async(self: "TestFaultInjectionTransportAsync"): 'name': 'sample document', 'key': 'value'} - initialized_objects = await self.setup_method_with_custom_transport( + initialized_objects = await TestFaultInjectionTransportAsync.setup_method_with_custom_transport( custom_transport, preferred_locations=["Read Region", "Write Region"]) try: @@ -210,7 +218,7 @@ async def test_swr_mrr_region_down_read_succeeds_async(self: "TestFaultInjection 'name': 
'sample document', 'key': 'value'} - initialized_objects = await self.setup_method_with_custom_transport( + initialized_objects = await TestFaultInjectionTransportAsync.setup_method_with_custom_transport( custom_transport, default_endpoint=expected_write_region_uri, preferred_locations=["Read Region", "Write Region"]) @@ -275,7 +283,7 @@ async def test_swr_mrr_region_down_envoy_read_succeeds_async(self: "TestFaultInj 'name': 'sample document', 'key': 'value'} - initialized_objects = await self.setup_method_with_custom_transport( + initialized_objects = await TestFaultInjectionTransportAsync.setup_method_with_custom_transport( custom_transport, default_endpoint=expected_write_region_uri, preferred_locations=["Read Region", "Write Region"]) @@ -320,7 +328,7 @@ async def test_mwr_succeeds_async(self: "TestFaultInjectionTransportAsync"): 'name': 'sample document', 'key': 'value'} - initialized_objects = await self.setup_method_with_custom_transport( + initialized_objects = await TestFaultInjectionTransportAsync.setup_method_with_custom_transport( custom_transport, preferred_locations=["First Region", "Second Region"]) try: @@ -373,7 +381,7 @@ async def test_mwr_region_down_succeeds_async(self: "TestFaultInjectionTransport 'name': 'sample document', 'key': 'value'} - initialized_objects = await self.setup_method_with_custom_transport( + initialized_objects = await TestFaultInjectionTransportAsync.setup_method_with_custom_transport( custom_transport, preferred_locations=["First Region", "Second Region"]) try: @@ -445,7 +453,7 @@ async def test_swr_mrr_all_regions_down_for_read_async(self: "TestFaultInjection 'name': 'sample document', 'key': 'value'} - initialized_objects = await self.setup_method_with_custom_transport( + initialized_objects = await TestFaultInjectionTransportAsync.setup_method_with_custom_transport( custom_transport, preferred_locations=["First Region", "Second Region"]) try: @@ -498,7 +506,7 @@ async def test_mwr_all_regions_down_async(self: 
"TestFaultInjectionTransportAsyn 'name': 'sample document', 'key': 'value'} - initialized_objects = await self.setup_method_with_custom_transport( + initialized_objects = await TestFaultInjectionTransportAsync.setup_method_with_custom_transport( custom_transport, preferred_locations=["First Region", "Second Region"]) try: From 5f2c5a08276db59a215f7b2e0feabf3f41800942 Mon Sep 17 00:00:00 2001 From: Allen Kim Date: Mon, 14 Apr 2025 12:19:16 -0700 Subject: [PATCH 084/152] Nit: Changed test names --- .../azure-cosmos/tests/test_excluded_locations_emulator.py | 6 +++--- .../tests/test_excluded_locations_emulator_async.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator.py index 96b3fc185afb..375bcfc899d8 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator.py +++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator.py @@ -63,14 +63,14 @@ def get_location( @pytest.mark.unittest @pytest.mark.cosmosEmulator -class TestExcludedLocations: +class TestExcludedLocationsEmulator: @pytest.mark.parametrize('test_data', delete_all_items_by_partition_key_test_data()) - def test_delete_all_items_by_partition_key(self: "TestExcludedLocations", test_data: List[List[str]]): + def test_delete_all_items_by_partition_key(self: "TestExcludedLocationsEmulator", test_data: List[List[str]]): # Init test variables preferred_locations, client_excluded_locations, request_excluded_locations, expected_location = test_data # Inject topology transformation that would make Emulator look like a multiple write region account - # account with two read regions + # with two read regions custom_transport = FaultInjectionTransport() is_get_account_predicate: Callable[[HttpRequest], bool] = lambda \ r: FaultInjectionTransport.predicate_is_database_account_call(r) diff --git 
a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator_async.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator_async.py index f706ca9e465e..c24c2f13c9f7 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator_async.py @@ -42,14 +42,14 @@ def delete_all_items_by_partition_key_test_data() -> List[str]: @pytest.mark.cosmosEmulator @pytest.mark.asyncio -class TestExcludedLocationsAsync: +class TestExcludedLocationsEmulatorAsync: @pytest.mark.parametrize('test_data', delete_all_items_by_partition_key_test_data()) - async def test_delete_all_items_by_partition_key(self: "TestExcludedLocationsAsync", test_data: List[List[str]]): + async def test_delete_all_items_by_partition_key(self: "TestExcludedLocationsEmulatorAsync", test_data: List[List[str]]): # Init test variables preferred_locations, client_excluded_locations, request_excluded_locations, expected_location = test_data # Inject topology transformation that would make Emulator look like a multiple write region account - # account with two read regions + # with two read regions custom_transport = FaultInjectionTransportAsync() is_get_account_predicate: Callable[[AioHttpTransport], bool] = lambda \ r: FaultInjectionTransportAsync.predicate_is_database_account_call(r) From c3e39e522415d4767e8c4a503da42f483b1bf217 Mon Sep 17 00:00:00 2001 From: Allen Kim Date: Tue, 15 Apr 2025 14:06:39 -0700 Subject: [PATCH 085/152] Addressed comments about documents --- .../azure-cosmos/azure/cosmos/documents.py | 2 +- .../samples/excluded_locations.py | 25 +++++++++++-------- 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/documents.py b/sdk/cosmos/azure-cosmos/azure/cosmos/documents.py index 9e04829be52f..7ccc99da9dfe 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/documents.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/documents.py @@ -314,7 +314,7 @@ 
class ConnectionPolicy: # pylint: disable=too-many-instance-attributes set of locations from the final location evaluation. The locations in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US', 'Central India' and so on. - :vartype ExcludedLocations: ~CosmosExcludedLocations + :vartype ExcludedLocations: List[str] :ivar RetryOptions: Gets or sets the retry options to be applied to all requests when retrying. diff --git a/sdk/cosmos/azure-cosmos/samples/excluded_locations.py b/sdk/cosmos/azure-cosmos/samples/excluded_locations.py index 06228c1a8cea..a8c699a7cccf 100644 --- a/sdk/cosmos/azure-cosmos/samples/excluded_locations.py +++ b/sdk/cosmos/azure-cosmos/samples/excluded_locations.py @@ -15,6 +15,9 @@ # # 2. Microsoft Azure Cosmos # pip install azure-cosmos>=4.3.0b4 +# +# 3. Configure Azure Cosmos account to add 3+ regions, such as 'West US 3', 'West US', 'East US 2'. +# If you added other regions, update L1~L3 with the regions in your account. 
# ---------------------------------------------------------------------------------------------------------- # Sample - demonstrates how to use excluded locations in client level and request level # ---------------------------------------------------------------------------------------------------------- @@ -33,12 +36,14 @@ DATABASE_ID = config.settings["database_id"] CONTAINER_ID = config.settings["container_id"] -PARTITION_KEY = PartitionKey(path="/id") +PARTITION_KEY = PartitionKey(path="/pk") +L1, L2, L3 = 'West US 3', 'West US', 'East US 2' def get_test_item(num): test_item = { 'id': 'Item_' + str(num), + 'pk': 'PartitionKey_' + str(num), 'test_object': True, 'lastName': 'Smith' } @@ -51,8 +56,8 @@ def clean_up_db(client): pass def excluded_locations_client_level_sample(): - preferred_locations = ['West US 3', 'West US', 'East US 2'] - excluded_locations = ['West US 3', 'West US'] + preferred_locations = [L1, L2, L3] + excluded_locations = [L1, L2] client = CosmosClient( HOST, MASTER_KEY, @@ -66,19 +71,19 @@ def excluded_locations_client_level_sample(): # For write operations with single master account, write endpoint will be the default endpoint, # since preferred_locations or excluded_locations are ignored and used - container.create_item(get_test_item(0)) + created_item = container.create_item(get_test_item(0)) # For read operations, read endpoints will be 'preferred_locations' - 'excluded_locations'. 
# In our sample, ['West US 3', 'West US', 'East US 2'] - ['West US 3', 'West US'] => ['East US 2'], # therefore 'East US 2' will be the read endpoint, and items will be read from 'East US 2' location - item = container.read_item(item='Item_0', partition_key='Item_0') + item = container.read_item(item=created_item['id'], partition_key=created_item['pk']) clean_up_db(client) def excluded_locations_request_level_sample(): - preferred_locations = ['West US 3', 'West US', 'East US 2'] - excluded_locations_on_client = ['West US 3', 'West US'] - excluded_locations_on_request = ['West US 3'] + preferred_locations = [L1, L2, L3] + excluded_locations_on_client = [L1, L2] + excluded_locations_on_request = [L1] client = CosmosClient( HOST, MASTER_KEY, @@ -92,7 +97,7 @@ def excluded_locations_request_level_sample(): # For write operations with single master account, write endpoint will be the default endpoint, # since preferred_locations or excluded_locations are ignored and used - container.create_item(get_test_item(0)) + created_item = container.create_item(get_test_item(0), excluded_locations=excluded_locations_on_request) # For read operations, read endpoints will be 'preferred_locations' - 'excluded_locations'. 
# However, in our sample, since the excluded_locations` were passed with the read request, the `excluded_location` @@ -101,7 +106,7 @@ def excluded_locations_request_level_sample(): # With the excluded_locations on request, the read endpoints will be ['West US', 'East US 2'] # ['West US 3', 'West US', 'East US 2'] - ['West US 3'] => ['West US', 'East US 2'] # Therefore, items will be read from 'West US' or 'East US 2' location - item = container.read_item(item='Item_0', partition_key='Item_0', excluded_locations=excluded_locations_on_request) + item = container.read_item(item=created_item['id'], partition_key=created_item['pk'], excluded_locations=excluded_locations_on_request) clean_up_db(client) From 39a464cb4a5f10c73557fac5dd685e9fcd15becc Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Wed, 16 Apr 2025 17:10:58 -0700 Subject: [PATCH 086/152] live tests --- .../azure/cosmos/_partition_health_tracker.py | 8 +- .../azure/cosmos/_request_object.py | 4 +- .../azure-cosmos/tests/test_ppcb_mm_async.py | 408 ++++++++++++++++++ .../tests/test_ppcb_sm_mrr_async.py | 2 + 4 files changed, 416 insertions(+), 6 deletions(-) create mode 100644 sdk/cosmos/azure-cosmos/tests/test_ppcb_mm_async.py diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py index 9f30bac2bd2c..034d1cf1ac10 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py @@ -23,7 +23,7 @@ """ import logging import os -from typing import Dict, Set, Any +from typing import Dict, Set, Any, List from azure.cosmos._routing.routing_range import PartitionKeyRangeWrapper from azure.cosmos._location_cache import current_time_millis, EndpointOperationType from ._constants import _Constants as Constants @@ -171,15 +171,15 @@ def _check_stale_partition_info(self, pk_range_wrapper: PartitionKeyRangeWrapper 
self._reset_partition_health_tracker_stats() - def get_excluded_locations(self, pk_range_wrapper: PartitionKeyRangeWrapper) -> Set[str]: + def get_excluded_locations(self, pk_range_wrapper: PartitionKeyRangeWrapper) -> List[str]: self._check_stale_partition_info(pk_range_wrapper) - excluded_locations = set() + excluded_locations = [] if pk_range_wrapper in self.pk_range_wrapper_to_health_info: for location, partition_health_info in self.pk_range_wrapper_to_health_info[pk_range_wrapper].items(): if partition_health_info.unavailability_info: health_status = partition_health_info.unavailability_info[HEALTH_STATUS] if health_status in (UNHEALTHY_TENTATIVE, UNHEALTHY): - excluded_locations.add(location) + excluded_locations.append(location) return excluded_locations diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py index 253434c8b881..24dc9e0dd9c2 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py @@ -42,7 +42,7 @@ def __init__( self.location_endpoint_to_route: Optional[str] = None self.last_routed_location_endpoint_within_region: Optional[str] = None self.excluded_locations: Optional[List[str]] = None - self.excluded_locations_circuit_breaker: Set[str] = set() + self.excluded_locations_circuit_breaker: List[str] = [] def route_to_location_with_preferred_location_flag( # pylint: disable=name-too-long self, @@ -89,5 +89,5 @@ def set_excluded_location_from_options(self, options: Mapping[str, Any]) -> None if self._can_set_excluded_location(options): self.excluded_locations = options['excludedLocations'] - def set_excluded_locations_from_circuit_breaker(self, excluded_locations: Set[str]) -> None: # pylint: disable=name-too-long + def set_excluded_locations_from_circuit_breaker(self, excluded_locations: List[str]) -> None: # pylint: disable=name-too-long self.excluded_locations_circuit_breaker = excluded_locations diff 
--git a/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm_async.py b/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm_async.py new file mode 100644 index 000000000000..2d033890429d --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm_async.py @@ -0,0 +1,408 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. +import asyncio +import os +import unittest +import uuid +from typing import Dict, Any + +import pytest +import pytest_asyncio +from azure.core.pipeline.transport._aiohttp import AioHttpTransport +from azure.core.exceptions import ServiceResponseError + +import test_config +from azure.cosmos import PartitionKey, _location_cache +from azure.cosmos._partition_health_tracker import HEALTH_STATUS, UNHEALTHY, UNHEALTHY_TENTATIVE +from azure.cosmos.aio import CosmosClient +from azure.cosmos.exceptions import CosmosHttpResponseError +from _fault_injection_transport_async import FaultInjectionTransportAsync + +REGION_1 = "West US 3" +REGION_2 = "Mexico Central" # "West US" + + +COLLECTION = "created_collection" +@pytest_asyncio.fixture() +async def setup(): + os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "True" + client = CosmosClient(TestPerPartitionCircuitBreakerMMAsync.host, TestPerPartitionCircuitBreakerMMAsync.master_key, consistency_level="Session") + created_database = client.get_database_client(TestPerPartitionCircuitBreakerMMAsync.TEST_DATABASE_ID) + await client.create_database_if_not_exists(TestPerPartitionCircuitBreakerMMAsync.TEST_DATABASE_ID) + created_collection = await created_database.create_container(TestPerPartitionCircuitBreakerMMAsync.TEST_CONTAINER_SINGLE_PARTITION_ID, + partition_key=PartitionKey("/pk"), + offer_throughput=10000) + yield { + COLLECTION: created_collection + } + + await created_database.delete_container(TestPerPartitionCircuitBreakerMMAsync.TEST_CONTAINER_SINGLE_PARTITION_ID) + await client.close() + os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "False" + +def 
write_operations_and_errors(): + write_operations = ["create", "upsert", "replace", "delete", "patch", "batch", "delete_all_items_by_partition_key"] + errors = [] + error_codes = [408, 500, 502, 503] + for error_code in error_codes: + errors.append(CosmosHttpResponseError( + status_code=error_code, + message="Some injected error.")) + errors.append(ServiceResponseError(message="Injected Service Response Error.")) + params = [] + for write_operation in write_operations: + for error in errors: + params.append((write_operation, error)) + + return params + +def read_operations_and_errors(): + read_operations = ["read", "query", "changefeed", "read_all_items"] + errors = [] + error_codes = [408, 500, 502, 503] + for error_code in error_codes: + errors.append(CosmosHttpResponseError( + status_code=error_code, + message="Some injected error.")) + errors.append(ServiceResponseError(message="Injected Service Response Error.")) + params = [] + for read_operation in read_operations: + for error in errors: + params.append((read_operation, error)) + + return params + +def validate_response_uri(response, expected_uri): + request = response.get_response_headers()["_request"] + assert request.url.startswith(expected_uri) + +async def perform_write_operation(operation, container, fault_injection_container, doc_id, pk, expected_uri): + doc = {'id': doc_id, + 'pk': pk, + 'name': 'sample document', + 'key': 'value'} + if operation == "create": + resp = await fault_injection_container.create_item(body=doc) + elif operation == "upsert": + resp = await fault_injection_container.upsert_item(body=doc) + elif operation == "replace": + new_doc = {'id': doc_id, + 'pk': pk, + 'name': 'sample document' + str(uuid), + 'key': 'value'} + resp = await fault_injection_container.replace_item(item=doc['id'], body=new_doc) + elif operation == "delete": + await container.create_item(body=doc) + resp = await fault_injection_container.delete_item(item=doc['id'], partition_key=doc['pk']) + elif operation 
== "patch": + operations = [{"op": "incr", "path": "/company", "value": 3}] + resp = await fault_injection_container.patch_item(item=doc['id'], partition_key=doc['pk'], patch_operations=operations) + elif operation == "batch": + batch_operations = [ + ("create", (doc, )), + ("upsert", (doc,)), + ("upsert", (doc,)), + ("upsert", (doc,)), + ] + resp = await fault_injection_container.execute_item_batch(batch_operations, partition_key=doc['pk']) + elif operation == "delete_all_items_by_partition_key": + await container.create_item(body=doc) + await container.create_item(body=doc) + await container.create_item(body=doc) + resp = await fault_injection_container.delete_all_items_by_partition_key(pk) + validate_response_uri(resp, expected_uri) + + +@pytest.mark.cosmosMultiRegion +@pytest.mark.asyncio +@pytest.mark.usefixtures("setup") +class TestPerPartitionCircuitBreakerMMAsync: + host = test_config.TestConfig.host + master_key = test_config.TestConfig.masterKey + connectionPolicy = test_config.TestConfig.connectionPolicy + TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID + TEST_CONTAINER_SINGLE_PARTITION_ID = os.path.basename(__file__) + str(uuid.uuid4()) + + async def setup_method_with_custom_transport(self, custom_transport: AioHttpTransport, default_endpoint=host, **kwargs): + client = CosmosClient(default_endpoint, self.master_key, consistency_level="Session", + preferred_locations=[REGION_1, REGION_2], + multiple_write_locations=True, + transport=custom_transport, **kwargs) + db = client.get_database_client(self.TEST_DATABASE_ID) + container = db.get_container_client(self.TEST_CONTAINER_SINGLE_PARTITION_ID) + return {"client": client, "db": db, "col": container} + + @staticmethod + async def cleanup_method(initialized_objects: Dict[str, Any]): + method_client: CosmosClient = initialized_objects["client"] + await method_client.close() + + @staticmethod + async def perform_read_operation(operation, container, doc_id, pk, expected_uri): + if operation == 
"read": + read_resp = await container.read_item(item=doc_id, partition_key=pk) + request = read_resp.get_response_headers()["_request"] + # Validate the response comes from "Read Region" (the most preferred read-only region) + assert request.url.startswith(expected_uri) + elif operation == "query": + query = "SELECT * FROM c WHERE c.id = @id AND c.pk = @pk" + parameters = [{"name": "@id", "value": doc_id}, {"name": "@pk", "value": pk}] + async for item in container.query_items(query=query, partition_key=pk, parameters=parameters): + assert item['id'] == doc_id + # need to do query with no pk and with feed range + elif operation == "changefeed": + async for _ in container.query_items_change_feed(): + pass + elif operation == "read_all_items": + async for item in container.read_all_items(partition_key=pk): + assert item['pk'] == pk + + + @pytest.mark.parametrize("write_operation, error", write_operations_and_errors()) + async def test_write_consecutive_failure_threshold_async(self, setup, write_operation, error): + expected_uri = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_2) + custom_transport = FaultInjectionTransportAsync() + id_value = 'failoverDoc-' + str(uuid.uuid4()) + document_definition = {'id': id_value, + 'pk': 'pk1', + 'name': 'sample document', + 'key': 'value'} + predicate = lambda r: (FaultInjectionTransportAsync.predicate_is_document_operation(r) and + FaultInjectionTransportAsync.predicate_targets_region(r, expected_uri)) + custom_transport.add_fault(predicate, lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( + 0, + error + ))) + + custom_setup = await self.setup_method_with_custom_transport(custom_transport, default_endpoint=self.host) + container = custom_setup['col'] + global_endpoint_manager = container.client_connection._global_endpoint_manager + + + await perform_write_operation(write_operation, + setup[COLLECTION], + container, + document_definition['id'], + document_definition['pk'], + 
expected_uri) + + TestPerPartitionCircuitBreakerMMAsync.validate_unhealthy_partitions(global_endpoint_manager, 0) + + # writes should fail but still be tracked + for i in range(6): + with pytest.raises((CosmosHttpResponseError, ServiceResponseError)): + await perform_write_operation(write_operation, + setup[COLLECTION], + container, + document_definition['id'], + document_definition['pk'], + expected_uri) + + TestPerPartitionCircuitBreakerMMAsync.validate_unhealthy_partitions(global_endpoint_manager, 1) + + @pytest.mark.parametrize("read_operation, error", read_operations_and_errors()) + async def test_read_consecutive_failure_threshold_async(self, setup, read_operation, error): + expected_uri = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_2) + custom_transport = FaultInjectionTransportAsync() + id_value = 'failoverDoc-' + str(uuid.uuid4()) + document_definition = {'id': id_value, + 'pk': 'pk1', + 'name': 'sample document', + 'key': 'value'} + predicate = lambda r: (FaultInjectionTransportAsync.predicate_is_document_operation(r) and + FaultInjectionTransportAsync.predicate_targets_region(r, expected_uri)) + custom_transport.add_fault(predicate, lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( + 0, + error + ))) + + custom_setup = await self.setup_method_with_custom_transport(custom_transport, default_endpoint=self.host) + container = custom_setup['col'] + global_endpoint_manager = container.client_connection._global_endpoint_manager + + + await TestPerPartitionCircuitBreakerMMAsync.perform_write_operation(write_operation, + setup[COLLECTION], + container, + document_definition['id'], + document_definition['pk'], + expected_uri) + + # writes should fail in sm mrr with circuit breaker and should not mark unavailable a partition + for i in range(6): + with pytest.raises((CosmosHttpResponseError, ServiceResponseError)): + await TestPerPartitionCircuitBreakerMMAsync.perform_write_operation(write_operation, + 
setup[COLLECTION], + container, + document_definition['id'], + document_definition['pk'], + expected_uri) + + TestPerPartitionCircuitBreakerMMAsync.validate_unhealthy_partitions(global_endpoint_manager, 1) + + # create item with client without fault injection + await setup[COLLECTION].create_item(body=document_definition) + + # reads should fail over and only the relevant partition should be marked as unavailable + await TestPerPartitionCircuitBreakerMMAsync.perform_read_operation(read_operation, + container, + document_definition['id'], + document_definition['pk'], + expected_uri) + # partition should not have been marked unavailable after one error + TestPerPartitionCircuitBreakerMMAsync.validate_unhealthy_partitions(global_endpoint_manager, 1) + + for i in range(10): + await TestPerPartitionCircuitBreakerMMAsync.perform_read_operation(read_operation, + container, + document_definition['id'], + document_definition['pk'], + expected_uri) + + # the partition should have been marked as unavailable after breaking read threshold + TestPerPartitionCircuitBreakerMMAsync.validate_unhealthy_partitions(global_endpoint_manager, 1) + # test recovering the partition again + + @pytest.mark.parametrize("write_operation, error", write_operations_and_errors()) + async def test_failure_rate_threshold_async(self, setup, write_operation, error): + expected_uri = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_2) + custom_transport = await self.create_custom_transport_sm_mrr() + id_value = 'failoverDoc-' + str(uuid.uuid4()) + # two documents targeted to same partition, one will always fail and the other will succeed + document_definition = {'id': id_value, + 'pk': 'pk1', + 'name': 'sample document', + 'key': 'value'} + doc_2 = {'id': str(uuid.uuid4()), + 'pk': 'pk1', + 'name': 'sample document', + 'key': 'value'} + predicate = lambda r: (FaultInjectionTransportAsync.predicate_req_for_document_with_id(r, id_value) and + 
FaultInjectionTransportAsync.predicate_targets_region(r, expected_write_region_uri)) + custom_transport.add_fault(predicate, + lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( + 0, + error + ))) + + custom_setup = await self.setup_method_with_custom_transport(custom_transport, default_endpoint=self.host) + container = custom_setup['col'] + global_endpoint_manager = container.client_connection._global_endpoint_manager + # lower minimum requests for testing + global_endpoint_manager.global_partition_endpoint_manager_core.partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 10 + try: + # writes should fail in sm mrr with circuit breaker and should not mark unavailable a partition + for i in range(14): + if i == 9: + await container.upsert_item(body=doc_2) + with pytest.raises((CosmosHttpResponseError, ServiceResponseError)): + await TestPerPartitionCircuitBreakerMMAsync.perform_write_operation(write_operation, + setup[COLLECTION], + container, + document_definition['id'], + document_definition['pk'], + expected_uri) + + TestPerPartitionCircuitBreakerMMAsync.validate_unhealthy_partitions(global_endpoint_manager, 0) + + # create item with client without fault injection + await setup[COLLECTION].create_item(body=document_definition) + finally: + # restore minimum requests + global_endpoint_manager.global_partition_endpoint_manager_core.partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 + + + + @pytest.mark.parametrize("write_operation, error", write_operations_and_errors()) + async def test_read_failure_rate_threshold_async(self, setup, write_operation, error): + expected_uri = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_2) + custom_transport = await self.create_custom_transport_sm_mrr() + id_value = 'failoverDoc-' + str(uuid.uuid4()) + # two documents targeted to same partition, one will always fail and the other will succeed + document_definition = {'id': id_value, + 'pk': 'pk1', + 'name': 
'sample document', + 'key': 'value'} + doc_2 = {'id': str(uuid.uuid4()), + 'pk': 'pk1', + 'name': 'sample document', + 'key': 'value'} + predicate = lambda r: (FaultInjectionTransportAsync.predicate_req_for_document_with_id(r, id_value) and + FaultInjectionTransportAsync.predicate_targets_region(r, expected_write_region_uri)) + custom_transport.add_fault(predicate, + lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( + 0, + error + ))) + + custom_setup = await self.setup_method_with_custom_transport(custom_transport, default_endpoint=self.host) + container = custom_setup['col'] + global_endpoint_manager = container.client_connection._global_endpoint_manager + # lower minimum requests for testing + global_endpoint_manager.global_partition_endpoint_manager_core.partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 10 + try: + # writes should fail in sm mrr with circuit breaker and should not mark unavailable a partition + for i in range(14): + if i == 9: + await container.upsert_item(body=doc_2) + with pytest.raises((CosmosHttpResponseError, ServiceResponseError)): + await TestPerPartitionCircuitBreakerMMAsync.perform_write_operation(write_operation, + setup[COLLECTION], + container, + document_definition['id'], + document_definition['pk'], + expected_uri) + + TestPerPartitionCircuitBreakerMMAsync.validate_unhealthy_partitions(global_endpoint_manager, 0) + + # create item with client without fault injection + await setup[COLLECTION].create_item(body=document_definition) + + # reads should fail over and only the relevant partition should be marked as unavailable + await container.read_item(item=document_definition['id'], partition_key=document_definition['pk']) + # partition should not have been marked unavailable after one error + TestPerPartitionCircuitBreakerMMAsync.validate_unhealthy_partitions(global_endpoint_manager, 0) + for i in range(20): + if i == 8: + read_resp = await container.read_item(item=doc_2['id'], + 
partition_key=doc_2['pk']) + request = read_resp.get_response_headers()["_request"] + # Validate the response comes from "Read Region" (the most preferred read-only region) + assert request.url.startswith(expected_uri) + else: + await TestPerPartitionCircuitBreakerMMAsync.perform_read_operation(read_operation, + container, + document_definition['id'], + document_definition['pk'], + expected_uri) + # the partition should have been marked as unavailable after breaking read threshold + TestPerPartitionCircuitBreakerMMAsync.validate_unhealthy_partitions(global_endpoint_manager, 1) + finally: + # restore minimum requests + global_endpoint_manager.global_partition_endpoint_manager_core.partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 + + # look at the urls for verifying fall back and use another id for same partition + + @staticmethod + def validate_unhealthy_partitions(global_endpoint_manager, + expected_unhealthy_partitions): + health_info_map = global_endpoint_manager.global_partition_endpoint_manager_core.partition_health_tracker.pk_range_wrapper_to_health_info + unhealthy_partitions = 0 + for pk_range_wrapper, location_to_health_info in health_info_map.items(): + for location, health_info in location_to_health_info.items(): + health_status = health_info.unavailability_info.get(HEALTH_STATUS) + if health_status == UNHEALTHY_TENTATIVE or health_status == UNHEALTHY: + unhealthy_partitions += 1 + else: + assert health_info.read_consecutive_failure_count < 10 + assert health_info.write_failure_count == 0 + assert health_info.write_consecutive_failure_count == 0 + + assert unhealthy_partitions == expected_unhealthy_partitions + + # test_failure_rate_threshold - add service response error - across operation types - test recovering the partition again + # test service request marks only a partition unavailable not an entire region - across operation types + # test cosmos client timeout + +if __name__ == '__main__': + unittest.main() diff --git 
a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py index ee5e3fcd6fa9..ec6cf5ecff7d 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py @@ -297,6 +297,8 @@ def validate_unhealthy_partitions(global_endpoint_manager, health_status = health_info.unavailability_info.get(HEALTH_STATUS) if health_status == UNHEALTHY_TENTATIVE or health_status == UNHEALTHY: unhealthy_partitions += 1 + assert health_info.write_failure_count == 0 + assert health_info.write_consecutive_failure_count == 0 else: assert health_info.read_consecutive_failure_count < 10 assert health_info.write_failure_count == 0 From d70343c7b23097dac2b51d0800dd5028e4a1db44 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Thu, 17 Apr 2025 11:30:30 -0700 Subject: [PATCH 087/152] fix tests --- .../azure/cosmos/aio/_container.py | 21 ++-- .../aio/_cosmos_client_connection_async.py | 36 +++--- .../azure/cosmos/aio/_retry_utility_async.py | 3 + .../azure-cosmos/tests/test_ppcb_mm_async.py | 114 ++++++++++-------- 4 files changed, 103 insertions(+), 71 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py index 360ed38af64b..ddf0daa77585 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py @@ -273,8 +273,10 @@ async def create_item( request_options["disableAutomaticIdGeneration"] = not enable_automatic_id_generation if indexing_directive is not None: request_options["indexingDirective"] = indexing_directive - if self.container_link in self.__get_client_container_caches(): - request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] + if not self.container_link in self.__get_client_container_caches(): + container = await self.read() + 
self.__get_client_container_caches()[self.container_link] = _set_properties_cache(container) + request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] result = await self.client_connection.CreateItem( database_or_container_link=self.container_link, document=body, options=request_options, **kwargs @@ -393,6 +395,7 @@ def read_all_items( response_hook.clear() if self.container_link in self.__get_client_container_caches(): feed_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] + kwargs["containerProperties"] = self._get_properties items = self.client_connection.ReadItems( collection_link=self.container_link, feed_options=feed_options, response_hook=response_hook, **kwargs @@ -493,7 +496,6 @@ def query_items( feed_options["enableScanInQuery"] = enable_scan_in_query if partition_key is not None: feed_options["partitionKey"] = self._set_partition_key(partition_key) - kwargs["containerProperties"] = self._get_properties else: feed_options["enableCrossPartitionQuery"] = True if max_integrated_cache_staleness_in_ms: @@ -507,6 +509,7 @@ def query_items( response_hook.clear() if self.container_link in self.__get_client_container_caches(): feed_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] + kwargs["containerProperties"] = self._get_properties items = self.client_connection.QueryItems( database_or_container_link=self.container_link, @@ -803,8 +806,10 @@ async def upsert_item( kwargs['no_response'] = no_response request_options = _build_options(kwargs) request_options["disableAutomaticIdGeneration"] = True - if self.container_link in self.__get_client_container_caches(): - request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] + if not self.container_link in self.__get_client_container_caches(): + container = await self.read() + self.__get_client_container_caches()[self.container_link] = 
_set_properties_cache(container) + request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] result = await self.client_connection.UpsertItem( database_or_container_link=self.container_link, @@ -880,8 +885,10 @@ async def replace_item( kwargs['no_response'] = no_response request_options = _build_options(kwargs) request_options["disableAutomaticIdGeneration"] = True - if self.container_link in self.__get_client_container_caches(): - request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] + if not self.container_link in self.__get_client_container_caches(): + container = await self.read() + self.__get_client_container_caches()[self.container_link] = _set_properties_cache(container) + request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] result = await self.client_connection.ReplaceItem( document_link=item_link, new_document=body, options=request_options, **kwargs diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py index 31ae9dc334cd..52c4e2f0826b 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py @@ -2853,7 +2853,7 @@ async def __QueryFeed( # pylint: disable=too-many-branches,too-many-statements, :raises SystemError: If the query compatibility mode is undefined. 
""" if options is None: - options = {} + options: Dict[str, Any] = {} if query: __GetBodiesFromQueryResult = result_fn @@ -2867,6 +2867,14 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: return [] initial_headers = self.default_headers.copy() + + + cont_prop = kwargs.pop("containerProperties", None) + + if cont_prop: + cont_prop = await cont_prop() + options["containerRID"] = cont_prop["_rid"] + # Copy to make sure that default_headers won't be changed. if query is None: op_typ = documents._OperationType.QueryPlan if is_query_plan else documents._OperationType.ReadFeed @@ -2904,6 +2912,18 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: else: raise SystemError("Unexpected query compatibility mode.") + # check if query has prefix partition key + partition_key = options.get("partitionKey", None) + isPrefixPartitionQuery = False + partition_key_definition = None + if cont_prop and partition_key: + pk_properties = cont_prop["partitionKey"] + partition_key_definition = PartitionKey(path=pk_properties["paths"], kind=pk_properties["kind"]) + if partition_key_definition.kind == "MultiHash" and \ + (isinstance(partition_key, List) and \ + len(partition_key_definition['paths']) != len(partition_key)): + isPrefixPartitionQuery = True + # Query operations will use ReadEndpoint even though it uses POST(for regular query operations) req_headers = base.GetHeaders(self, initial_headers, "post", path, id_, typ, documents._OperationType.SqlQuery, @@ -2911,20 +2931,6 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: request_params = _request_object.RequestObject(typ, documents._OperationType.SqlQuery, req_headers) request_params.set_excluded_location_from_options(options) - # check if query has prefix partition key - cont_prop = kwargs.pop("containerProperties", None) - partition_key = options.get("partitionKey", None) - isPrefixPartitionQuery = False - partition_key_definition = 
None - if cont_prop: - cont_prop = await cont_prop() - pk_properties = cont_prop["partitionKey"] - partition_key_definition = PartitionKey(path=pk_properties["paths"], kind=pk_properties["kind"]) - if partition_key_definition.kind == "MultiHash" and \ - (isinstance(partition_key, List) and \ - len(partition_key_definition['paths']) != len(partition_key)): - isPrefixPartitionQuery = True - if isPrefixPartitionQuery and partition_key_definition: # here get the overlapping ranges req_headers.pop(http_constants.HttpHeaders.PartitionKey, None) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py index 5d4d680b50e4..be93acc64d1f 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py @@ -224,9 +224,11 @@ async def ExecuteAsync(client, global_endpoint_manager, function, *args, **kwarg if isinstance(e.inner_exception, ClientConnectionError): _handle_service_request_retries(client, service_request_retry_policy, e, *args) else: + await global_endpoint_manager.record_failure(args[0]) _handle_service_response_retries(request, client, service_response_retry_policy, e, *args) # in case customer is not using aiohttp except ImportError: + await global_endpoint_manager.record_failure(args[0]) _handle_service_response_retries(request, client, service_response_retry_policy, e, *args) @@ -270,6 +272,7 @@ async def send(self, request): start_time = time.time() try: _configure_timeout(request, absolute_timeout, per_request_timeout) + print("RetryUtility - Sending request") response = await self.next.send(request) break except ClientAuthenticationError: # pylint:disable=try-except-raise diff --git a/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm_async.py b/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm_async.py index 2d033890429d..01f39f1e7b70 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm_async.py +++ 
b/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm_async.py @@ -4,7 +4,7 @@ import os import unittest import uuid -from typing import Dict, Any +from typing import Dict, Any, List import pytest import pytest_asyncio @@ -23,25 +23,22 @@ COLLECTION = "created_collection" -@pytest_asyncio.fixture() -async def setup(): +@pytest_asyncio.fixture(scope="class", autouse=True) +async def setup_teardown(): os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "True" - client = CosmosClient(TestPerPartitionCircuitBreakerMMAsync.host, TestPerPartitionCircuitBreakerMMAsync.master_key, consistency_level="Session") + client = CosmosClient(TestPerPartitionCircuitBreakerMMAsync.host, + TestPerPartitionCircuitBreakerMMAsync.master_key) created_database = client.get_database_client(TestPerPartitionCircuitBreakerMMAsync.TEST_DATABASE_ID) - await client.create_database_if_not_exists(TestPerPartitionCircuitBreakerMMAsync.TEST_DATABASE_ID) - created_collection = await created_database.create_container(TestPerPartitionCircuitBreakerMMAsync.TEST_CONTAINER_SINGLE_PARTITION_ID, + await created_database.create_container(TestPerPartitionCircuitBreakerMMAsync.TEST_CONTAINER_SINGLE_PARTITION_ID, partition_key=PartitionKey("/pk"), offer_throughput=10000) - yield { - COLLECTION: created_collection - } - + yield await created_database.delete_container(TestPerPartitionCircuitBreakerMMAsync.TEST_CONTAINER_SINGLE_PARTITION_ID) await client.close() os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "False" def write_operations_and_errors(): - write_operations = ["create", "upsert", "replace", "delete", "patch", "batch", "delete_all_items_by_partition_key"] + write_operations = ["create", "upsert", "replace", "delete", "patch", "batch"] # "delete_all_items_by_partition_key"] errors = [] error_codes = [408, 500, 502, 503] for error_code in error_codes: @@ -86,13 +83,13 @@ async def perform_write_operation(operation, container, fault_injection_containe elif operation == "upsert": resp = await 
fault_injection_container.upsert_item(body=doc) elif operation == "replace": + await container.create_item(body=doc) new_doc = {'id': doc_id, 'pk': pk, 'name': 'sample document' + str(uuid), 'key': 'value'} resp = await fault_injection_container.replace_item(item=doc['id'], body=new_doc) elif operation == "delete": - await container.create_item(body=doc) resp = await fault_injection_container.delete_item(item=doc['id'], partition_key=doc['pk']) elif operation == "patch": operations = [{"op": "incr", "path": "/company", "value": 3}] @@ -105,26 +102,26 @@ async def perform_write_operation(operation, container, fault_injection_containe ("upsert", (doc,)), ] resp = await fault_injection_container.execute_item_batch(batch_operations, partition_key=doc['pk']) - elif operation == "delete_all_items_by_partition_key": - await container.create_item(body=doc) - await container.create_item(body=doc) - await container.create_item(body=doc) - resp = await fault_injection_container.delete_all_items_by_partition_key(pk) + # this will need to be emulator only + # elif operation == "delete_all_items_by_partition_key": + # await container.create_item(body=doc) + # await container.create_item(body=doc) + # await container.create_item(body=doc) + # resp = await fault_injection_container.delete_all_items_by_partition_key(pk) validate_response_uri(resp, expected_uri) @pytest.mark.cosmosMultiRegion @pytest.mark.asyncio -@pytest.mark.usefixtures("setup") +@pytest.mark.usefixtures("setup_teardown") class TestPerPartitionCircuitBreakerMMAsync: host = test_config.TestConfig.host master_key = test_config.TestConfig.masterKey - connectionPolicy = test_config.TestConfig.connectionPolicy TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID TEST_CONTAINER_SINGLE_PARTITION_ID = os.path.basename(__file__) + str(uuid.uuid4()) async def setup_method_with_custom_transport(self, custom_transport: AioHttpTransport, default_endpoint=host, **kwargs): - client = CosmosClient(default_endpoint, 
self.master_key, consistency_level="Session", + client = CosmosClient(default_endpoint, self.master_key, preferred_locations=[REGION_1, REGION_2], multiple_write_locations=True, transport=custom_transport, **kwargs) @@ -132,10 +129,20 @@ async def setup_method_with_custom_transport(self, custom_transport: AioHttpTran container = db.get_container_client(self.TEST_CONTAINER_SINGLE_PARTITION_ID) return {"client": client, "db": db, "col": container} + async def setup_method(self, default_endpoint=host, **kwargs): + client = CosmosClient(default_endpoint, self.master_key, + preferred_locations=[REGION_1, REGION_2], + multiple_write_locations=True, + **kwargs) + db = client.get_database_client(self.TEST_DATABASE_ID) + container = db.get_container_client(self.TEST_CONTAINER_SINGLE_PARTITION_ID) + return {"client": client, "db": db, "col": container} + @staticmethod - async def cleanup_method(initialized_objects: Dict[str, Any]): - method_client: CosmosClient = initialized_objects["client"] - await method_client.close() + async def cleanup_method(initialized_objects: List[Dict[str, Any]]): + for obj in initialized_objects: + method_client: CosmosClient = obj["client"] + await method_client.close() @staticmethod async def perform_read_operation(operation, container, doc_id, pk, expected_uri): @@ -159,8 +166,9 @@ async def perform_read_operation(operation, container, doc_id, pk, expected_uri) @pytest.mark.parametrize("write_operation, error", write_operations_and_errors()) - async def test_write_consecutive_failure_threshold_async(self, setup, write_operation, error): + async def test_write_consecutive_failure_threshold_async(self, setup_teardown, write_operation, error): expected_uri = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_2) + uri_down = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_1) custom_transport = FaultInjectionTransportAsync() id_value = 'failoverDoc-' + str(uuid.uuid4()) document_definition = {'id': id_value, 
@@ -168,40 +176,49 @@ async def test_write_consecutive_failure_threshold_async(self, setup, write_oper 'name': 'sample document', 'key': 'value'} predicate = lambda r: (FaultInjectionTransportAsync.predicate_is_document_operation(r) and - FaultInjectionTransportAsync.predicate_targets_region(r, expected_uri)) + FaultInjectionTransportAsync.predicate_targets_region(r, uri_down)) custom_transport.add_fault(predicate, lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( 0, error ))) custom_setup = await self.setup_method_with_custom_transport(custom_transport, default_endpoint=self.host) - container = custom_setup['col'] - global_endpoint_manager = container.client_connection._global_endpoint_manager - - - await perform_write_operation(write_operation, - setup[COLLECTION], - container, - document_definition['id'], - document_definition['pk'], - expected_uri) + setup_teardown = await self.setup_method(custom_transport, default_endpoint=self.host) + container = setup['col'] + fault_injection_container = custom_setup['col'] + global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager + + with pytest.raises((CosmosHttpResponseError, ServiceResponseError)): + await perform_write_operation(write_operation, + container, + fault_injection_container, + document_definition['id'], + document_definition['pk'], + expected_uri) TestPerPartitionCircuitBreakerMMAsync.validate_unhealthy_partitions(global_endpoint_manager, 0) # writes should fail but still be tracked - for i in range(6): + for i in range(4): with pytest.raises((CosmosHttpResponseError, ServiceResponseError)): await perform_write_operation(write_operation, setup[COLLECTION], - container, + fault_injection_container, document_definition['id'], document_definition['pk'], expected_uri) + await perform_write_operation(write_operation, + setup[COLLECTION], + fault_injection_container, + document_definition['id'], + document_definition['pk'], + expected_uri) 
TestPerPartitionCircuitBreakerMMAsync.validate_unhealthy_partitions(global_endpoint_manager, 1) + TestPerPartitionCircuitBreakerMMAsync.cleanup_method(custom_setup) @pytest.mark.parametrize("read_operation, error", read_operations_and_errors()) - async def test_read_consecutive_failure_threshold_async(self, setup, read_operation, error): + async def test_read_consecutive_failure_threshold_async(self, setup_teardown, read_operation, error): expected_uri = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_2) custom_transport = FaultInjectionTransportAsync() id_value = 'failoverDoc-' + str(uuid.uuid4()) @@ -222,7 +239,7 @@ async def test_read_consecutive_failure_threshold_async(self, setup, read_operat await TestPerPartitionCircuitBreakerMMAsync.perform_write_operation(write_operation, - setup[COLLECTION], + setup_teardown[COLLECTION], container, document_definition['id'], document_definition['pk'], @@ -232,7 +249,7 @@ async def test_read_consecutive_failure_threshold_async(self, setup, read_operat for i in range(6): with pytest.raises((CosmosHttpResponseError, ServiceResponseError)): await TestPerPartitionCircuitBreakerMMAsync.perform_write_operation(write_operation, - setup[COLLECTION], + setup_teardown[COLLECTION], container, document_definition['id'], document_definition['pk'], @@ -241,7 +258,7 @@ async def test_read_consecutive_failure_threshold_async(self, setup, read_operat TestPerPartitionCircuitBreakerMMAsync.validate_unhealthy_partitions(global_endpoint_manager, 1) # create item with client without fault injection - await setup[COLLECTION].create_item(body=document_definition) + await setup_teardown[COLLECTION].create_item(body=document_definition) # reads should fail over and only the relevant partition should be marked as unavailable await TestPerPartitionCircuitBreakerMMAsync.perform_read_operation(read_operation, @@ -264,7 +281,7 @@ async def test_read_consecutive_failure_threshold_async(self, setup, read_operat # test recovering the 
partition again @pytest.mark.parametrize("write_operation, error", write_operations_and_errors()) - async def test_failure_rate_threshold_async(self, setup, write_operation, error): + async def test_failure_rate_threshold_async(self, setup_teardown, write_operation, error): expected_uri = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_2) custom_transport = await self.create_custom_transport_sm_mrr() id_value = 'failoverDoc-' + str(uuid.uuid4()) @@ -297,7 +314,7 @@ async def test_failure_rate_threshold_async(self, setup, write_operation, error) await container.upsert_item(body=doc_2) with pytest.raises((CosmosHttpResponseError, ServiceResponseError)): await TestPerPartitionCircuitBreakerMMAsync.perform_write_operation(write_operation, - setup[COLLECTION], + setup_teardown[COLLECTION], container, document_definition['id'], document_definition['pk'], @@ -306,7 +323,7 @@ async def test_failure_rate_threshold_async(self, setup, write_operation, error) TestPerPartitionCircuitBreakerMMAsync.validate_unhealthy_partitions(global_endpoint_manager, 0) # create item with client without fault injection - await setup[COLLECTION].create_item(body=document_definition) + await setup_teardown[COLLECTION].create_item(body=document_definition) finally: # restore minimum requests global_endpoint_manager.global_partition_endpoint_manager_core.partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 @@ -314,7 +331,7 @@ async def test_failure_rate_threshold_async(self, setup, write_operation, error) @pytest.mark.parametrize("write_operation, error", write_operations_and_errors()) - async def test_read_failure_rate_threshold_async(self, setup, write_operation, error): + async def test_read_failure_rate_threshold_async(self, setup_teardown, write_operation, error): expected_uri = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_2) custom_transport = await self.create_custom_transport_sm_mrr() id_value = 'failoverDoc-' + str(uuid.uuid4()) 
@@ -347,7 +364,7 @@ async def test_read_failure_rate_threshold_async(self, setup, write_operation, e await container.upsert_item(body=doc_2) with pytest.raises((CosmosHttpResponseError, ServiceResponseError)): await TestPerPartitionCircuitBreakerMMAsync.perform_write_operation(write_operation, - setup[COLLECTION], + setup_teardown[COLLECTION], container, document_definition['id'], document_definition['pk'], @@ -356,7 +373,7 @@ async def test_read_failure_rate_threshold_async(self, setup, write_operation, e TestPerPartitionCircuitBreakerMMAsync.validate_unhealthy_partitions(global_endpoint_manager, 0) # create item with client without fault injection - await setup[COLLECTION].create_item(body=document_definition) + await setup_teardown[COLLECTION].create_item(body=document_definition) # reads should fail over and only the relevant partition should be marked as unavailable await container.read_item(item=document_definition['id'], partition_key=document_definition['pk']) @@ -395,8 +412,7 @@ def validate_unhealthy_partitions(global_endpoint_manager, unhealthy_partitions += 1 else: assert health_info.read_consecutive_failure_count < 10 - assert health_info.write_failure_count == 0 - assert health_info.write_consecutive_failure_count == 0 + assert health_info.write_consecutive_failure_count < 5 assert unhealthy_partitions == expected_unhealthy_partitions From 137a13044732c801b58c2ab432c5d80ed631bf1b Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Thu, 17 Apr 2025 15:50:36 -0700 Subject: [PATCH 088/152] add container rid --- .../azure/cosmos/aio/_container.py | 24 ++++++++++++------- .../azure/cosmos/aio/_retry_utility_async.py | 2 +- .../azure-cosmos/tests/test_ppcb_mm_async.py | 17 ++++++------- 3 files changed, 26 insertions(+), 17 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py index ddf0daa77585..d9747fec6b51 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py +++ 
b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py @@ -965,8 +965,10 @@ async def patch_item( request_options["partitionKey"] = await self._set_partition_key(partition_key) if filter_predicate is not None: request_options["filterPredicate"] = filter_predicate - if self.container_link in self.__get_client_container_caches(): - request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] + if not self.container_link in self.__get_client_container_caches(): + container = await self.read() + self.__get_client_container_caches()[self.container_link] = _set_properties_cache(container) + request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] item_link = self._get_document_link(item) result = await self.client_connection.PatchItem( @@ -1031,8 +1033,10 @@ async def delete_item( kwargs['priority'] = priority request_options = _build_options(kwargs) request_options["partitionKey"] = await self._set_partition_key(partition_key) - if self.container_link in self.__get_client_container_caches(): - request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] + if not self.container_link in self.__get_client_container_caches(): + container = await self.read() + self.__get_client_container_caches()[self.container_link] = _set_properties_cache(container) + request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] document_link = self._get_document_link(item) await self.client_connection.DeleteItem(document_link=document_link, options=request_options, **kwargs) @@ -1286,8 +1290,10 @@ async def delete_all_items_by_partition_key( request_options = _build_options(kwargs) # regardless if partition key is valid we set it as invalid partition keys are set to a default empty value request_options["partitionKey"] = await self._set_partition_key(partition_key) - if self.container_link in self.__get_client_container_caches(): 
- request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] + if not self.container_link in self.__get_client_container_caches(): + container = await self.read() + self.__get_client_container_caches()[self.container_link] = _set_properties_cache(container) + request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] await self.client_connection.DeleteAllItemsByPartitionKey(collection_link=self.container_link, options=request_options, **kwargs) @@ -1348,8 +1354,10 @@ async def execute_item_batch( request_options = _build_options(kwargs) request_options["partitionKey"] = await self._set_partition_key(partition_key) request_options["disableAutomaticIdGeneration"] = True - if self.container_link in self.__get_client_container_caches(): - request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] + if not self.container_link in self.__get_client_container_caches(): + container = await self.read() + self.__get_client_container_caches()[self.container_link] = _set_properties_cache(container) + request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] return await self.client_connection.Batch( collection_link=self.container_link, batch_operations=batch_operations, options=request_options, **kwargs) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py index be93acc64d1f..c2f12d7185bc 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py @@ -272,7 +272,7 @@ async def send(self, request): start_time = time.time() try: _configure_timeout(request, absolute_timeout, per_request_timeout) - print("RetryUtility - Sending request") + print("RetryUtility - Sending request ", request_params.operation_type, request_params.resource_type) 
response = await self.next.send(request) break except ClientAuthenticationError: # pylint:disable=try-except-raise diff --git a/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm_async.py b/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm_async.py index 01f39f1e7b70..9acaca81e0a2 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm_async.py @@ -90,8 +90,10 @@ async def perform_write_operation(operation, container, fault_injection_containe 'key': 'value'} resp = await fault_injection_container.replace_item(item=doc['id'], body=new_doc) elif operation == "delete": + await container.create_item(body=doc) resp = await fault_injection_container.delete_item(item=doc['id'], partition_key=doc['pk']) elif operation == "patch": + await container.create_item(body=doc) operations = [{"op": "incr", "path": "/company", "value": 3}] resp = await fault_injection_container.patch_item(item=doc['id'], partition_key=doc['pk'], patch_operations=operations) elif operation == "batch": @@ -172,9 +174,7 @@ async def test_write_consecutive_failure_threshold_async(self, setup_teardown, w custom_transport = FaultInjectionTransportAsync() id_value = 'failoverDoc-' + str(uuid.uuid4()) document_definition = {'id': id_value, - 'pk': 'pk1', - 'name': 'sample document', - 'key': 'value'} + 'pk': 'pk1'} predicate = lambda r: (FaultInjectionTransportAsync.predicate_is_document_operation(r) and FaultInjectionTransportAsync.predicate_targets_region(r, uri_down)) custom_transport.add_fault(predicate, lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( @@ -183,18 +183,19 @@ async def test_write_consecutive_failure_threshold_async(self, setup_teardown, w ))) custom_setup = await self.setup_method_with_custom_transport(custom_transport, default_endpoint=self.host) - setup_teardown = await self.setup_method(custom_transport, default_endpoint=self.host) + setup = await self.setup_method(default_endpoint=self.host) container = 
setup['col'] fault_injection_container = custom_setup['col'] global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager - with pytest.raises((CosmosHttpResponseError, ServiceResponseError)): + with pytest.raises((CosmosHttpResponseError, ServiceResponseError)) as exc_info: await perform_write_operation(write_operation, container, fault_injection_container, document_definition['id'], document_definition['pk'], expected_uri) + assert exc_info.value == error TestPerPartitionCircuitBreakerMMAsync.validate_unhealthy_partitions(global_endpoint_manager, 0) @@ -202,20 +203,20 @@ async def test_write_consecutive_failure_threshold_async(self, setup_teardown, w for i in range(4): with pytest.raises((CosmosHttpResponseError, ServiceResponseError)): await perform_write_operation(write_operation, - setup[COLLECTION], + container, fault_injection_container, document_definition['id'], document_definition['pk'], expected_uri) await perform_write_operation(write_operation, - setup[COLLECTION], + container, fault_injection_container, document_definition['id'], document_definition['pk'], expected_uri) TestPerPartitionCircuitBreakerMMAsync.validate_unhealthy_partitions(global_endpoint_manager, 1) - TestPerPartitionCircuitBreakerMMAsync.cleanup_method(custom_setup) + await TestPerPartitionCircuitBreakerMMAsync.cleanup_method([custom_setup, setup]) @pytest.mark.parametrize("read_operation, error", read_operations_and_errors()) async def test_read_consecutive_failure_threshold_async(self, setup_teardown, read_operation, error): From 08e9e45db66e562dbdcdf7c3f7a9e0bcf71becc8 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Thu, 17 Apr 2025 17:05:18 -0700 Subject: [PATCH 089/152] fix mm tests --- .../azure/cosmos/aio/_retry_utility_async.py | 11 -- .../azure-cosmos/tests/test_ppcb_mm_async.py | 124 ++++++++---------- 2 files changed, 55 insertions(+), 80 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py 
b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py index c2f12d7185bc..a9cdff17703a 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py @@ -317,17 +317,6 @@ async def send(self, request): except ImportError: raise err # pylint: disable=raise-missing-from raise err - except AzureError as err: - retry_error = err - if _has_database_account_header(request.http_request.headers): - raise err - if self._is_method_retryable(retry_settings, request.http_request): - await global_endpoint_manager.record_failure(request_params) - retry_active = self.increment(retry_settings, response=request, error=err) - if retry_active: - await self.sleep(retry_settings, request.context.transport) - continue - raise err finally: end_time = time.time() if absolute_timeout: diff --git a/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm_async.py b/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm_async.py index 9acaca81e0a2..b8bce3a4d019 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm_async.py @@ -32,6 +32,8 @@ async def setup_teardown(): await created_database.create_container(TestPerPartitionCircuitBreakerMMAsync.TEST_CONTAINER_SINGLE_PARTITION_ID, partition_key=PartitionKey("/pk"), offer_throughput=10000) + # allow some time for the container to be created as this method is in different event loop + await asyncio.sleep(2) yield await created_database.delete_container(TestPerPartitionCircuitBreakerMMAsync.TEST_CONTAINER_SINGLE_PARTITION_ID) await client.close() @@ -110,7 +112,27 @@ async def perform_write_operation(operation, container, fault_injection_containe # await container.create_item(body=doc) # await container.create_item(body=doc) # resp = await fault_injection_container.delete_all_items_by_partition_key(pk) - validate_response_uri(resp, expected_uri) + if resp: + validate_response_uri(resp, expected_uri) + +async def 
perform_read_operation(operation, container, doc_id, pk, expected_uri): + if operation == "read": + read_resp = await container.read_item(item=doc_id, partition_key=pk) + request = read_resp.get_response_headers()["_request"] + # Validate the response comes from "Read Region" (the most preferred read-only region) + assert request.url.startswith(expected_uri) + elif operation == "query": + query = "SELECT * FROM c WHERE c.id = @id AND c.pk = @pk" + parameters = [{"name": "@id", "value": doc_id}, {"name": "@pk", "value": pk}] + async for item in container.query_items(query=query, partition_key=pk, parameters=parameters): + assert item['id'] == doc_id + # need to do query with no pk and with feed range + elif operation == "changefeed": + async for _ in container.query_items_change_feed(): + pass + elif operation == "read_all_items": + async for item in container.read_all_items(partition_key=pk): + assert item['pk'] == pk @pytest.mark.cosmosMultiRegion @@ -146,35 +168,12 @@ async def cleanup_method(initialized_objects: List[Dict[str, Any]]): method_client: CosmosClient = obj["client"] await method_client.close() - @staticmethod - async def perform_read_operation(operation, container, doc_id, pk, expected_uri): - if operation == "read": - read_resp = await container.read_item(item=doc_id, partition_key=pk) - request = read_resp.get_response_headers()["_request"] - # Validate the response comes from "Read Region" (the most preferred read-only region) - assert request.url.startswith(expected_uri) - elif operation == "query": - query = "SELECT * FROM c WHERE c.id = @id AND c.pk = @pk" - parameters = [{"name": "@id", "value": doc_id}, {"name": "@pk", "value": pk}] - async for item in container.query_items(query=query, partition_key=pk, parameters=parameters): - assert item['id'] == doc_id - # need to do query with no pk and with feed range - elif operation == "changefeed": - async for _ in container.query_items_change_feed(): - pass - elif operation == "read_all_items": - 
async for item in container.read_all_items(partition_key=pk): - assert item['pk'] == pk - - @pytest.mark.parametrize("write_operation, error", write_operations_and_errors()) async def test_write_consecutive_failure_threshold_async(self, setup_teardown, write_operation, error): expected_uri = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_2) uri_down = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_1) custom_transport = FaultInjectionTransportAsync() - id_value = 'failoverDoc-' + str(uuid.uuid4()) - document_definition = {'id': id_value, - 'pk': 'pk1'} + pk_value = "pk1" predicate = lambda r: (FaultInjectionTransportAsync.predicate_is_document_operation(r) and FaultInjectionTransportAsync.predicate_targets_region(r, uri_down)) custom_transport.add_fault(predicate, lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( @@ -192,8 +191,8 @@ async def test_write_consecutive_failure_threshold_async(self, setup_teardown, w await perform_write_operation(write_operation, container, fault_injection_container, - document_definition['id'], - document_definition['pk'], + str(uuid.uuid4()), + pk_value, expected_uri) assert exc_info.value == error @@ -205,24 +204,27 @@ async def test_write_consecutive_failure_threshold_async(self, setup_teardown, w await perform_write_operation(write_operation, container, fault_injection_container, - document_definition['id'], - document_definition['pk'], + str(uuid.uuid4()), + pk_value, expected_uri) await perform_write_operation(write_operation, container, fault_injection_container, - document_definition['id'], - document_definition['pk'], + str(uuid.uuid4()), + pk_value, expected_uri) TestPerPartitionCircuitBreakerMMAsync.validate_unhealthy_partitions(global_endpoint_manager, 1) await TestPerPartitionCircuitBreakerMMAsync.cleanup_method([custom_setup, setup]) + # test recovering the partition + + @pytest.mark.parametrize("read_operation, error", read_operations_and_errors()) 
async def test_read_consecutive_failure_threshold_async(self, setup_teardown, read_operation, error): expected_uri = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_2) custom_transport = FaultInjectionTransportAsync() - id_value = 'failoverDoc-' + str(uuid.uuid4()) + id_value = str(uuid.uuid4()) document_definition = {'id': id_value, 'pk': 'pk1', 'name': 'sample document', @@ -235,51 +237,35 @@ async def test_read_consecutive_failure_threshold_async(self, setup_teardown, re ))) custom_setup = await self.setup_method_with_custom_transport(custom_transport, default_endpoint=self.host) - container = custom_setup['col'] - global_endpoint_manager = container.client_connection._global_endpoint_manager - - - await TestPerPartitionCircuitBreakerMMAsync.perform_write_operation(write_operation, - setup_teardown[COLLECTION], - container, - document_definition['id'], - document_definition['pk'], - expected_uri) - - # writes should fail in sm mrr with circuit breaker and should not mark unavailable a partition - for i in range(6): - with pytest.raises((CosmosHttpResponseError, ServiceResponseError)): - await TestPerPartitionCircuitBreakerMMAsync.perform_write_operation(write_operation, - setup_teardown[COLLECTION], - container, - document_definition['id'], - document_definition['pk'], - expected_uri) - - TestPerPartitionCircuitBreakerMMAsync.validate_unhealthy_partitions(global_endpoint_manager, 1) - - # create item with client without fault injection - await setup_teardown[COLLECTION].create_item(body=document_definition) + fault_injection_container = custom_setup['col'] + setup = await self.setup_method(default_endpoint=self.host) + container = setup['col'] + global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager + + # create some documents + await container.create_item(body=document_definition) + for i in range(5): + document_definition['id'] = str(uuid.uuid4()) + await 
container.create_item(body=document_definition) # reads should fail over and only the relevant partition should be marked as unavailable - await TestPerPartitionCircuitBreakerMMAsync.perform_read_operation(read_operation, - container, - document_definition['id'], - document_definition['pk'], - expected_uri) + await perform_read_operation(read_operation, + fault_injection_container, + document_definition['id'], + document_definition['pk'], + expected_uri) # partition should not have been marked unavailable after one error - TestPerPartitionCircuitBreakerMMAsync.validate_unhealthy_partitions(global_endpoint_manager, 1) + TestPerPartitionCircuitBreakerMMAsync.validate_unhealthy_partitions(global_endpoint_manager, 0) for i in range(10): - await TestPerPartitionCircuitBreakerMMAsync.perform_read_operation(read_operation, - container, - document_definition['id'], - document_definition['pk'], - expected_uri) + await perform_read_operation(read_operation, + fault_injection_container, + document_definition['id'], + document_definition['pk'], + expected_uri) # the partition should have been marked as unavailable after breaking read threshold TestPerPartitionCircuitBreakerMMAsync.validate_unhealthy_partitions(global_endpoint_manager, 1) - # test recovering the partition again @pytest.mark.parametrize("write_operation, error", write_operations_and_errors()) async def test_failure_rate_threshold_async(self, setup_teardown, write_operation, error): From fdba91433764f0ca7f02f5b4c50c565b1ac9fef7 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Fri, 18 Apr 2025 12:57:01 -0700 Subject: [PATCH 090/152] test improvements --- .../azure/cosmos/_partition_health_tracker.py | 10 +- .../azure/cosmos/aio/_container.py | 32 +-- ..._endpoint_manager_circuit_breaker_async.py | 1 + .../azure-cosmos/tests/test_ppcb_mm_async.py | 257 +++++++++--------- 4 files changed, 150 insertions(+), 150 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py 
b/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py index 034d1cf1ac10..9760a0540387 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py @@ -48,9 +48,12 @@ def _has_exceeded_failure_rate_threshold( failures: int, failure_rate_threshold: int, ) -> bool: + print(MINIMUM_REQUESTS_FOR_FAILURE_RATE) if successes + failures < MINIMUM_REQUESTS_FOR_FAILURE_RATE: return False - return (failures / successes * 100) >= failure_rate_threshold + failure_rate = failures / (failures + successes) * 100 + print("Failure rate", failure_rate) + return failure_rate >= failure_rate_threshold class _PartitionHealthInfo(object): """ @@ -97,8 +100,6 @@ def __init__(self) -> None: self.pk_range_wrapper_to_health_info: Dict[PartitionKeyRangeWrapper, Dict[str, _PartitionHealthInfo]] = {} self.last_refresh = current_time_millis() - # TODO: @tvaron3 look for useful places to add logs - def mark_partition_unavailable(self, pk_range_wrapper: PartitionKeyRangeWrapper, location: str) -> None: # mark the partition key range as unavailable self._transition_health_status_on_failure(pk_range_wrapper, location) @@ -200,6 +201,7 @@ def add_failure( self.pk_range_wrapper_to_health_info[pk_range_wrapper][location] = _PartitionHealthInfo() health_info = self.pk_range_wrapper_to_health_info[pk_range_wrapper][location] + print(failure_rate_threshold) # Determine attribute names and environment variables based on the operation type. 
if operation_type == EndpointOperationType.WriteType: @@ -246,7 +248,7 @@ def _check_thresholds( failure_rate_threshold: int, consecutive_failure_threshold: int, ) -> None: - + print("Check Thresholds called") # check the failure rate was not exceeded if _has_exceeded_failure_rate_threshold( successes, diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py index 6a103072c640..12fa8dfefc74 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py @@ -272,9 +272,7 @@ async def create_item( request_options["disableAutomaticIdGeneration"] = not enable_automatic_id_generation if indexing_directive is not None: request_options["indexingDirective"] = indexing_directive - if not self.container_link in self.__get_client_container_caches(): - container = await self.read() - self.__get_client_container_caches()[self.container_link] = _set_properties_cache(container) + await self._get_properties() request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] result = await self.client_connection.CreateItem( @@ -345,8 +343,8 @@ async def read_item( if max_integrated_cache_staleness_in_ms is not None: validate_cache_staleness_value(max_integrated_cache_staleness_in_ms) request_options["maxIntegratedCacheStaleness"] = max_integrated_cache_staleness_in_ms - if self.container_link in self.__get_client_container_caches(): - request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] + await self._get_properties() + request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] return await self.client_connection.ReadItem(document_link=doc_link, options=request_options, **kwargs) @@ -823,9 +821,7 @@ async def upsert_item( kwargs['no_response'] = no_response request_options = _build_options(kwargs) 
request_options["disableAutomaticIdGeneration"] = True - if not self.container_link in self.__get_client_container_caches(): - container = await self.read() - self.__get_client_container_caches()[self.container_link] = _set_properties_cache(container) + await self._get_properties() request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] result = await self.client_connection.UpsertItem( @@ -904,9 +900,7 @@ async def replace_item( kwargs['no_response'] = no_response request_options = _build_options(kwargs) request_options["disableAutomaticIdGeneration"] = True - if not self.container_link in self.__get_client_container_caches(): - container = await self.read() - self.__get_client_container_caches()[self.container_link] = _set_properties_cache(container) + await self._get_properties() request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] result = await self.client_connection.ReplaceItem( @@ -986,9 +980,7 @@ async def patch_item( request_options["partitionKey"] = await self._set_partition_key(partition_key) if filter_predicate is not None: request_options["filterPredicate"] = filter_predicate - if not self.container_link in self.__get_client_container_caches(): - container = await self.read() - self.__get_client_container_caches()[self.container_link] = _set_properties_cache(container) + await self._get_properties() request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] item_link = self._get_document_link(item) @@ -1056,9 +1048,7 @@ async def delete_item( kwargs['priority'] = priority request_options = _build_options(kwargs) request_options["partitionKey"] = await self._set_partition_key(partition_key) - if not self.container_link in self.__get_client_container_caches(): - container = await self.read() - self.__get_client_container_caches()[self.container_link] = _set_properties_cache(container) + await self._get_properties() 
request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] document_link = self._get_document_link(item) @@ -1315,9 +1305,7 @@ async def delete_all_items_by_partition_key( request_options = _build_options(kwargs) # regardless if partition key is valid we set it as invalid partition keys are set to a default empty value request_options["partitionKey"] = await self._set_partition_key(partition_key) - if not self.container_link in self.__get_client_container_caches(): - container = await self.read() - self.__get_client_container_caches()[self.container_link] = _set_properties_cache(container) + await self._get_properties() request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] await self.client_connection.DeleteAllItemsByPartitionKey(collection_link=self.container_link, @@ -1381,9 +1369,7 @@ async def execute_item_batch( request_options = _build_options(kwargs) request_options["partitionKey"] = await self._set_partition_key(partition_key) request_options["disableAutomaticIdGeneration"] = True - if not self.container_link in self.__get_client_container_caches(): - container = await self.read() - self.__get_client_container_caches()[self.container_link] = _set_properties_cache(container) + await self._get_properties() request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] return await self.client_connection.Batch( diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py index 5231ed5c06c4..3e6010ac36a7 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py @@ -52,6 +52,7 @@ def __init__(self, client: 
"CosmosClientConnection"): async def create_pk_range_wrapper(self, request: RequestObject) -> PartitionKeyRangeWrapper: container_rid = request.headers[HttpHeaders.IntendedCollectionRID] print(request.headers) + print(request) target_container_link = None partition_key = None # TODO: @tvaron3 see if the container link is absolutely necessary or change container cache diff --git a/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm_async.py b/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm_async.py index b8bce3a4d019..ad56e20da02b 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm_async.py @@ -12,7 +12,7 @@ from azure.core.exceptions import ServiceResponseError import test_config -from azure.cosmos import PartitionKey, _location_cache +from azure.cosmos import PartitionKey, _location_cache, _partition_health_tracker from azure.cosmos._partition_health_tracker import HEALTH_STATUS, UNHEALTHY, UNHEALTHY_TENTATIVE from azure.cosmos.aio import CosmosClient from azure.cosmos.exceptions import CosmosHttpResponseError @@ -56,7 +56,7 @@ def write_operations_and_errors(): return params def read_operations_and_errors(): - read_operations = ["read", "query", "changefeed", "read_all_items"] + read_operations = ["read", "query", "query_pk", "changefeed", "changefeed_pk", "changefeed_epk", "read_all_items"] errors = [] error_codes = [408, 500, 502, 503] for error_code in error_codes: @@ -121,19 +121,53 @@ async def perform_read_operation(operation, container, doc_id, pk, expected_uri) request = read_resp.get_response_headers()["_request"] # Validate the response comes from "Read Region" (the most preferred read-only region) assert request.url.startswith(expected_uri) - elif operation == "query": + elif operation == "query_pk": + # partition key filtered query query = "SELECT * FROM c WHERE c.id = @id AND c.pk = @pk" parameters = [{"name": "@id", "value": doc_id}, {"name": "@pk", "value": pk}] async for item in 
container.query_items(query=query, partition_key=pk, parameters=parameters): assert item['id'] == doc_id # need to do query with no pk and with feed range + elif operation == "query": + # cross partition query + query = "SELECT * FROM c WHERE c.id = @id" + async for item in container.query_items(query=query): + assert item['id'] == doc_id elif operation == "changefeed": async for _ in container.query_items_change_feed(): pass + elif operation == "changefeed_pk": + # partition key filtered change feed + async for _ in container.query_items_change_feed(partition_key=pk): + pass + elif operation == "changefeed_epk": + # partition key filtered by feed range + feed_range = await container.feed_range_from_partition_key(partition_key=pk) + async for _ in container.query_items_change_feed(feed_range=feed_range): + pass elif operation == "read_all_items": async for item in container.read_all_items(partition_key=pk): assert item['pk'] == pk +def validate_unhealthy_partitions(global_endpoint_manager, + expected_unhealthy_partitions): + health_info_map = global_endpoint_manager.global_partition_endpoint_manager_core.partition_health_tracker.pk_range_wrapper_to_health_info + unhealthy_partitions = 0 + for pk_range_wrapper, location_to_health_info in health_info_map.items(): + for location, health_info in location_to_health_info.items(): + health_status = health_info.unavailability_info.get(HEALTH_STATUS) + if health_status == UNHEALTHY_TENTATIVE or health_status == UNHEALTHY: + unhealthy_partitions += 1 + else: + assert health_info.read_consecutive_failure_count < 10 + assert health_info.write_consecutive_failure_count < 5 + + assert unhealthy_partitions == expected_unhealthy_partitions + +async def cleanup_method(initialized_objects: List[Dict[str, Any]]): + for obj in initialized_objects: + method_client: CosmosClient = obj["client"] + await method_client.close() @pytest.mark.cosmosMultiRegion @pytest.mark.asyncio @@ -162,12 +196,6 @@ async def setup_method(self, 
default_endpoint=host, **kwargs): container = db.get_container_client(self.TEST_CONTAINER_SINGLE_PARTITION_ID) return {"client": client, "db": db, "col": container} - @staticmethod - async def cleanup_method(initialized_objects: List[Dict[str, Any]]): - for obj in initialized_objects: - method_client: CosmosClient = obj["client"] - await method_client.close() - @pytest.mark.parametrize("write_operation, error", write_operations_and_errors()) async def test_write_consecutive_failure_threshold_async(self, setup_teardown, write_operation, error): expected_uri = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_2) @@ -196,33 +224,37 @@ async def test_write_consecutive_failure_threshold_async(self, setup_teardown, w expected_uri) assert exc_info.value == error - TestPerPartitionCircuitBreakerMMAsync.validate_unhealthy_partitions(global_endpoint_manager, 0) + validate_unhealthy_partitions(global_endpoint_manager, 0) # writes should fail but still be tracked for i in range(4): - with pytest.raises((CosmosHttpResponseError, ServiceResponseError)): + with pytest.raises((CosmosHttpResponseError, ServiceResponseError)) as exc_info: await perform_write_operation(write_operation, container, fault_injection_container, str(uuid.uuid4()), pk_value, expected_uri) + assert exc_info.value == error + # writes should now succeed because going to the other region await perform_write_operation(write_operation, container, fault_injection_container, str(uuid.uuid4()), pk_value, expected_uri) - TestPerPartitionCircuitBreakerMMAsync.validate_unhealthy_partitions(global_endpoint_manager, 1) - await TestPerPartitionCircuitBreakerMMAsync.cleanup_method([custom_setup, setup]) - # test recovering the partition + + validate_unhealthy_partitions(global_endpoint_manager, 1) + await cleanup_method([custom_setup, setup]) + # test recovering the partition --------------------------------------------------------------------- @pytest.mark.parametrize("read_operation, error", 
read_operations_and_errors()) async def test_read_consecutive_failure_threshold_async(self, setup_teardown, read_operation, error): expected_uri = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_2) + uri_down = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_1) custom_transport = FaultInjectionTransportAsync() id_value = str(uuid.uuid4()) document_definition = {'id': id_value, @@ -230,7 +262,7 @@ async def test_read_consecutive_failure_threshold_async(self, setup_teardown, re 'name': 'sample document', 'key': 'value'} predicate = lambda r: (FaultInjectionTransportAsync.predicate_is_document_operation(r) and - FaultInjectionTransportAsync.predicate_targets_region(r, expected_uri)) + FaultInjectionTransportAsync.predicate_targets_region(r, uri_down)) custom_transport.add_fault(predicate, lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( 0, error @@ -242,11 +274,8 @@ async def test_read_consecutive_failure_threshold_async(self, setup_teardown, re container = setup['col'] global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager - # create some documents + # create a document to read await container.create_item(body=document_definition) - for i in range(5): - document_definition['id'] = str(uuid.uuid4()) - await container.create_item(body=document_definition) # reads should fail over and only the relevant partition should be marked as unavailable await perform_read_operation(read_operation, @@ -255,7 +284,7 @@ async def test_read_consecutive_failure_threshold_async(self, setup_teardown, re document_definition['pk'], expected_uri) # partition should not have been marked unavailable after one error - TestPerPartitionCircuitBreakerMMAsync.validate_unhealthy_partitions(global_endpoint_manager, 0) + validate_unhealthy_partitions(global_endpoint_manager, 0) for i in range(10): await perform_read_operation(read_operation, @@ -265,24 +294,22 @@ async def 
test_read_consecutive_failure_threshold_async(self, setup_teardown, re expected_uri) # the partition should have been marked as unavailable after breaking read threshold - TestPerPartitionCircuitBreakerMMAsync.validate_unhealthy_partitions(global_endpoint_manager, 1) + validate_unhealthy_partitions(global_endpoint_manager, 1) + await cleanup_method([custom_setup, setup]) @pytest.mark.parametrize("write_operation, error", write_operations_and_errors()) - async def test_failure_rate_threshold_async(self, setup_teardown, write_operation, error): + async def test_write_failure_rate_threshold_async(self, setup_teardown, write_operation, error): expected_uri = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_2) - custom_transport = await self.create_custom_transport_sm_mrr() - id_value = 'failoverDoc-' + str(uuid.uuid4()) + uri_down = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_1) + custom_transport = FaultInjectionTransportAsync() # two documents targeted to same partition, one will always fail and the other will succeed - document_definition = {'id': id_value, - 'pk': 'pk1', - 'name': 'sample document', - 'key': 'value'} - doc_2 = {'id': str(uuid.uuid4()), - 'pk': 'pk1', + pk_value = "pk1" + doc = {'id': str(uuid.uuid4()), + 'pk': pk_value, 'name': 'sample document', 'key': 'value'} - predicate = lambda r: (FaultInjectionTransportAsync.predicate_req_for_document_with_id(r, id_value) and - FaultInjectionTransportAsync.predicate_targets_region(r, expected_write_region_uri)) + predicate = lambda r: (FaultInjectionTransportAsync.predicate_is_document_operation(r) and + FaultInjectionTransportAsync.predicate_targets_region(r, uri_down)) custom_transport.add_fault(predicate, lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( 0, @@ -290,49 +317,58 @@ async def test_failure_rate_threshold_async(self, setup_teardown, write_operatio ))) custom_setup = await 
self.setup_method_with_custom_transport(custom_transport, default_endpoint=self.host) - container = custom_setup['col'] - global_endpoint_manager = container.client_connection._global_endpoint_manager + fault_injection_container = custom_setup['col'] + setup = await self.setup_method(default_endpoint=self.host) + container = setup['col'] + global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager # lower minimum requests for testing - global_endpoint_manager.global_partition_endpoint_manager_core.partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 10 + _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 10 + os.environ["AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED"] = "80" try: - # writes should fail in sm mrr with circuit breaker and should not mark unavailable a partition - for i in range(14): - if i == 9: - await container.upsert_item(body=doc_2) - with pytest.raises((CosmosHttpResponseError, ServiceResponseError)): - await TestPerPartitionCircuitBreakerMMAsync.perform_write_operation(write_operation, - setup_teardown[COLLECTION], - container, - document_definition['id'], - document_definition['pk'], - expected_uri) - - TestPerPartitionCircuitBreakerMMAsync.validate_unhealthy_partitions(global_endpoint_manager, 0) - - # create item with client without fault injection - await setup_teardown[COLLECTION].create_item(body=document_definition) - finally: - # restore minimum requests - global_endpoint_manager.global_partition_endpoint_manager_core.partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 + # writes should fail but still be tracked and mark unavailable a partition after crossing threshold + for i in range(10): + validate_unhealthy_partitions(global_endpoint_manager, 0) + if i == 4 or i == 8: + # perform some successful creates to reset consecutive counter + # remove faults and perform a write + custom_transport.faults = [] + await fault_injection_container.upsert_item(body=doc) + 
custom_transport.add_fault(predicate, + lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( + 0, + error + ))) + else: + with pytest.raises((CosmosHttpResponseError, ServiceResponseError)) as exc_info: + await perform_read_operation(write_operation, + container, + fault_injection_container, + str(uuid.uuid4()), + pk_value, + expected_uri) + assert exc_info.value == error + validate_unhealthy_partitions(global_endpoint_manager, 1) + finally: + os.environ["AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED"] = "90" + # restore minimum requests + _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 + await cleanup_method([custom_setup, setup]) - @pytest.mark.parametrize("write_operation, error", write_operations_and_errors()) - async def test_read_failure_rate_threshold_async(self, setup_teardown, write_operation, error): + @pytest.mark.parametrize("read_operation, error", read_operations_and_errors()) + async def test_read_failure_rate_threshold_async(self, setup_teardown, read_operation, error): expected_uri = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_2) - custom_transport = await self.create_custom_transport_sm_mrr() - id_value = 'failoverDoc-' + str(uuid.uuid4()) + uri_down = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_1) + custom_transport = FaultInjectionTransportAsync() # two documents targeted to same partition, one will always fail and the other will succeed - document_definition = {'id': id_value, - 'pk': 'pk1', - 'name': 'sample document', - 'key': 'value'} - doc_2 = {'id': str(uuid.uuid4()), - 'pk': 'pk1', - 'name': 'sample document', - 'key': 'value'} - predicate = lambda r: (FaultInjectionTransportAsync.predicate_req_for_document_with_id(r, id_value) and - FaultInjectionTransportAsync.predicate_targets_region(r, expected_write_region_uri)) + pk_value = "pk1" + doc = {'id': str(uuid.uuid4()), + 'pk': pk_value, + 'name': 'sample document', + 'key': 'value'} + predicate = 
lambda r: (FaultInjectionTransportAsync.predicate_is_document_operation(r) and + FaultInjectionTransportAsync.predicate_targets_region(r, uri_down)) custom_transport.add_fault(predicate, lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( 0, @@ -340,69 +376,44 @@ async def test_read_failure_rate_threshold_async(self, setup_teardown, write_ope ))) custom_setup = await self.setup_method_with_custom_transport(custom_transport, default_endpoint=self.host) - container = custom_setup['col'] - global_endpoint_manager = container.client_connection._global_endpoint_manager + fault_injection_container = custom_setup['col'] + setup = await self.setup_method(default_endpoint=self.host) + container = setup['col'] + await container.upsert_item(body=doc) + global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager # lower minimum requests for testing - global_endpoint_manager.global_partition_endpoint_manager_core.partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 10 + _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 10 + os.environ["AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED"] = "80" try: - # writes should fail in sm mrr with circuit breaker and should not mark unavailable a partition - for i in range(14): - if i == 9: - await container.upsert_item(body=doc_2) - with pytest.raises((CosmosHttpResponseError, ServiceResponseError)): - await TestPerPartitionCircuitBreakerMMAsync.perform_write_operation(write_operation, - setup_teardown[COLLECTION], - container, - document_definition['id'], - document_definition['pk'], - expected_uri) - - TestPerPartitionCircuitBreakerMMAsync.validate_unhealthy_partitions(global_endpoint_manager, 0) - - # create item with client without fault injection - await setup_teardown[COLLECTION].create_item(body=document_definition) - - # reads should fail over and only the relevant partition should be marked as unavailable - await 
container.read_item(item=document_definition['id'], partition_key=document_definition['pk']) - # partition should not have been marked unavailable after one error - TestPerPartitionCircuitBreakerMMAsync.validate_unhealthy_partitions(global_endpoint_manager, 0) - for i in range(20): - if i == 8: - read_resp = await container.read_item(item=doc_2['id'], - partition_key=doc_2['pk']) - request = read_resp.get_response_headers()["_request"] - # Validate the response comes from "Read Region" (the most preferred read-only region) - assert request.url.startswith(expected_uri) + for i in range(10): + validate_unhealthy_partitions(global_endpoint_manager, 0) + if i == 4: + # perform some successful read to reset consecutive counter + # remove faults and perform a read + custom_transport.faults = [] + await fault_injection_container.read_item(item=doc["id"], partition_key=pk_value) + custom_transport.add_fault(predicate, + lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( + 0, + error + ))) else: - await TestPerPartitionCircuitBreakerMMAsync.perform_read_operation(read_operation, - container, - document_definition['id'], - document_definition['pk'], - expected_uri) - # the partition should have been marked as unavailable after breaking read threshold - TestPerPartitionCircuitBreakerMMAsync.validate_unhealthy_partitions(global_endpoint_manager, 1) + # read will fail and retry in other region + await perform_read_operation(read_operation, + fault_injection_container, + doc['id'], + pk_value, + expected_uri) + + validate_unhealthy_partitions(global_endpoint_manager, 1) finally: + os.environ["AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED"] = "90" # restore minimum requests - global_endpoint_manager.global_partition_endpoint_manager_core.partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 + _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 + await cleanup_method([custom_setup, setup]) # look at the urls for verifying fall back 
and use another id for same partition - @staticmethod - def validate_unhealthy_partitions(global_endpoint_manager, - expected_unhealthy_partitions): - health_info_map = global_endpoint_manager.global_partition_endpoint_manager_core.partition_health_tracker.pk_range_wrapper_to_health_info - unhealthy_partitions = 0 - for pk_range_wrapper, location_to_health_info in health_info_map.items(): - for location, health_info in location_to_health_info.items(): - health_status = health_info.unavailability_info.get(HEALTH_STATUS) - if health_status == UNHEALTHY_TENTATIVE or health_status == UNHEALTHY: - unhealthy_partitions += 1 - else: - assert health_info.read_consecutive_failure_count < 10 - assert health_info.write_consecutive_failure_count < 5 - - assert unhealthy_partitions == expected_unhealthy_partitions - # test_failure_rate_threshold - add service response error - across operation types - test recovering the partition again # test service request marks only a partition unavailable not an entire region - across operation types # test cosmos client timeout From 8ee6b3e429b6b2998f5f32b52f74c082a05cf5de Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Fri, 18 Apr 2025 23:23:30 -0700 Subject: [PATCH 091/152] recovering optimizations, lower request timeout, disable in region retries --- ...tition_endpoint_manager_circuit_breaker.py | 1 - ...n_endpoint_manager_circuit_breaker_core.py | 5 ++-- .../azure/cosmos/_partition_health_tracker.py | 14 +++++++--- .../cosmos/_routing/routing_map_provider.py | 1 + .../azure/cosmos/aio/_asynchronous_request.py | 2 ++ .../aio/_global_endpoint_manager_async.py | 2 +- ..._endpoint_manager_circuit_breaker_async.py | 10 +++++-- .../azure/cosmos/aio/_retry_utility_async.py | 6 +++-- .../azure-cosmos/azure/cosmos/documents.py | 4 +++ .../azure-cosmos/tests/test_ppcb_mm_async.py | 26 +++++++++++++++---- 10 files changed, 54 insertions(+), 17 deletions(-) diff --git 
a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py index 288205cb2411..f5a0f780be0a 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py @@ -62,7 +62,6 @@ def create_pk_range_wrapper(self, request: RequestObject) -> PartitionKeyRangeWr target_container_link = container_link if not target_container_link: raise RuntimeError("Illegal state: the container cache is not properly initialized.") - # TODO: @tvaron3 check different clients and create them in different ways pk_range = (self.Client._routing_map_provider # pylint: disable=protected-access .get_overlapping_ranges(target_container_link, partition_key)) return PartitionKeyRangeWrapper(pk_range, container_rid) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py index 577b8410f435..ac70cc873dc3 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py @@ -80,7 +80,6 @@ def record_failure( else EndpointOperationType.ReadType) location = self.location_cache.get_location_from_endpoint(str(request.location_endpoint_to_route)) self.partition_health_tracker.add_failure(pk_range_wrapper, endpoint_operation_type, str(location)) - # TODO: @tvaron3 lower request timeout to 5.5 seconds for recovering # TODO: @tvaron3 exponential backoff for recovering def add_excluded_locations_to_request( @@ -108,4 +107,6 @@ def record_success( location = self.location_cache.get_location_from_endpoint(str(request.location_endpoint_to_route)) 
self.partition_health_tracker.add_success(pk_range_wrapper, endpoint_operation_type, location) -# TODO: @tvaron3 there should be no in region retries when trying on healthy tentative ----------------------- + def is_healthy_tentative(self, request: RequestObject, pk_range_wrapper: PartitionKeyRangeWrapper): + location = self.location_cache.get_location_from_endpoint(str(request.location_endpoint_to_route)) + return self.partition_health_tracker.is_healthy_tentative(pk_range_wrapper, location) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py index 9760a0540387..ef773b454522 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py @@ -23,7 +23,7 @@ """ import logging import os -from typing import Dict, Set, Any, List +from typing import Dict, Any, List from azure.cosmos._routing.routing_range import PartitionKeyRangeWrapper from azure.cosmos._location_cache import current_time_millis, EndpointOperationType from ._constants import _Constants as Constants @@ -48,11 +48,9 @@ def _has_exceeded_failure_rate_threshold( failures: int, failure_rate_threshold: int, ) -> bool: - print(MINIMUM_REQUESTS_FOR_FAILURE_RATE) if successes + failures < MINIMUM_REQUESTS_FOR_FAILURE_RATE: return False failure_rate = failures / (failures + successes) * 100 - print("Failure rate", failure_rate) return failure_rate >= failure_rate_threshold class _PartitionHealthInfo(object): @@ -248,7 +246,6 @@ def _check_thresholds( failure_rate_threshold: int, consecutive_failure_threshold: int, ) -> None: - print("Check Thresholds called") # check the failure rate was not exceeded if _has_exceeded_failure_rate_threshold( successes, @@ -286,3 +283,12 @@ def _reset_partition_health_tracker_stats(self) -> None: for locations in self.pk_range_wrapper_to_health_info.values(): for health_info in locations.values(): 
health_info.reset_health_stats() + + def is_healthy_tentative(self, pk_range_wrapper, location): + if pk_range_wrapper in self.pk_range_wrapper_to_health_info: + if location in self.pk_range_wrapper_to_health_info[pk_range_wrapper]: + health_info = self.pk_range_wrapper_to_health_info[pk_range_wrapper][location] + return health_info.unavailability_info and \ + health_info.unavailability_info[HEALTH_STATUS] == HEALTHY_TENTATIVE + return False + diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_map_provider.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_map_provider.py index 8dacb5190e07..21e1e09b78cc 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_map_provider.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_map_provider.py @@ -63,6 +63,7 @@ def get_overlapping_ranges(self, collection_link, partition_key_ranges, **kwargs collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) + # TODO: @tvaron3 change this to be by collectionRID collection_routing_map = self._collection_routing_map_by_item.get(collection_id) if collection_routing_map is None: collection_pk_ranges = list(cl._ReadPartitionKeyRanges(collection_link, **kwargs)) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py index 4fda37ea0a87..71d74f14a1a9 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py @@ -59,6 +59,8 @@ async def _Request(global_endpoint_manager, request_params, connection_policy, p # Every request tries to perform a refresh client_timeout = kwargs.get('timeout') start_time = time.time() + if global_endpoint_manager.is_healthy_tentative(request_params): + read_timeout = connection_policy.RecoveryReadTimeout if request_params.resource_type != http_constants.ResourceType.DatabaseAccount: await 
global_endpoint_manager.refresh_endpoint_list(None, **kwargs) else: diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py index 0fe666f1983c..b543e42cedfd 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py @@ -215,5 +215,5 @@ async def close(self): self.refresh_task.cancel() try: await self.refresh_task - except (Exception, asyncio.CancelledError) : #pylint: disable=broad-exception-caught + except (Exception, asyncio.CancelledError) : pass diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py index 3e6010ac36a7..9cc2a07472b4 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py @@ -94,7 +94,7 @@ async def record_failure( self.global_partition_endpoint_manager_core.record_failure(request, pk_range_wrapper) def resolve_service_endpoint(self, request: RequestObject, pk_range_wrapper: PartitionKeyRangeWrapper): - if self.global_partition_endpoint_manager_core.is_circuit_breaker_applicable(request) and pk_range_wrapper: + if self.is_circuit_breaker_applicable(request): request = self.global_partition_endpoint_manager_core.add_excluded_locations_to_request(request, pk_range_wrapper) return (super(_GlobalPartitionEndpointManagerForCircuitBreakerAsync, self) @@ -111,6 +111,12 @@ async def record_success( self, request: RequestObject ) -> None: - if self.global_partition_endpoint_manager_core.is_circuit_breaker_applicable(request): + if self.is_circuit_breaker_applicable(request): pk_range_wrapper = await 
self.create_pk_range_wrapper(request) self.global_partition_endpoint_manager_core.record_success(request, pk_range_wrapper) + + async def is_healthy_tentative(self, request: RequestObject) -> bool: + if self.is_circuit_breaker_applicable(request): + pk_range_wrapper = await self.create_pk_range_wrapper(request) + return self.global_partition_endpoint_manager_core.is_healthy_tentative(request, pk_range_wrapper) + return False diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py index a9cdff17703a..89787e8a21e3 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py @@ -288,7 +288,8 @@ async def send(self, request): retry_error = err # the request ran into a socket timeout or failed to establish a new connection # since request wasn't sent, raise exception immediately to be dealt with in client retry policies - if not _has_database_account_header(request.http_request.headers): + if (not _has_database_account_header(request.http_request.headers) + and not global_endpoint_manager.is_healthy_tentative(request_params)): await global_endpoint_manager.record_failure(request_params) if retry_settings['connect'] > 0: retry_active = self.increment(retry_settings, response=request, error=err) @@ -298,7 +299,8 @@ async def send(self, request): raise err except ServiceResponseError as err: retry_error = err - if _has_database_account_header(request.http_request.headers): + if (_has_database_account_header(request.http_request.headers) or + global_endpoint_manager.is_healthy_tentative(request_params)): raise err # Since this is ClientConnectionError, it is safe to be retried on both read and write requests try: diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/documents.py b/sdk/cosmos/azure-cosmos/azure/cosmos/documents.py index 7ccc99da9dfe..784efe91cd05 100644 --- 
a/sdk/cosmos/azure-cosmos/azure/cosmos/documents.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/documents.py @@ -339,6 +339,7 @@ class ConnectionPolicy: # pylint: disable=too-many-instance-attributes __defaultRequestTimeout: int = 5 # seconds __defaultDBAConnectionTimeout: int = 3 # seconds __defaultReadTimeout: int = 65 # seconds + __defaultRecoveryReadTimeout: int = 6 # seconds __defaultDBAReadTimeout: int = 3 # seconds __defaultMaxBackoff: int = 1 # seconds @@ -347,6 +348,9 @@ def __init__(self) -> None: self.RequestTimeout: int = self.__defaultRequestTimeout self.DBAConnectionTimeout: int = self.__defaultDBAConnectionTimeout self.ReadTimeout: int = self.__defaultReadTimeout + # The request timeout for a request trying to recover a unavailable partition + # This is only applicable if circuit breaker is enabled + self.RecoveryReadTimeout: int = self.__defaultRecoveryReadTimeout self.DBAReadTimeout: int = self.__defaultDBAReadTimeout self.MaxBackoff: int = self.__defaultMaxBackoff self.ConnectionMode: int = ConnectionMode.Gateway diff --git a/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm_async.py b/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm_async.py index ad56e20da02b..91fa5ffd6dd1 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm_async.py @@ -246,8 +246,20 @@ async def test_write_consecutive_failure_threshold_async(self, setup_teardown, w expected_uri) validate_unhealthy_partitions(global_endpoint_manager, 1) + # remove faults and reduce initial recover time and perform a write + original_val = _partition_health_tracker.INITIAL_UNAVAILABLE_TIME + _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = 1 + custom_transport.faults = [] + try: + await perform_write_operation(write_operation, + container, + fault_injection_container, + str(uuid.uuid4()), + pk_value, + uri_down) + finally: + _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = original_val await cleanup_method([custom_setup, setup]) - # test 
recovering the partition --------------------------------------------------------------------- @@ -340,7 +352,7 @@ async def test_write_failure_rate_threshold_async(self, setup_teardown, write_op ))) else: with pytest.raises((CosmosHttpResponseError, ServiceResponseError)) as exc_info: - await perform_read_operation(write_operation, + await perform_write_operation(write_operation, container, fault_injection_container, str(uuid.uuid4()), @@ -382,12 +394,16 @@ async def test_read_failure_rate_threshold_async(self, setup_teardown, read_oper await container.upsert_item(body=doc) global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager # lower minimum requests for testing - _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 10 + _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 8 os.environ["AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED"] = "80" try: - for i in range(10): + if isinstance(error, ServiceResponseError): + num_operations = 4 + else: + num_operations = 8 + for i in range(num_operations): validate_unhealthy_partitions(global_endpoint_manager, 0) - if i == 4: + if i == 2: # perform some successful read to reset consecutive counter # remove faults and perform a read custom_transport.faults = [] From b80b20c9ba15e4439012f6194ce05da342bc61d3 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Sun, 20 Apr 2025 16:32:13 -0700 Subject: [PATCH 092/152] recovering optimizations, lower request timeout, disable in region retries --- .../azure/cosmos/_location_cache.py | 3 +- .../azure/cosmos/_partition_health_tracker.py | 1 + ..._endpoint_manager_circuit_breaker_async.py | 1 - .../azure/cosmos/aio/_retry_utility_async.py | 8 +-- .../azure-cosmos/tests/test_ppcb_mm_async.py | 52 ++++++++++--------- 5 files changed, 34 insertions(+), 31 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py index 089e8af7b29f..ab70259cfb4e 100644 --- 
a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py @@ -205,7 +205,8 @@ def _get_configured_excluded_locations(self, request: RequestObject) -> List[str excluded_locations = request.excluded_locations if excluded_locations is None: # If excluded locations were only configured on client(connection_policy), use client level - excluded_locations = self.connection_policy.ExcludedLocations + # make copy of excluded locations to avoid modifying the original list + excluded_locations = list(self.connection_policy.ExcludedLocations) excluded_locations.extend(request.excluded_locations_circuit_breaker) return excluded_locations diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py index ef773b454522..2eba88939e81 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py @@ -30,6 +30,7 @@ MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 +MAX_UNAVAILABLE_TIME = 1800 * 1000 # milliseconds REFRESH_INTERVAL = 60 * 1000 # milliseconds INITIAL_UNAVAILABLE_TIME = 60 * 1000 # milliseconds # partition is unhealthy if sdk tried to recover and failed diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py index 9cc2a07472b4..a5e8e0692400 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py @@ -52,7 +52,6 @@ def __init__(self, client: "CosmosClientConnection"): async def create_pk_range_wrapper(self, request: RequestObject) -> PartitionKeyRangeWrapper: container_rid = request.headers[HttpHeaders.IntendedCollectionRID] 
print(request.headers) - print(request) target_container_link = None partition_key = None # TODO: @tvaron3 see if the container link is absolutely necessary or change container cache diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py index 89787e8a21e3..85f7f0e790d8 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py @@ -289,9 +289,9 @@ async def send(self, request): # the request ran into a socket timeout or failed to establish a new connection # since request wasn't sent, raise exception immediately to be dealt with in client retry policies if (not _has_database_account_header(request.http_request.headers) - and not global_endpoint_manager.is_healthy_tentative(request_params)): - await global_endpoint_manager.record_failure(request_params) + and not await global_endpoint_manager.is_healthy_tentative(request_params)): if retry_settings['connect'] > 0: + await global_endpoint_manager.record_failure(request_params) retry_active = self.increment(retry_settings, response=request, error=err) if retry_active: await self.sleep(retry_settings, request.context.transport) @@ -300,7 +300,7 @@ async def send(self, request): except ServiceResponseError as err: retry_error = err if (_has_database_account_header(request.http_request.headers) or - global_endpoint_manager.is_healthy_tentative(request_params)): + await global_endpoint_manager.is_healthy_tentative(request_params)): raise err # Since this is ClientConnectionError, it is safe to be retried on both read and write requests try: @@ -309,9 +309,9 @@ async def send(self, request): ClientConnectionError) if (isinstance(err.inner_exception, ClientConnectionError) or _has_read_retryable_headers(request.http_request.headers)): - await global_endpoint_manager.record_failure(request_params) # This logic is based on the _retry.py file from 
azure-core if retry_settings['read'] > 0: + await global_endpoint_manager.record_failure(request_params) retry_active = self.increment(retry_settings, response=request, error=err) if retry_active: await self.sleep(retry_settings, request.context.transport) diff --git a/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm_async.py b/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm_async.py index 91fa5ffd6dd1..652a93a07842 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm_async.py @@ -33,7 +33,7 @@ async def setup_teardown(): partition_key=PartitionKey("/pk"), offer_throughput=10000) # allow some time for the container to be created as this method is in different event loop - await asyncio.sleep(2) + await asyncio.sleep(3) yield await created_database.delete_container(TestPerPartitionCircuitBreakerMMAsync.TEST_CONTAINER_SINGLE_PARTITION_ID) await client.close() @@ -247,7 +247,7 @@ async def test_write_consecutive_failure_threshold_async(self, setup_teardown, w validate_unhealthy_partitions(global_endpoint_manager, 1) # remove faults and reduce initial recover time and perform a write - original_val = _partition_health_tracker.INITIAL_UNAVAILABLE_TIME + original_unavailable_time = _partition_health_tracker.INITIAL_UNAVAILABLE_TIME _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = 1 custom_transport.faults = [] try: @@ -258,11 +258,9 @@ async def test_write_consecutive_failure_threshold_async(self, setup_teardown, w pk_value, uri_down) finally: - _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = original_val + _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = original_unavailable_time await cleanup_method([custom_setup, setup]) - - @pytest.mark.parametrize("read_operation, error", read_operations_and_errors()) async def test_read_consecutive_failure_threshold_async(self, setup_teardown, read_operation, error): expected_uri = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_2) @@ -307,6 
+305,18 @@ async def test_read_consecutive_failure_threshold_async(self, setup_teardown, re # the partition should have been marked as unavailable after breaking read threshold validate_unhealthy_partitions(global_endpoint_manager, 1) + # remove faults and reduce initial recover time and perform a read + original_unavailable_time = _partition_health_tracker.INITIAL_UNAVAILABLE_TIME + _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = 1 + custom_transport.faults = [] + try: + await perform_read_operation(read_operation, + fault_injection_container, + document_definition['id'], + document_definition['pk'], + uri_down) + finally: + _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = original_unavailable_time await cleanup_method([custom_setup, setup]) @pytest.mark.parametrize("write_operation, error", write_operations_and_errors()) @@ -398,28 +408,18 @@ async def test_read_failure_rate_threshold_async(self, setup_teardown, read_oper os.environ["AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED"] = "80" try: if isinstance(error, ServiceResponseError): - num_operations = 4 + # service response error retries in region 3 additional times before failing over + num_operations = 2 else: num_operations = 8 for i in range(num_operations): validate_unhealthy_partitions(global_endpoint_manager, 0) - if i == 2: - # perform some successful read to reset consecutive counter - # remove faults and perform a read - custom_transport.faults = [] - await fault_injection_container.read_item(item=doc["id"], partition_key=pk_value) - custom_transport.add_fault(predicate, - lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( - 0, - error - ))) - else: - # read will fail and retry in other region - await perform_read_operation(read_operation, - fault_injection_container, - doc['id'], - pk_value, - expected_uri) + # read will fail and retry in other region + await perform_read_operation(read_operation, + fault_injection_container, + doc['id'], + pk_value, + expected_uri) 
validate_unhealthy_partitions(global_endpoint_manager, 1) finally: @@ -428,9 +428,11 @@ async def test_read_failure_rate_threshold_async(self, setup_teardown, read_oper _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 await cleanup_method([custom_setup, setup]) - # look at the urls for verifying fall back and use another id for same partition + async def test_service_request_error_async(self): + # the region should be tried 4 times before failing over and mark the partition as unavailable + # the region should not be marked as unavailable + pass - # test_failure_rate_threshold - add service response error - across operation types - test recovering the partition again # test service request marks only a partition unavailable not an entire region - across operation types # test cosmos client timeout From e3ab320d83f9fd8d7617b4a8540b554bd51b242f Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Mon, 21 Apr 2025 01:00:29 -0700 Subject: [PATCH 093/152] fix transitions from a success --- ...tition_endpoint_manager_circuit_breaker.py | 1 - ...n_endpoint_manager_circuit_breaker_core.py | 13 +- .../azure/cosmos/_partition_health_tracker.py | 66 +++--- .../azure/cosmos/_request_object.py | 1 + .../_routing/aio/routing_map_provider.py | 10 +- .../cosmos/_routing/routing_map_provider.py | 1 - .../azure/cosmos/_routing/routing_range.py | 4 +- .../azure/cosmos/aio/_asynchronous_request.py | 2 +- .../aio/_cosmos_client_connection_async.py | 2 +- .../aio/_global_endpoint_manager_async.py | 2 +- ..._endpoint_manager_circuit_breaker_async.py | 9 +- .../azure/cosmos/aio/_retry_utility_async.py | 5 +- .../routing/test_collection_routing_map.py | 13 ++ .../azure-cosmos/tests/test_ppcb_mm_async.py | 189 ++++++++++++++++-- 14 files changed, 236 insertions(+), 82 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py index 
f5a0f780be0a..608f63c06ecf 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py @@ -75,7 +75,6 @@ def record_failure( self.global_partition_endpoint_manager_core.record_failure(request, pk_range_wrapper) def resolve_service_endpoint(self, request: RequestObject, pk_range_wrapper: PartitionKeyRangeWrapper) -> str: - # TODO: @tvaron3 check here if it is healthy tentative and move it back to Unhealthy if self.is_circuit_breaker_applicable(request): request = self.global_partition_endpoint_manager_core.add_excluded_locations_to_request(request, pk_range_wrapper) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py index ac70cc873dc3..457fae94a2ec 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py @@ -82,13 +82,21 @@ def record_failure( self.partition_health_tracker.add_failure(pk_range_wrapper, endpoint_operation_type, str(location)) # TODO: @tvaron3 exponential backoff for recovering + def check_stale_partition_info( + self, + request: RequestObject, + pk_range_wrapper: PartitionKeyRangeWrapper + ) -> None: + self.partition_health_tracker._check_stale_partition_info(request, pk_range_wrapper) + + def add_excluded_locations_to_request( self, request: RequestObject, pk_range_wrapper: PartitionKeyRangeWrapper ) -> RequestObject: request.set_excluded_locations_from_circuit_breaker( - self.partition_health_tracker.get_excluded_locations(pk_range_wrapper) + self.partition_health_tracker.get_excluded_locations(request, pk_range_wrapper) ) return request @@ -107,6 +115,3 @@ def record_success( location = 
self.location_cache.get_location_from_endpoint(str(request.location_endpoint_to_route)) self.partition_health_tracker.add_success(pk_range_wrapper, endpoint_operation_type, location) - def is_healthy_tentative(self, request: RequestObject, pk_range_wrapper: PartitionKeyRangeWrapper): - location = self.location_cache.get_location_from_endpoint(str(request.location_endpoint_to_route)) - return self.partition_health_tracker.is_healthy_tentative(pk_range_wrapper, location) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py index 2eba88939e81..f7f270f47897 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py @@ -22,10 +22,12 @@ """Internal class for partition health tracker for circuit breaker. """ import logging +import threading import os from typing import Dict, Any, List from azure.cosmos._routing.routing_range import PartitionKeyRangeWrapper from azure.cosmos._location_cache import current_time_millis, EndpointOperationType +from azure.cosmos._request_object import RequestObject from ._constants import _Constants as Constants @@ -37,8 +39,6 @@ UNHEALTHY = "unhealthy" # partition is unhealthy tentative when it initially marked unavailable UNHEALTHY_TENTATIVE = "unhealthy_tentative" -# partition is healthy tentative when sdk is trying to recover -HEALTHY_TENTATIVE = "healthy_tentative" # unavailability info keys LAST_UNAVAILABILITY_CHECK_TIME_STAMP = "lastUnavailabilityCheckTimeStamp" HEALTH_STATUS = "healthStatus" @@ -86,6 +86,14 @@ def __str__(self) -> str: f"write consecutive failure count: {self.write_consecutive_failure_count}\n" f"read consecutive failure count: {self.read_consecutive_failure_count}\n") +def _should_mark_healthy_tentative(partition_health_info: _PartitionHealthInfo, curr_time: int) -> bool: + elapsed_time = (current_time - + 
partition_health_info.unavailability_info[LAST_UNAVAILABILITY_CHECK_TIME_STAMP]) + current_health_status = partition_health_info.unavailability_info[HEALTH_STATUS] + # check if the partition key range is still unavailable + return ((current_health_status == UNHEALTHY and elapsed_time > stale_partition_unavailability_check) + or (current_health_status == UNHEALTHY_TENTATIVE and elapsed_time > INITIAL_UNAVAILABLE_TIME)) + logger = logging.getLogger("azure.cosmos._PartitionHealthTracker") class _PartitionHealthTracker(object): @@ -98,6 +106,7 @@ def __init__(self) -> None: # partition -> regions -> health info self.pk_range_wrapper_to_health_info: Dict[PartitionKeyRangeWrapper, Dict[str, _PartitionHealthInfo]] = {} self.last_refresh = current_time_millis() + self.stale_partiton_lock = threading.Lock() def mark_partition_unavailable(self, pk_range_wrapper: PartitionKeyRangeWrapper, location: str) -> None: # mark the partition key range as unavailable @@ -146,39 +155,46 @@ def _transition_health_status_on_success( ) -> None: if pk_range_wrapper in self.pk_range_wrapper_to_health_info: # healthy tentative -> healthy - self.pk_range_wrapper_to_health_info[pk_range_wrapper].pop(location, None) + self.pk_range_wrapper_to_health_info[pk_range_wrapper][location].unavailability_info = {} - def _check_stale_partition_info(self, pk_range_wrapper: PartitionKeyRangeWrapper) -> None: + def _check_stale_partition_info( + self, + request: RequestObject, + pk_range_wrapper: PartitionKeyRangeWrapper + ) -> None: current_time = current_time_millis() stale_partition_unavailability_check = int(os.environ.get(Constants.STALE_PARTITION_UNAVAILABILITY_CHECK, Constants.STALE_PARTITION_UNAVAILABILITY_CHECK_DEFAULT)) * 1000 if pk_range_wrapper in self.pk_range_wrapper_to_health_info: - for _, partition_health_info in self.pk_range_wrapper_to_health_info[pk_range_wrapper].items(): + for location, partition_health_info in self.pk_range_wrapper_to_health_info[pk_range_wrapper].items(): if 
partition_health_info.unavailability_info: - elapsed_time = (current_time - - partition_health_info.unavailability_info[LAST_UNAVAILABILITY_CHECK_TIME_STAMP]) - current_health_status = partition_health_info.unavailability_info[HEALTH_STATUS] - # check if the partition key range is still unavailable - if ((current_health_status == UNHEALTHY and elapsed_time > stale_partition_unavailability_check) - or (current_health_status == UNHEALTHY_TENTATIVE - and elapsed_time > INITIAL_UNAVAILABLE_TIME)): + if self._should_mark_healthy_tentative(partition_health_info, current_time): # unhealthy or unhealthy tentative -> healthy tentative - partition_health_info.unavailability_info[HEALTH_STATUS] = HEALTHY_TENTATIVE + with (self.stale_partiton_lock): + if self._should_mark_healthy_tentative(partition_health_info, current_time): + # this will trigger one attempt to recover + partition_health_info.unavailability_info[LAST_UNAVAILABILITY_CHECK_TIME_STAMP] = current_time + partition_health_info.unavailability_info[HEALTH_STATUS] = UNHEALTHY + request.healthy_tentative_location = location if current_time - self.last_refresh > REFRESH_INTERVAL: # all partition stats reset every minute self._reset_partition_health_tracker_stats() - def get_excluded_locations(self, pk_range_wrapper: PartitionKeyRangeWrapper) -> List[str]: - self._check_stale_partition_info(pk_range_wrapper) + def get_excluded_locations( + self, + request: RequestObject, + pk_range_wrapper: PartitionKeyRangeWrapper + ) -> List[str]: excluded_locations = [] if pk_range_wrapper in self.pk_range_wrapper_to_health_info: for location, partition_health_info in self.pk_range_wrapper_to_health_info[pk_range_wrapper].items(): - if partition_health_info.unavailability_info: + if (partition_health_info.unavailability_info and + not (request.healthy_tentative_location and request.healthy_tentative_location == location)): health_status = partition_health_info.unavailability_info[HEALTH_STATUS] - if health_status in 
(UNHEALTHY_TENTATIVE, UNHEALTHY): + if health_status in (UNHEALTHY_TENTATIVE, UNHEALTHY) : excluded_locations.append(location) return excluded_locations @@ -200,7 +216,6 @@ def add_failure( self.pk_range_wrapper_to_health_info[pk_range_wrapper][location] = _PartitionHealthInfo() health_info = self.pk_range_wrapper_to_health_info[pk_range_wrapper][location] - print(failure_rate_threshold) # Determine attribute names and environment variables based on the operation type. if operation_type == EndpointOperationType.WriteType: @@ -233,9 +248,9 @@ def add_failure( failure_rate_threshold, consecutive_failure_threshold ) - print(self.pk_range_wrapper_to_health_info[pk_range_wrapper][location]) print(pk_range_wrapper) print(location) + print(self.pk_range_wrapper_to_health_info[pk_range_wrapper][location]) def _check_thresholds( self, @@ -274,22 +289,13 @@ def add_success(self, pk_range_wrapper: PartitionKeyRangeWrapper, operation_type else: health_info.read_success_count += 1 health_info.read_consecutive_failure_count = 0 - self._transition_health_status_on_success(pk_range_wrapper, operation_type) - print(self.pk_range_wrapper_to_health_info[pk_range_wrapper][location]) print(pk_range_wrapper) print(location) + print(self.pk_range_wrapper_to_health_info[pk_range_wrapper][location]) + self._transition_health_status_on_success(pk_range_wrapper, location) def _reset_partition_health_tracker_stats(self) -> None: for locations in self.pk_range_wrapper_to_health_info.values(): for health_info in locations.values(): health_info.reset_health_stats() - - def is_healthy_tentative(self, pk_range_wrapper, location): - if pk_range_wrapper in self.pk_range_wrapper_to_health_info: - if location in self.pk_range_wrapper_to_health_info[pk_range_wrapper]: - health_info = self.pk_range_wrapper_to_health_info[pk_range_wrapper][location] - return health_info.unavailability_info and \ - health_info.unavailability_info[HEALTH_STATUS] == HEALTHY_TENTATIVE - return False - diff --git 
a/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py index 24dc9e0dd9c2..8ba07bd0b29e 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py @@ -43,6 +43,7 @@ def __init__( self.last_routed_location_endpoint_within_region: Optional[str] = None self.excluded_locations: Optional[List[str]] = None self.excluded_locations_circuit_breaker: List[str] = [] + self.healthy_tentative_location: Optional[str] = None def route_to_location_with_preferred_location_flag( # pylint: disable=name-too-long self, diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/aio/routing_map_provider.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/aio/routing_map_provider.py index f59513d05d24..951f56a89400 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/aio/routing_map_provider.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/aio/routing_map_provider.py @@ -82,7 +82,7 @@ async def initialize_collection_routing_map_if_needed( collection_routing_map = CollectionRoutingMap.CompleteRoutingMap( [(r, True) for r in collection_pk_ranges], collection_id ) - self._collection_routing_map_by_item[collection_id] = collection_routing_map + self._collection_routing_map_by_item[collection_id] = collection_routing_map async def get_range_by_partition_key_range_id( self, @@ -216,11 +216,3 @@ async def get_overlapping_ranges(self, collection_link, partition_key_ranges, ** pass return target_partition_key_ranges - - async def get_range_by_partition_key_range_id( - self, - collection_link: str, - partition_key_range_id: int, - **kwargs: Dict[str, Any] - ) -> Dict[str, Any]: - return await super().get_range_by_partition_key_range_id(collection_link, partition_key_range_id, **kwargs) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_map_provider.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_map_provider.py index 
21e1e09b78cc..8dacb5190e07 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_map_provider.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_map_provider.py @@ -63,7 +63,6 @@ def get_overlapping_ranges(self, collection_link, partition_key_ranges, **kwargs collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) - # TODO: @tvaron3 change this to be by collectionRID collection_routing_map = self._collection_routing_map_by_item.get(collection_id) if collection_routing_map is None: collection_pk_ranges = list(cl._ReadPartitionKeyRanges(collection_link, **kwargs)) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_range.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_range.py index e31682725828..36053a4ab4d5 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_range.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_range.py @@ -100,10 +100,10 @@ def to_normalized_range(self): normalized_max = self.max if not self.isMinInclusive: - normalized_min = self.add_to_effective_partition_key(self.min, -1) + normalized_min = self.add_to_effective_partition_key(self.min, -1).upper() if self.isMaxInclusive: - normalized_max = self.add_to_effective_partition_key(self.max, 1) + normalized_max = self.add_to_effective_partition_key(self.max, 1).upper() return Range(normalized_min, normalized_max, True, False) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py index 71d74f14a1a9..53e87c1d0211 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py @@ -59,7 +59,7 @@ async def _Request(global_endpoint_manager, request_params, connection_policy, p # Every request tries to perform a refresh client_timeout = kwargs.get('timeout') start_time = time.time() - if 
global_endpoint_manager.is_healthy_tentative(request_params): + if request_params.healthy_tentative_location: read_timeout = connection_policy.RecoveryReadTimeout if request_params.resource_type != http_constants.ResourceType.DatabaseAccount: await global_endpoint_manager.refresh_endpoint_list(None, **kwargs) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py index 4aa2386ca344..5df1bf70d44e 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py @@ -2859,7 +2859,7 @@ async def __QueryFeed( # pylint: disable=too-many-branches,too-many-statements, :raises SystemError: If the query compatibility mode is undefined. """ if options is None: - options: Dict[str, Any] = {} + options = {} if query: __GetBodiesFromQueryResult = result_fn diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py index 411c556b93bd..6a61cacc5fc1 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py @@ -103,7 +103,7 @@ async def refresh_endpoint_list(self, database_account, **kwargs): try: await self.refresh_task self.refresh_task = None - except (Exception, asyncio.CancelledError) as exception: #pylint: disable=broad-exception-caught + except (Exception, asyncio.CancelledError) as exception: logger.exception("Health check task failed: %s", exception) if current_time_millis() - self.last_refresh_time > self.refresh_time_interval_in_ms: self.refresh_needed = True diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py 
b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py index a5e8e0692400..3f6f4c333808 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py @@ -68,7 +68,6 @@ async def create_pk_range_wrapper(self, request: RequestObject) -> PartitionKeyR if request.headers.get(HttpHeaders.PartitionKey): partition_key_value = request.headers[HttpHeaders.PartitionKey] # get the partition key range for the given partition key - # TODO: @tvaron3 check different clients and create them in different ways epk_range = [partition_key._get_epk_range_for_partition_key(partition_key_value)] partition_ranges = await (self.client._routing_map_provider .get_overlapping_ranges(target_container_link, epk_range)) @@ -94,8 +93,10 @@ async def record_failure( def resolve_service_endpoint(self, request: RequestObject, pk_range_wrapper: PartitionKeyRangeWrapper): if self.is_circuit_breaker_applicable(request): + self.global_partition_endpoint_manager_core.check_stale_partition_info(request, pk_range_wrapper) request = self.global_partition_endpoint_manager_core.add_excluded_locations_to_request(request, pk_range_wrapper) + # Todo: @tvaron3 think about how to switch healthy tentative unhealthy tentative return (super(_GlobalPartitionEndpointManagerForCircuitBreakerAsync, self) .resolve_service_endpoint(request, pk_range_wrapper)) @@ -113,9 +114,3 @@ async def record_success( if self.is_circuit_breaker_applicable(request): pk_range_wrapper = await self.create_pk_range_wrapper(request) self.global_partition_endpoint_manager_core.record_success(request, pk_range_wrapper) - - async def is_healthy_tentative(self, request: RequestObject) -> bool: - if self.is_circuit_breaker_applicable(request): - pk_range_wrapper = await self.create_pk_range_wrapper(request) - return 
self.global_partition_endpoint_manager_core.is_healthy_tentative(request, pk_range_wrapper) - return False diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py index 85f7f0e790d8..71e78dd1df86 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py @@ -289,7 +289,7 @@ async def send(self, request): # the request ran into a socket timeout or failed to establish a new connection # since request wasn't sent, raise exception immediately to be dealt with in client retry policies if (not _has_database_account_header(request.http_request.headers) - and not await global_endpoint_manager.is_healthy_tentative(request_params)): + and not request_params.healthy_tentative_location): if retry_settings['connect'] > 0: await global_endpoint_manager.record_failure(request_params) retry_active = self.increment(retry_settings, response=request, error=err) @@ -299,8 +299,7 @@ async def send(self, request): raise err except ServiceResponseError as err: retry_error = err - if (_has_database_account_header(request.http_request.headers) or - await global_endpoint_manager.is_healthy_tentative(request_params)): + if (_has_database_account_header(request.http_request.headers) or request_params.healthy_tentative_location): raise err # Since this is ClientConnectionError, it is safe to be retried on both read and write requests try: diff --git a/sdk/cosmos/azure-cosmos/tests/routing/test_collection_routing_map.py b/sdk/cosmos/azure-cosmos/tests/routing/test_collection_routing_map.py index 10b5819310a9..69468423a111 100644 --- a/sdk/cosmos/azure-cosmos/tests/routing/test_collection_routing_map.py +++ b/sdk/cosmos/azure-cosmos/tests/routing/test_collection_routing_map.py @@ -28,6 +28,19 @@ def test_advanced(self): self.assertEqual(len(overlapping_partition_key_ranges), len(partition_key_ranges)) 
self.assertEqual(overlapping_partition_key_ranges, partition_key_ranges) + # def test_lowercased_range(self): + # partition_key_ranges = [{u'id': u'0', u'minInclusive': u'', u'maxExclusive': u'1FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF'}, + # {u'id': u'1', u'minInclusive': u'1FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF', u'maxExclusive': u'FF'}] + # partitionRangeWithInfo = [(r, True) for r in partition_key_ranges] + # expected_partition_key_range = routing_range.Range("", "1FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF", True, False) + # + # pkRange = routing_range.Range("1EC0C2CBE45DBC919CF2B65D399C2673", "1ec0c2cbe45dbc919cf2b65d399c2674", True, True) + # collection_routing_map = CollectionRoutingMap.CompleteRoutingMap(partitionRangeWithInfo, 'sample collection id') + # overlapping_partition_key_ranges = collection_routing_map.get_overlapping_ranges(pkRange) + # + # self.assertEqual(len(overlapping_partition_key_ranges), 1) + # self.assertEqual(expected_partition_key_range, overlapping_partition_key_ranges[0]) + def test_partition_key_ranges_parent_filter(self): # for large collection with thousands of partitions, a split may complete between the read partition key ranges query pages, # causing the return map to have both the new children ranges and their ranges. This test is to verify the fix for that. 
diff --git a/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm_async.py b/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm_async.py index 652a93a07842..29c4ecfdcac4 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm_async.py @@ -20,6 +20,20 @@ REGION_1 = "West US 3" REGION_2 = "Mexico Central" # "West US" +CHANGE_FEED = "changefeed" +CHANGE_FEED_PK = "changefeed_pk" +CHANGE_FEED_EPK = "changefeed_epk" +READ = "read" +CREATE = "create" +READ_ALL_ITEMS = "read_all_items" +DELETE_ALL_ITEMS_BY_PARTITION_KEY = "delete_all_items_by_partition_key" +QUERY = "query" +QUERY_PK = "query_pk" +BATCH = "batch" +UPSERT = "upsert" +REPLACE = "replace" +PATCH = "patch" +DELETE = "delete" COLLECTION = "created_collection" @@ -40,7 +54,7 @@ async def setup_teardown(): os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "False" def write_operations_and_errors(): - write_operations = ["create", "upsert", "replace", "delete", "patch", "batch"] # "delete_all_items_by_partition_key"] + write_operations = [CREATE, UPSERT, REPLACE, DELETE, PATCH, BATCH]# "delete_all_items_by_partition_key"] errors = [] error_codes = [408, 500, 502, 503] for error_code in error_codes: @@ -56,7 +70,7 @@ def write_operations_and_errors(): return params def read_operations_and_errors(): - read_operations = ["read", "query", "query_pk", "changefeed", "changefeed_pk", "changefeed_epk", "read_all_items"] + read_operations = [READ, QUERY, QUERY_PK, CHANGE_FEED, CHANGE_FEED_PK, CHANGE_FEED_EPK, READ_ALL_ITEMS] errors = [] error_codes = [408, 500, 502, 503] for error_code in error_codes: @@ -71,6 +85,15 @@ def read_operations_and_errors(): return params +def operations(): + write_operations = [CREATE, UPSERT, REPLACE, DELETE, PATCH, BATCH] + read_operations = [READ, QUERY_PK, CHANGE_FEED_PK, CHANGE_FEED_EPK] + operations = [] + for i, write_operation in enumerate(write_operations): + operations.append((read_operations[i % len(read_operations)], write_operation)) + + 
return operations + def validate_response_uri(response, expected_uri): request = response.get_response_headers()["_request"] assert request.url.startswith(expected_uri) @@ -80,25 +103,25 @@ async def perform_write_operation(operation, container, fault_injection_containe 'pk': pk, 'name': 'sample document', 'key': 'value'} - if operation == "create": + if operation == CREATE: resp = await fault_injection_container.create_item(body=doc) - elif operation == "upsert": + elif operation == UPSERT: resp = await fault_injection_container.upsert_item(body=doc) - elif operation == "replace": + elif operation == REPLACE: await container.create_item(body=doc) new_doc = {'id': doc_id, 'pk': pk, 'name': 'sample document' + str(uuid), 'key': 'value'} resp = await fault_injection_container.replace_item(item=doc['id'], body=new_doc) - elif operation == "delete": + elif operation == DELETE: await container.create_item(body=doc) resp = await fault_injection_container.delete_item(item=doc['id'], partition_key=doc['pk']) - elif operation == "patch": + elif operation == PATCH: await container.create_item(body=doc) operations = [{"op": "incr", "path": "/company", "value": 3}] resp = await fault_injection_container.patch_item(item=doc['id'], partition_key=doc['pk'], patch_operations=operations) - elif operation == "batch": + elif operation == BATCH: batch_operations = [ ("create", (doc, )), ("upsert", (doc,)), @@ -116,38 +139,38 @@ async def perform_write_operation(operation, container, fault_injection_containe validate_response_uri(resp, expected_uri) async def perform_read_operation(operation, container, doc_id, pk, expected_uri): - if operation == "read": + if operation == READ: read_resp = await container.read_item(item=doc_id, partition_key=pk) request = read_resp.get_response_headers()["_request"] # Validate the response comes from "Read Region" (the most preferred read-only region) assert request.url.startswith(expected_uri) - elif operation == "query_pk": + elif operation == 
QUERY_PK: # partition key filtered query query = "SELECT * FROM c WHERE c.id = @id AND c.pk = @pk" parameters = [{"name": "@id", "value": doc_id}, {"name": "@pk", "value": pk}] async for item in container.query_items(query=query, partition_key=pk, parameters=parameters): assert item['id'] == doc_id # need to do query with no pk and with feed range - elif operation == "query": + elif operation == QUERY: # cross partition query query = "SELECT * FROM c WHERE c.id = @id" async for item in container.query_items(query=query): assert item['id'] == doc_id - elif operation == "changefeed": + elif operation == CHANGE_FEED: async for _ in container.query_items_change_feed(): pass - elif operation == "changefeed_pk": + elif operation == CHANGE_FEED_PK: # partition key filtered change feed async for _ in container.query_items_change_feed(partition_key=pk): pass - elif operation == "changefeed_epk": + elif operation == CHANGE_FEED_EPK: # partition key filtered by feed range feed_range = await container.feed_range_from_partition_key(partition_key=pk) async for _ in container.query_items_change_feed(feed_range=feed_range): pass - elif operation == "read_all_items": - async for item in container.read_all_items(partition_key=pk): - assert item['pk'] == pk + elif operation == READ_ALL_ITEMS: + async for _ in container.read_all_items(): + pass def validate_unhealthy_partitions(global_endpoint_manager, expected_unhealthy_partitions): @@ -259,6 +282,7 @@ async def test_write_consecutive_failure_threshold_async(self, setup_teardown, w uri_down) finally: _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = original_unavailable_time + validate_unhealthy_partitions(global_endpoint_manager, 0) await cleanup_method([custom_setup, setup]) @pytest.mark.parametrize("read_operation, error", read_operations_and_errors()) @@ -304,7 +328,12 @@ async def test_read_consecutive_failure_threshold_async(self, setup_teardown, re expected_uri) # the partition should have been marked as unavailable after 
breaking read threshold - validate_unhealthy_partitions(global_endpoint_manager, 1) + if read_operation in (CHANGE_FEED, QUERY, READ_ALL_ITEMS): + # these operations are cross partition so they would mark both partitions as unavailable + expected_unhealthy_partitions = 2 + else: + expected_unhealthy_partitions = 1 + validate_unhealthy_partitions(global_endpoint_manager, expected_unhealthy_partitions) # remove faults and reduce initial recover time and perform a read original_unavailable_time = _partition_health_tracker.INITIAL_UNAVAILABLE_TIME _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = 1 @@ -317,6 +346,7 @@ async def test_read_consecutive_failure_threshold_async(self, setup_teardown, re uri_down) finally: _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = original_unavailable_time + validate_unhealthy_partitions(global_endpoint_manager, 0) await cleanup_method([custom_setup, setup]) @pytest.mark.parametrize("write_operation, error", write_operations_and_errors()) @@ -420,20 +450,135 @@ async def test_read_failure_rate_threshold_async(self, setup_teardown, read_oper doc['id'], pk_value, expected_uri) + if read_operation in (CHANGE_FEED, QUERY, READ_ALL_ITEMS): + # these operations are cross partition so they would mark both partitions as unavailable + expected_unhealthy_partitions = 2 + else: + expected_unhealthy_partitions = 1 - validate_unhealthy_partitions(global_endpoint_manager, 1) + validate_unhealthy_partitions(global_endpoint_manager, expected_unhealthy_partitions) finally: os.environ["AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED"] = "90" # restore minimum requests _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 await cleanup_method([custom_setup, setup]) - async def test_service_request_error_async(self): + @pytest.mark.parametrize("read_operation, write_operation", operations()) + async def test_service_request_error_async(self, read_operation, write_operation): # the region should be tried 4 times before failing over and mark 
the partition as unavailable # the region should not be marked as unavailable - pass + expected_uri = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_2) + uri_down = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_1) + custom_transport = FaultInjectionTransportAsync() + predicate = lambda r: (FaultInjectionTransportAsync.predicate_is_document_operation(r) and + FaultInjectionTransportAsync.predicate_targets_region(r, uri_down)) + custom_transport.add_fault(predicate, + lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_region_down())) + pk_value = "pk1" + doc = {'id': str(uuid.uuid4()), + 'pk': pk_value, + 'name': 'sample document', + 'key': 'value'} + custom_setup = await self.setup_method_with_custom_transport(custom_transport, default_endpoint=self.host) + fault_injection_container = custom_setup['col'] + setup = await self.setup_method(default_endpoint=self.host) + container = setup['col'] + await container.upsert_item(body=doc) + await perform_read_operation(read_operation, + fault_injection_container, + doc['id'], + pk_value, + expected_uri) + global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager + + validate_unhealthy_partitions(global_endpoint_manager, 1) + # there shouldn't be region marked as unavailable + assert len(global_endpoint_manager.location_cache.location_unavailability_info_by_endpoint) == 0 + + # recover partition + # remove faults and reduce initial recover time and perform a write + original_unavailable_time = _partition_health_tracker.INITIAL_UNAVAILABLE_TIME + _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = 1 + custom_transport.faults = [] + try: + await perform_read_operation(read_operation, + fault_injection_container, + doc['id'], + pk_value, + uri_down) + finally: + _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = original_unavailable_time + validate_unhealthy_partitions(global_endpoint_manager, 0) + + 
custom_transport.add_fault(predicate, + lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_region_down())) + + await perform_write_operation(write_operation, + container, + fault_injection_container, + str(uuid.uuid4()), + pk_value, + expected_uri) + global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager + + validate_unhealthy_partitions(global_endpoint_manager, 1) + # there shouldn't be region marked as unavailable + assert len(global_endpoint_manager.location_cache.location_unavailability_info_by_endpoint) == 0 + + + await cleanup_method([custom_setup, setup]) + + + # send 5 write concurrent requests when trying to recover + # verify that only one failed + async def test_recovering_only_fails_one_requests_async(self): + uri_down = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_1) + custom_transport = FaultInjectionTransportAsync() + predicate = lambda r: (FaultInjectionTransportAsync.predicate_is_document_operation(r) and + FaultInjectionTransportAsync.predicate_targets_region(r, uri_down)) + custom_transport.add_fault(predicate, + lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay(0, CosmosHttpResponseError( + status_code=502, + message="Some envoy error.")))) + pk_value = "pk1" + custom_setup = await self.setup_method_with_custom_transport(custom_transport, default_endpoint=self.host) + fault_injection_container = custom_setup['col'] + doc = {'id': str(uuid.uuid4()), + 'pk': pk_value, + 'name': 'sample document', + 'key': 'value'} + + for i in range(5): + with pytest.raises(CosmosHttpResponseError): + await fault_injection_container.create_item(body=doc) + + + number_of_errors = 0 + + async def concurrent_upsert(): + nonlocal number_of_errors + doc = {'id': str(uuid.uuid4()), + 'pk': pk_value, + 'name': 'sample document', + 'key': 'value'} + try: + await fault_injection_container.upsert_item(doc) + except CosmosHttpResponseError as e: + number_of_errors += 1 + 
+ # attempt to recover partition + original_unavailable_time = _partition_health_tracker.INITIAL_UNAVAILABLE_TIME + _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = 1 + try: + tasks = [] + for i in range(10): + tasks.append(concurrent_upsert()) + await asyncio.gather(*tasks) + assert number_of_errors == 1 + finally: + _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = original_unavailable_time + await cleanup_method([custom_setup]) - # test service request marks only a partition unavailable not an entire region - across operation types # test cosmos client timeout if __name__ == '__main__': From 93071579e41ea742917080f5493e5fc82623b668 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Mon, 21 Apr 2025 10:35:55 -0700 Subject: [PATCH 094/152] Implement exponential backoff --- .../azure-cosmos/azure/cosmos/_constants.py | 2 - .../azure/cosmos/_partition_health_tracker.py | 70 +++++++++++-------- 2 files changed, 39 insertions(+), 33 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py index d0f5a3d185ad..d0e0f54ae04c 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py @@ -60,8 +60,6 @@ class _Constants: CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_WRITE_DEFAULT: int = 5 FAILURE_PERCENTAGE_TOLERATED = "AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED" FAILURE_PERCENTAGE_TOLERATED_DEFAULT: int = 90 - STALE_PARTITION_UNAVAILABILITY_CHECK = "AZURE_COSMOS_STALE_PARTITION_UNAVAILABILITY_CHECK_IN_SECONDS" - STALE_PARTITION_UNAVAILABILITY_CHECK_DEFAULT: int = 120 # ------------------------------------------------------------------------- # Error code translations diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py index f7f270f47897..aa2ffddee5f6 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py +++ 
b/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py @@ -32,7 +32,7 @@ MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 -MAX_UNAVAILABLE_TIME = 1800 * 1000 # milliseconds +MAX_UNAVAILABLE_TIME = 1200 * 1000 # milliseconds REFRESH_INTERVAL = 60 * 1000 # milliseconds INITIAL_UNAVAILABLE_TIME = 60 * 1000 # milliseconds # partition is unhealthy if sdk tried to recover and failed @@ -40,20 +40,10 @@ # partition is unhealthy tentative when it initially marked unavailable UNHEALTHY_TENTATIVE = "unhealthy_tentative" # unavailability info keys +UNAVAILABLE_INTERVAL = "unavailableInterval" LAST_UNAVAILABILITY_CHECK_TIME_STAMP = "lastUnavailabilityCheckTimeStamp" HEALTH_STATUS = "healthStatus" - -def _has_exceeded_failure_rate_threshold( - successes: int, - failures: int, - failure_rate_threshold: int, -) -> bool: - if successes + failures < MINIMUM_REQUESTS_FOR_FAILURE_RATE: - return False - failure_rate = failures / (failures + successes) * 100 - return failure_rate >= failure_rate_threshold - class _PartitionHealthInfo(object): """ This internal class keeps the health and statistics for a partition. 
@@ -77,6 +67,22 @@ def reset_health_stats(self) -> None: self.read_consecutive_failure_count = 0 self.write_consecutive_failure_count = 0 + def transition_health_status(self, target_health_status: str, curr_time: int) -> None: + if target_health_status == UNHEALTHY : + self.unavailability_info[HEALTH_STATUS] = UNHEALTHY + # reset the last unavailability check time stamp + self.unavailability_info[UNAVAILABLE_INTERVAL] = \ + min(self.unavailability_info[UNAVAILABLE_INTERVAL] * 2, + MAX_UNAVAILABLE_TIME) + self.unavailability_info[LAST_UNAVAILABILITY_CHECK_TIME_STAMP] \ + = curr_time + elif target_health_status == UNHEALTHY_TENTATIVE : + self.unavailability_info = { + LAST_UNAVAILABILITY_CHECK_TIME_STAMP: curr_time, + UNAVAILABLE_INTERVAL: INITIAL_UNAVAILABLE_TIME, + HEALTH_STATUS: UNHEALTHY_TENTATIVE + } + def __str__(self) -> str: return (f"{self.__class__.__name__}: {self.unavailability_info}\n" f"write failure count: {self.write_failure_count}\n" @@ -86,10 +92,21 @@ def __str__(self) -> str: f"write consecutive failure count: {self.write_consecutive_failure_count}\n" f"read consecutive failure count: {self.read_consecutive_failure_count}\n") +def _has_exceeded_failure_rate_threshold( + successes: int, + failures: int, + failure_rate_threshold: int, +) -> bool: + if successes + failures < MINIMUM_REQUESTS_FOR_FAILURE_RATE: + return False + failure_rate = failures / (failures + successes) * 100 + return failure_rate >= failure_rate_threshold + def _should_mark_healthy_tentative(partition_health_info: _PartitionHealthInfo, curr_time: int) -> bool: - elapsed_time = (current_time - + elapsed_time = (curr_time - partition_health_info.unavailability_info[LAST_UNAVAILABILITY_CHECK_TIME_STAMP]) current_health_status = partition_health_info.unavailability_info[HEALTH_STATUS] + stale_partition_unavailability_check = partition_health_info.unavailability_info[UNAVAILABLE_INTERVAL] # check if the partition key range is still unavailable return ((current_health_status == 
UNHEALTHY and elapsed_time > stale_partition_unavailability_check) or (current_health_status == UNHEALTHY_TENTATIVE and elapsed_time > INITIAL_UNAVAILABLE_TIME)) @@ -122,10 +139,7 @@ def _transition_health_status_on_failure( if pk_range_wrapper not in self.pk_range_wrapper_to_health_info: # healthy -> unhealthy tentative partition_health_info = _PartitionHealthInfo() - partition_health_info.unavailability_info = { - LAST_UNAVAILABILITY_CHECK_TIME_STAMP: current_time, - HEALTH_STATUS: UNHEALTHY_TENTATIVE - } + partition_health_info.transition_health_status(UNHEALTHY_TENTATIVE, current_time) self.pk_range_wrapper_to_health_info[pk_range_wrapper] = { location: partition_health_info } @@ -133,21 +147,17 @@ def _transition_health_status_on_failure( region_to_partition_health = self.pk_range_wrapper_to_health_info[pk_range_wrapper] if location in region_to_partition_health and region_to_partition_health[location].unavailability_info: # healthy tentative -> unhealthy + region_to_partition_health[location].transition_health_status(UNHEALTHY, current_time) # if the operation type is not empty, we are in the healthy tentative state - region_to_partition_health[location].unavailability_info[HEALTH_STATUS] = UNHEALTHY - # reset the last unavailability check time stamp - region_to_partition_health[location].unavailability_info[LAST_UNAVAILABILITY_CHECK_TIME_STAMP] \ - = current_time else: # healthy -> unhealthy tentative # if the operation type is empty, we are in the unhealthy tentative state partition_health_info = _PartitionHealthInfo() - partition_health_info.unavailability_info = { - LAST_UNAVAILABILITY_CHECK_TIME_STAMP: current_time, - HEALTH_STATUS: UNHEALTHY_TENTATIVE - } + partition_health_info.transition_health_status(UNHEALTHY_TENTATIVE, current_time) self.pk_range_wrapper_to_health_info[pk_range_wrapper][location] = partition_health_info + + def _transition_health_status_on_success( self, pk_range_wrapper: PartitionKeyRangeWrapper, @@ -164,18 +174,16 @@ def 
_check_stale_partition_info( ) -> None: current_time = current_time_millis() - stale_partition_unavailability_check = int(os.environ.get(Constants.STALE_PARTITION_UNAVAILABILITY_CHECK, - Constants.STALE_PARTITION_UNAVAILABILITY_CHECK_DEFAULT)) * 1000 if pk_range_wrapper in self.pk_range_wrapper_to_health_info: for location, partition_health_info in self.pk_range_wrapper_to_health_info[pk_range_wrapper].items(): if partition_health_info.unavailability_info: - if self._should_mark_healthy_tentative(partition_health_info, current_time): + if _should_mark_healthy_tentative(partition_health_info, current_time): # unhealthy or unhealthy tentative -> healthy tentative + # only one request should be used to recover with (self.stale_partiton_lock): - if self._should_mark_healthy_tentative(partition_health_info, current_time): + if _should_mark_healthy_tentative(partition_health_info, current_time): # this will trigger one attempt to recover - partition_health_info.unavailability_info[LAST_UNAVAILABILITY_CHECK_TIME_STAMP] = current_time - partition_health_info.unavailability_info[HEALTH_STATUS] = UNHEALTHY + partition_health_info.transition_health_status(UNHEALTHY, current_time) request.healthy_tentative_location = location if current_time - self.last_refresh > REFRESH_INTERVAL: From 0a0df4f3fb07c690d539bdf56c15e860eae86bfd Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Mon, 21 Apr 2025 10:40:09 -0700 Subject: [PATCH 095/152] fix pylint --- .../azure/cosmos/_cosmos_client_connection.py | 4 +++- ...artition_endpoint_manager_circuit_breaker_core.py | 3 +-- .../azure/cosmos/_partition_health_tracker.py | 4 ++-- .../azure-cosmos/azure/cosmos/_request_object.py | 2 +- .../cosmos/_routing/aio/routing_map_provider.py | 9 +++++---- .../cosmos/aio/_cosmos_client_connection_async.py | 12 ++++++------ .../cosmos/aio/_global_endpoint_manager_async.py | 4 ++-- .../azure/cosmos/aio/_retry_utility_async.py | 5 +++-- 8 files changed, 23 insertions(+), 20 deletions(-) diff --git 
a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py index 8e5736b006c4..3db8fe922346 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py @@ -2194,7 +2194,9 @@ def DeleteAllItemsByPartitionKey( collection_id = base.GetResourceIdOrFullNameFromLink(collection_link) headers = base.GetHeaders(self, self.default_headers, "post", path, collection_id, http_constants.ResourceType.PartitionKey, documents._OperationType.Delete, options) - request_params = RequestObject(http_constants.ResourceType.PartitionKey, documents._OperationType.Delete, headers) + request_params = RequestObject(http_constants.ResourceType.PartitionKey, + documents._OperationType.Delete, + headers) request_params.set_excluded_location_from_options(options) request_params.set_excluded_location_from_options(options) _, last_response_headers = self.__Post( diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py index 457fae94a2ec..e38164d1e193 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py @@ -87,7 +87,7 @@ def check_stale_partition_info( request: RequestObject, pk_range_wrapper: PartitionKeyRangeWrapper ) -> None: - self.partition_health_tracker._check_stale_partition_info(request, pk_range_wrapper) + self.partition_health_tracker.check_stale_partition_info(request, pk_range_wrapper) def add_excluded_locations_to_request( @@ -114,4 +114,3 @@ def record_success( documents._OperationType.IsWriteOperation(request.operation_type)) else EndpointOperationType.ReadType # pylint: disable=protected-access location = 
self.location_cache.get_location_from_endpoint(str(request.location_endpoint_to_route)) self.partition_health_tracker.add_success(pk_range_wrapper, endpoint_operation_type, location) - diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py index aa2ffddee5f6..673e653ca9fc 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py @@ -167,7 +167,7 @@ def _transition_health_status_on_success( # healthy tentative -> healthy self.pk_range_wrapper_to_health_info[pk_range_wrapper][location].unavailability_info = {} - def _check_stale_partition_info( + def check_stale_partition_info( self, request: RequestObject, pk_range_wrapper: PartitionKeyRangeWrapper @@ -180,7 +180,7 @@ def _check_stale_partition_info( if _should_mark_healthy_tentative(partition_health_info, current_time): # unhealthy or unhealthy tentative -> healthy tentative # only one request should be used to recover - with (self.stale_partiton_lock): + with self.stale_partiton_lock: if _should_mark_healthy_tentative(partition_health_info, current_time): # this will trigger one attempt to recover partition_health_info.transition_health_status(UNHEALTHY, current_time) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py index 8ba07bd0b29e..1ddce7880a3d 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py @@ -21,7 +21,7 @@ """Represents a request object. """ -from typing import Optional, Mapping, Any, Dict, Set, List +from typing import Optional, Mapping, Any, Dict, List from . 
import http_constants class RequestObject(object): # pylint: disable=too-many-instance-attributes diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/aio/routing_map_provider.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/aio/routing_map_provider.py index 951f56a89400..680191ac2228 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/aio/routing_map_provider.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/aio/routing_map_provider.py @@ -61,11 +61,11 @@ async def get_overlapping_ranges(self, collection_link, partition_key_ranges, ** :rtype: list """ collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) - await self.initialize_collection_routing_map_if_needed(collection_link, str(collection_id), **kwargs) + await self.init_collection_routing_map_if_needed(collection_link, str(collection_id), **kwargs) return self._collection_routing_map_by_item[collection_id].get_overlapping_ranges(partition_key_ranges) - async def initialize_collection_routing_map_if_needed( + async def init_collection_routing_map_if_needed( self, collection_link: str, collection_id: str, @@ -91,9 +91,10 @@ async def get_range_by_partition_key_range_id( **kwargs: Dict[str, Any] ) -> Dict[str, Any]: collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) - await self.initialize_collection_routing_map_if_needed(collection_link, str(collection_id), **kwargs) + await self.init_collection_routing_map_if_needed(collection_link, str(collection_id), **kwargs) - return self._collection_routing_map_by_item[collection_id].get_range_by_partition_key_range_id(partition_key_range_id) + return self._collection_routing_map_by_item[collection_id].get_range_by_partition_key_range_id( + partition_key_range_id) @staticmethod def _discard_parent_ranges(partitionKeyRanges): diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py index 
5df1bf70d44e..3aa7d65c39a7 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py @@ -2923,12 +2923,12 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: isPrefixPartitionQuery = False partition_key_definition = None if cont_prop and partition_key: - pk_properties = cont_prop["partitionKey"] - partition_key_definition = PartitionKey(path=pk_properties["paths"], kind=pk_properties["kind"]) - if partition_key_definition.kind == "MultiHash" and \ - (isinstance(partition_key, List) and \ - len(partition_key_definition['paths']) != len(partition_key)): - isPrefixPartitionQuery = True + pk_properties = cont_prop["partitionKey"] + partition_key_definition = PartitionKey(path=pk_properties["paths"], kind=pk_properties["kind"]) + if partition_key_definition.kind == "MultiHash" and \ + (isinstance(partition_key, List) and \ + len(partition_key_definition['paths']) != len(partition_key)): + isPrefixPartitionQuery = True # Query operations will use ReadEndpoint even though it uses POST(for regular query operations) req_headers = base.GetHeaders(self, initial_headers, "post", path, id_, typ, diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py index 6a61cacc5fc1..e4a3ef96161f 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py @@ -103,7 +103,7 @@ async def refresh_endpoint_list(self, database_account, **kwargs): try: await self.refresh_task self.refresh_task = None - except (Exception, asyncio.CancelledError) as exception: + except (Exception, asyncio.CancelledError) as exception: #pylint: disable=broad-exception-caught logger.exception("Health check task failed: %s", exception) if current_time_millis() - 
self.last_refresh_time > self.refresh_time_interval_in_ms: self.refresh_needed = True @@ -213,5 +213,5 @@ async def close(self): self.refresh_task.cancel() try: await self.refresh_task - except (Exception, asyncio.CancelledError) : + except (Exception, asyncio.CancelledError) : #pylint: disable=broad-exception-caught pass diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py index 71e78dd1df86..5ce2acabaa78 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py @@ -25,7 +25,7 @@ import json import time -from azure.core.exceptions import AzureError, ClientAuthenticationError, ServiceRequestError, ServiceResponseError +from azure.core.exceptions import ClientAuthenticationError, ServiceRequestError, ServiceResponseError from azure.core.pipeline.policies import AsyncRetryPolicy from .. import _default_retry_policy, _database_account_retry_policy @@ -299,7 +299,8 @@ async def send(self, request): raise err except ServiceResponseError as err: retry_error = err - if (_has_database_account_header(request.http_request.headers) or request_params.healthy_tentative_location): + if (_has_database_account_header(request.http_request.headers) or + request_params.healthy_tentative_location): raise err # Since this is ClientConnectionError, it is safe to be retried on both read and write requests try: From b7effeea48ba8014efd8b8d77248b10033b7c61f Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Mon, 21 Apr 2025 11:55:41 -0700 Subject: [PATCH 096/152] add sync tests --- sdk/cosmos/azure-cosmos/tests/test_ppcb_mm.py | 383 ++++++++++++++++++ .../azure-cosmos/tests/test_ppcb_mm_async.py | 28 +- 2 files changed, 397 insertions(+), 14 deletions(-) create mode 100644 sdk/cosmos/azure-cosmos/tests/test_ppcb_mm.py diff --git a/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm.py 
b/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm.py new file mode 100644 index 000000000000..91a3ca82e209 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm.py @@ -0,0 +1,383 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. +import os +import unittest +import uuid + +import pytest +from azure.core.pipeline import HttpTransport +from azure.core.exceptions import ServiceResponseError + +import test_config +from azure.cosmos import PartitionKey, _location_cache, _partition_health_tracker +from azure.cosmos import CosmosClient +from azure.cosmos.exceptions import CosmosHttpResponseError +from tests._fault_injection_transport import FaultInjectionTransport +from tests.test_ppcb_mm_async import DELETE, CREATE, UPSERT, REPLACE, PATCH, BATCH, validate_response_uri, READ, \ + QUERY_PK, QUERY, CHANGE_FEED, CHANGE_FEED_PK, CHANGE_FEED_EPK, READ_ALL_ITEMS, REGION_1, REGION_2, \ + write_operations_and_errors, validate_unhealthy_partitions, read_operations_and_errors, PK_VALUE, operations, \ + create_doc + + +@pytest.fixture(scope="class", autouse=True) +def setup_teardown(): + os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "True" + client = CosmosClient(TestPerPartitionCircuitBreakerMM.host, + TestPerPartitionCircuitBreakerMM.master_key) + created_database = client.get_database_client(TestPerPartitionCircuitBreakerMM.TEST_DATABASE_ID) + created_database.create_container(TestPerPartitionCircuitBreakerMM.TEST_CONTAINER_SINGLE_PARTITION_ID, + partition_key=PartitionKey("/pk"), + offer_throughput=10000) + # allow some time for the container to be created as this method is in different event loop + # sleep(3) + yield + + created_database.delete_container(TestPerPartitionCircuitBreakerMM.TEST_CONTAINER_SINGLE_PARTITION_ID) + os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "False" + +def perform_write_operation(operation, container, fault_injection_container, doc_id, pk, expected_uri): + doc = {'id': doc_id, + 'pk': pk, + 
'name': 'sample document', + 'key': 'value'} + if operation == CREATE: + resp = fault_injection_container.create_item(body=doc) + elif operation == UPSERT: + resp = fault_injection_container.upsert_item(body=doc) + elif operation == REPLACE: + container.create_item(body=doc) + new_doc = {'id': doc_id, + 'pk': pk, + 'name': 'sample document' + str(uuid), + 'key': 'value'} + resp = fault_injection_container.replace_item(item=doc['id'], body=new_doc) + elif operation == DELETE: + container.create_item(body=doc) + resp = fault_injection_container.delete_item(item=doc['id'], partition_key=doc['pk']) + elif operation == PATCH: + container.create_item(body=doc) + operations = [{"op": "incr", "path": "/company", "value": 3}] + resp = fault_injection_container.patch_item(item=doc['id'], partition_key=doc['pk'], patch_operations=operations) + elif operation == BATCH: + batch_operations = [ + ("create", (doc, )), + ("upsert", (doc,)), + ("upsert", (doc,)), + ("upsert", (doc,)), + ] + resp = fault_injection_container.execute_item_batch(batch_operations, partition_key=doc['pk']) + # this will need to be emulator only + # elif operation == "delete_all_items_by_partition_key": + # await container.create_item(body=doc) + # await container.create_item(body=doc) + # await container.create_item(body=doc) + # resp = await fault_injection_container.delete_all_items_by_partition_key(pk) + if resp: + validate_response_uri(resp, expected_uri) + +def perform_read_operation(operation, container, doc_id, pk, expected_uri): + if operation == READ: + read_resp = container.read_item(item=doc_id, partition_key=pk) + request = read_resp.get_response_headers()["_request"] + # Validate the response comes from "Read Region" (the most preferred read-only region) + assert request.url.startswith(expected_uri) + elif operation == QUERY_PK: + # partition key filtered query + query = "SELECT * FROM c WHERE c.id = @id AND c.pk = @pk" + parameters = [{"name": "@id", "value": doc_id}, {"name": "@pk", 
"value": pk}] + for item in container.query_items(query=query, partition_key=pk, parameters=parameters): + assert item['id'] == doc_id + # need to do query with no pk and with feed range + elif operation == QUERY: + # cross partition query + query = "SELECT * FROM c WHERE c.id = @id" + for item in container.query_items(query=query): + assert item['id'] == doc_id + elif operation == CHANGE_FEED: + for _ in container.query_items_change_feed(): + pass + elif operation == CHANGE_FEED_PK: + # partition key filtered change feed + for _ in container.query_items_change_feed(partition_key=pk): + pass + elif operation == CHANGE_FEED_EPK: + # partition key filtered by feed range + feed_range = container.feed_range_from_partition_key(partition_key=pk) + async for _ in container.query_items_change_feed(feed_range=feed_range): + pass + elif operation == READ_ALL_ITEMS: + async for _ in container.read_all_items(): + pass + +@pytest.mark.cosmosMultiRegion +@pytest.mark.usefixtures("setup_teardown") +class TestPerPartitionCircuitBreakerMM: + host = test_config.TestConfig.host + master_key = test_config.TestConfig.masterKey + TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID + TEST_CONTAINER_SINGLE_PARTITION_ID = os.path.basename(__file__) + str(uuid.uuid4()) + + def setup_method_with_custom_transport(self, custom_transport: HttpTransport, default_endpoint=host, **kwargs): + client = CosmosClient(default_endpoint, self.master_key, + preferred_locations=[REGION_1, REGION_2], + multiple_write_locations=True, + transport=custom_transport, **kwargs) + db = client.get_database_client(self.TEST_DATABASE_ID) + container = db.get_container_client(self.TEST_CONTAINER_SINGLE_PARTITION_ID) + return {"client": client, "db": db, "col": container} + + @pytest.mark.parametrize("write_operation, error", write_operations_and_errors()) + def test_write_consecutive_failure_threshold(self, setup_teardown, write_operation, error): + container, doc, expected_uri, uri_down, 
fault_injection_container, custom_transport = self.setup_info(error) + global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager + + with pytest.raises((CosmosHttpResponseError, ServiceResponseError)) as exc_info: + perform_write_operation(write_operation, + container, + fault_injection_container, + str(uuid.uuid4()), + PK_VALUE, + expected_uri) + assert exc_info.value == error + + validate_unhealthy_partitions(global_endpoint_manager, 0) + + # writes should fail but still be tracked + for i in range(4): + with pytest.raises((CosmosHttpResponseError, ServiceResponseError)) as exc_info: + perform_write_operation(write_operation, + container, + fault_injection_container, + str(uuid.uuid4()), + PK_VALUE, + expected_uri) + assert exc_info.value == error + + # writes should now succeed because going to the other region + perform_write_operation(write_operation, + container, + fault_injection_container, + str(uuid.uuid4()), + PK_VALUE, + expected_uri) + + validate_unhealthy_partitions(global_endpoint_manager, 1) + # remove faults and reduce initial recover time and perform a write + original_unavailable_time = _partition_health_tracker.INITIAL_UNAVAILABLE_TIME + _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = 1 + custom_transport.faults = [] + try: + perform_write_operation(write_operation, + container, + fault_injection_container, + str(uuid.uuid4()), + PK_VALUE, + uri_down) + finally: + _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = original_unavailable_time + validate_unhealthy_partitions(global_endpoint_manager, 0) + + @pytest.mark.parametrize("read_operation, error", read_operations_and_errors()) + def test_read_consecutive_failure_threshold(self, setup_teardown, read_operation, error): + container, doc, expected_uri, uri_down, fault_injection_container, custom_transport = self.setup_info(error=error) + + global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager + + # create a 
document to read + container.create_item(body=doc) + + # reads should fail over and only the relevant partition should be marked as unavailable + perform_read_operation(read_operation, + fault_injection_container, + doc['id'], + doc['pk'], + expected_uri) + # partition should not have been marked unavailable after one error + validate_unhealthy_partitions(global_endpoint_manager, 0) + + for i in range(10): + perform_read_operation(read_operation, + fault_injection_container, + doc['id'], + doc['pk'], + expected_uri) + + # the partition should have been marked as unavailable after breaking read threshold + if read_operation in (CHANGE_FEED, QUERY, READ_ALL_ITEMS): + # these operations are cross partition so they would mark both partitions as unavailable + expected_unhealthy_partitions = 2 + else: + expected_unhealthy_partitions = 1 + validate_unhealthy_partitions(global_endpoint_manager, expected_unhealthy_partitions) + # remove faults and reduce initial recover time and perform a read + original_unavailable_time = _partition_health_tracker.INITIAL_UNAVAILABLE_TIME + _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = 1 + custom_transport.faults = [] + try: + perform_read_operation(read_operation, + fault_injection_container, + doc['id'], + doc['pk'], + uri_down) + finally: + _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = original_unavailable_time + validate_unhealthy_partitions(global_endpoint_manager, 0) + + @pytest.mark.parametrize("write_operation, error", write_operations_and_errors()) + def test_write_failure_rate_threshold(self, setup_teardown, write_operation, error): + container, doc, expected_uri, uri_down, fault_injection_container, custom_transport = self.setup_info(error=error) + global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager + # lower minimum requests for testing + _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 10 + os.environ["AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED"] = "80" + 
try: + # writes should fail but still be tracked and mark unavailable a partition after crossing threshold + for i in range(10): + validate_unhealthy_partitions(global_endpoint_manager, 0) + if i == 4 or i == 8: + # perform some successful creates to reset consecutive counter + # remove faults and perform a write + custom_transport.faults = [] + fault_injection_container.upsert_item(body=doc) + custom_transport.add_fault(predicate, + lambda r: FaultInjectionTransport.error_after_delay( + 0, + error + )) + else: + with pytest.raises((CosmosHttpResponseError, ServiceResponseError)) as exc_info: + perform_write_operation(write_operation, + container, + fault_injection_container, + str(uuid.uuid4()), + PK_VALUE, + expected_uri) + assert exc_info.value == error + + validate_unhealthy_partitions(global_endpoint_manager, 1) + + finally: + os.environ["AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED"] = "90" + # restore minimum requests + _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 + + @pytest.mark.parametrize("read_operation, error", read_operations_and_errors()) + def test_read_failure_rate_threshold(self, setup_teardown, read_operation, error): + container, doc, expected_uri, uri_down, fault_injection_container, custom_transport = self.setup_info(error) + container.upsert_item(body=doc) + global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager + # lower minimum requests for testing + _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 8 + os.environ["AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED"] = "80" + try: + if isinstance(error, ServiceResponseError): + # service response error retries in region 3 additional times before failing over + print("Service response error") + print("num operations") + num_operations = 2 + else: + num_operations = 8 + for i in range(num_operations): + validate_unhealthy_partitions(global_endpoint_manager, 0) + # read will fail and retry in other region + 
perform_read_operation(read_operation, + fault_injection_container, + doc['id'], + PK_VALUE, + expected_uri) + if read_operation in (CHANGE_FEED, QUERY, READ_ALL_ITEMS): + # these operations are cross partition so they would mark both partitions as unavailable + expected_unhealthy_partitions = 2 + else: + expected_unhealthy_partitions = 1 + + validate_unhealthy_partitions(global_endpoint_manager, expected_unhealthy_partitions) + finally: + os.environ["AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED"] = "90" + # restore minimum requests + _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 + + def setup_info(self, error): + expected_uri = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_2) + uri_down = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_1) + custom_transport = FaultInjectionTransport() + # two documents targeted to same partition, one will always fail and the other will succeed + doc = create_doc() + predicate = lambda r: (FaultInjectionTransport.predicate_is_document_operation(r) and + FaultInjectionTransport.predicate_targets_region(r, uri_down)) + custom_transport.add_fault(predicate, + lambda r: FaultInjectionTransport.error_after_delay( + 0, + error + )) + custom_setup = self.setup_method_with_custom_transport(custom_transport, default_endpoint=self.host) + fault_injection_container = custom_setup['col'] + setup = self.setup_method_with_custom_transport(None, default_endpoint=self.host) + container = setup['col'] + return container, doc, expected_uri, uri_down, fault_injection_container, custom_transport + + @pytest.mark.parametrize("read_operation, write_operation", operations()) + def test_service_request_error_async(self, read_operation, write_operation): + # the region should be tried 4 times before failing over and mark the partition as unavailable + # the region should not be marked as unavailable + expected_uri = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_2)
+ uri_down = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_1) + custom_transport = FaultInjectionTransport() + predicate = lambda r: (FaultInjectionTransport.predicate_is_document_operation(r) and + FaultInjectionTransport.predicate_targets_region(r, uri_down)) + custom_transport.add_fault(predicate, + lambda r: FaultInjectionTransport.error_region_down()) + doc = {'id': str(uuid.uuid4()), + 'pk': PK_VALUE, + 'name': 'sample document', + 'key': 'value'} + custom_setup = self.setup_method_with_custom_transport(custom_transport, default_endpoint=self.host) + fault_injection_container = custom_setup['col'] + setup = self.setup_method_with_custom_transport(None, default_endpoint=self.host) + container = setup['col'] + container.upsert_item(body=doc) + perform_read_operation(read_operation, + fault_injection_container, + doc['id'], + PK_VALUE, + expected_uri) + global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager + + validate_unhealthy_partitions(global_endpoint_manager, 1) + # there shouldn't be region marked as unavailable + assert len(global_endpoint_manager.location_cache.location_unavailability_info_by_endpoint) == 0 + + # recover partition + # remove faults and reduce initial recover time and perform a write + original_unavailable_time = _partition_health_tracker.INITIAL_UNAVAILABLE_TIME + _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = 1 + custom_transport.faults = [] + try: + perform_read_operation(read_operation, + fault_injection_container, + doc['id'], + PK_VALUE, + uri_down) + finally: + _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = original_unavailable_time + validate_unhealthy_partitions(global_endpoint_manager, 0) + + custom_transport.add_fault(predicate, + lambda r: FaultInjectionTransport.error_region_down()) + + perform_write_operation(write_operation, + container, + fault_injection_container, + str(uuid.uuid4()), + PK_VALUE, + expected_uri) + global_endpoint_manager = 
fault_injection_container.client_connection._global_endpoint_manager + + validate_unhealthy_partitions(global_endpoint_manager, 1) + # there shouldn't be region marked as unavailable + assert len(global_endpoint_manager.location_cache.location_unavailability_info_by_endpoint) == 0 + + # test cosmos client timeout + +if __name__ == '__main__': + unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm_async.py b/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm_async.py index 29c4ecfdcac4..fde210367ecd 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm_async.py @@ -34,6 +34,7 @@ REPLACE = "replace" PATCH = "patch" DELETE = "delete" +PK_VALUE = "pk1" COLLECTION = "created_collection" @@ -69,6 +70,12 @@ def write_operations_and_errors(): return params +def create_doc(): + return {'id': str(uuid.uuid4()), + 'pk': PK_VALUE, + 'name': 'sample document', + 'key': 'value'} + def read_operations_and_errors(): read_operations = [READ, QUERY, QUERY_PK, CHANGE_FEED, CHANGE_FEED_PK, CHANGE_FEED_EPK, READ_ALL_ITEMS] errors = [] @@ -210,15 +217,6 @@ async def setup_method_with_custom_transport(self, custom_transport: AioHttpTran container = db.get_container_client(self.TEST_CONTAINER_SINGLE_PARTITION_ID) return {"client": client, "db": db, "col": container} - async def setup_method(self, default_endpoint=host, **kwargs): - client = CosmosClient(default_endpoint, self.master_key, - preferred_locations=[REGION_1, REGION_2], - multiple_write_locations=True, - **kwargs) - db = client.get_database_client(self.TEST_DATABASE_ID) - container = db.get_container_client(self.TEST_CONTAINER_SINGLE_PARTITION_ID) - return {"client": client, "db": db, "col": container} - @pytest.mark.parametrize("write_operation, error", write_operations_and_errors()) async def test_write_consecutive_failure_threshold_async(self, setup_teardown, write_operation, error): expected_uri = 
_location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_2) @@ -233,7 +231,7 @@ async def test_write_consecutive_failure_threshold_async(self, setup_teardown, w ))) custom_setup = await self.setup_method_with_custom_transport(custom_transport, default_endpoint=self.host) - setup = await self.setup_method(default_endpoint=self.host) + setup = await self.setup_method_with_custom_transport(None, default_endpoint=self.host) container = setup['col'] fault_injection_container = custom_setup['col'] global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager @@ -304,7 +302,7 @@ async def test_read_consecutive_failure_threshold_async(self, setup_teardown, re custom_setup = await self.setup_method_with_custom_transport(custom_transport, default_endpoint=self.host) fault_injection_container = custom_setup['col'] - setup = await self.setup_method(default_endpoint=self.host) + setup = await self.setup_method_with_custom_transport(None, default_endpoint=self.host) container = setup['col'] global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager @@ -370,7 +368,7 @@ async def test_write_failure_rate_threshold_async(self, setup_teardown, write_op custom_setup = await self.setup_method_with_custom_transport(custom_transport, default_endpoint=self.host) fault_injection_container = custom_setup['col'] - setup = await self.setup_method(default_endpoint=self.host) + setup = await self.setup_method_with_custom_transport(None, default_endpoint=self.host) container = setup['col'] global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager # lower minimum requests for testing @@ -429,7 +427,7 @@ async def test_read_failure_rate_threshold_async(self, setup_teardown, read_oper custom_setup = await self.setup_method_with_custom_transport(custom_transport, default_endpoint=self.host) fault_injection_container = custom_setup['col'] - setup = await 
self.setup_method(default_endpoint=self.host) + setup = await self.setup_method_with_custom_transport(None, default_endpoint=self.host) container = setup['col'] await container.upsert_item(body=doc) global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager @@ -439,6 +437,8 @@ async def test_read_failure_rate_threshold_async(self, setup_teardown, read_oper try: if isinstance(error, ServiceResponseError): # service response error retries in region 3 additional times before failing over + print("Service response error") + print("num operations") num_operations = 2 else: num_operations = 8 @@ -481,7 +481,7 @@ async def test_service_request_error_async(self, read_operation, write_operation 'key': 'value'} custom_setup = await self.setup_method_with_custom_transport(custom_transport, default_endpoint=self.host) fault_injection_container = custom_setup['col'] - setup = await self.setup_method(default_endpoint=self.host) + setup = await self.setup_method_with_custom_transport(None, default_endpoint=self.host) container = setup['col'] await container.upsert_item(body=doc) await perform_read_operation(read_operation, From 5a235c501615ee82784764ebd9b4148576caf13f Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Mon, 21 Apr 2025 14:41:54 -0700 Subject: [PATCH 097/152] refactor tests --- sdk/cosmos/azure-cosmos/tests/test_ppcb_mm.py | 60 +++--- .../azure-cosmos/tests/test_ppcb_mm_async.py | 172 ++++++------------ 2 files changed, 87 insertions(+), 145 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm.py b/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm.py index 91a3ca82e209..36fa264405aa 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm.py +++ b/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm.py @@ -103,10 +103,10 @@ def perform_read_operation(operation, container, doc_id, pk, expected_uri): elif operation == CHANGE_FEED_EPK: # partition key filtered by feed range feed_range = 
container.feed_range_from_partition_key(partition_key=pk) - async for _ in container.query_items_change_feed(feed_range=feed_range): + for _ in container.query_items_change_feed(feed_range=feed_range): pass elif operation == READ_ALL_ITEMS: - async for _ in container.read_all_items(): + for _ in container.read_all_items(): pass @pytest.mark.cosmosMultiRegion @@ -128,7 +128,11 @@ def setup_method_with_custom_transport(self, custom_transport: HttpTransport, de @pytest.mark.parametrize("write_operation, error", write_operations_and_errors()) def test_write_consecutive_failure_threshold(self, setup_teardown, write_operation, error): - container, doc, expected_uri, uri_down, fault_injection_container, custom_transport = self.setup_info(error) + error_lambda = lambda r: FaultInjectionTransport.error_after_delay( + 0, + error + ) + container, doc, expected_uri, uri_down, fault_injection_container, custom_transport, predicate = self.setup_info(error_lambda) global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager with pytest.raises((CosmosHttpResponseError, ServiceResponseError)) as exc_info: @@ -179,7 +183,11 @@ def test_write_consecutive_failure_threshold(self, setup_teardown, write_operati @pytest.mark.parametrize("read_operation, error", read_operations_and_errors()) def test_read_consecutive_failure_threshold(self, setup_teardown, read_operation, error): - container, doc, expected_uri, uri_down, fault_injection_container, custom_transport = self.setup_info(error=error) + error_lambda = lambda r: FaultInjectionTransport.error_after_delay( + 0, + error + ) + container, doc, expected_uri, uri_down, fault_injection_container, custom_transport, predicate = self.setup_info(error_lambda) global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager @@ -225,7 +233,11 @@ def test_read_consecutive_failure_threshold(self, setup_teardown, read_operation @pytest.mark.parametrize("write_operation, error", 
write_operations_and_errors()) def test_write_failure_rate_threshold(self, setup_teardown, write_operation, error): - container, doc, expected_uri, uri_down, fault_injection_container, custom_transport = self.setup_info(error=error) + error_lambda = lambda r: FaultInjectionTransport.error_after_delay( + 0, + error + ) + container, doc, expected_uri, uri_down, fault_injection_container, custom_transport, predicate = self.setup_info(error_lambda) global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager # lower minimum requests for testing _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 10 @@ -263,7 +275,11 @@ def test_write_failure_rate_threshold(self, setup_teardown, write_operation, err @pytest.mark.parametrize("read_operation, error", read_operations_and_errors()) def test_read_failure_rate_threshold(self, setup_teardown, read_operation, error): - container, doc, expected_uri, uri_down, fault_injection_container, custom_transport = self.setup_info(error) + error_lambda = lambda r: FaultInjectionTransport.error_after_delay( + 0, + error + ) + container, doc, expected_uri, uri_down, fault_injection_container, custom_transport, predicate = self.setup_info(error_lambda) container.upsert_item(body=doc) global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager # lower minimum requests for testing @@ -272,8 +288,6 @@ def test_read_failure_rate_threshold(self, setup_teardown, read_operation, error try: if isinstance(error, ServiceResponseError): # service response error retries in region 3 additional times before failing over - print("Service response error") - print("num operations") num_operations = 2 else: num_operations = 8 @@ -297,7 +311,7 @@ def test_read_failure_rate_threshold(self, setup_teardown, read_operation, error # restore minimum requests _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 - async def setup_info(self, error): + def setup_info(self, error): 
expected_uri = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_2) uri_down = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_1) custom_transport = FaultInjectionTransport() @@ -306,35 +320,19 @@ async def setup_info(self, error): predicate = lambda r: (FaultInjectionTransport.predicate_is_document_operation(r) and FaultInjectionTransport.predicate_targets_region(r, uri_down)) custom_transport.add_fault(predicate, - lambda r: FaultInjectionTransport.error_after_delay( - 0, - error - )) + error) custom_setup = self.setup_method_with_custom_transport(custom_transport, default_endpoint=self.host) fault_injection_container = custom_setup['col'] - setup = self.setup_method_with_custom_transport(default_endpoint=self.host) + setup = self.setup_method_with_custom_transport(None, default_endpoint=self.host) container = setup['col'] - return container, doc, expected_uri, uri_down, fault_injection_container, custom_transport + return container, doc, expected_uri, uri_down, fault_injection_container, custom_transport, predicate @pytest.mark.parametrize("read_operation, write_operation", operations()) - async def test_service_request_error_async(self, read_operation, write_operation): + def test_service_request_error(self, read_operation, write_operation): # the region should be tried 4 times before failing over and mark the partition as unavailable # the region should not be marked as unavailable - expected_uri = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_2) - uri_down = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_1) - custom_transport = FaultInjectionTransport() - predicate = lambda r: (FaultInjectionTransport.predicate_is_document_operation(r) and - FaultInjectionTransport.predicate_targets_region(r, uri_down)) - custom_transport.add_fault(predicate, - lambda r: FaultInjectionTransport.error_region_down()) - doc = {'id': str(uuid.uuid4()), - 'pk': PK_VALUE, - 'name': 'sample 
document', - 'key': 'value'} - custom_setup = self.setup_method_with_custom_transport(custom_transport, default_endpoint=self.host) - fault_injection_container = custom_setup['col'] - setup = self.setup_method_with_custom_transport(None, default_endpoint=self.host) - container = setup['col'] + error_lambda = lambda r: FaultInjectionTransport.error_region_down() + container, doc, expected_uri, uri_down, fault_injection_container, custom_transport, predicate = self.setup_info(error_lambda) container.upsert_item(body=doc) perform_read_operation(read_operation, fault_injection_container, diff --git a/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm_async.py b/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm_async.py index fde210367ecd..418bd113469e 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm_async.py @@ -219,19 +219,11 @@ async def setup_method_with_custom_transport(self, custom_transport: AioHttpTran @pytest.mark.parametrize("write_operation, error", write_operations_and_errors()) async def test_write_consecutive_failure_threshold_async(self, setup_teardown, write_operation, error): - expected_uri = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_2) - uri_down = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_1) - custom_transport = FaultInjectionTransportAsync() - pk_value = "pk1" - predicate = lambda r: (FaultInjectionTransportAsync.predicate_is_document_operation(r) and - FaultInjectionTransportAsync.predicate_targets_region(r, uri_down)) - custom_transport.add_fault(predicate, lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( + error_lambda = lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( 0, error - ))) - - custom_setup = await self.setup_method_with_custom_transport(custom_transport, default_endpoint=self.host) - setup = await self.setup_method_with_custom_transport(None, default_endpoint=self.host) + )) 
+ setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = await self.setup_info(error_lambda) container = setup['col'] fault_injection_container = custom_setup['col'] global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager @@ -241,7 +233,7 @@ async def test_write_consecutive_failure_threshold_async(self, setup_teardown, w container, fault_injection_container, str(uuid.uuid4()), - pk_value, + PK_VALUE, expected_uri) assert exc_info.value == error @@ -254,7 +246,7 @@ async def test_write_consecutive_failure_threshold_async(self, setup_teardown, w container, fault_injection_container, str(uuid.uuid4()), - pk_value, + PK_VALUE, expected_uri) assert exc_info.value == error @@ -263,7 +255,7 @@ async def test_write_consecutive_failure_threshold_async(self, setup_teardown, w container, fault_injection_container, str(uuid.uuid4()), - pk_value, + PK_VALUE, expected_uri) validate_unhealthy_partitions(global_endpoint_manager, 1) @@ -276,44 +268,48 @@ async def test_write_consecutive_failure_threshold_async(self, setup_teardown, w container, fault_injection_container, str(uuid.uuid4()), - pk_value, + PK_VALUE, uri_down) finally: _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = original_unavailable_time validate_unhealthy_partitions(global_endpoint_manager, 0) await cleanup_method([custom_setup, setup]) - @pytest.mark.parametrize("read_operation, error", read_operations_and_errors()) - async def test_read_consecutive_failure_threshold_async(self, setup_teardown, read_operation, error): + async def setup_info(self, error): expected_uri = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_2) uri_down = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_1) - custom_transport = FaultInjectionTransportAsync() - id_value = str(uuid.uuid4()) - document_definition = {'id': id_value, - 'pk': 'pk1', - 'name': 'sample document', - 'key': 'value'} + custom_transport = 
FaultInjectionTransportAsync() + # two documents targeted to same partition, one will always fail and the other will succeed + doc = create_doc() predicate = lambda r: (FaultInjectionTransportAsync.predicate_is_document_operation(r) and FaultInjectionTransportAsync.predicate_targets_region(r, uri_down)) - custom_transport.add_fault(predicate, lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( - 0, - error - ))) - + custom_transport.add_fault(predicate, + error) custom_setup = await self.setup_method_with_custom_transport(custom_transport, default_endpoint=self.host) - fault_injection_container = custom_setup['col'] setup = await self.setup_method_with_custom_transport(None, default_endpoint=self.host) + return setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate + + + @pytest.mark.parametrize("read_operation, error", read_operations_and_errors()) + async def test_read_consecutive_failure_threshold_async(self, setup_teardown, read_operation, error): + error_lambda = lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( + 0, + error + )) + setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = await self.setup_info(error_lambda) container = setup['col'] + fault_injection_container = custom_setup['col'] + global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager # create a document to read - await container.create_item(body=document_definition) + await container.create_item(body=doc) # reads should fail over and only the relevant partition should be marked as unavailable await perform_read_operation(read_operation, fault_injection_container, - document_definition['id'], - document_definition['pk'], + doc['id'], + doc['pk'], expected_uri) # partition should not have been marked unavailable after one error validate_unhealthy_partitions(global_endpoint_manager, 0) @@ -321,8 +317,8 @@ async def 
test_read_consecutive_failure_threshold_async(self, setup_teardown, re for i in range(10): await perform_read_operation(read_operation, fault_injection_container, - document_definition['id'], - document_definition['pk'], + doc['id'], + doc['pk'], expected_uri) # the partition should have been marked as unavailable after breaking read threshold @@ -339,8 +335,8 @@ async def test_read_consecutive_failure_threshold_async(self, setup_teardown, re try: await perform_read_operation(read_operation, fault_injection_container, - document_definition['id'], - document_definition['pk'], + doc['id'], + doc['pk'], uri_down) finally: _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = original_unavailable_time @@ -349,27 +345,13 @@ async def test_read_consecutive_failure_threshold_async(self, setup_teardown, re @pytest.mark.parametrize("write_operation, error", write_operations_and_errors()) async def test_write_failure_rate_threshold_async(self, setup_teardown, write_operation, error): - expected_uri = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_2) - uri_down = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_1) - custom_transport = FaultInjectionTransportAsync() - # two documents targeted to same partition, one will always fail and the other will succeed - pk_value = "pk1" - doc = {'id': str(uuid.uuid4()), - 'pk': pk_value, - 'name': 'sample document', - 'key': 'value'} - predicate = lambda r: (FaultInjectionTransportAsync.predicate_is_document_operation(r) and - FaultInjectionTransportAsync.predicate_targets_region(r, uri_down)) - custom_transport.add_fault(predicate, - lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( - 0, - error - ))) - - custom_setup = await self.setup_method_with_custom_transport(custom_transport, default_endpoint=self.host) - fault_injection_container = custom_setup['col'] - setup = await self.setup_method_with_custom_transport(None, default_endpoint=self.host) + error_lambda = 
lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( + 0, + error + )) + setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = await self.setup_info(error_lambda) container = setup['col'] + fault_injection_container = custom_setup['col'] global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager # lower minimum requests for testing _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 10 @@ -394,7 +376,7 @@ async def test_write_failure_rate_threshold_async(self, setup_teardown, write_op container, fault_injection_container, str(uuid.uuid4()), - pk_value, + PK_VALUE, expected_uri) assert exc_info.value == error @@ -408,27 +390,13 @@ async def test_write_failure_rate_threshold_async(self, setup_teardown, write_op @pytest.mark.parametrize("read_operation, error", read_operations_and_errors()) async def test_read_failure_rate_threshold_async(self, setup_teardown, read_operation, error): - expected_uri = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_2) - uri_down = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_1) - custom_transport = FaultInjectionTransportAsync() - # two documents targeted to same partition, one will always fail and the other will succeed - pk_value = "pk1" - doc = {'id': str(uuid.uuid4()), - 'pk': pk_value, - 'name': 'sample document', - 'key': 'value'} - predicate = lambda r: (FaultInjectionTransportAsync.predicate_is_document_operation(r) and - FaultInjectionTransportAsync.predicate_targets_region(r, uri_down)) - custom_transport.add_fault(predicate, - lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( - 0, - error - ))) - - custom_setup = await self.setup_method_with_custom_transport(custom_transport, default_endpoint=self.host) - fault_injection_container = custom_setup['col'] - setup = await self.setup_method_with_custom_transport(None, default_endpoint=self.host) + error_lambda 
= lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( + 0, + error + )) + setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = await self.setup_info(error_lambda) container = setup['col'] + fault_injection_container = custom_setup['col'] await container.upsert_item(body=doc) global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager # lower minimum requests for testing @@ -437,8 +405,6 @@ async def test_read_failure_rate_threshold_async(self, setup_teardown, read_oper try: if isinstance(error, ServiceResponseError): # service response error retries in region 3 additional times before failing over - print("Service response error") - print("num operations") num_operations = 2 else: num_operations = 8 @@ -448,7 +414,7 @@ async def test_read_failure_rate_threshold_async(self, setup_teardown, read_oper await perform_read_operation(read_operation, fault_injection_container, doc['id'], - pk_value, + PK_VALUE, expected_uri) if read_operation in (CHANGE_FEED, QUERY, READ_ALL_ITEMS): # these operations are cross partition so they would mark both partitions as unavailable @@ -467,27 +433,15 @@ async def test_read_failure_rate_threshold_async(self, setup_teardown, read_oper async def test_service_request_error_async(self, read_operation, write_operation): # the region should be tried 4 times before failing over and mark the partition as unavailable # the region should not be marked as unavailable - expected_uri = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_2) - uri_down = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_1) - custom_transport = FaultInjectionTransportAsync() - predicate = lambda r: (FaultInjectionTransportAsync.predicate_is_document_operation(r) and - FaultInjectionTransportAsync.predicate_targets_region(r, uri_down)) - custom_transport.add_fault(predicate, - lambda r: 
asyncio.create_task(FaultInjectionTransportAsync.error_region_down())) - pk_value = "pk1" - doc = {'id': str(uuid.uuid4()), - 'pk': pk_value, - 'name': 'sample document', - 'key': 'value'} - custom_setup = await self.setup_method_with_custom_transport(custom_transport, default_endpoint=self.host) - fault_injection_container = custom_setup['col'] - setup = await self.setup_method_with_custom_transport(None, default_endpoint=self.host) + error_lambda = lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_region_down()) + setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = await self.setup_info(error_lambda) container = setup['col'] + fault_injection_container = custom_setup['col'] await container.upsert_item(body=doc) await perform_read_operation(read_operation, fault_injection_container, doc['id'], - pk_value, + PK_VALUE, expected_uri) global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager @@ -504,7 +458,7 @@ async def test_service_request_error_async(self, read_operation, write_operation await perform_read_operation(read_operation, fault_injection_container, doc['id'], - pk_value, + PK_VALUE, uri_down) finally: _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = original_unavailable_time @@ -517,7 +471,7 @@ async def test_service_request_error_async(self, read_operation, write_operation container, fault_injection_container, str(uuid.uuid4()), - pk_value, + PK_VALUE, expected_uri) global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager @@ -532,22 +486,12 @@ async def test_service_request_error_async(self, read_operation, write_operation # send 5 write concurrent requests when trying to recover # verify that only one failed async def test_recovering_only_fails_one_requests_async(self): - uri_down = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_1) - custom_transport = FaultInjectionTransportAsync() - predicate = lambda r: 
(FaultInjectionTransportAsync.predicate_is_document_operation(r) and - FaultInjectionTransportAsync.predicate_targets_region(r, uri_down)) - custom_transport.add_fault(predicate, - lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay(0, CosmosHttpResponseError( - status_code=502, - message="Some envoy error.")))) - pk_value = "pk1" - custom_setup = await self.setup_method_with_custom_transport(custom_transport, default_endpoint=self.host) + error_lambda = lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( + 0, CosmosHttpResponseError( + status_code=502, + message="Some envoy error."))) + setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = await self.setup_info(error_lambda) fault_injection_container = custom_setup['col'] - doc = {'id': str(uuid.uuid4()), - 'pk': pk_value, - 'name': 'sample document', - 'key': 'value'} - for i in range(5): with pytest.raises(CosmosHttpResponseError): await fault_injection_container.create_item(body=doc) @@ -558,7 +502,7 @@ async def test_recovering_only_fails_one_requests_async(self): async def concurrent_upsert(): nonlocal number_of_errors doc = {'id': str(uuid.uuid4()), - 'pk': pk_value, + 'pk': PK_VALUE, 'name': 'sample document', 'key': 'value'} try: From 0f857168f0093ca66a70cecd783e79f8f9fda80c Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Mon, 21 Apr 2025 15:12:24 -0700 Subject: [PATCH 098/152] sync changes --- .../azure/cosmos/aio/_cosmos_client_connection_async.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py index 3aa7d65c39a7..665c019f4fc7 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py @@ -2875,11 +2875,6 @@ def __GetBodiesFromQueryResult(result: 
Dict[str, Any]) -> List[Dict[str, Any]]: initial_headers = self.default_headers.copy() - cont_prop = kwargs.pop("containerProperties", None) - - if cont_prop: - cont_prop = await cont_prop() - options["containerRID"] = cont_prop["_rid"] # Copy to make sure that default_headers won't be changed. if query is None: @@ -2922,7 +2917,10 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: partition_key = options.get("partitionKey", None) isPrefixPartitionQuery = False partition_key_definition = None + cont_prop = kwargs.pop("containerProperties", None) if cont_prop and partition_key: + cont_prop = await cont_prop() + options["containerRID"] = cont_prop["_rid"] pk_properties = cont_prop["partitionKey"] partition_key_definition = PartitionKey(path=pk_properties["paths"], kind=pk_properties["kind"]) if partition_key_definition.kind == "MultiHash" and \ From 397d087a4824bb59d10234099c41dad84c15bac2 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Tue, 22 Apr 2025 10:51:37 -0700 Subject: [PATCH 099/152] sync changes and cleanup --- ...tition_endpoint_manager_circuit_breaker.py | 46 +++++++++++++------ .../azure/cosmos/_request_object.py | 1 + .../azure/cosmos/_retry_utility.py | 22 +++------ .../cosmos/_routing/routing_map_provider.py | 45 ++++++++++++------ .../azure/cosmos/_synchronized_request.py | 2 + .../azure/cosmos/aio/_container.py | 3 -- ..._endpoint_manager_circuit_breaker_async.py | 3 +- .../azure-cosmos/azure/cosmos/container.py | 38 ++++++++------- sdk/cosmos/azure-cosmos/tests/test_ppcb_mm.py | 3 +- 9 files changed, 96 insertions(+), 67 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py index 608f63c06ecf..c36367ac92ac 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py +++ 
b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py @@ -23,19 +23,18 @@ """ from typing import TYPE_CHECKING +from azure.cosmos import PartitionKey from azure.cosmos._global_partition_endpoint_manager_circuit_breaker_core import \ _GlobalPartitionEndpointManagerForCircuitBreakerCore from azure.cosmos._global_endpoint_manager import _GlobalEndpointManager from azure.cosmos._request_object import RequestObject -from azure.cosmos._routing.routing_range import PartitionKeyRangeWrapper +from azure.cosmos._routing.routing_range import PartitionKeyRangeWrapper, Range from azure.cosmos.http_constants import HttpHeaders if TYPE_CHECKING: from azure.cosmos._cosmos_client_connection import CosmosClientConnection - - class _GlobalPartitionEndpointManagerForCircuitBreaker(_GlobalEndpointManager): """ This internal class implements the logic for partition endpoint management for @@ -54,17 +53,37 @@ def is_circuit_breaker_applicable(self, request: RequestObject) -> bool: def create_pk_range_wrapper(self, request: RequestObject) -> PartitionKeyRangeWrapper: container_rid = request.headers[HttpHeaders.IntendedCollectionRID] - partition_key = request.headers[HttpHeaders.PartitionKey] - # get the partition key range for the given partition key + print(request.headers) target_container_link = None - for container_link, properties in self.Client._container_properties_cache.items(): # pylint: disable=protected-access + partition_key = None + # TODO: @tvaron3 see if the container link is absolutely necessary or change container cache + for container_link, properties in self.Client._container_properties_cache.items(): if properties["_rid"] == container_rid: target_container_link = container_link - if not target_container_link: + partition_key_definition = properties["partitionKey"] + partition_key = PartitionKey(path=partition_key_definition["paths"], + kind=partition_key_definition["kind"]) + + if not target_container_link or not partition_key: 
raise RuntimeError("Illegal state: the container cache is not properly initialized.") - pk_range = (self.Client._routing_map_provider # pylint: disable=protected-access - .get_overlapping_ranges(target_container_link, partition_key)) - return PartitionKeyRangeWrapper(pk_range, container_rid) + + if request.headers.get(HttpHeaders.PartitionKey): + partition_key_value = request.headers[HttpHeaders.PartitionKey] + # get the partition key range for the given partition key + epk_range = [partition_key._get_epk_range_for_partition_key(partition_key_value)] + partition_ranges = (self.Client._routing_map_provider + .get_overlapping_ranges(target_container_link, epk_range)) + partition_range = Range.PartitionKeyRangeToRange(partition_ranges[0]) + elif request.headers.get(HttpHeaders.PartitionKeyRangeID): + pk_range_id = request.headers[HttpHeaders.PartitionKeyRangeID] + range =(self.Client._routing_map_provider + .get_range_by_partition_key_range_id(target_container_link, pk_range_id)) + partition_range = Range.PartitionKeyRangeToRange(range) + else: + raise RuntimeError("Illegal state: the request does not contain partition information.") + + return PartitionKeyRangeWrapper(partition_range, container_rid) + def record_failure( self, @@ -74,12 +93,13 @@ def record_failure( pk_range_wrapper = self.create_pk_range_wrapper(request) self.global_partition_endpoint_manager_core.record_failure(request, pk_range_wrapper) - def resolve_service_endpoint(self, request: RequestObject, pk_range_wrapper: PartitionKeyRangeWrapper) -> str: + def resolve_service_endpoint(self, request: RequestObject, pk_range_wrapper: PartitionKeyRangeWrapper): if self.is_circuit_breaker_applicable(request): + self.global_partition_endpoint_manager_core.check_stale_partition_info(request, pk_range_wrapper) request = self.global_partition_endpoint_manager_core.add_excluded_locations_to_request(request, pk_range_wrapper) - return super(_GlobalPartitionEndpointManagerForCircuitBreaker, 
self).resolve_service_endpoint(request, - pk_range_wrapper) + return (super(_GlobalPartitionEndpointManagerForCircuitBreaker, self) + .resolve_service_endpoint(request, pk_range_wrapper)) def mark_partition_unavailable( self, diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py index 1ddce7880a3d..d085ddb06823 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py @@ -25,6 +25,7 @@ from . import http_constants class RequestObject(object): # pylint: disable=too-many-instance-attributes + # TODO: @tvaron3 add container link here def __init__( self, resource_type: str, diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py index 7d27885f10db..70c1eacccb89 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py @@ -63,7 +63,7 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs): # pylin pk_range_wrapper = global_endpoint_manager.create_pk_range_wrapper(args[0]) # instantiate all retry policies here to be applied for each request execution endpointDiscovery_retry_policy = _endpoint_discovery_retry_policy.EndpointDiscoveryRetryPolicy( - client.connection_policy, global_endpoint_manager, *args + client.connection_policy, global_endpoint_manager, pk_range_wrapper, *args ) database_account_retry_policy = _database_account_retry_policy.DatabaseAccountRetryPolicy( client.connection_policy @@ -212,6 +212,7 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs): # pylin if not database_account_retry_policy.ShouldRetry(e): raise e else: + global_endpoint_manager.record_failure(args[0]) _handle_service_request_retries(client, service_request_retry_policy, e, *args) except ServiceResponseError as e: @@ -322,9 +323,10 @@ def send(self, request): # the 
request ran into a socket timeout or failed to establish a new connection # since request wasn't sent, raise exception immediately to be dealt with in client retry policies # This logic is based on the _retry.py file from azure-core - if not _has_database_account_header(request.http_request.headers): - global_endpoint_manager.record_failure(request_params) + if (not _has_database_account_header(request.http_request.headers) + and not request_params.healthy_tentative_location): if retry_settings['connect'] > 0: + global_endpoint_manager.record_failure(request_params) retry_active = self.increment(retry_settings, response=request, error=err) if retry_active: self.sleep(retry_settings, request.context.transport) @@ -334,21 +336,11 @@ def send(self, request): retry_error = err # Only read operations can be safely retried with ServiceResponseError if (not _has_read_retryable_headers(request.http_request.headers) or - _has_database_account_header(request.http_request.headers)): + _has_database_account_header(request.http_request.headers) or + request_params.healthy_tentative_location): raise err - global_endpoint_manager.record_failure(request_params) # This logic is based on the _retry.py file from azure-core if retry_settings['read'] > 0: - retry_active = self.increment(retry_settings, response=request, error=err) - if retry_active: - self.sleep(retry_settings, request.context.transport) - continue - raise err - except AzureError as err: - retry_error = err - if _has_database_account_header(request.http_request.headers): - raise err - if self._is_method_retryable(retry_settings, request.http_request): global_endpoint_manager.record_failure(request_params) retry_active = self.increment(retry_settings, response=request, error=err) if retry_active: diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_map_provider.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_map_provider.py index 8dacb5190e07..99ce35de3bff 100644 --- 
a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_map_provider.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_map_provider.py @@ -22,6 +22,7 @@ """Internal class for partition key range cache implementation in the Azure Cosmos database service. """ +from typing import Dict, Any from .. import _base from .collection_routing_map import CollectionRoutingMap @@ -50,6 +51,25 @@ def __init__(self, client): # keeps the cached collection routing map by collection id self._collection_routing_map_by_item = {} + def initialize_collection_routing_map_if_needed( + self, + collection_link: str, + collection_id: str, + **kwargs: Dict[str, Any] + ): + client = self._documentClient + collection_routing_map = self._collection_routing_map_by_item.get(collection_id) + if collection_routing_map is None: + collection_pk_ranges = list(client._ReadPartitionKeyRanges(collection_link, **kwargs)) + # for large collections, a split may complete between the read partition key ranges query page responses, + # causing the partitionKeyRanges to have both the children ranges and their parents. Therefore, we need + # to discard the parent ranges to have a valid routing map. + collection_pk_ranges = PartitionKeyRangeCache._discard_parent_ranges(collection_pk_ranges) + collection_routing_map = CollectionRoutingMap.CompleteRoutingMap( + [(r, True) for r in collection_pk_ranges], collection_id + ) + self._collection_routing_map_by_item[collection_id] = collection_routing_map + def get_overlapping_ranges(self, collection_link, partition_key_ranges, **kwargs): """Given a partition key range and a collection, return the list of overlapping partition key ranges. @@ -59,22 +79,21 @@ def get_overlapping_ranges(self, collection_link, partition_key_ranges, **kwargs :return: List of overlapping partition key ranges. 
:rtype: list """ - cl = self._documentClient + collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) + self.initialize_collection_routing_map_if_needed(collection_link, str(collection_id), **kwargs) + return self._collection_routing_map_by_item[collection_id].get_overlapping_ranges(partition_key_ranges) + + def get_range_by_partition_key_range_id( + self, + collection_link: str, + partition_key_range_id: int, + **kwargs: Dict[str, Any] + ) -> Dict[str, Any]: collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) + self.initialize_collection_routing_map_if_needed(collection_link, str(collection_id), **kwargs) - collection_routing_map = self._collection_routing_map_by_item.get(collection_id) - if collection_routing_map is None: - collection_pk_ranges = list(cl._ReadPartitionKeyRanges(collection_link, **kwargs)) - # for large collections, a split may complete between the read partition key ranges query page responses, - # causing the partitionKeyRanges to have both the children ranges and their parents. Therefore, we need - # to discard the parent ranges to have a valid routing map. 
- collection_pk_ranges = PartitionKeyRangeCache._discard_parent_ranges(collection_pk_ranges) - collection_routing_map = CollectionRoutingMap.CompleteRoutingMap( - [(r, True) for r in collection_pk_ranges], collection_id - ) - self._collection_routing_map_by_item[collection_id] = collection_routing_map - return collection_routing_map.get_overlapping_ranges(partition_key_ranges) + return self._collection_routing_map_by_item[collection_id].get_range_by_partition_key_range_id(partition_key_range_id) @staticmethod def _discard_parent_ranges(partitionKeyRanges): diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py index 43516430bdb6..9e7e38b31322 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py @@ -90,6 +90,8 @@ def _Request(global_endpoint_manager, request_params, connection_policy, pipelin # Every request tries to perform a refresh client_timeout = kwargs.get('timeout') start_time = time.time() + if request_params.healthy_tentative_location: + read_timeout = connection_policy.RecoveryReadTimeout if request_params.resource_type != http_constants.ResourceType.DatabaseAccount: global_endpoint_manager.refresh_endpoint_list(None, **kwargs) else: diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py index 12fa8dfefc74..d85266cccb13 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py @@ -164,8 +164,6 @@ async def read( range statistics in response headers. :keyword bool populate_quota_info: Enable returning collection storage quota information in response headers. :keyword dict[str, str] initial_headers: Initial headers to be sent as part of the request. - :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. 
The locations - in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Mapping[str, str], Dict[str, Any]], None] :keyword Literal["High", "Low"] priority: Priority based execution allows users to set a priority for each @@ -231,7 +229,6 @@ async def create_item( If all preferred locations were excluded, primary/hub location will be used. This excluded_location will override existing excluded_locations in client level. :keyword response_hook: A callable invoked with the response metadata. - :paramtype response_hook: Callable[[Dict[str, str], Dict[str, Any]], None] :paramtype response_hook: Callable[[Mapping[str, str], Dict[str, Any]], None] :keyword Literal["High", "Low"] priority: Priority based execution allows users to set a priority for each request. Once the user has reached their provisioned throughput, low priority requests are throttled diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py index 3f6f4c333808..100cd672cb65 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py @@ -77,6 +77,8 @@ async def create_pk_range_wrapper(self, request: RequestObject) -> PartitionKeyR range = await (self.client._routing_map_provider .get_range_by_partition_key_range_id(target_container_link, pk_range_id)) partition_range = Range.PartitionKeyRangeToRange(range) + else: + raise RuntimeError("Illegal state: the request does not contain partition information.") return PartitionKeyRangeWrapper(partition_range, container_rid) @@ -96,7 +98,6 @@ def resolve_service_endpoint(self, request: 
RequestObject, pk_range_wrapper: Par self.global_partition_endpoint_manager_core.check_stale_partition_info(request, pk_range_wrapper) request = self.global_partition_endpoint_manager_core.add_excluded_locations_to_request(request, pk_range_wrapper) - # Todo: @tvaron3 think about how to switch healthy tentative unhealthy tentative return (super(_GlobalPartitionEndpointManagerForCircuitBreakerAsync, self) .resolve_service_endpoint(request, pk_range_wrapper)) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/container.py b/sdk/cosmos/azure-cosmos/azure/cosmos/container.py index f2c55ddff4e8..083bc7b074b4 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/container.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/container.py @@ -274,8 +274,8 @@ def read_item( # pylint:disable=docstring-missing-param if max_integrated_cache_staleness_in_ms is not None: validate_cache_staleness_value(max_integrated_cache_staleness_in_ms) request_options["maxIntegratedCacheStaleness"] = max_integrated_cache_staleness_in_ms - if self.container_link in self.__get_client_container_caches(): - request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] + self._get_properties() + request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] return self.client_connection.ReadItem(document_link=doc_link, options=request_options, **kwargs) @distributed_trace @@ -334,8 +334,8 @@ def read_all_items( # pylint:disable=docstring-missing-param if response_hook and hasattr(response_hook, "clear"): response_hook.clear() - if self.container_link in self.__get_client_container_caches(): - feed_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] + self._get_properties() + feed_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] items = self.client_connection.ReadItems( collection_link=self.container_link, feed_options=feed_options, 
response_hook=response_hook, **kwargs) @@ -689,8 +689,8 @@ def query_items( # pylint:disable=docstring-missing-param feed_options["responseContinuationTokenLimitInKb"] = continuation_token_limit if response_hook and hasattr(response_hook, "clear"): response_hook.clear() - if self.container_link in self.__get_client_container_caches(): - feed_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] + self._get_properties() + feed_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] items = self.client_connection.QueryItems( database_or_container_link=self.container_link, query=query if parameters is None else {"query": query, "parameters": parameters}, @@ -788,8 +788,8 @@ def replace_item( # pylint:disable=docstring-missing-param ) request_options["populateQueryMetrics"] = populate_query_metrics - if self.container_link in self.__get_client_container_caches(): - request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] + self._get_properties() + request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] result = self.client_connection.ReplaceItem( document_link=item_link, new_document=body, options=request_options, **kwargs) return result @@ -867,8 +867,8 @@ def upsert_item( # pylint:disable=docstring-missing-param DeprecationWarning, ) request_options["populateQueryMetrics"] = populate_query_metrics - if self.container_link in self.__get_client_container_caches(): - request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] + self._get_properties() + request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] result = self.client_connection.UpsertItem( database_or_container_link=self.container_link, @@ -963,8 +963,8 @@ def create_item( # pylint:disable=docstring-missing-param request_options["populateQueryMetrics"] = 
populate_query_metrics if indexing_directive is not None: request_options["indexingDirective"] = indexing_directive - if self.container_link in self.__get_client_container_caches(): - request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] + self._get_properties() + request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] result = self.client_connection.CreateItem( database_or_container_link=self.container_link, document=body, options=request_options, **kwargs) return result @@ -1045,8 +1045,8 @@ def patch_item( if filter_predicate is not None: request_options["filterPredicate"] = filter_predicate - if self.container_link in self.__get_client_container_caches(): - request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] + self._get_properties() + request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] item_link = self._get_document_link(item) result = self.client_connection.PatchItem( document_link=item_link, operations=patch_operations, options=request_options, **kwargs) @@ -1077,8 +1077,6 @@ def execute_item_batch( :keyword Literal["High", "Low"] priority: Priority based execution allows users to set a priority for each request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. - :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations - in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. If all preferred locations were excluded, primary/hub location will be used. This excluded_location will override existing excluded_locations in client level. :keyword response_hook: A callable invoked with the response metadata. 
@@ -1190,8 +1188,8 @@ def delete_item( # pylint:disable=docstring-missing-param request_options["preTriggerInclude"] = pre_trigger_include if post_trigger_include is not None: request_options["postTriggerInclude"] = post_trigger_include - if self.container_link in self.__get_client_container_caches(): - request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] + self._get_properties() + request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] document_link = self._get_document_link(item) self.client_connection.DeleteItem(document_link=document_link, options=request_options, **kwargs) @@ -1466,8 +1464,8 @@ def delete_all_items_by_partition_key( request_options = build_options(kwargs) # regardless if partition key is valid we set it as invalid partition keys are set to a default empty value request_options["partitionKey"] = self._set_partition_key(partition_key) - if self.container_link in self.__get_client_container_caches(): - request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] + self._get_properties() + request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] self.client_connection.DeleteAllItemsByPartitionKey( collection_link=self.container_link, options=request_options, **kwargs) diff --git a/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm.py b/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm.py index 36fa264405aa..447cd88996db 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm.py +++ b/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm.py @@ -5,7 +5,6 @@ import uuid import pytest -from azure.core.pipeline import HttpTransport from azure.core.exceptions import ServiceResponseError import test_config @@ -117,7 +116,7 @@ class TestPerPartitionCircuitBreakerMM: TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID TEST_CONTAINER_SINGLE_PARTITION_ID = os.path.basename(__file__) + str(uuid.uuid4()) - 
def setup_method_with_custom_transport(self, custom_transport: HttpTransport, default_endpoint=host, **kwargs): + def setup_method_with_custom_transport(self, custom_transport, default_endpoint=host, **kwargs): client = CosmosClient(default_endpoint, self.master_key, preferred_locations=[REGION_1, REGION_2], multiple_write_locations=True, From 9ee3e8776b31b7e3f9fd8d41768a4befc960987d Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Tue, 22 Apr 2025 16:24:19 -0700 Subject: [PATCH 100/152] container cache changes --- .../azure/cosmos/_cosmos_client_connection.py | 59 ++++++++++++++++--- ...tition_endpoint_manager_circuit_breaker.py | 2 +- ...n_endpoint_manager_circuit_breaker_core.py | 1 - .../azure/cosmos/_request_object.py | 2 + .../azure/cosmos/_retry_utility.py | 3 +- .../aio/_cosmos_client_connection_async.py | 1 + .../azure-cosmos/azure/cosmos/container.py | 11 +++- sdk/cosmos/azure-cosmos/tests/test_ppcb_mm.py | 3 +- 8 files changed, 67 insertions(+), 15 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py index 3db8fe922346..fbf849fdd4f3 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py @@ -257,6 +257,7 @@ def _set_container_properties_cache(self, container_link: str, properties: Optio :type properties: Optional[Dict[str, Any]]""" if properties: self.__container_properties_cache[container_link] = properties + self.__container_properties_cache[properties["_rid"]] = properties else: self.__container_properties_cache[container_link] = {} @@ -1294,7 +1295,14 @@ def CreateItem( if base.IsItemContainerLink(database_or_container_link): options = self._AddPartitionKey(database_or_container_link, document, options) - return self.Create(document, path, "docs", collection_id, None, options, **kwargs) + return self.Create(document, + path, + "docs", + collection_id, + 
None, + options, + container_link=database_or_container_link, + **kwargs) def UpsertItem( self, @@ -1330,7 +1338,14 @@ def UpsertItem( collection_id, document, path = self._GetContainerIdWithPathForItem( database_or_container_link, document, options ) - return self.Upsert(document, path, "docs", collection_id, None, options, **kwargs) + return self.Upsert(document, + path, + "docs", + collection_id, + None, + options, + container_link=database_or_container_link, + **kwargs) PartitionResolverErrorMessage = ( "Couldn't find any partition resolvers for the database link provided. " @@ -1977,6 +1992,7 @@ def ReplaceItem( document_link: str, new_document: Dict[str, Any], options: Optional[Mapping[str, Any]] = None, + container_link: Optional[str] = None, **kwargs: Any ) -> CosmosDict: """Replaces a document and returns it. @@ -1986,6 +2002,7 @@ def ReplaceItem( :param dict new_document: :param dict options: The request options for the request. + :param str container_link: :return: The new Document. @@ -2011,12 +2028,20 @@ def ReplaceItem( collection_link = base.GetItemContainerLink(document_link) options = self._AddPartitionKey(collection_link, new_document, options) - return self.Replace(new_document, path, "docs", document_id, None, options, **kwargs) + return self.Replace(new_document, + path, + "docs", + document_id, + None, + options, + container_link=container_link, + **kwargs) def PatchItem( self, document_link: str, operations: List[Dict[str, Any]], + container_link: str, options: Optional[Mapping[str, Any]] = None, **kwargs: Any ) -> CosmosDict: @@ -2024,6 +2049,7 @@ def PatchItem( :param str document_link: The link to the document. :param list operations: The operations for the patch request. + :param str container_link: The container name. :param dict options: The request options for the request. 
:return: @@ -2042,7 +2068,10 @@ def PatchItem( headers = base.GetHeaders(self, self.default_headers, "patch", path, document_id, resource_type, documents._OperationType.Patch, options) # Patch will use WriteEndpoint since it uses PUT operation - request_params = RequestObject(resource_type, documents._OperationType.Patch, headers) + request_params = RequestObject(resource_type, + documents._OperationType.Patch, + headers, + container_link=container_link) request_params.set_excluded_location_from_options(options) request_params.set_excluded_location_from_options(options) request_data = {} @@ -2088,6 +2117,7 @@ def Batch( results, last_response_headers = self._Batch( formatted_operations, path, + collection_link, collection_id, options, **kwargs @@ -2124,6 +2154,7 @@ def _Batch( self, batch_operations: List[Dict[str, Any]], path: str, + container_link: str, collection_id: Optional[str], options: Mapping[str, Any], **kwargs: Any @@ -2132,7 +2163,10 @@ def _Batch( base._populate_batch_headers(initial_headers) headers = base.GetHeaders(self, initial_headers, "post", path, collection_id, "docs", documents._OperationType.Batch, options) - request_params = RequestObject("docs", documents._OperationType.Batch, headers) + request_params = RequestObject("docs", + documents._OperationType.Batch, + headers, + container_link=container_link) request_params.set_excluded_location_from_options(options) request_params.set_excluded_location_from_options(options) return cast( @@ -2196,7 +2230,8 @@ def DeleteAllItemsByPartitionKey( http_constants.ResourceType.PartitionKey, documents._OperationType.Delete, options) request_params = RequestObject(http_constants.ResourceType.PartitionKey, documents._OperationType.Delete, - headers) + headers, + container_link=collection_link) request_params.set_excluded_location_from_options(options) request_params.set_excluded_location_from_options(options) _, last_response_headers = self.__Post( @@ -2627,6 +2662,7 @@ def Create( id: Optional[str], 
initial_headers: Optional[Mapping[str, Any]], options: Optional[Mapping[str, Any]] = None, + container_link: Optional[str] = None, **kwargs: Any ) -> CosmosDict: """Creates an Azure Cosmos resource and returns it. @@ -2638,6 +2674,7 @@ def Create( :param dict initial_headers: :param dict options: The request options for the request. + :param str container_link: :return: The created Azure Cosmos resource. @@ -2654,7 +2691,7 @@ def Create( options) # Create will use WriteEndpoint since it uses POST operation - request_params = RequestObject(typ, documents._OperationType.Create, headers) + request_params = RequestObject(typ, documents._OperationType.Create, headers, container_link=container_link) request_params.set_excluded_location_from_options(options) request_params.set_excluded_location_from_options(options) result, last_response_headers = self.__Post(path, request_params, body, headers, **kwargs) @@ -2674,6 +2711,7 @@ def Upsert( id: Optional[str], initial_headers: Optional[Mapping[str, Any]], options: Optional[Mapping[str, Any]] = None, + container_link: Optional[str] = None, **kwargs: Any ) -> CosmosDict: """Upserts an Azure Cosmos resource and returns it. @@ -2685,6 +2723,7 @@ def Upsert( :param dict initial_headers: :param dict options: The request options for the request. + :param str container_link: :return: The upserted Azure Cosmos resource. 
@@ -2702,7 +2741,7 @@ def Upsert( headers[http_constants.HttpHeaders.IsUpsert] = True # Upsert will use WriteEndpoint since it uses POST operation - request_params = RequestObject(typ, documents._OperationType.Upsert, headers) + request_params = RequestObject(typ, documents._OperationType.Upsert, headers, container_link=container_link) request_params.set_excluded_location_from_options(options) request_params.set_excluded_location_from_options(options) result, last_response_headers = self.__Post(path, request_params, body, headers, **kwargs) @@ -2721,6 +2760,7 @@ def Replace( id: Optional[str], initial_headers: Optional[Mapping[str, Any]], options: Optional[Mapping[str, Any]] = None, + container_link: Optional[str] = None, **kwargs: Any ) -> CosmosDict: """Replaces an Azure Cosmos resource and returns it. @@ -2732,6 +2772,7 @@ def Replace( :param dict initial_headers: :param dict options: The request options for the request. + :param str container_link: :return: The new Azure Cosmos resource. 
@@ -2747,7 +2788,7 @@ def Replace( headers = base.GetHeaders(self, initial_headers, "put", path, id, typ, documents._OperationType.Replace, options) # Replace will use WriteEndpoint since it uses PUT operation - request_params = RequestObject(typ, documents._OperationType.Replace, headers) + request_params = RequestObject(typ, documents._OperationType.Replace, headers, container_link=container_link) request_params.set_excluded_location_from_options(options) request_params.set_excluded_location_from_options(options) result, last_response_headers = self.__Put(path, request_params, resource, headers, **kwargs) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py index c36367ac92ac..281ddb9d2e77 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py @@ -23,7 +23,7 @@ """ from typing import TYPE_CHECKING -from azure.cosmos import PartitionKey +from azure.cosmos.partition_key import PartitionKey from azure.cosmos._global_partition_endpoint_manager_circuit_breaker_core import \ _GlobalPartitionEndpointManagerForCircuitBreakerCore diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py index e38164d1e193..3ca64c9575b4 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py @@ -80,7 +80,6 @@ def record_failure( else EndpointOperationType.ReadType) location = self.location_cache.get_location_from_endpoint(str(request.location_endpoint_to_route)) 
self.partition_health_tracker.add_failure(pk_range_wrapper, endpoint_operation_type, str(location)) - # TODO: @tvaron3 exponential backoff for recovering def check_stale_partition_info( self, diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py index d085ddb06823..c8d230086704 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py @@ -31,11 +31,13 @@ def __init__( resource_type: str, operation_type: str, headers: Dict[str, Any], + container_link: Optional[str] = None, endpoint_override: Optional[str] = None, ) -> None: self.resource_type = resource_type self.operation_type = operation_type self.endpoint_override = endpoint_override + self.container_link = container_link self.should_clear_session_token_on_session_read_failure: bool = False # pylint: disable=name-too-long self.headers = headers self.use_preferred_locations: Optional[bool] = None diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py index 70c1eacccb89..05ef5021e6d8 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py @@ -25,7 +25,7 @@ import time from typing import Optional -from azure.core.exceptions import AzureError, ClientAuthenticationError, ServiceRequestError, ServiceResponseError +from azure.core.exceptions import ClientAuthenticationError, ServiceRequestError, ServiceResponseError from azure.core.pipeline import PipelineRequest from azure.core.pipeline.policies import RetryPolicy @@ -220,6 +220,7 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs): # pylin if not database_account_retry_policy.ShouldRetry(e): raise e else: + global_endpoint_manager.record_failure(args[0]) _handle_service_response_retries(request, client, service_response_retry_policy, e, *args) def 
ExecuteFunction(function, *args, **kwargs): diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py index 665c019f4fc7..ed197428c61c 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py @@ -258,6 +258,7 @@ def _set_container_properties_cache(self, container_link: str, properties: Optio :type properties: Optional[Dict[str, Any]]""" if properties: self.__container_properties_cache[container_link] = properties + self.__container_properties_cache[properties["_rid"]] = properties else: self.__container_properties_cache[container_link] = {} diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/container.py b/sdk/cosmos/azure-cosmos/azure/cosmos/container.py index 083bc7b074b4..756a54ea8e1c 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/container.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/container.py @@ -791,7 +791,11 @@ def replace_item( # pylint:disable=docstring-missing-param self._get_properties() request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] result = self.client_connection.ReplaceItem( - document_link=item_link, new_document=body, options=request_options, **kwargs) + document_link=item_link, + new_document=body, + options=request_options, + container_link=self.container_link, + **kwargs) return result @distributed_trace @@ -1049,7 +1053,10 @@ def patch_item( request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] item_link = self._get_document_link(item) result = self.client_connection.PatchItem( - document_link=item_link, operations=patch_operations, options=request_options, **kwargs) + document_link=item_link, + operations=patch_operations, + container_link=self.container_link, + options=request_options, **kwargs) return result 
@distributed_trace diff --git a/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm.py b/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm.py index 447cd88996db..bd437db3b0c9 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm.py +++ b/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm.py @@ -8,6 +8,7 @@ from azure.core.exceptions import ServiceResponseError import test_config +from time import sleep from azure.cosmos import PartitionKey, _location_cache, _partition_health_tracker from azure.cosmos import CosmosClient from azure.cosmos.exceptions import CosmosHttpResponseError @@ -28,7 +29,7 @@ def setup_teardown(): partition_key=PartitionKey("/pk"), offer_throughput=10000) # allow some time for the container to be created as this method is in different event loop - # sleep(3) + sleep(3) yield created_database.delete_container(TestPerPartitionCircuitBreakerMM.TEST_CONTAINER_SINGLE_PARTITION_ID) From 870296d0b857f420fba1ef0c318c87e690eea41e Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Tue, 22 Apr 2025 23:14:59 -0700 Subject: [PATCH 101/152] revert change --- sdk/cosmos/azure-cosmos/azure/cosmos/container.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/container.py b/sdk/cosmos/azure-cosmos/azure/cosmos/container.py index 756a54ea8e1c..4f90f4b9bc2a 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/container.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/container.py @@ -167,8 +167,6 @@ def read( # pylint:disable=docstring-missing-param request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. :keyword dict[str, str] initial_headers: Initial headers to be sent as part of the request. - :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. 
The locations - in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Mapping[str, str], Dict[str, Any]], None] :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: Raised if the container couldn't be retrieved. @@ -1084,6 +1082,8 @@ def execute_item_batch( :keyword Literal["High", "Low"] priority: Priority based execution allows users to set a priority for each request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. If all preferred locations were excluded, primary/hub location will be used. This excluded_location will override existing excluded_locations in client level. :keyword response_hook: A callable invoked with the response metadata. 
From 12df430a45240c0acfe9f2794dcf5d29bfa8de35 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Wed, 23 Apr 2025 11:10:14 -0700 Subject: [PATCH 102/152] add extra mapping to container cache --- sdk/cosmos/azure-cosmos/azure/cosmos/_base.py | 4 +-- .../_container_recreate_retry_policy.py | 5 +-- .../azure/cosmos/_cosmos_client_connection.py | 31 +++++-------------- ...tition_endpoint_manager_circuit_breaker.py | 21 +++++-------- .../azure/cosmos/_request_object.py | 3 -- .../azure/cosmos/aio/_container.py | 6 ++-- .../aio/_cosmos_client_connection_async.py | 4 +-- ..._endpoint_manager_circuit_breaker_async.py | 21 +++++-------- .../azure-cosmos/azure/cosmos/container.py | 8 ++--- .../routing/test_collection_routing_map.py | 13 -------- 10 files changed, 35 insertions(+), 81 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py index 7b5ac8f13dbf..ea62c78eace1 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py @@ -873,8 +873,8 @@ def _format_batch_operations( return final_operations -def _set_properties_cache(properties: Dict[str, Any]) -> Dict[str, Any]: +def _set_properties_cache(properties: Dict[str, Any], container_link: str) -> Dict[str, Any]: return { "_self": properties.get("_self", None), "_rid": properties.get("_rid", None), - "partitionKey": properties.get("partitionKey", None) + "partitionKey": properties.get("partitionKey", None), "container_link": container_link, } diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_container_recreate_retry_policy.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_container_recreate_retry_policy.py index a85e3c081ddf..78bba46319b9 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_container_recreate_retry_policy.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_container_recreate_retry_policy.py @@ -72,10 +72,7 @@ def ShouldRetry(self, exception: Optional[Any]) -> bool: def __find_container_link_with_rid(self, 
container_properties_caches: Optional[Dict[str, Any]], rid: str) -> \ Optional[str]: if container_properties_caches: - for key, inner_dict in container_properties_caches.items(): - is_match = next((k for k, v in inner_dict.items() if v == rid), None) - if is_match: - return key + return container_properties_caches.get(rid) # If we cannot get the container link at all it might mean the cache was somehow deleted, this isn't # a container request so this retry is not needed. Return None. return None diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py index fbf849fdd4f3..e9e9cf834efb 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py @@ -1992,7 +1992,6 @@ def ReplaceItem( document_link: str, new_document: Dict[str, Any], options: Optional[Mapping[str, Any]] = None, - container_link: Optional[str] = None, **kwargs: Any ) -> CosmosDict: """Replaces a document and returns it. @@ -2002,7 +2001,6 @@ def ReplaceItem( :param dict new_document: :param dict options: The request options for the request. - :param str container_link: :return: The new Document. @@ -2034,14 +2032,12 @@ def ReplaceItem( document_id, None, options, - container_link=container_link, **kwargs) def PatchItem( self, document_link: str, operations: List[Dict[str, Any]], - container_link: str, options: Optional[Mapping[str, Any]] = None, **kwargs: Any ) -> CosmosDict: @@ -2049,7 +2045,6 @@ def PatchItem( :param str document_link: The link to the document. :param list operations: The operations for the patch request. - :param str container_link: The container name. :param dict options: The request options for the request. 
:return: @@ -2070,8 +2065,7 @@ def PatchItem( # Patch will use WriteEndpoint since it uses PUT operation request_params = RequestObject(resource_type, documents._OperationType.Patch, - headers, - container_link=container_link) + headers) request_params.set_excluded_location_from_options(options) request_params.set_excluded_location_from_options(options) request_data = {} @@ -2154,7 +2148,6 @@ def _Batch( self, batch_operations: List[Dict[str, Any]], path: str, - container_link: str, collection_id: Optional[str], options: Mapping[str, Any], **kwargs: Any @@ -2165,8 +2158,7 @@ def _Batch( documents._OperationType.Batch, options) request_params = RequestObject("docs", documents._OperationType.Batch, - headers, - container_link=container_link) + headers) request_params.set_excluded_location_from_options(options) request_params.set_excluded_location_from_options(options) return cast( @@ -2230,8 +2222,7 @@ def DeleteAllItemsByPartitionKey( http_constants.ResourceType.PartitionKey, documents._OperationType.Delete, options) request_params = RequestObject(http_constants.ResourceType.PartitionKey, documents._OperationType.Delete, - headers, - container_link=collection_link) + headers) request_params.set_excluded_location_from_options(options) request_params.set_excluded_location_from_options(options) _, last_response_headers = self.__Post( @@ -2662,7 +2653,6 @@ def Create( id: Optional[str], initial_headers: Optional[Mapping[str, Any]], options: Optional[Mapping[str, Any]] = None, - container_link: Optional[str] = None, **kwargs: Any ) -> CosmosDict: """Creates an Azure Cosmos resource and returns it. @@ -2674,7 +2664,6 @@ def Create( :param dict initial_headers: :param dict options: The request options for the request. - :param str container_link: :return: The created Azure Cosmos resource. 
@@ -2691,7 +2680,7 @@ def Create( options) # Create will use WriteEndpoint since it uses POST operation - request_params = RequestObject(typ, documents._OperationType.Create, headers, container_link=container_link) + request_params = RequestObject(typ, documents._OperationType.Create, headers) request_params.set_excluded_location_from_options(options) request_params.set_excluded_location_from_options(options) result, last_response_headers = self.__Post(path, request_params, body, headers, **kwargs) @@ -2711,7 +2700,6 @@ def Upsert( id: Optional[str], initial_headers: Optional[Mapping[str, Any]], options: Optional[Mapping[str, Any]] = None, - container_link: Optional[str] = None, **kwargs: Any ) -> CosmosDict: """Upserts an Azure Cosmos resource and returns it. @@ -2723,7 +2711,6 @@ def Upsert( :param dict initial_headers: :param dict options: The request options for the request. - :param str container_link: :return: The upserted Azure Cosmos resource. @@ -2741,7 +2728,7 @@ def Upsert( headers[http_constants.HttpHeaders.IsUpsert] = True # Upsert will use WriteEndpoint since it uses POST operation - request_params = RequestObject(typ, documents._OperationType.Upsert, headers, container_link=container_link) + request_params = RequestObject(typ, documents._OperationType.Upsert, headers) request_params.set_excluded_location_from_options(options) request_params.set_excluded_location_from_options(options) result, last_response_headers = self.__Post(path, request_params, body, headers, **kwargs) @@ -2760,7 +2747,6 @@ def Replace( id: Optional[str], initial_headers: Optional[Mapping[str, Any]], options: Optional[Mapping[str, Any]] = None, - container_link: Optional[str] = None, **kwargs: Any ) -> CosmosDict: """Replaces an Azure Cosmos resource and returns it. @@ -2772,7 +2758,6 @@ def Replace( :param dict initial_headers: :param dict options: The request options for the request. - :param str container_link: :return: The new Azure Cosmos resource. 
@@ -2788,7 +2773,7 @@ def Replace( headers = base.GetHeaders(self, initial_headers, "put", path, id, typ, documents._OperationType.Replace, options) # Replace will use WriteEndpoint since it uses PUT operation - request_params = RequestObject(typ, documents._OperationType.Replace, headers, container_link=container_link) + request_params = RequestObject(typ, documents._OperationType.Replace, headers) request_params.set_excluded_location_from_options(options) request_params.set_excluded_location_from_options(options) result, last_response_headers = self.__Put(path, request_params, resource, headers, **kwargs) @@ -3394,7 +3379,7 @@ def _refresh_container_properties_cache(self, container_link: str): # If container properties cache is stale, refresh it by reading the container. container = self.ReadContainer(container_link, options=None) # Only cache Container Properties that will not change in the lifetime of the container - self._set_container_properties_cache(container_link, _set_properties_cache(container)) + self._set_container_properties_cache(container_link, _set_properties_cache(container, container_link)) def _UpdateSessionIfRequired( self, @@ -3437,5 +3422,5 @@ def _get_partition_key_definition( else: container = self.ReadContainer(collection_link, options) partition_key_definition = container.get("partitionKey") - self.__container_properties_cache[collection_link] = _set_properties_cache(container) + self._set_container_properties_cache(collection_link, _set_properties_cache(container, collection_link)) return partition_key_definition diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py index 281ddb9d2e77..e7bf34ee0ce5 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py 
@@ -54,30 +54,23 @@ def is_circuit_breaker_applicable(self, request: RequestObject) -> bool: def create_pk_range_wrapper(self, request: RequestObject) -> PartitionKeyRangeWrapper: container_rid = request.headers[HttpHeaders.IntendedCollectionRID] print(request.headers) - target_container_link = None - partition_key = None - # TODO: @tvaron3 see if the container link is absolutely necessary or change container cache - for container_link, properties in self.Client._container_properties_cache.items(): - if properties["_rid"] == container_rid: - target_container_link = container_link - partition_key_definition = properties["partitionKey"] - partition_key = PartitionKey(path=partition_key_definition["paths"], - kind=partition_key_definition["kind"]) - - if not target_container_link or not partition_key: - raise RuntimeError("Illegal state: the container cache is not properly initialized.") + properties = self.Client._container_properties_cache[container_rid] + # get relevant information from container cache to get the overlapping ranges + container_link = properties["container_link"] + partition_key_definition = properties["partitionKey"] + partition_key = PartitionKey(path=partition_key_definition["paths"], kind=partition_key_definition["kind"]) if request.headers.get(HttpHeaders.PartitionKey): partition_key_value = request.headers[HttpHeaders.PartitionKey] # get the partition key range for the given partition key epk_range = [partition_key._get_epk_range_for_partition_key(partition_key_value)] partition_ranges = (self.Client._routing_map_provider - .get_overlapping_ranges(target_container_link, epk_range)) + .get_overlapping_ranges(container_link, epk_range)) partition_range = Range.PartitionKeyRangeToRange(partition_ranges[0]) elif request.headers.get(HttpHeaders.PartitionKeyRangeID): pk_range_id = request.headers[HttpHeaders.PartitionKeyRangeID] range =(self.Client._routing_map_provider - .get_range_by_partition_key_range_id(target_container_link, pk_range_id)) + 
.get_range_by_partition_key_range_id(container_link, pk_range_id)) partition_range = Range.PartitionKeyRangeToRange(range) else: raise RuntimeError("Illegal state: the request does not contain partition information.") diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py index c8d230086704..1ddce7880a3d 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py @@ -25,19 +25,16 @@ from . import http_constants class RequestObject(object): # pylint: disable=too-many-instance-attributes - # TODO: @tvaron3 add container link here def __init__( self, resource_type: str, operation_type: str, headers: Dict[str, Any], - container_link: Optional[str] = None, endpoint_override: Optional[str] = None, ) -> None: self.resource_type = resource_type self.operation_type = operation_type self.endpoint_override = endpoint_override - self.container_link = container_link self.should_clear_session_token_on_session_read_failure: bool = False # pylint: disable=name-too-long self.headers = headers self.use_preferred_locations: Optional[bool] = None diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py index d85266cccb13..16d9196120b6 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py @@ -94,7 +94,8 @@ def __init__( self._scripts: Optional[ScriptsProxy] = None if properties: self.client_connection._set_container_properties_cache(self.container_link, - _set_properties_cache(properties)) + _set_properties_cache(properties, + self.container_link)) def __repr__(self) -> str: return "".format(self.container_link)[:1024] @@ -192,7 +193,8 @@ async def read( request_options["populateQuotaInfo"] = populate_quota_info container = await self.client_connection.ReadContainer(self.container_link, 
options=request_options, **kwargs) # Only cache Container Properties that will not change in the lifetime of the container - self.client_connection._set_container_properties_cache(self.container_link, _set_properties_cache(container)) # pylint: disable=protected-access, line-too-long + self.client_connection._set_container_properties_cache(self.container_link, # pylint: disable=protected-access + _set_properties_cache(container, self.container_link)) return container @distributed_trace_async diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py index ed197428c61c..a90aca304d97 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py @@ -3211,7 +3211,7 @@ async def _refresh_container_properties_cache(self, container_link: str): # If container properties cache is stale, refresh it by reading the container. 
container = await self.ReadContainer(container_link, options=None) # Only cache Container Properties that will not change in the lifetime of the container - self._set_container_properties_cache(container_link, _set_properties_cache(container)) + self._set_container_properties_cache(container_link, _set_properties_cache(container, container_link)) async def _GetQueryPlanThroughGateway(self, query: str, resource_link: str, **kwargs) -> List[Dict[str, Any]]: supported_query_features = (documents._QueryFeature.Aggregate + "," + @@ -3308,5 +3308,5 @@ async def _get_partition_key_definition(self, collection_link: str) -> Optional[ else: container = await self.ReadContainer(collection_link) partition_key_definition = container.get("partitionKey") - self.__container_properties_cache[collection_link] = _set_properties_cache(container) + self._set_container_properties_cache(collection_link, _set_properties_cache(container, collection_link)) return partition_key_definition diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py index 100cd672cb65..b2edae462450 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py @@ -52,30 +52,23 @@ def __init__(self, client: "CosmosClientConnection"): async def create_pk_range_wrapper(self, request: RequestObject) -> PartitionKeyRangeWrapper: container_rid = request.headers[HttpHeaders.IntendedCollectionRID] print(request.headers) - target_container_link = None - partition_key = None - # TODO: @tvaron3 see if the container link is absolutely necessary or change container cache - for container_link, properties in self.client._container_properties_cache.items(): - if properties["_rid"] == container_rid: - 
target_container_link = container_link - partition_key_definition = properties["partitionKey"] - partition_key = PartitionKey(path=partition_key_definition["paths"], - kind=partition_key_definition["kind"]) - - if not target_container_link or not partition_key: - raise RuntimeError("Illegal state: the container cache is not properly initialized.") + properties = self.client._container_properties_cache[container_rid] + # get relevant information from container cache to get the overlapping ranges + container_link = properties["container_link"] + partition_key_definition = properties["partitionKey"] + partition_key = PartitionKey(path=partition_key_definition["paths"], kind=partition_key_definition["kind"]) if request.headers.get(HttpHeaders.PartitionKey): partition_key_value = request.headers[HttpHeaders.PartitionKey] # get the partition key range for the given partition key epk_range = [partition_key._get_epk_range_for_partition_key(partition_key_value)] partition_ranges = await (self.client._routing_map_provider - .get_overlapping_ranges(target_container_link, epk_range)) + .get_overlapping_ranges(container_link, epk_range)) partition_range = Range.PartitionKeyRangeToRange(partition_ranges[0]) elif request.headers.get(HttpHeaders.PartitionKeyRangeID): pk_range_id = request.headers[HttpHeaders.PartitionKeyRangeID] range = await (self.client._routing_map_provider - .get_range_by_partition_key_range_id(target_container_link, pk_range_id)) + .get_range_by_partition_key_range_id(container_link, pk_range_id)) partition_range = Range.PartitionKeyRangeToRange(range) else: raise RuntimeError("Illegal state: the request does not contain partition information.") diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/container.py b/sdk/cosmos/azure-cosmos/azure/cosmos/container.py index 4f90f4b9bc2a..c0c8bd7334b0 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/container.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/container.py @@ -92,7 +92,8 @@ def __init__( self._scripts: 
Optional[ScriptsProxy] = None if properties: self.client_connection._set_container_properties_cache(self.container_link, - _set_properties_cache(properties)) + _set_properties_cache(properties, + self.container_link)) def __repr__(self) -> str: return "".format(self.container_link)[:1024] @@ -198,7 +199,8 @@ def read( # pylint:disable=docstring-missing-param request_options["populateQuotaInfo"] = populate_quota_info container = self.client_connection.ReadContainer(self.container_link, options=request_options, **kwargs) # Only cache Container Properties that will not change in the lifetime of the container - self.client_connection._set_container_properties_cache(self.container_link, _set_properties_cache(container)) # pylint: disable=protected-access, line-too-long + self.client_connection._set_container_properties_cache(self.container_link, # pylint: disable=protected-access + _set_properties_cache(container, self.container_link)) return container @distributed_trace @@ -792,7 +794,6 @@ def replace_item( # pylint:disable=docstring-missing-param document_link=item_link, new_document=body, options=request_options, - container_link=self.container_link, **kwargs) return result @@ -1053,7 +1054,6 @@ def patch_item( result = self.client_connection.PatchItem( document_link=item_link, operations=patch_operations, - container_link=self.container_link, options=request_options, **kwargs) return result diff --git a/sdk/cosmos/azure-cosmos/tests/routing/test_collection_routing_map.py b/sdk/cosmos/azure-cosmos/tests/routing/test_collection_routing_map.py index 69468423a111..10b5819310a9 100644 --- a/sdk/cosmos/azure-cosmos/tests/routing/test_collection_routing_map.py +++ b/sdk/cosmos/azure-cosmos/tests/routing/test_collection_routing_map.py @@ -28,19 +28,6 @@ def test_advanced(self): self.assertEqual(len(overlapping_partition_key_ranges), len(partition_key_ranges)) self.assertEqual(overlapping_partition_key_ranges, partition_key_ranges) - # def test_lowercased_range(self): - # 
partition_key_ranges = [{u'id': u'0', u'minInclusive': u'', u'maxExclusive': u'1FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF'}, - # {u'id': u'1', u'minInclusive': u'1FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF', u'maxExclusive': u'FF'}] - # partitionRangeWithInfo = [(r, True) for r in partition_key_ranges] - # expected_partition_key_range = routing_range.Range("", "1FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF", True, False) - # - # pkRange = routing_range.Range("1EC0C2CBE45DBC919CF2B65D399C2673", "1ec0c2cbe45dbc919cf2b65d399c2674", True, True) - # collection_routing_map = CollectionRoutingMap.CompleteRoutingMap(partitionRangeWithInfo, 'sample collection id') - # overlapping_partition_key_ranges = collection_routing_map.get_overlapping_ranges(pkRange) - # - # self.assertEqual(len(overlapping_partition_key_ranges), 1) - # self.assertEqual(expected_partition_key_range, overlapping_partition_key_ranges[0]) - def test_partition_key_ranges_parent_filter(self): # for large collection with thousands of partitions, a split may complete between the read partition key ranges query pages, # causing the return map to have both the new children ranges and their ranges. This test is to verify the fix for that. 
From b493b368e37509de9db1be699109596afd3b11ff Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Wed, 23 Apr 2025 14:38:52 -0700 Subject: [PATCH 103/152] fix emulator tests --- .../azure-cosmos/tests/test_ppcb_mm_async.py | 10 +- .../tests/test_ppcb_sm_mrr_async.py | 394 +++++++++++++----- 2 files changed, 290 insertions(+), 114 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm_async.py b/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm_async.py index 418bd113469e..8b8c23d2b62b 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm_async.py @@ -137,11 +137,11 @@ async def perform_write_operation(operation, container, fault_injection_containe ] resp = await fault_injection_container.execute_item_batch(batch_operations, partition_key=doc['pk']) # this will need to be emulator only - # elif operation == "delete_all_items_by_partition_key": - # await container.create_item(body=doc) - # await container.create_item(body=doc) - # await container.create_item(body=doc) - # resp = await fault_injection_container.delete_all_items_by_partition_key(pk) + elif operation == DELETE_ALL_ITEMS_BY_PARTITION_KEY: + await container.create_item(body=doc) + await container.create_item(body=doc) + await container.create_item(body=doc) + resp = await fault_injection_container.delete_all_items_by_partition_key(pk) if resp: validate_response_uri(resp, expected_uri) diff --git a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py index ec6cf5ecff7d..3be437d79ada 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py @@ -12,26 +12,26 @@ from azure.core.exceptions import ServiceResponseError import test_config -from azure.cosmos import PartitionKey +from azure.cosmos import PartitionKey, _partition_health_tracker from azure.cosmos._partition_health_tracker import HEALTH_STATUS, UNHEALTHY, 
UNHEALTHY_TENTATIVE from azure.cosmos.aio import CosmosClient from azure.cosmos.exceptions import CosmosHttpResponseError from _fault_injection_transport_async import FaultInjectionTransportAsync +from tests.test_ppcb_mm_async import perform_write_operation, create_doc, PK_VALUE, write_operations_and_errors, \ + cleanup_method, read_operations_and_errors, perform_read_operation, CHANGE_FEED, QUERY, READ_ALL_ITEMS, operations COLLECTION = "created_collection" @pytest_asyncio.fixture() -async def setup(): +async def setup_teardown(): os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "True" - client = CosmosClient(TestPerPartitionCircuitBreakerSmMrrAsync.host, TestPerPartitionCircuitBreakerSmMrrAsync.master_key, consistency_level="Session") + client = CosmosClient(TestPerPartitionCircuitBreakerSmMrrAsync.host, TestPerPartitionCircuitBreakerSmMrrAsync.master_key) created_database = client.get_database_client(TestPerPartitionCircuitBreakerSmMrrAsync.TEST_DATABASE_ID) - await client.create_database_if_not_exists(TestPerPartitionCircuitBreakerSmMrrAsync.TEST_DATABASE_ID) - created_collection = await created_database.create_container(TestPerPartitionCircuitBreakerSmMrrAsync.TEST_CONTAINER_SINGLE_PARTITION_ID, + await created_database.create_container(TestPerPartitionCircuitBreakerSmMrrAsync.TEST_CONTAINER_SINGLE_PARTITION_ID, partition_key=PartitionKey("/pk"), offer_throughput=10000) - yield { - COLLECTION: created_collection - } - + # allow some time for the container to be created as this method is in different event loop + await asyncio.sleep(3) + yield await created_database.delete_container(TestPerPartitionCircuitBreakerSmMrrAsync.TEST_CONTAINER_SINGLE_PARTITION_ID) await client.close() os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "False" @@ -54,6 +54,24 @@ def operations_and_errors(): return params +def validate_unhealthy_partitions(global_endpoint_manager, + expected_unhealthy_partitions): + health_info_map = 
global_endpoint_manager.global_partition_endpoint_manager_core.partition_health_tracker.pk_range_wrapper_to_health_info + unhealthy_partitions = 0 + for pk_range_wrapper, location_to_health_info in health_info_map.items(): + for location, health_info in location_to_health_info.items(): + health_status = health_info.unavailability_info.get(HEALTH_STATUS) + if health_status == UNHEALTHY_TENTATIVE or health_status == UNHEALTHY: + unhealthy_partitions += 1 + + else: + assert health_info.read_consecutive_failure_count < 10 + # single region write account should never track write failures + assert health_info.write_failure_count == 0 + assert health_info.write_consecutive_failure_count == 0 + + assert unhealthy_partitions == expected_unhealthy_partitions + @pytest.mark.cosmosEmulator @pytest.mark.asyncio @@ -78,42 +96,6 @@ async def cleanup_method(initialized_objects: Dict[str, Any]): method_client: CosmosClient = initialized_objects["client"] await method_client.close() - @staticmethod - async def perform_write_operation(operation, container, doc_id, pk): - doc = {'id': doc_id, - 'pk': pk, - 'name': 'sample document', - 'key': 'value'} - if operation == "create": - await container.create_item(body=doc) - elif operation == "upsert": - await container.upsert_item(body=doc) - elif operation == "replace": - new_doc = {'id': doc_id, - 'pk': pk, - 'name': 'sample document' + str(uuid), - 'key': 'value'} - await container.replace_item(item=doc['id'], body=new_doc) - elif operation == "delete": - await container.create_item(body=doc) - await container.delete_item(item=doc['id'], partition_key=doc['pk']) - elif operation == "patch": - operations = [{"op": "incr", "path": "/company", "value": 3}] - await container.patch_item(item=doc['id'], partition_key=doc['pk'], patch_operations=operations) - elif operation == "batch": - batch_operations = [ - ("create", (doc, )), - ("upsert", (doc,)), - ("upsert", (doc,)), - ("upsert", (doc,)), - ] - await 
container.execute_item_batch(batch_operations, partition_key=doc['pk']) - elif operation == "delete_all_items_by_partition_key": - await container.create_item(body=doc) - await container.create_item(body=doc) - await container.create_item(body=doc) - await container.delete_all_items_by_partition_key(pk) - @staticmethod async def perform_read_operation(operation, container, doc_id, pk, expected_read_region_uri): if operation == "read": @@ -134,10 +116,6 @@ async def perform_read_operation(operation, container, doc_id, pk, expected_read async for item in container.read_all_items(partition_key=pk): assert item['pk'] == pk - - - - async def create_custom_transport_sm_mrr(self): custom_transport = FaultInjectionTransportAsync() # Inject rule to disallow writes in the read-only region @@ -161,62 +139,100 @@ async def create_custom_transport_sm_mrr(self): emulator_as_multi_region_sm_account_transformation) return custom_transport - - # split this into write and read tests - - @pytest.mark.parametrize("write_operation, read_operation, error", operations_and_errors()) - async def test_consecutive_failure_threshold_async(self, setup, write_operation, read_operation, error): - expected_read_region_uri = self.host - expected_write_region_uri = expected_read_region_uri.replace("localhost", "127.0.0.1") + async def setup_info(self, error): + expected_uri = self.host + uri_down = expected_uri.replace("localhost", "127.0.0.1") custom_transport = await self.create_custom_transport_sm_mrr() - id_value = 'failoverDoc-' + str(uuid.uuid4()) - document_definition = {'id': id_value, - 'pk': 'pk1', - 'name': 'sample document', - 'key': 'value'} + # two documents targeted to same partition, one will always fail and the other will succeed + doc = create_doc() predicate = lambda r: (FaultInjectionTransportAsync.predicate_is_document_operation(r) and - FaultInjectionTransportAsync.predicate_targets_region(r, expected_write_region_uri)) - custom_transport.add_fault(predicate, lambda r: 
asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( + FaultInjectionTransportAsync.predicate_targets_region(r, uri_down)) + custom_transport.add_fault(predicate, + error) + custom_setup = await self.setup_method_with_custom_transport(custom_transport, default_endpoint=self.host) + setup = await self.setup_method_with_custom_transport(None, default_endpoint=self.host) + return setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate + + @pytest.mark.parametrize("write_operation, error", write_operations_and_errors()) + async def test_write_consecutive_failure_threshold_async(self, setup_teardown, write_operation, error): + error_lambda = lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( 0, error - ))) - - custom_setup = await self.setup_method_with_custom_transport(custom_transport, default_endpoint=self.host) - container = custom_setup['col'] + )) + setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = await self.setup_info(error_lambda) + container = setup['col'] + fault_injection_container = custom_setup['col'] + global_endpoint_manager = container.client_connection._global_endpoint_manager # writes should fail in sm mrr with circuit breaker and should not mark unavailable a partition for i in range(6): + validate_unhealthy_partitions(global_endpoint_manager, 0) with pytest.raises((CosmosHttpResponseError, ServiceResponseError)): - await TestPerPartitionCircuitBreakerSmMrrAsync.perform_write_operation(write_operation, - container, - document_definition['id'], - document_definition['pk']) - global_endpoint_manager = container.client_connection._global_endpoint_manager + await perform_write_operation( + write_operation, + container, + fault_injection_container, + str(uuid.uuid4()), + PK_VALUE, + expected_uri, + ) + + validate_unhealthy_partitions(global_endpoint_manager, 0) + await cleanup_method([custom_setup, setup]) + + + @pytest.mark.parametrize("read_operation, error", 
read_operations_and_errors()) + async def test_read_consecutive_failure_threshold_async(self, setup_teardown, read_operation, error): + error_lambda = lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( + 0, + error + )) + setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = await self.setup_info(error_lambda) + container = setup['col'] + fault_injection_container = custom_setup['col'] - TestPerPartitionCircuitBreakerSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 0) + global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager - # create item with client without fault injection - await setup[COLLECTION].create_item(body=document_definition) + # create a document to read + await container.create_item(body=doc) # reads should fail over and only the relevant partition should be marked as unavailable - await TestPerPartitionCircuitBreakerSmMrrAsync.perform_read_operation(read_operation, - container, - document_definition['id'], - document_definition['pk'], - expected_read_region_uri) + await perform_read_operation(read_operation, + fault_injection_container, + doc['id'], + doc['pk'], + expected_uri) # partition should not have been marked unavailable after one error - TestPerPartitionCircuitBreakerSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 0) + validate_unhealthy_partitions(global_endpoint_manager, 0) for i in range(10): - await TestPerPartitionCircuitBreakerSmMrrAsync.perform_read_operation(read_operation, - container, - document_definition['id'], - document_definition['pk'], - expected_read_region_uri) - - # the partition should have been marked as unavailable after breaking read threshold - TestPerPartitionCircuitBreakerSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 1) - # test recovering the partition again + await perform_read_operation(read_operation, + fault_injection_container, + doc['id'], + doc['pk'], + 
expected_uri) + + # the partition should have been marked as unavailable after breaking read threshold + if read_operation in (CHANGE_FEED, QUERY, READ_ALL_ITEMS): + # these operations are cross partition so they would mark both partitions as unavailable + expected_unhealthy_partitions = 2 + else: + expected_unhealthy_partitions = 1 + validate_unhealthy_partitions(global_endpoint_manager, expected_unhealthy_partitions) + # remove faults and reduce initial recover time and perform a read + original_unavailable_time = _partition_health_tracker.INITIAL_UNAVAILABLE_TIME + _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = 1 + custom_transport.faults = [] + try: + await perform_read_operation(read_operation, + fault_injection_container, + doc['id'], + doc['pk'], + uri_down) + finally: + _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = original_unavailable_time + validate_unhealthy_partitions(global_endpoint_manager, 0) + await cleanup_method([custom_setup, setup]) @pytest.mark.parametrize("write_operation, read_operation, error", operations_and_errors()) async def test_failure_rate_threshold_async(self, setup, write_operation, read_operation, error): @@ -280,31 +296,191 @@ async def test_failure_rate_threshold_async(self, setup, write_operation, read_o document_definition['pk'], expected_read_region_uri) # the partition should have been marked as unavailable after breaking read threshold - TestPerPartitionCircuitBreakerSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 1) + validate_unhealthy_partitions(global_endpoint_manager, 1) finally: # restore minimum requests global_endpoint_manager.global_partition_endpoint_manager_core.partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 - # look at the urls for verifying fall back and use another id for same partition - - @staticmethod - def validate_unhealthy_partitions(global_endpoint_manager, - expected_unhealthy_partitions): - health_info_map = 
global_endpoint_manager.global_partition_endpoint_manager_core.partition_health_tracker.pk_range_wrapper_to_health_info - unhealthy_partitions = 0 - for pk_range_wrapper, location_to_health_info in health_info_map.items(): - for location, health_info in location_to_health_info.items(): - health_status = health_info.unavailability_info.get(HEALTH_STATUS) - if health_status == UNHEALTHY_TENTATIVE or health_status == UNHEALTHY: - unhealthy_partitions += 1 - assert health_info.write_failure_count == 0 - assert health_info.write_consecutive_failure_count == 0 + @pytest.mark.parametrize("write_operation, error", write_operations_and_errors()) + async def test_write_failure_rate_threshold_async(self, setup_teardown, write_operation, error): + error_lambda = lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( + 0, + error + )) + setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = await self.setup_info(error_lambda) + container = setup['col'] + fault_injection_container = custom_setup['col'] + global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager + # lower minimum requests for testing + _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 10 + os.environ["AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED"] = "80" + try: + # writes should fail but still be tracked and mark unavailable a partition after crossing threshold + for i in range(10): + validate_unhealthy_partitions(global_endpoint_manager, 0) + if i == 4 or i == 8: + # perform some successful creates to reset consecutive counter + # remove faults and perform a write + custom_transport.faults = [] + await fault_injection_container.upsert_item(body=doc) + custom_transport.add_fault(predicate, + lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( + 0, + error + ))) else: - assert health_info.read_consecutive_failure_count < 10 - assert health_info.write_failure_count == 0 - assert 
health_info.write_consecutive_failure_count == 0 + with pytest.raises((CosmosHttpResponseError, ServiceResponseError)) as exc_info: + await perform_write_operation(write_operation, + container, + fault_injection_container, + str(uuid.uuid4()), + PK_VALUE, + expected_uri) + assert exc_info.value == error + + validate_unhealthy_partitions(global_endpoint_manager, 0) + + finally: + os.environ["AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED"] = "90" + # restore minimum requests + _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 + await cleanup_method([custom_setup, setup]) + + @pytest.mark.parametrize("read_operation, error", read_operations_and_errors()) + async def test_read_failure_rate_threshold_async(self, setup_teardown, read_operation, error): + error_lambda = lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( + 0, + error + )) + setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = await self.setup_info(error_lambda) + container = setup['col'] + fault_injection_container = custom_setup['col'] + await container.upsert_item(body=doc) + global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager + # lower minimum requests for testing + _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 8 + os.environ["AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED"] = "80" + try: + if isinstance(error, ServiceResponseError): + # service response error retries in region 3 additional times before failing over + num_operations = 2 + else: + num_operations = 8 + for i in range(num_operations): + validate_unhealthy_partitions(global_endpoint_manager, 0) + # read will fail and retry in other region + await perform_read_operation(read_operation, + fault_injection_container, + doc['id'], + PK_VALUE, + expected_uri) + if read_operation in (CHANGE_FEED, QUERY, READ_ALL_ITEMS): + # these operations are cross partition so they would mark both partitions as unavailable + 
expected_unhealthy_partitions = 2 + else: + expected_unhealthy_partitions = 1 + + validate_unhealthy_partitions(global_endpoint_manager, expected_unhealthy_partitions) + finally: + os.environ["AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED"] = "90" + # restore minimum requests + _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 + await cleanup_method([custom_setup, setup]) + + + @pytest.mark.parametrize("read_operation, write_operation", operations()) + async def test_service_request_error_async(self, read_operation, write_operation): + # the region should be tried 4 times before failing over and mark the partition as unavailable + # the region should not be marked as unavailable + error_lambda = lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_region_down()) + setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = await self.setup_info(error_lambda) + container = setup['col'] + fault_injection_container = custom_setup['col'] + await container.upsert_item(body=doc) + await perform_read_operation(read_operation, + fault_injection_container, + doc['id'], + PK_VALUE, + expected_uri) + global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager + + validate_unhealthy_partitions(global_endpoint_manager, 1) + # there shouldn't be region marked as unavailable + assert len(global_endpoint_manager.location_cache.location_unavailability_info_by_endpoint) == 0 + + # recover partition + # remove faults and reduce initial recover time and perform a write + original_unavailable_time = _partition_health_tracker.INITIAL_UNAVAILABLE_TIME + _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = 1 + custom_transport.faults = [] + try: + await perform_read_operation(read_operation, + fault_injection_container, + doc['id'], + PK_VALUE, + uri_down) + finally: + _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = original_unavailable_time + validate_unhealthy_partitions(global_endpoint_manager, 0) + + 
custom_transport.add_fault(predicate, + lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_region_down())) + + await perform_write_operation(write_operation, + container, + fault_injection_container, + str(uuid.uuid4()), + PK_VALUE, + expected_uri) + global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager + + validate_unhealthy_partitions(global_endpoint_manager, 0) + # there shouldn't be region marked as unavailable + assert len(global_endpoint_manager.location_cache.location_unavailability_info_by_endpoint) == 1 + await cleanup_method([custom_setup, setup]) + + + + # send 5 write concurrent requests when trying to recover + # verify that only one failed + async def test_recovering_only_fails_one_requests_async(self): + error_lambda = lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( + 0, CosmosHttpResponseError( + status_code=502, + message="Some envoy error."))) + setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = await self.setup_info(error_lambda) + fault_injection_container = custom_setup['col'] + for i in range(5): + with pytest.raises(CosmosHttpResponseError): + await fault_injection_container.create_item(body=doc) + + + number_of_errors = 0 + + async def concurrent_upsert(): + nonlocal number_of_errors + doc = {'id': str(uuid.uuid4()), + 'pk': PK_VALUE, + 'name': 'sample document', + 'key': 'value'} + try: + await fault_injection_container.upsert_item(doc) + except CosmosHttpResponseError as e: + number_of_errors += 1 + + # attempt to recover partition + original_unavailable_time = _partition_health_tracker.INITIAL_UNAVAILABLE_TIME + _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = 1 + try: + tasks = [] + for i in range(10): + tasks.append(concurrent_upsert()) + await asyncio.gather(*tasks) + assert number_of_errors == 1 + finally: + _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = original_unavailable_time + await 
cleanup_method([custom_setup]) - assert unhealthy_partitions == expected_unhealthy_partitions # test_failure_rate_threshold - add service response error - across operation types - test recovering the partition again # test service request marks only a partition unavailable not an entire region - across operation types From 998aaa752a58a13981b1398376ae17f3202fa4a9 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Wed, 23 Apr 2025 15:11:57 -0700 Subject: [PATCH 104/152] add sync single master tests --- .../azure-cosmos/tests/test_ppcb_sm_mrr.py | 309 ++++++++++++++++++ .../tests/test_ppcb_sm_mrr_async.py | 108 ------ 2 files changed, 309 insertions(+), 108 deletions(-) create mode 100644 sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr.py diff --git a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr.py b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr.py new file mode 100644 index 000000000000..be435358d782 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr.py @@ -0,0 +1,309 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. 
+import asyncio +import os +import unittest +import uuid +from typing import Dict, Any + +import pytest +import pytest_asyncio +from azure.core.pipeline.transport._aiohttp import AioHttpTransport +from azure.core.exceptions import ServiceResponseError + +import test_config +from azure.cosmos import PartitionKey, _partition_health_tracker +from azure.cosmos.aio import CosmosClient +from azure.cosmos.exceptions import CosmosHttpResponseError +from _fault_injection_transport_async import FaultInjectionTransportAsync +from tests.test_ppcb_mm_async import perform_write_operation, create_doc, PK_VALUE, write_operations_and_errors, \ + cleanup_method, read_operations_and_errors, perform_read_operation, CHANGE_FEED, QUERY, READ_ALL_ITEMS, operations +from tests.test_ppcb_sm_mrr_async import validate_unhealthy_partitions + +COLLECTION = "created_collection" +@pytest_asyncio.fixture() +def setup_teardown(): + os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "True" + client = CosmosClient(TestPerPartitionCircuitBreakerSmMrr.host, TestPerPartitionCircuitBreakerSmMrr.master_key) + created_database = client.get_database_client(TestPerPartitionCircuitBreakerSmMrr.TEST_DATABASE_ID) + created_database.create_container(TestPerPartitionCircuitBreakerSmMrr.TEST_CONTAINER_SINGLE_PARTITION_ID, + partition_key=PartitionKey("/pk"), + offer_throughput=10000) + # allow some time for the container to be created as this method is in different event loop + asyncio.sleep(3) + yield + created_database.delete_container(TestPerPartitionCircuitBreakerSmMrr.TEST_CONTAINER_SINGLE_PARTITION_ID) + client.close() + os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "False" + +@pytest.mark.cosmosEmulator +@pytest.mark.usefixtures("setup") +class TestPerPartitionCircuitBreakerSmMrr: + host = test_config.TestConfig.host + master_key = test_config.TestConfig.masterKey + connectionPolicy = test_config.TestConfig.connectionPolicy + TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID + 
TEST_CONTAINER_SINGLE_PARTITION_ID = os.path.basename(__file__) + str(uuid.uuid4()) + + def setup_method_with_custom_transport(self, custom_transport: AioHttpTransport, default_endpoint=host, **kwargs): + client = CosmosClient(default_endpoint, self.master_key, consistency_level="Session", + preferred_locations=["Write Region", "Read Region"], + transport=custom_transport, **kwargs) + db = client.get_database_client(self.TEST_DATABASE_ID) + container = db.get_container_client(self.TEST_CONTAINER_SINGLE_PARTITION_ID) + return {"client": client, "db": db, "col": container} + + @staticmethod + def cleanup_method(initialized_objects: Dict[str, Any]): + method_client: CosmosClient = initialized_objects["client"] + method_client.close() + + + def create_custom_transport_sm_mrr(self): + custom_transport = FaultInjectionTransportAsync() + # Inject rule to disallow writes in the read-only region + is_write_operation_in_read_region_predicate = lambda \ + r: FaultInjectionTransportAsync.predicate_is_write_operation(r, self.host) + + custom_transport.add_fault( + is_write_operation_in_read_region_predicate, + lambda r: FaultInjectionTransportAsync.error_write_forbidden()) + + # Inject topology transformation that would make Emulator look like a single write region + # account with two read regions + is_get_account_predicate = lambda r: FaultInjectionTransportAsync.predicate_is_database_account_call(r) + emulator_as_multi_region_sm_account_transformation = \ + lambda r, inner: FaultInjectionTransportAsync.transform_topology_swr_mrr( + write_region_name="Write Region", + read_region_name="Read Region", + inner=inner) + custom_transport.add_response_transformation( + is_get_account_predicate, + emulator_as_multi_region_sm_account_transformation) + return custom_transport + + def setup_info(self, error): + expected_uri = self.host + uri_down = expected_uri.replace("localhost", "127.0.0.1") + custom_transport = self.create_custom_transport_sm_mrr() + # two documents targeted to 
same partition, one will always fail and the other will succeed + doc = create_doc() + predicate = lambda r: (FaultInjectionTransportAsync.predicate_is_document_operation(r) and + FaultInjectionTransportAsync.predicate_targets_region(r, uri_down)) + custom_transport.add_fault(predicate, + error) + custom_setup = self.setup_method_with_custom_transport(custom_transport, default_endpoint=self.host) + setup = self.setup_method_with_custom_transport(None, default_endpoint=self.host) + return setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate + + @pytest.mark.parametrize("write_operation, error", write_operations_and_errors()) + def test_write_consecutive_failure_threshold(self, setup_teardown, write_operation, error): + error_lambda = lambda r: FaultInjectionTransportAsync.error_after_delay(0, error) + setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = self.setup_info(error_lambda) + container = setup['col'] + fault_injection_container = custom_setup['col'] + global_endpoint_manager = container.client_connection._global_endpoint_manager + + # writes should fail in sm mrr with circuit breaker and should not mark unavailable a partition + for i in range(6): + validate_unhealthy_partitions(global_endpoint_manager, 0) + with pytest.raises((CosmosHttpResponseError, ServiceResponseError)): + perform_write_operation( + write_operation, + container, + fault_injection_container, + str(uuid.uuid4()), + PK_VALUE, + expected_uri, + ) + + validate_unhealthy_partitions(global_endpoint_manager, 0) + cleanup_method([custom_setup, setup]) + + + @pytest.mark.parametrize("read_operation, error", read_operations_and_errors()) + def test_read_consecutive_failure_threshold(self, setup_teardown, read_operation, error): + error_lambda = lambda r: FaultInjectionTransportAsync.error_after_delay(0, error) + setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = self.setup_info(error_lambda) + container = 
setup['col'] + fault_injection_container = custom_setup['col'] + + global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager + + # create a document to read + container.create_item(body=doc) + + # reads should fail over and only the relevant partition should be marked as unavailable + perform_read_operation(read_operation, + fault_injection_container, + doc['id'], + doc['pk'], + expected_uri) + # partition should not have been marked unavailable after one error + validate_unhealthy_partitions(global_endpoint_manager, 0) + + for i in range(10): + perform_read_operation(read_operation, + fault_injection_container, + doc['id'], + doc['pk'], + expected_uri) + + # the partition should have been marked as unavailable after breaking read threshold + if read_operation in (CHANGE_FEED, QUERY, READ_ALL_ITEMS): + # these operations are cross partition so they would mark both partitions as unavailable + expected_unhealthy_partitions = 2 + else: + expected_unhealthy_partitions = 1 + validate_unhealthy_partitions(global_endpoint_manager, expected_unhealthy_partitions) + # remove faults and reduce initial recover time and perform a read + original_unavailable_time = _partition_health_tracker.INITIAL_UNAVAILABLE_TIME + _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = 1 + custom_transport.faults = [] + try: + perform_read_operation(read_operation, + fault_injection_container, + doc['id'], + doc['pk'], + uri_down) + finally: + _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = original_unavailable_time + validate_unhealthy_partitions(global_endpoint_manager, 0) + cleanup_method([custom_setup, setup]) + + @pytest.mark.parametrize("write_operation, error", write_operations_and_errors()) + def test_write_failure_rate_threshold(self, setup_teardown, write_operation, error): + error_lambda = lambda r: FaultInjectionTransportAsync.error_after_delay(0, error) + setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = 
self.setup_info(error_lambda) + container = setup['col'] + fault_injection_container = custom_setup['col'] + global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager + # lower minimum requests for testing + _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 10 + os.environ["AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED"] = "80" + try: + # writes should fail but still be tracked and mark unavailable a partition after crossing threshold + for i in range(10): + validate_unhealthy_partitions(global_endpoint_manager, 0) + if i == 4 or i == 8: + # perform some successful creates to reset consecutive counter + # remove faults and perform a write + custom_transport.faults = [] + fault_injection_container.upsert_item(body=doc) + custom_transport.add_fault(predicate, + lambda r: FaultInjectionTransportAsync.error_after_delay( + 0, + error + )) + else: + with pytest.raises((CosmosHttpResponseError, ServiceResponseError)) as exc_info: + perform_write_operation(write_operation, + container, + fault_injection_container, + str(uuid.uuid4()), + PK_VALUE, + expected_uri) + assert exc_info.value == error + + validate_unhealthy_partitions(global_endpoint_manager, 0) + + finally: + os.environ["AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED"] = "90" + # restore minimum requests + _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 + cleanup_method([custom_setup, setup]) + + @pytest.mark.parametrize("read_operation, error", read_operations_and_errors()) + def test_read_failure_rate_threshold(self, setup_teardown, read_operation, error): + error_lambda = lambda r: FaultInjectionTransportAsync.error_after_delay(0, error) + setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = self.setup_info(error_lambda) + container = setup['col'] + fault_injection_container = custom_setup['col'] + container.upsert_item(body=doc) + global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager 
+ # lower minimum requests for testing + _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 8 + os.environ["AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED"] = "80" + try: + if isinstance(error, ServiceResponseError): + # service response error retries in region 3 additional times before failing over + num_operations = 2 + else: + num_operations = 8 + for i in range(num_operations): + validate_unhealthy_partitions(global_endpoint_manager, 0) + # read will fail and retry in other region + perform_read_operation(read_operation, + fault_injection_container, + doc['id'], + PK_VALUE, + expected_uri) + if read_operation in (CHANGE_FEED, QUERY, READ_ALL_ITEMS): + # these operations are cross partition so they would mark both partitions as unavailable + expected_unhealthy_partitions = 2 + else: + expected_unhealthy_partitions = 1 + + validate_unhealthy_partitions(global_endpoint_manager, expected_unhealthy_partitions) + finally: + os.environ["AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED"] = "90" + # restore minimum requests + _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 + cleanup_method([custom_setup, setup]) + + + @pytest.mark.parametrize("read_operation, write_operation", operations()) + def test_service_request_error(self, read_operation, write_operation): + # the region should be tried 4 times before failing over and mark the partition as unavailable + # the region should not be marked as unavailable + error_lambda = lambda r: FaultInjectionTransportAsync.error_region_down() + setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = self.setup_info(error_lambda) + container = setup['col'] + fault_injection_container = custom_setup['col'] + container.upsert_item(body=doc) + perform_read_operation(read_operation, + fault_injection_container, + doc['id'], + PK_VALUE, + expected_uri) + global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager + + 
validate_unhealthy_partitions(global_endpoint_manager, 1) + # there shouldn't be region marked as unavailable + assert len(global_endpoint_manager.location_cache.location_unavailability_info_by_endpoint) == 0 + + # recover partition + # remove faults and reduce initial recover time and perform a write + original_unavailable_time = _partition_health_tracker.INITIAL_UNAVAILABLE_TIME + _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = 1 + custom_transport.faults = [] + try: + perform_read_operation(read_operation, + fault_injection_container, + doc['id'], + PK_VALUE, + uri_down) + finally: + _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = original_unavailable_time + validate_unhealthy_partitions(global_endpoint_manager, 0) + + custom_transport.add_fault(predicate, + lambda r: FaultInjectionTransportAsync.error_region_down()) + + perform_write_operation(write_operation, + container, + fault_injection_container, + str(uuid.uuid4()), + PK_VALUE, + expected_uri) + global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager + + validate_unhealthy_partitions(global_endpoint_manager, 0) + # there shouldn't be region marked as unavailable + assert len(global_endpoint_manager.location_cache.location_unavailability_info_by_endpoint) == 1 + cleanup_method([custom_setup, setup]) + + # test cosmos client timeout + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py index 3be437d79ada..e272ec01f397 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py @@ -36,24 +36,6 @@ async def setup_teardown(): await client.close() os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "False" -def operations_and_errors(): - write_operations = ["create", "upsert", "replace", "delete", "patch", "batch"] - read_operations = ["read", 
"query", "changefeed", "read_all_items", "delete_all_items_by_partition_key"] - errors = [] - error_codes = [408, 500, 502, 503] - for error_code in error_codes: - errors.append(CosmosHttpResponseError( - status_code=error_code, - message="Some injected error.")) - errors.append(ServiceResponseError(message="Injected Service Response Error.")) - params = [] - for write_operation in write_operations: - for read_operation in read_operations: - for error in errors: - params.append((write_operation, read_operation, error)) - - return params - def validate_unhealthy_partitions(global_endpoint_manager, expected_unhealthy_partitions): health_info_map = global_endpoint_manager.global_partition_endpoint_manager_core.partition_health_tracker.pk_range_wrapper_to_health_info @@ -96,26 +78,6 @@ async def cleanup_method(initialized_objects: Dict[str, Any]): method_client: CosmosClient = initialized_objects["client"] await method_client.close() - @staticmethod - async def perform_read_operation(operation, container, doc_id, pk, expected_read_region_uri): - if operation == "read": - read_resp = await container.read_item(item=doc_id, partition_key=pk) - request = read_resp.get_response_headers()["_request"] - # Validate the response comes from "Read Region" (the most preferred read-only region) - assert request.url.startswith(expected_read_region_uri) - elif operation == "query": - query = "SELECT * FROM c WHERE c.id = @id AND c.pk = @pk" - parameters = [{"name": "@id", "value": doc_id}, {"name": "@pk", "value": pk}] - async for item in container.query_items(query=query, partition_key=pk, parameters=parameters): - assert item['id'] == doc_id - # need to do query with no pk and with feed range - elif operation == "changefeed": - async for _ in container.query_items_change_feed(): - pass - elif operation == "read_all_items": - async for item in container.read_all_items(partition_key=pk): - assert item['pk'] == pk - async def create_custom_transport_sm_mrr(self): custom_transport = 
FaultInjectionTransportAsync() # Inject rule to disallow writes in the read-only region @@ -234,73 +196,6 @@ async def test_read_consecutive_failure_threshold_async(self, setup_teardown, re validate_unhealthy_partitions(global_endpoint_manager, 0) await cleanup_method([custom_setup, setup]) - @pytest.mark.parametrize("write_operation, read_operation, error", operations_and_errors()) - async def test_failure_rate_threshold_async(self, setup, write_operation, read_operation, error): - expected_read_region_uri = self.host - expected_write_region_uri = expected_read_region_uri.replace("localhost", "127.0.0.1") - custom_transport = await self.create_custom_transport_sm_mrr() - id_value = 'failoverDoc-' + str(uuid.uuid4()) - # two documents targeted to same partition, one will always fail and the other will succeed - document_definition = {'id': id_value, - 'pk': 'pk1', - 'name': 'sample document', - 'key': 'value'} - doc_2 = {'id': str(uuid.uuid4()), - 'pk': 'pk1', - 'name': 'sample document', - 'key': 'value'} - predicate = lambda r: (FaultInjectionTransportAsync.predicate_req_for_document_with_id(r, id_value) and - FaultInjectionTransportAsync.predicate_targets_region(r, expected_write_region_uri)) - custom_transport.add_fault(predicate, - lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( - 0, - error - ))) - - custom_setup = await self.setup_method_with_custom_transport(custom_transport, default_endpoint=self.host) - container = custom_setup['col'] - global_endpoint_manager = container.client_connection._global_endpoint_manager - # lower minimum requests for testing - global_endpoint_manager.global_partition_endpoint_manager_core.partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 10 - try: - # writes should fail in sm mrr with circuit breaker and should not mark unavailable a partition - for i in range(14): - if i == 9: - await container.upsert_item(body=doc_2) - with pytest.raises((CosmosHttpResponseError, 
ServiceResponseError)): - await TestPerPartitionCircuitBreakerSmMrrAsync.perform_write_operation(write_operation, - container, - document_definition['id'], - document_definition['pk']) - - TestPerPartitionCircuitBreakerSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 0) - - # create item with client without fault injection - await setup[COLLECTION].create_item(body=document_definition) - - # reads should fail over and only the relevant partition should be marked as unavailable - await container.read_item(item=document_definition['id'], partition_key=document_definition['pk']) - # partition should not have been marked unavailable after one error - TestPerPartitionCircuitBreakerSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 0) - for i in range(20): - if i == 8: - read_resp = await container.read_item(item=doc_2['id'], - partition_key=doc_2['pk']) - request = read_resp.get_response_headers()["_request"] - # Validate the response comes from "Read Region" (the most preferred read-only region) - assert request.url.startswith(expected_read_region_uri) - else: - await TestPerPartitionCircuitBreakerSmMrrAsync.perform_read_operation(read_operation, - container, - document_definition['id'], - document_definition['pk'], - expected_read_region_uri) - # the partition should have been marked as unavailable after breaking read threshold - validate_unhealthy_partitions(global_endpoint_manager, 1) - finally: - # restore minimum requests - global_endpoint_manager.global_partition_endpoint_manager_core.partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 - @pytest.mark.parametrize("write_operation, error", write_operations_and_errors()) async def test_write_failure_rate_threshold_async(self, setup_teardown, write_operation, error): error_lambda = lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( @@ -481,9 +376,6 @@ async def concurrent_upsert(): _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = 
original_unavailable_time await cleanup_method([custom_setup]) - - # test_failure_rate_threshold - add service response error - across operation types - test recovering the partition again - # test service request marks only a partition unavailable not an entire region - across operation types # test cosmos client timeout if __name__ == '__main__': From 64b02bd924862b80c61ea640461959ea89d4c07b Mon Sep 17 00:00:00 2001 From: Tomas Varon Saldarriaga Date: Wed, 23 Apr 2025 15:12:49 -0700 Subject: [PATCH 105/152] fix tests --- sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py index 3be437d79ada..d629f7186eac 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py @@ -75,7 +75,7 @@ def validate_unhealthy_partitions(global_endpoint_manager, @pytest.mark.cosmosEmulator @pytest.mark.asyncio -@pytest.mark.usefixtures("setup") +@pytest.mark.usefixtures("setup_teardown") class TestPerPartitionCircuitBreakerSmMrrAsync: host = test_config.TestConfig.host master_key = test_config.TestConfig.masterKey From d78838e857108f26060e31c424f8c43ecda80010 Mon Sep 17 00:00:00 2001 From: Tomas Varon Saldarriaga Date: Wed, 23 Apr 2025 15:15:06 -0700 Subject: [PATCH 106/152] fix tests --- .../azure-cosmos/tests/test_ppcb_sm_mrr.py | 47 +++++++------------ 1 file changed, 17 insertions(+), 30 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr.py b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr.py index be435358d782..1467fd207951 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr.py +++ b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr.py @@ -4,7 +4,6 @@ import os import unittest import uuid -from typing import Dict, Any import pytest import pytest_asyncio @@ -15,9 +14,9 @@ from azure.cosmos import 
PartitionKey, _partition_health_tracker from azure.cosmos.aio import CosmosClient from azure.cosmos.exceptions import CosmosHttpResponseError -from _fault_injection_transport_async import FaultInjectionTransportAsync +from tests._fault_injection_transport import FaultInjectionTransport from tests.test_ppcb_mm_async import perform_write_operation, create_doc, PK_VALUE, write_operations_and_errors, \ - cleanup_method, read_operations_and_errors, perform_read_operation, CHANGE_FEED, QUERY, READ_ALL_ITEMS, operations + read_operations_and_errors, perform_read_operation, CHANGE_FEED, QUERY, READ_ALL_ITEMS, operations from tests.test_ppcb_sm_mrr_async import validate_unhealthy_partitions COLLECTION = "created_collection" @@ -33,11 +32,10 @@ def setup_teardown(): asyncio.sleep(3) yield created_database.delete_container(TestPerPartitionCircuitBreakerSmMrr.TEST_CONTAINER_SINGLE_PARTITION_ID) - client.close() os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "False" @pytest.mark.cosmosEmulator -@pytest.mark.usefixtures("setup") +@pytest.mark.usefixtures("setup_teardown") class TestPerPartitionCircuitBreakerSmMrr: host = test_config.TestConfig.host master_key = test_config.TestConfig.masterKey @@ -53,27 +51,21 @@ def setup_method_with_custom_transport(self, custom_transport: AioHttpTransport, container = db.get_container_client(self.TEST_CONTAINER_SINGLE_PARTITION_ID) return {"client": client, "db": db, "col": container} - @staticmethod - def cleanup_method(initialized_objects: Dict[str, Any]): - method_client: CosmosClient = initialized_objects["client"] - method_client.close() - - def create_custom_transport_sm_mrr(self): - custom_transport = FaultInjectionTransportAsync() + custom_transport = FaultInjectionTransport() # Inject rule to disallow writes in the read-only region is_write_operation_in_read_region_predicate = lambda \ - r: FaultInjectionTransportAsync.predicate_is_write_operation(r, self.host) + r: FaultInjectionTransport.predicate_is_write_operation(r, 
self.host) custom_transport.add_fault( is_write_operation_in_read_region_predicate, - lambda r: FaultInjectionTransportAsync.error_write_forbidden()) + lambda r: FaultInjectionTransport.error_write_forbidden()) # Inject topology transformation that would make Emulator look like a single write region # account with two read regions - is_get_account_predicate = lambda r: FaultInjectionTransportAsync.predicate_is_database_account_call(r) + is_get_account_predicate = lambda r: FaultInjectionTransport.predicate_is_database_account_call(r) emulator_as_multi_region_sm_account_transformation = \ - lambda r, inner: FaultInjectionTransportAsync.transform_topology_swr_mrr( + lambda r, inner: FaultInjectionTransport.transform_topology_swr_mrr( write_region_name="Write Region", read_region_name="Read Region", inner=inner) @@ -88,8 +80,8 @@ def setup_info(self, error): custom_transport = self.create_custom_transport_sm_mrr() # two documents targeted to same partition, one will always fail and the other will succeed doc = create_doc() - predicate = lambda r: (FaultInjectionTransportAsync.predicate_is_document_operation(r) and - FaultInjectionTransportAsync.predicate_targets_region(r, uri_down)) + predicate = lambda r: (FaultInjectionTransport.predicate_is_document_operation(r) and + FaultInjectionTransport.predicate_targets_region(r, uri_down)) custom_transport.add_fault(predicate, error) custom_setup = self.setup_method_with_custom_transport(custom_transport, default_endpoint=self.host) @@ -98,7 +90,7 @@ def setup_info(self, error): @pytest.mark.parametrize("write_operation, error", write_operations_and_errors()) def test_write_consecutive_failure_threshold(self, setup_teardown, write_operation, error): - error_lambda = lambda r: FaultInjectionTransportAsync.error_after_delay(0, error) + error_lambda = lambda r: FaultInjectionTransport.error_after_delay(0, error) setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = self.setup_info(error_lambda) 
container = setup['col'] fault_injection_container = custom_setup['col'] @@ -118,12 +110,11 @@ def test_write_consecutive_failure_threshold(self, setup_teardown, write_operati ) validate_unhealthy_partitions(global_endpoint_manager, 0) - cleanup_method([custom_setup, setup]) @pytest.mark.parametrize("read_operation, error", read_operations_and_errors()) def test_read_consecutive_failure_threshold(self, setup_teardown, read_operation, error): - error_lambda = lambda r: FaultInjectionTransportAsync.error_after_delay(0, error) + error_lambda = lambda r: FaultInjectionTransport.error_after_delay(0, error) setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = self.setup_info(error_lambda) container = setup['col'] fault_injection_container = custom_setup['col'] @@ -169,11 +160,10 @@ def test_read_consecutive_failure_threshold(self, setup_teardown, read_operation finally: _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = original_unavailable_time validate_unhealthy_partitions(global_endpoint_manager, 0) - cleanup_method([custom_setup, setup]) @pytest.mark.parametrize("write_operation, error", write_operations_and_errors()) def test_write_failure_rate_threshold(self, setup_teardown, write_operation, error): - error_lambda = lambda r: FaultInjectionTransportAsync.error_after_delay(0, error) + error_lambda = lambda r: FaultInjectionTransport.error_after_delay(0, error) setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = self.setup_info(error_lambda) container = setup['col'] fault_injection_container = custom_setup['col'] @@ -191,7 +181,7 @@ def test_write_failure_rate_threshold(self, setup_teardown, write_operation, err custom_transport.faults = [] fault_injection_container.upsert_item(body=doc) custom_transport.add_fault(predicate, - lambda r: FaultInjectionTransportAsync.error_after_delay( + lambda r: FaultInjectionTransport.error_after_delay( 0, error )) @@ -211,11 +201,10 @@ def 
test_write_failure_rate_threshold(self, setup_teardown, write_operation, err os.environ["AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED"] = "90" # restore minimum requests _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 - cleanup_method([custom_setup, setup]) @pytest.mark.parametrize("read_operation, error", read_operations_and_errors()) def test_read_failure_rate_threshold(self, setup_teardown, read_operation, error): - error_lambda = lambda r: FaultInjectionTransportAsync.error_after_delay(0, error) + error_lambda = lambda r: FaultInjectionTransport.error_after_delay(0, error) setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = self.setup_info(error_lambda) container = setup['col'] fault_injection_container = custom_setup['col'] @@ -249,14 +238,13 @@ def test_read_failure_rate_threshold(self, setup_teardown, read_operation, error os.environ["AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED"] = "90" # restore minimum requests _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 - cleanup_method([custom_setup, setup]) @pytest.mark.parametrize("read_operation, write_operation", operations()) def test_service_request_error(self, read_operation, write_operation): # the region should be tried 4 times before failing over and mark the partition as unavailable # the region should not be marked as unavailable - error_lambda = lambda r: FaultInjectionTransportAsync.error_region_down() + error_lambda = lambda r: FaultInjectionTransport.error_region_down() setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = self.setup_info(error_lambda) container = setup['col'] fault_injection_container = custom_setup['col'] @@ -288,7 +276,7 @@ def test_service_request_error(self, read_operation, write_operation): validate_unhealthy_partitions(global_endpoint_manager, 0) custom_transport.add_fault(predicate, - lambda r: FaultInjectionTransportAsync.error_region_down()) + lambda r: 
FaultInjectionTransport.error_region_down()) perform_write_operation(write_operation, container, @@ -301,7 +289,6 @@ def test_service_request_error(self, read_operation, write_operation): validate_unhealthy_partitions(global_endpoint_manager, 0) # there shouldn't be region marked as unavailable assert len(global_endpoint_manager.location_cache.location_unavailability_info_by_endpoint) == 1 - cleanup_method([custom_setup, setup]) # test cosmos client timeout From 2ffddf1126e00ccada4351bce2797be75ff1f27d Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Wed, 23 Apr 2025 15:32:56 -0700 Subject: [PATCH 107/152] fix tests --- .../azure-cosmos/azure/cosmos/_cosmos_client_connection.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py index e9e9cf834efb..3dc2ab1e2c78 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py @@ -1301,7 +1301,6 @@ def CreateItem( collection_id, None, options, - container_link=database_or_container_link, **kwargs) def UpsertItem( @@ -1344,7 +1343,6 @@ def UpsertItem( collection_id, None, options, - container_link=database_or_container_link, **kwargs) PartitionResolverErrorMessage = ( From f5acc273aecc3e476383c193a945754e9d663e66 Mon Sep 17 00:00:00 2001 From: Tomas Varon Saldarriaga Date: Wed, 23 Apr 2025 15:34:34 -0700 Subject: [PATCH 108/152] fix tests --- sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr.py b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr.py index 1467fd207951..faa286ec5b2e 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr.py +++ b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr.py @@ -1,9 +1,9 @@ # The MIT License (MIT) # Copyright (c) Microsoft Corporation. 
All rights reserved. -import asyncio import os import unittest import uuid +from time import sleep import pytest import pytest_asyncio @@ -12,11 +12,12 @@ import test_config from azure.cosmos import PartitionKey, _partition_health_tracker -from azure.cosmos.aio import CosmosClient +from azure.cosmos import CosmosClient from azure.cosmos.exceptions import CosmosHttpResponseError from tests._fault_injection_transport import FaultInjectionTransport -from tests.test_ppcb_mm_async import perform_write_operation, create_doc, PK_VALUE, write_operations_and_errors, \ - read_operations_and_errors, perform_read_operation, CHANGE_FEED, QUERY, READ_ALL_ITEMS, operations +from tests.test_ppcb_mm import perform_write_operation, perform_read_operation +from tests.test_ppcb_mm_async import create_doc, PK_VALUE, write_operations_and_errors, \ + read_operations_and_errors, CHANGE_FEED, QUERY, READ_ALL_ITEMS, operations from tests.test_ppcb_sm_mrr_async import validate_unhealthy_partitions COLLECTION = "created_collection" @@ -29,7 +30,7 @@ def setup_teardown(): partition_key=PartitionKey("/pk"), offer_throughput=10000) # allow some time for the container to be created as this method is in different event loop - asyncio.sleep(3) + sleep(3) yield created_database.delete_container(TestPerPartitionCircuitBreakerSmMrr.TEST_CONTAINER_SINGLE_PARTITION_ID) os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "False" From b7521fd026b3daab22237015ff5892fdbad3ab7f Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Wed, 23 Apr 2025 16:27:33 -0700 Subject: [PATCH 109/152] fix tests --- .../azure/cosmos/_cosmos_client_connection.py | 1 - ...lobal_partition_endpoint_manager_circuit_breaker.py | 10 +++++----- .../azure/cosmos/_routing/routing_map_provider.py | 9 +++++---- ...partition_endpoint_manager_circuit_breaker_async.py | 4 ++-- sdk/cosmos/azure-cosmos/tests/test_ppcb_mm.py | 4 ++-- sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr.py | 4 ++-- .../azure-cosmos/tests/test_ppcb_sm_mrr_async.py | 2 +- 7 
files changed, 17 insertions(+), 17 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py index 3dc2ab1e2c78..8c64563e0533 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py @@ -2109,7 +2109,6 @@ def Batch( results, last_response_headers = self._Batch( formatted_operations, path, - collection_link, collection_id, options, **kwargs diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py index e7bf34ee0ce5..87748b0bacdf 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py @@ -54,7 +54,7 @@ def is_circuit_breaker_applicable(self, request: RequestObject) -> bool: def create_pk_range_wrapper(self, request: RequestObject) -> PartitionKeyRangeWrapper: container_rid = request.headers[HttpHeaders.IntendedCollectionRID] print(request.headers) - properties = self.Client._container_properties_cache[container_rid] + properties = self.Client._container_properties_cache[container_rid] # pylint: disable=protected-access # get relevant information from container cache to get the overlapping ranges container_link = properties["container_link"] partition_key_definition = properties["partitionKey"] @@ -63,15 +63,15 @@ def create_pk_range_wrapper(self, request: RequestObject) -> PartitionKeyRangeWr if request.headers.get(HttpHeaders.PartitionKey): partition_key_value = request.headers[HttpHeaders.PartitionKey] # get the partition key range for the given partition key - epk_range = [partition_key._get_epk_range_for_partition_key(partition_key_value)] - partition_ranges = 
(self.Client._routing_map_provider + epk_range = [partition_key._get_epk_range_for_partition_key(partition_key_value)] # pylint: disable=protected-access + partition_ranges = (self.Client._routing_map_provider # pylint: disable=protected-access .get_overlapping_ranges(container_link, epk_range)) partition_range = Range.PartitionKeyRangeToRange(partition_ranges[0]) elif request.headers.get(HttpHeaders.PartitionKeyRangeID): pk_range_id = request.headers[HttpHeaders.PartitionKeyRangeID] - range =(self.Client._routing_map_provider + epk_range =(self.Client._routing_map_provider # pylint: disable=protected-access .get_range_by_partition_key_range_id(container_link, pk_range_id)) - partition_range = Range.PartitionKeyRangeToRange(range) + partition_range = Range.PartitionKeyRangeToRange(epk_range) else: raise RuntimeError("Illegal state: the request does not contain partition information.") diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_map_provider.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_map_provider.py index 99ce35de3bff..57ea23c1c646 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_map_provider.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_map_provider.py @@ -51,7 +51,7 @@ def __init__(self, client): # keeps the cached collection routing map by collection id self._collection_routing_map_by_item = {} - def initialize_collection_routing_map_if_needed( + def init_collection_routing_map_if_needed( self, collection_link: str, collection_id: str, @@ -80,7 +80,7 @@ def get_overlapping_ranges(self, collection_link, partition_key_ranges, **kwargs :rtype: list """ collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) - self.initialize_collection_routing_map_if_needed(collection_link, str(collection_id), **kwargs) + self.init_collection_routing_map_if_needed(collection_link, str(collection_id), **kwargs) return 
self._collection_routing_map_by_item[collection_id].get_overlapping_ranges(partition_key_ranges) @@ -91,9 +91,10 @@ def get_range_by_partition_key_range_id( **kwargs: Dict[str, Any] ) -> Dict[str, Any]: collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) - self.initialize_collection_routing_map_if_needed(collection_link, str(collection_id), **kwargs) + self.init_collection_routing_map_if_needed(collection_link, str(collection_id), **kwargs) - return self._collection_routing_map_by_item[collection_id].get_range_by_partition_key_range_id(partition_key_range_id) + return (self._collection_routing_map_by_item[collection_id] + .get_range_by_partition_key_range_id(partition_key_range_id)) @staticmethod def _discard_parent_ranges(partitionKeyRanges): diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py index b2edae462450..0140c04ab033 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py @@ -67,9 +67,9 @@ async def create_pk_range_wrapper(self, request: RequestObject) -> PartitionKeyR partition_range = Range.PartitionKeyRangeToRange(partition_ranges[0]) elif request.headers.get(HttpHeaders.PartitionKeyRangeID): pk_range_id = request.headers[HttpHeaders.PartitionKeyRangeID] - range = await (self.client._routing_map_provider + epk_range = await (self.client._routing_map_provider .get_range_by_partition_key_range_id(container_link, pk_range_id)) - partition_range = Range.PartitionKeyRangeToRange(range) + partition_range = Range.PartitionKeyRangeToRange(epk_range) else: raise RuntimeError("Illegal state: the request does not contain partition information.") diff --git a/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm.py 
b/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm.py index bd437db3b0c9..469293dd83d4 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm.py +++ b/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm.py @@ -12,8 +12,8 @@ from azure.cosmos import PartitionKey, _location_cache, _partition_health_tracker from azure.cosmos import CosmosClient from azure.cosmos.exceptions import CosmosHttpResponseError -from tests._fault_injection_transport import FaultInjectionTransport -from tests.test_ppcb_mm_async import DELETE, CREATE, UPSERT, REPLACE, PATCH, BATCH, validate_response_uri, READ, \ +from _fault_injection_transport import FaultInjectionTransport +from test_ppcb_mm_async import DELETE, CREATE, UPSERT, REPLACE, PATCH, BATCH, validate_response_uri, READ, \ QUERY_PK, QUERY, CHANGE_FEED, CHANGE_FEED_PK, CHANGE_FEED_EPK, READ_ALL_ITEMS, REGION_1, REGION_2, \ write_operations_and_errors, validate_unhealthy_partitions, read_operations_and_errors, PK_VALUE, operations, \ create_doc diff --git a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr.py b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr.py index be435358d782..70d12f81cdf8 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr.py +++ b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr.py @@ -16,9 +16,9 @@ from azure.cosmos.aio import CosmosClient from azure.cosmos.exceptions import CosmosHttpResponseError from _fault_injection_transport_async import FaultInjectionTransportAsync -from tests.test_ppcb_mm_async import perform_write_operation, create_doc, PK_VALUE, write_operations_and_errors, \ +from test_ppcb_mm_async import perform_write_operation, create_doc, PK_VALUE, write_operations_and_errors, \ cleanup_method, read_operations_and_errors, perform_read_operation, CHANGE_FEED, QUERY, READ_ALL_ITEMS, operations -from tests.test_ppcb_sm_mrr_async import validate_unhealthy_partitions +from test_ppcb_sm_mrr_async import validate_unhealthy_partitions COLLECTION = "created_collection" @pytest_asyncio.fixture() diff --git 
a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py index 5c1009fed25f..0102a1cb7d99 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py @@ -17,7 +17,7 @@ from azure.cosmos.aio import CosmosClient from azure.cosmos.exceptions import CosmosHttpResponseError from _fault_injection_transport_async import FaultInjectionTransportAsync -from tests.test_ppcb_mm_async import perform_write_operation, create_doc, PK_VALUE, write_operations_and_errors, \ +from test_ppcb_mm_async import perform_write_operation, create_doc, PK_VALUE, write_operations_and_errors, \ cleanup_method, read_operations_and_errors, perform_read_operation, CHANGE_FEED, QUERY, READ_ALL_ITEMS, operations COLLECTION = "created_collection" From d1e553e1db205739adf3a18dce344bb0647791f6 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Wed, 23 Apr 2025 17:15:30 -0700 Subject: [PATCH 110/152] test changes --- sdk/cosmos/azure-cosmos/pytest.ini | 2 ++ sdk/cosmos/azure-cosmos/tests/test_crud.py | 2 +- sdk/cosmos/live-platform-matrix.json | 20 ++++++++++++++++++++ 3 files changed, 23 insertions(+), 1 deletion(-) diff --git a/sdk/cosmos/azure-cosmos/pytest.ini b/sdk/cosmos/azure-cosmos/pytest.ini index 0ea65741e343..08e4868261f8 100644 --- a/sdk/cosmos/azure-cosmos/pytest.ini +++ b/sdk/cosmos/azure-cosmos/pytest.ini @@ -4,3 +4,5 @@ markers = cosmosLong: marks tests to be run on a Cosmos DB live account. cosmosQuery: marks tests running queries on Cosmos DB live account. cosmosSplit: marks test where there are partition splits on CosmosDB live account. + cosmosMultiRegion: marks tests running on a Cosmos DB live account with multi-region enabled. + cosmosPPCB: marks tests running on Cosmos DB live account with per partition circuit breaker enabled. 
diff --git a/sdk/cosmos/azure-cosmos/tests/test_crud.py b/sdk/cosmos/azure-cosmos/tests/test_crud.py index 90fe5a22eaea..51b81ce5bcad 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_crud.py +++ b/sdk/cosmos/azure-cosmos/tests/test_crud.py @@ -47,7 +47,7 @@ def send(self, *args, **kwargs): response = RequestsTransportResponse(None, output) return response - +@pytest.mark.cosmosPPCB @pytest.mark.cosmosLong class TestCRUDOperations(unittest.TestCase): """Python CRUD Tests. diff --git a/sdk/cosmos/live-platform-matrix.json b/sdk/cosmos/live-platform-matrix.json index 347c166f869c..66f15f713fa6 100644 --- a/sdk/cosmos/live-platform-matrix.json +++ b/sdk/cosmos/live-platform-matrix.json @@ -25,6 +25,26 @@ } } }, + { + "PPCBTestConfig": { + "Ubuntu2004_39_ppcb": { + "OSVmImage": "env:LINUXVMIMAGE", + "Pool": "env:LINUXPOOL", + "PythonVersion": "3.9", + "CoverageArg": "--disablecov", + "TestSamples": "false", + "TestMarkArgument": "cosmosPPCB" + }, + "Ubuntu2004_313_ppcb": { + "OSVmImage": "env:LINUXVMIMAGE", + "Pool": "env:LINUXPOOL", + "PythonVersion": "3.13", + "CoverageArg": "--disablecov", + "TestSamples": "false", + "TestMarkArgument": "cosmosPPCB" + } + } + }, { "MacTestConfig": { "macos311_search_query": { From 454a3ce9e26e16a8f34177960ccb93529afc2383 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Wed, 23 Apr 2025 23:01:59 -0700 Subject: [PATCH 111/152] fix some tests --- .../azure/cosmos/_routing/routing_map_provider.py | 2 +- .../tests/routing/test_routing_map_provider.py | 2 +- sdk/cosmos/azure-cosmos/tests/test_feed_range.py | 2 +- sdk/cosmos/azure-cosmos/tests/test_location_cache.py | 10 +++++----- sdk/cosmos/azure-cosmos/tests/test_ppcb_mm_async.py | 2 +- 5 files changed, 9 insertions(+), 9 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_map_provider.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_map_provider.py index 57ea23c1c646..b22f05acb6b0 100644 --- 
a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_map_provider.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_map_provider.py @@ -59,7 +59,7 @@ def init_collection_routing_map_if_needed( ): client = self._documentClient collection_routing_map = self._collection_routing_map_by_item.get(collection_id) - if collection_routing_map is None: + if not collection_routing_map: collection_pk_ranges = list(client._ReadPartitionKeyRanges(collection_link, **kwargs)) # for large collections, a split may complete between the read partition key ranges query page responses, # causing the partitionKeyRanges to have both the children ranges and their parents. Therefore, we need diff --git a/sdk/cosmos/azure-cosmos/tests/routing/test_routing_map_provider.py b/sdk/cosmos/azure-cosmos/tests/routing/test_routing_map_provider.py index cc69d630c162..8e4633b3ccd7 100644 --- a/sdk/cosmos/azure-cosmos/tests/routing/test_routing_map_provider.py +++ b/sdk/cosmos/azure-cosmos/tests/routing/test_routing_map_provider.py @@ -181,7 +181,7 @@ def validate_empty_query_ranges(self, smart_routing_map_provider, *queryRangesLi self.validate_overlapping_ranges_results(queryRanges, []) def get_overlapping_ranges(self, queryRanges): - return self.smart_routing_map_provider.get_overlapping_ranges("sample collection id", queryRanges) + return self.smart_routing_map_provider.get_overlapping_ranges("dbs/db/colls/container", queryRanges) if __name__ == "__main__": diff --git a/sdk/cosmos/azure-cosmos/tests/test_feed_range.py b/sdk/cosmos/azure-cosmos/tests/test_feed_range.py index f44c39e034a9..fe292547d505 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_feed_range.py +++ b/sdk/cosmos/azure-cosmos/tests/test_feed_range.py @@ -41,7 +41,7 @@ def setup(): True), (Range("3F", "7F", False, True), Range("3F", "7F", True, True), - False), + True), (Range("3F", "7F", True, False), Range("3F", "7F", True, True), False), diff --git a/sdk/cosmos/azure-cosmos/tests/test_location_cache.py 
b/sdk/cosmos/azure-cosmos/tests/test_location_cache.py index 3b804ee37a3c..ad8e021ab577 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_location_cache.py +++ b/sdk/cosmos/azure-cosmos/tests/test_location_cache.py @@ -215,9 +215,9 @@ def test_get_applicable_regional_endpoints_excluded_regions(self, test_type): location_cache.perform_on_database_account_read(database_account) # Init requests and set excluded regions on requests - write_doc_request = RequestObject(ResourceType.Document, _OperationType.Create) + write_doc_request = RequestObject(ResourceType.Document, _OperationType.Create, None) write_doc_request.excluded_locations = excluded_locations_on_requests - read_doc_request = RequestObject(ResourceType.Document, _OperationType.Read) + read_doc_request = RequestObject(ResourceType.Document, _OperationType.Read, None) read_doc_request.excluded_locations = excluded_locations_on_requests # Test if read endpoints were correctly filtered on client level @@ -247,7 +247,7 @@ def test_set_excluded_locations_for_requests(self): options: Mapping[str, Any] = {"excludedLocations": excluded_locations} expected_excluded_locations = excluded_locations - read_doc_request = RequestObject(ResourceType.Document, _OperationType.Create) + read_doc_request = RequestObject(ResourceType.Document, _OperationType.Create, None) read_doc_request.set_excluded_location_from_options(options) actual_excluded_locations = read_doc_request.excluded_locations assert actual_excluded_locations == expected_excluded_locations @@ -262,7 +262,7 @@ def test_set_excluded_locations_for_requests(self): expected_excluded_locations = None for resource_type in [ResourceType.Offer, ResourceType.Conflict]: options: Mapping[str, Any] = {"excludedLocations": [location1_name]} - read_doc_request = RequestObject(resource_type, _OperationType.Create) + read_doc_request = RequestObject(resource_type, _OperationType.Create, None) read_doc_request.set_excluded_location_from_options(options) actual_excluded_locations = 
read_doc_request.excluded_locations assert actual_excluded_locations == expected_excluded_locations @@ -279,7 +279,7 @@ def test_set_excluded_locations_for_requests(self): "If you want to remove all excluded locations, try passing an empty list.") with pytest.raises(ValueError) as e: options: Mapping[str, Any] = {"excludedLocations": None} - doc_request = RequestObject(ResourceType.Document, _OperationType.Create) + doc_request = RequestObject(ResourceType.Document, _OperationType.Create, None) doc_request.set_excluded_location_from_options(options) assert str( e.value) == expected_error_message diff --git a/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm_async.py b/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm_async.py index 8b8c23d2b62b..b5d312c28a41 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm_async.py @@ -19,7 +19,7 @@ from _fault_injection_transport_async import FaultInjectionTransportAsync REGION_1 = "West US 3" -REGION_2 = "Mexico Central" # "West US" +REGION_2 = "West US" CHANGE_FEED = "changefeed" CHANGE_FEED_PK = "changefeed_pk" CHANGE_FEED_EPK = "changefeed_epk" From 629e07bfd03ce39df9d72362b869211f0ff84ccd Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Thu, 24 Apr 2025 21:35:32 -0700 Subject: [PATCH 112/152] fix tests --- sdk/cosmos/azure-cosmos/tests/test_location_cache.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_location_cache.py b/sdk/cosmos/azure-cosmos/tests/test_location_cache.py index ad8e021ab577..be3fc7140a50 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_location_cache.py +++ b/sdk/cosmos/azure-cosmos/tests/test_location_cache.py @@ -126,8 +126,8 @@ def test_resolve_request_endpoint_preferred_regions(self): lc = refresh_location_cache([location1_name, location3_name, location4_name], True) db_acc = create_database_account(True) lc.perform_on_database_account_read(db_acc) - write_doc_request = 
RequestObject(ResourceType.Document, _OperationType.Create) - read_doc_request = RequestObject(ResourceType.Document, _OperationType.Read) + write_doc_request = RequestObject(ResourceType.Document, _OperationType.Create, None) + read_doc_request = RequestObject(ResourceType.Document, _OperationType.Read, None) # resolve both document requests with all regions available write_doc_resolved = lc.resolve_service_endpoint(write_doc_request) From dd5034d78985e12008066a318f0bf9cb2ca3377a Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Thu, 24 Apr 2025 22:56:15 -0700 Subject: [PATCH 113/152] Rename test files --- ..._ppcb_mm.py => test_per_partition_circuit_breaker_mm.py} | 2 +- ...nc.py => test_per_partition_circuit_breaker_mm_async.py} | 0 ..._mrr.py => test_per_partition_circuit_breaker_sm_mrr.py} | 6 +++--- ...y => test_per_partition_circuit_breaker_sm_mrr_async.py} | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) rename sdk/cosmos/azure-cosmos/tests/{test_ppcb_mm.py => test_per_partition_circuit_breaker_mm.py} (99%) rename sdk/cosmos/azure-cosmos/tests/{test_ppcb_mm_async.py => test_per_partition_circuit_breaker_mm_async.py} (100%) rename sdk/cosmos/azure-cosmos/tests/{test_ppcb_sm_mrr.py => test_per_partition_circuit_breaker_sm_mrr.py} (98%) rename sdk/cosmos/azure-cosmos/tests/{test_ppcb_sm_mrr_async.py => test_per_partition_circuit_breaker_sm_mrr_async.py} (99%) diff --git a/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py similarity index 99% rename from sdk/cosmos/azure-cosmos/tests/test_ppcb_mm.py rename to sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py index 469293dd83d4..de5c45fa3513 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm.py +++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py @@ -13,7 +13,7 @@ from azure.cosmos import CosmosClient from azure.cosmos.exceptions import CosmosHttpResponseError from _fault_injection_transport 
import FaultInjectionTransport -from test_ppcb_mm_async import DELETE, CREATE, UPSERT, REPLACE, PATCH, BATCH, validate_response_uri, READ, \ +from test_per_partition_circuit_breaker_mm_async import DELETE, CREATE, UPSERT, REPLACE, PATCH, BATCH, validate_response_uri, READ, \ QUERY_PK, QUERY, CHANGE_FEED, CHANGE_FEED_PK, CHANGE_FEED_EPK, READ_ALL_ITEMS, REGION_1, REGION_2, \ write_operations_and_errors, validate_unhealthy_partitions, read_operations_and_errors, PK_VALUE, operations, \ create_doc diff --git a/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm_async.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py similarity index 100% rename from sdk/cosmos/azure-cosmos/tests/test_ppcb_mm_async.py rename to sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py diff --git a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr.py similarity index 98% rename from sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr.py rename to sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr.py index 4f84b9c27a46..f3fa4dfde167 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr.py +++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr.py @@ -15,10 +15,10 @@ from azure.cosmos import CosmosClient from azure.cosmos.exceptions import CosmosHttpResponseError from _fault_injection_transport import FaultInjectionTransport -from test_ppcb_mm import perform_write_operation, perform_read_operation -from test_ppcb_mm_async import create_doc, PK_VALUE, write_operations_and_errors, \ +from test_per_partition_circuit_breaker_mm import perform_write_operation, perform_read_operation +from test_per_partition_circuit_breaker_mm_async import create_doc, PK_VALUE, write_operations_and_errors, \ read_operations_and_errors, CHANGE_FEED, QUERY, READ_ALL_ITEMS, operations -from test_ppcb_sm_mrr_async import validate_unhealthy_partitions +from 
test_per_partition_circuit_breaker_sm_mrr_async import validate_unhealthy_partitions COLLECTION = "created_collection" @pytest_asyncio.fixture() diff --git a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr_async.py similarity index 99% rename from sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py rename to sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr_async.py index 8220c950ce9f..a7002e62ab7b 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr_async.py @@ -17,7 +17,7 @@ from azure.cosmos.aio import CosmosClient from azure.cosmos.exceptions import CosmosHttpResponseError from _fault_injection_transport_async import FaultInjectionTransportAsync -from test_ppcb_mm_async import perform_write_operation, create_doc, PK_VALUE, write_operations_and_errors, \ +from test_per_partition_circuit_breaker_mm_async import perform_write_operation, create_doc, PK_VALUE, write_operations_and_errors, \ cleanup_method, read_operations_and_errors, perform_read_operation, CHANGE_FEED, QUERY, READ_ALL_ITEMS, operations COLLECTION = "created_collection" From 5d942d2917717fac95b0d8ec7b2a928bb6c821e1 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Thu, 24 Apr 2025 23:39:51 -0700 Subject: [PATCH 114/152] fix tests and setup ppcb pipeline --- sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py | 3 ++- .../azure-cosmos/azure/cosmos/aio/_retry_utility_async.py | 3 ++- sdk/cosmos/azure-cosmos/pytest.ini | 2 +- sdk/cosmos/azure-cosmos/tests/test_crud_async.py | 4 +++- .../tests/test_per_partition_circuit_breaker_mm.py | 2 +- .../tests/test_per_partition_circuit_breaker_mm_async.py | 2 +- sdk/cosmos/live-platform-matrix.json | 5 +++++ sdk/cosmos/test-resources.bicep | 5 ++++- 8 files changed, 19 insertions(+), 7 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py 
b/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py index 7e4019cb2607..8c7174d8cb08 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py @@ -353,7 +353,8 @@ def send(self, request): raise err except AzureError as err: retry_error = err - if _has_database_account_header(request.http_request.headers): + if (_has_database_account_header(request.http_request.headers) or + request_params.healthy_tentative_location): raise err if _has_read_retryable_headers(request.http_request.headers) and retry_settings['read'] > 0: global_endpoint_manager.record_failure(request_params) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py index 8c59644780ca..2ed4de959757 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py @@ -324,7 +324,8 @@ async def send(self, request): raise err except AzureError as err: retry_error = err - if _has_database_account_header(request.http_request.headers): + if (_has_database_account_header(request.http_request.headers) or + request_params.healthy_tentative_location): raise err if _has_read_retryable_headers(request.http_request.headers) and retry_settings['read'] > 0: retry_active = self.increment(retry_settings, response=request, error=err) diff --git a/sdk/cosmos/azure-cosmos/pytest.ini b/sdk/cosmos/azure-cosmos/pytest.ini index 08e4868261f8..a12da2004f38 100644 --- a/sdk/cosmos/azure-cosmos/pytest.ini +++ b/sdk/cosmos/azure-cosmos/pytest.ini @@ -4,5 +4,5 @@ markers = cosmosLong: marks tests to be run on a Cosmos DB live account. cosmosQuery: marks tests running queries on Cosmos DB live account. cosmosSplit: marks test where there are partition splits on CosmosDB live account. - cosmosMultiRegion: marks tests running on a Cosmos DB live account with multi-region enabled. 
+ cosmosMultiRegion: marks tests running on a Cosmos DB live account with multi-region and multi-write enabled. cosmosPPCB: marks tests running on Cosmos DB live account with per partition circuit breaker enabled. diff --git a/sdk/cosmos/azure-cosmos/tests/test_crud_async.py b/sdk/cosmos/azure-cosmos/tests/test_crud_async.py index ca6cfad8287d..19f9f3c89ac4 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_crud_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_crud_async.py @@ -4,6 +4,7 @@ """End-to-end test. """ +import os import time import unittest import urllib.parse as urllib @@ -42,7 +43,7 @@ async def send(self, *args, **kwargs): response = AsyncioRequestsTransportResponse(None, output) return response - +@pytest.mark.cosmosPPCB @pytest.mark.cosmosLong class TestCRUDOperationsAsync(unittest.IsolatedAsyncioTestCase): """Python CRUD Tests. @@ -70,6 +71,7 @@ async def __assert_http_failure_with_status(self, status_code, func, *args, **kw @classmethod def setUpClass(cls): + print("Circuit Breaker enabled: " + os.environ.get("AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER", "True")) if (cls.masterKey == '[YOUR_KEY_HERE]' or cls.host == '[YOUR_ENDPOINT_HERE]'): raise Exception( diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py index de5c45fa3513..c957d14af530 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py +++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py @@ -109,7 +109,7 @@ def perform_read_operation(operation, container, doc_id, pk, expected_uri): for _ in container.read_all_items(): pass -@pytest.mark.cosmosMultiRegion +@pytest.mark.cosmosPPCB @pytest.mark.usefixtures("setup_teardown") class TestPerPartitionCircuitBreakerMM: host = test_config.TestConfig.host diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py 
b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py index b5d312c28a41..59cc93c697f2 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py @@ -199,7 +199,7 @@ async def cleanup_method(initialized_objects: List[Dict[str, Any]]): method_client: CosmosClient = obj["client"] await method_client.close() -@pytest.mark.cosmosMultiRegion +@pytest.mark.cosmosPPCB @pytest.mark.asyncio @pytest.mark.usefixtures("setup_teardown") class TestPerPartitionCircuitBreakerMMAsync: diff --git a/sdk/cosmos/live-platform-matrix.json b/sdk/cosmos/live-platform-matrix.json index 66f15f713fa6..546a6de8543d 100644 --- a/sdk/cosmos/live-platform-matrix.json +++ b/sdk/cosmos/live-platform-matrix.json @@ -43,6 +43,11 @@ "TestSamples": "false", "TestMarkArgument": "cosmosPPCB" } + }, + "ArmConfig": { + "MultiMaster_MultiRegion": { + "ArmTemplateParameters": "@{ enableMultipleWriteLocations = $true; defaultConsistencyLevel = 'Session'; enableMultipleRegions = $true; ppcbEnabled = $true }" + } } }, { diff --git a/sdk/cosmos/test-resources.bicep b/sdk/cosmos/test-resources.bicep index 88abe955f8d8..9ad6aaa1a365 100644 --- a/sdk/cosmos/test-resources.bicep +++ b/sdk/cosmos/test-resources.bicep @@ -12,6 +12,9 @@ param enableMultipleRegions bool = false @description('Location for the Cosmos DB account.') param location string = resourceGroup().location +@description('Whether Per Partition Circuit Breaker should be enabled.') +param ppcbEnabled bool = false + @description('The api version to be used by Bicep to create resources') param apiVersion string = '2023-04-15' @@ -101,6 +104,6 @@ resource accountName_roleAssignmentId 'Microsoft.DocumentDB/databaseAccounts/sql } } - +output PPCB_ENABLED bool = ppcbEnabled output ACCOUNT_HOST string = reference(resourceId, apiVersion).documentEndpoint output ACCOUNT_KEY string = listKeys(resourceId, 
apiVersion).primaryMasterKey From ef52e43fecb7b7f1364bf7a42d11a1f9e60a1ede Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Fri, 25 Apr 2025 11:16:49 -0700 Subject: [PATCH 115/152] fix tests --- sdk/cosmos/azure-cosmos/tests/test_crud_async.py | 7 +++++-- sdk/cosmos/test-resources.bicep | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_crud_async.py b/sdk/cosmos/azure-cosmos/tests/test_crud_async.py index 19f9f3c89ac4..856ce4359560 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_crud_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_crud_async.py @@ -71,7 +71,6 @@ async def __assert_http_failure_with_status(self, status_code, func, *args, **kw @classmethod def setUpClass(cls): - print("Circuit Breaker enabled: " + os.environ.get("AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER", "True")) if (cls.masterKey == '[YOUR_KEY_HERE]' or cls.host == '[YOUR_ENDPOINT_HERE]'): raise Exception( @@ -80,7 +79,11 @@ def setUpClass(cls): "tests.") async def asyncSetUp(self): - self.client = CosmosClient(self.host, self.masterKey) + print("Circuit Breaker enabled: " + os.environ.get("AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER", "True")) + use_multiple_write_locations = False + if os.environ.get("AZURE_COSMOS_ENABLE_MULTIPLE_WRITE_LOCATIONS", "False") == "True": + use_multiple_write_locations = True + self.client = CosmosClient(self.host, self.masterKey, multiple_write_locations=use_multiple_write_locations) self.database_for_test = self.client.get_database_client(self.configs.TEST_DATABASE_ID) async def asyncTearDown(self): diff --git a/sdk/cosmos/test-resources.bicep b/sdk/cosmos/test-resources.bicep index 9ad6aaa1a365..5824c936bb8f 100644 --- a/sdk/cosmos/test-resources.bicep +++ b/sdk/cosmos/test-resources.bicep @@ -104,6 +104,6 @@ resource accountName_roleAssignmentId 'Microsoft.DocumentDB/databaseAccounts/sql } } -output PPCB_ENABLED bool = ppcbEnabled +output AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER bool = ppcbEnabled output ACCOUNT_HOST string = 
reference(resourceId, apiVersion).documentEndpoint output ACCOUNT_KEY string = listKeys(resourceId, apiVersion).primaryMasterKey From 218cf2a63c5899542541fcd5fd8912aeed918941 Mon Sep 17 00:00:00 2001 From: Tomas Varon Saldarriaga Date: Sat, 26 Apr 2025 22:49:46 -0700 Subject: [PATCH 116/152] fix ci tests --- ...n_endpoint_manager_circuit_breaker_core.py | 6 +- .../azure/cosmos/_location_cache.py | 5 +- .../aio/_cosmos_client_connection_async.py | 7 +- sdk/cosmos/azure-cosmos/pytest.ini | 3 +- sdk/cosmos/azure-cosmos/tests/test_config.py | 108 ++++++++++++------ sdk/cosmos/azure-cosmos/tests/test_crud.py | 2 +- .../azure-cosmos/tests/test_crud_async.py | 2 +- .../test_per_partition_circuit_breaker_mm.py | 2 +- ..._per_partition_circuit_breaker_mm_async.py | 4 +- ..._partition_circuit_breaker_sm_mrr_async.py | 2 +- .../tests/test_retry_policy_async.py | 2 +- 11 files changed, 98 insertions(+), 45 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py index 3ca64c9575b4..2c6af49d4b9e 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py @@ -30,7 +30,7 @@ from azure.cosmos._routing.routing_range import PartitionKeyRangeWrapper from azure.cosmos._location_cache import EndpointOperationType, LocationCache from azure.cosmos._request_object import RequestObject -from azure.cosmos.http_constants import ResourceType +from azure.cosmos.http_constants import ResourceType, HttpMethods, HttpHeaders from azure.cosmos._constants import _Constants as Constants @@ -67,6 +67,10 @@ def is_circuit_breaker_applicable(self, request: RequestObject) -> bool: if request.operation_type == documents._OperationType.QueryPlan: # pylint: disable=protected-access return False + # 
this is for certain cross partition queries and read all items where we cannot discern partition information + if not request.headers.get(HttpHeaders.PartitionKeyRangeID) and not request.headers.get(HttpHeaders.PartitionKey): + return False + return True def record_failure( diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py index ab70259cfb4e..0b0e8c32bd90 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py @@ -204,9 +204,12 @@ def _get_configured_excluded_locations(self, request: RequestObject) -> List[str # If excluded locations were configured on request, use request level excluded locations. excluded_locations = request.excluded_locations if excluded_locations is None: + if self.connection_policy.ExcludedLocations: + excluded_locations = [] # If excluded locations were only configured on client(connection_policy), use client level # make copy of excluded locations to avoid modifying the original list - excluded_locations = list(self.connection_policy.ExcludedLocations) + else: + excluded_locations = list(self.connection_policy.ExcludedLocations) excluded_locations.extend(request.excluded_locations_circuit_breaker) return excluded_locations diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py index 8e58374c9afc..c756074d5450 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py @@ -2878,6 +2878,10 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: return [] initial_headers = self.default_headers.copy() + cont_prop = kwargs.pop("containerProperties", None) + if cont_prop: + cont_prop = await cont_prop() + options["containerRID"] = cont_prop["_rid"] @@ 
-2922,10 +2926,7 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: partition_key = options.get("partitionKey", None) isPrefixPartitionQuery = False partition_key_definition = None - cont_prop = kwargs.pop("containerProperties", None) if cont_prop and partition_key: - cont_prop = await cont_prop() - options["containerRID"] = cont_prop["_rid"] pk_properties = cont_prop["partitionKey"] partition_key_definition = PartitionKey(path=pk_properties["paths"], kind=pk_properties["kind"]) if partition_key_definition.kind == "MultiHash" and \ diff --git a/sdk/cosmos/azure-cosmos/pytest.ini b/sdk/cosmos/azure-cosmos/pytest.ini index a12da2004f38..35bfa6b6e494 100644 --- a/sdk/cosmos/azure-cosmos/pytest.ini +++ b/sdk/cosmos/azure-cosmos/pytest.ini @@ -5,4 +5,5 @@ markers = cosmosQuery: marks tests running queries on Cosmos DB live account. cosmosSplit: marks test where there are partition splits on CosmosDB live account. cosmosMultiRegion: marks tests running on a Cosmos DB live account with multi-region and multi-write enabled. - cosmosPPCB: marks tests running on Cosmos DB live account with per partition circuit breaker enabled. + cosmosPPCBMultiWrite: marks tests running on Cosmos DB live account with per partition circuit breaker enabled and multi-write enabled. + cosmosPPCBMultiRegion: marks tests running on Cosmos DB live account with per partition circuit breaker enabled with multiple regions. 
diff --git a/sdk/cosmos/azure-cosmos/tests/test_config.py b/sdk/cosmos/azure-cosmos/tests/test_config.py index 4ba1e0cc0a07..6bd41796d7cc 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_config.py +++ b/sdk/cosmos/azure-cosmos/tests/test_config.py @@ -7,15 +7,14 @@ import unittest import uuid -from azure.cosmos._retry_utility import _has_database_account_header, _has_read_retryable_headers +from azure.cosmos._retry_utility import _has_database_account_header, _has_read_retryable_headers, _configure_timeout from azure.cosmos.cosmos_client import CosmosClient from azure.cosmos.exceptions import CosmosHttpResponseError from azure.cosmos.http_constants import StatusCodes from azure.cosmos.partition_key import PartitionKey from azure.cosmos import (ContainerProxy, DatabaseProxy, documents, exceptions, http_constants, _retry_utility) -from azure.cosmos.aio import _retry_utility_async -from azure.core.exceptions import AzureError, ServiceRequestError, ServiceResponseError +from azure.core.exceptions import AzureError, ServiceRequestError, ServiceResponseError, ClientAuthenticationError from azure.core.pipeline.policies import AsyncRetryPolicy, RetryPolicy from devtools_testutils.azure_recorded_testcase import get_credential from devtools_testutils.helpers import is_live @@ -302,7 +301,10 @@ def __init__(self, resource_type, error=None, **kwargs): def send(self, request): self.counter = 0 absolute_timeout = request.context.options.pop('timeout', None) - + per_request_timeout = request.context.options.pop('connection_timeout', 0) + request_params = request.context.options.pop('request_params', None) + global_endpoint_manager = request.context.options.pop('global_endpoint_manager', None) + retry_error = None retry_active = True response = None retry_settings = self.configure_retries(request.context.options) @@ -314,26 +316,44 @@ def send(self, request): self.request_endpoints.append(request.http_request.url) if self.error: raise self.error + _configure_timeout(request, 
absolute_timeout, per_request_timeout) response = self.next.send(request) break + except ClientAuthenticationError: # pylint:disable=try-except-raise + # the authentication policy failed such that the client's request can't + # succeed--we'll never have a response to it, so propagate the exception + raise + except exceptions.CosmosClientTimeoutError as timeout_error: + timeout_error.inner_exception = retry_error + timeout_error.response = response + timeout_error.history = retry_settings['history'] + raise except ServiceRequestError as err: + retry_error = err # the request ran into a socket timeout or failed to establish a new connection # since request wasn't sent, raise exception immediately to be dealt with in client retry policies # This logic is based on the _retry.py file from azure-core - if retry_settings['connect'] > 0: - self.counter += 1 - retry_active = self.increment(retry_settings, response=request, error=err) - if retry_active: - self.sleep(retry_settings, request.context.transport) - continue + if (not _has_database_account_header(request.http_request.headers) + and not request_params.healthy_tentative_location): + if retry_settings['connect'] > 0: + self.counter += 1 + global_endpoint_manager.record_failure(request_params) + retry_active = self.increment(retry_settings, response=request, error=err) + if retry_active: + self.sleep(retry_settings, request.context.transport) + continue raise err except ServiceResponseError as err: + retry_error = err # Only read operations can be safely retried with ServiceResponseError - if not _retry_utility._has_read_retryable_headers(request.http_request.headers): + if (not _has_read_retryable_headers(request.http_request.headers) or + _has_database_account_header(request.http_request.headers) or + request_params.healthy_tentative_location): raise err # This logic is based on the _retry.py file from azure-core if retry_settings['read'] > 0: self.counter += 1 + 
global_endpoint_manager.record_failure(request_params) retry_active = self.increment(retry_settings, response=request, error=err) if retry_active: self.sleep(retry_settings, request.context.transport) @@ -343,10 +363,12 @@ def send(self, request): raise err except AzureError as err: retry_error = err - if _has_database_account_header(request.http_request.headers): + if (_has_database_account_header(request.http_request.headers) or + request_params.healthy_tentative_location): raise err if _has_read_retryable_headers(request.http_request.headers) and retry_settings['read'] > 0: self.counter += 1 + global_endpoint_manager.record_failure(request_params) retry_active = self.increment(retry_settings, response=request, error=err) if retry_active: self.sleep(retry_settings, request.context.transport) @@ -382,9 +404,11 @@ async def send(self, request): :raises ~azure.cosmos.exceptions.CosmosClientTimeoutError: Specified timeout exceeded. :raises ~azure.core.exceptions.ClientAuthenticationError: Authentication failed. 
""" + self.counter = 0 absolute_timeout = request.context.options.pop('timeout', None) per_request_timeout = request.context.options.pop('connection_timeout', 0) - self.counter = 0 + request_params = request.context.options.pop('request_params', None) + global_endpoint_manager = request.context.options.pop('global_endpoint_manager', None) retry_error = None retry_active = True response = None @@ -392,15 +416,18 @@ async def send(self, request): while retry_active: start_time = time.time() try: - # raise the passed in exception for the passed in resource + operation combination if request.http_request.headers.get( http_constants.HttpHeaders.ThinClientProxyResourceType) == self.resource_type: self.request_endpoints.append(request.http_request.url) if self.error: raise self.error - _retry_utility._configure_timeout(request, absolute_timeout, per_request_timeout) + _configure_timeout(request, absolute_timeout, per_request_timeout) response = await self.next.send(request) break + except ClientAuthenticationError: # pylint:disable=try-except-raise + # the authentication policy failed such that the client's request can't + # succeed--we'll never have a response to it, so propagate the exception + raise except exceptions.CosmosClientTimeoutError as timeout_error: timeout_error.inner_exception = retry_error timeout_error.response = response @@ -410,40 +437,57 @@ async def send(self, request): retry_error = err # the request ran into a socket timeout or failed to establish a new connection # since request wasn't sent, raise exception immediately to be dealt with in client retry policies - if retry_settings['connect'] > 0: - self.counter += 1 - retry_active = self.increment(retry_settings, response=request, error=err) - if retry_active: - await self.sleep(retry_settings, request.context.transport) - continue - raise err - except ServiceResponseError as err: - retry_error = err - # Since this is ClientConnectionError, it is safe to be retried on both read and write requests - 
from aiohttp.client_exceptions import ( - ClientConnectionError) # pylint: disable=networking-import-outside-azure-core-transport - if isinstance(err.inner_exception, ClientConnectionError) or _retry_utility_async._has_read_retryable_headers(request.http_request.headers): - # This logic is based on the _retry.py file from azure-core - if retry_settings['read'] > 0: + if (not _has_database_account_header(request.http_request.headers) + and not request_params.healthy_tentative_location): + if retry_settings['connect'] > 0: self.counter += 1 + await global_endpoint_manager.record_failure(request_params) retry_active = self.increment(retry_settings, response=request, error=err) if retry_active: await self.sleep(retry_settings, request.context.transport) continue raise err + except ServiceResponseError as err: + retry_error = err + if (_has_database_account_header(request.http_request.headers) or + request_params.healthy_tentative_location): + raise err + # Since this is ClientConnectionError, it is safe to be retried on both read and write requests + try: + # pylint: disable=networking-import-outside-azure-core-transport + from aiohttp.client_exceptions import ( + ClientConnectionError) + if (isinstance(err.inner_exception, ClientConnectionError) + or _has_read_retryable_headers(request.http_request.headers)): + # This logic is based on the _retry.py file from azure-core + if retry_settings['read'] > 0: + self.counter += 1 + await global_endpoint_manager.record_failure(request_params) + retry_active = self.increment(retry_settings, response=request, error=err) + if retry_active: + await self.sleep(retry_settings, request.context.transport) + continue + except ImportError: + raise err # pylint: disable=raise-missing-from + raise err except CosmosHttpResponseError as err: raise err except AzureError as err: retry_error = err - if _has_database_account_header(request.http_request.headers): + if (_has_database_account_header(request.http_request.headers) or + 
request_params.healthy_tentative_location): raise err if _has_read_retryable_headers(request.http_request.headers) and retry_settings['read'] > 0: - retry_active = self.increment(retry_settings, response=request, error=err) self.counter += 1 + retry_active = self.increment(retry_settings, response=request, error=err) if retry_active: await self.sleep(retry_settings, request.context.transport) continue raise err + finally: + end_time = time.time() + if absolute_timeout: + absolute_timeout -= (end_time - start_time) self.update_context(response.context, retry_settings) return response diff --git a/sdk/cosmos/azure-cosmos/tests/test_crud.py b/sdk/cosmos/azure-cosmos/tests/test_crud.py index 51b81ce5bcad..ae90850e5a98 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_crud.py +++ b/sdk/cosmos/azure-cosmos/tests/test_crud.py @@ -47,7 +47,7 @@ def send(self, *args, **kwargs): response = RequestsTransportResponse(None, output) return response -@pytest.mark.cosmosPPCB +@pytest.mark.cosmosPPCBMultiWrite @pytest.mark.cosmosLong class TestCRUDOperations(unittest.TestCase): """Python CRUD Tests. diff --git a/sdk/cosmos/azure-cosmos/tests/test_crud_async.py b/sdk/cosmos/azure-cosmos/tests/test_crud_async.py index 856ce4359560..8b20b20b293c 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_crud_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_crud_async.py @@ -43,7 +43,7 @@ async def send(self, *args, **kwargs): response = AsyncioRequestsTransportResponse(None, output) return response -@pytest.mark.cosmosPPCB +@pytest.mark.cosmosPPCBMultiWrite @pytest.mark.cosmosLong class TestCRUDOperationsAsync(unittest.IsolatedAsyncioTestCase): """Python CRUD Tests. 
diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py index c957d14af530..1299e18566f8 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py +++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py @@ -109,7 +109,7 @@ def perform_read_operation(operation, container, doc_id, pk, expected_uri): for _ in container.read_all_items(): pass -@pytest.mark.cosmosPPCB +@pytest.mark.cosmosPPCBMultiWrite @pytest.mark.usefixtures("setup_teardown") class TestPerPartitionCircuitBreakerMM: host = test_config.TestConfig.host diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py index 59cc93c697f2..8f614b58dd3e 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py @@ -77,7 +77,7 @@ def create_doc(): 'key': 'value'} def read_operations_and_errors(): - read_operations = [READ, QUERY, QUERY_PK, CHANGE_FEED, CHANGE_FEED_PK, CHANGE_FEED_EPK, READ_ALL_ITEMS] + read_operations = [READ, QUERY_PK, CHANGE_FEED, CHANGE_FEED_PK, CHANGE_FEED_EPK] errors = [] error_codes = [408, 500, 502, 503] for error_code in error_codes: @@ -199,7 +199,7 @@ async def cleanup_method(initialized_objects: List[Dict[str, Any]]): method_client: CosmosClient = obj["client"] await method_client.close() -@pytest.mark.cosmosPPCB +@pytest.mark.cosmosPPCBMultiWrite @pytest.mark.asyncio @pytest.mark.usefixtures("setup_teardown") class TestPerPartitionCircuitBreakerMMAsync: diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr_async.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr_async.py index a7002e62ab7b..f7501485a162 100644 --- 
a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr_async.py @@ -55,7 +55,7 @@ def validate_unhealthy_partitions(global_endpoint_manager, assert unhealthy_partitions == expected_unhealthy_partitions -@pytest.mark.cosmosEmulator +@pytest.mark.cosmosPPCBMultiRegion @pytest.mark.asyncio @pytest.mark.usefixtures("setup_teardown") class TestPerPartitionCircuitBreakerSmMrrAsync: diff --git a/sdk/cosmos/azure-cosmos/tests/test_retry_policy_async.py b/sdk/cosmos/azure-cosmos/tests/test_retry_policy_async.py index dd9870a99231..fccf19225cfe 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_retry_policy_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_retry_policy_async.py @@ -46,7 +46,7 @@ def __init__(self) -> None: self.ProxyConfiguration: Optional[ProxyConfiguration] = None self.EnableEndpointDiscovery: bool = True self.PreferredLocations: List[str] = [] - self.ExcludedLocations = None + self.ExcludedLocations = [] self.RetryOptions: RetryOptions = RetryOptions() self.DisableSSLVerification: bool = False self.UseMultipleWriteLocations: bool = False From 7952bba433f29c4a3c88b14bcec94e468e70d408 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Sun, 27 Apr 2025 01:24:05 -0700 Subject: [PATCH 117/152] move all ppcb tests to live tests --- sdk/cosmos/azure-cosmos/azure/cosmos/_base.py | 2 +- .../_container_recreate_retry_policy.py | 3 +- ...n_endpoint_manager_circuit_breaker_core.py | 11 +++--- .../azure/cosmos/_location_cache.py | 4 +- .../tests/test_excluded_locations.py | 8 ++-- ...st_per_partition_circuit_breaker_sm_mrr.py | 38 ++++--------------- ..._partition_circuit_breaker_sm_mrr_async.py | 14 ++++--- sdk/cosmos/live-platform-matrix.json | 27 ++++++++----- 8 files changed, 48 insertions(+), 59 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py index af3813938237..af77a499ffa2 100644 --- 
a/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py @@ -880,5 +880,5 @@ def _format_batch_operations( def _set_properties_cache(properties: Dict[str, Any], container_link: str) -> Dict[str, Any]: return { "_self": properties.get("_self", None), "_rid": properties.get("_rid", None), - "partitionKey": properties.get("partitionKey", None), "container_link": container_link, + "partitionKey": properties.get("partitionKey", None), "container_link": container_link } diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_container_recreate_retry_policy.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_container_recreate_retry_policy.py index 78bba46319b9..4080d3272d85 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_container_recreate_retry_policy.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_container_recreate_retry_policy.py @@ -72,7 +72,8 @@ def ShouldRetry(self, exception: Optional[Any]) -> bool: def __find_container_link_with_rid(self, container_properties_caches: Optional[Dict[str, Any]], rid: str) -> \ Optional[str]: if container_properties_caches: - return container_properties_caches.get(rid) + if container_properties_caches.get(rid): + return container_properties_caches[rid]["container_link"] # If we cannot get the container link at all it might mean the cache was somehow deleted, this isn't # a container request so this retry is not needed. Return None. 
return None diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py index 2c6af49d4b9e..6e23efabbed3 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py @@ -30,7 +30,7 @@ from azure.cosmos._routing.routing_range import PartitionKeyRangeWrapper from azure.cosmos._location_cache import EndpointOperationType, LocationCache from azure.cosmos._request_object import RequestObject -from azure.cosmos.http_constants import ResourceType, HttpMethods, HttpHeaders +from azure.cosmos.http_constants import ResourceType, HttpHeaders from azure.cosmos._constants import _Constants as Constants @@ -61,14 +61,13 @@ def is_circuit_breaker_applicable(self, request: RequestObject) -> bool: and documents._OperationType.IsWriteOperation(request.operation_type)): # pylint: disable=protected-access return False - if request.resource_type != ResourceType.Document: - return False - - if request.operation_type == documents._OperationType.QueryPlan: # pylint: disable=protected-access + if (request.resource_type != ResourceType.Document + or request.operation_type == documents._OperationType.QueryPlan): # pylint: disable=protected-access return False # this is for certain cross partition queries and read all items where we cannot discern partition information - if not request.headers.get(HttpHeaders.PartitionKeyRangeID) and not request.headers.get(HttpHeaders.PartitionKey): + if (not request.headers.get(HttpHeaders.PartitionKeyRangeID) + and not request.headers.get(HttpHeaders.PartitionKey)): return False return True diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py index 0b0e8c32bd90..2c5d1ff5e3c5 100644 --- 
a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py @@ -210,7 +210,9 @@ def _get_configured_excluded_locations(self, request: RequestObject) -> List[str # make copy of excluded locations to avoid modifying the original list else: excluded_locations = list(self.connection_policy.ExcludedLocations) - excluded_locations.extend(request.excluded_locations_circuit_breaker) + for excluded_location in request.excluded_locations_circuit_breaker: + if excluded_location not in excluded_locations: + excluded_locations.append(excluded_location) return excluded_locations def _get_applicable_read_regional_routing_contexts(self, request: RequestObject) -> List[RegionalRoutingContext]: diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py index 4b517796dd3a..c1b72b3f71a3 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py +++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py @@ -210,18 +210,18 @@ def _verify_endpoint(messages, client, expected_locations): req_urls = [url.replace("Request URL: '", "") for url in messages if 'Request URL:' in url] # get location - actual_locations = [] + actual_locations = set() for req_url in req_urls: if req_url.startswith(default_endpoint): - actual_locations.append(L0) + actual_locations.add(L0) else: for endpoint in location_mapping: if req_url.startswith(endpoint): location = location_mapping[endpoint] - actual_locations.append(location) + actual_locations.add(location) break - assert actual_locations == expected_locations + assert list(actual_locations) == expected_locations @pytest.mark.cosmosMultiRegion class TestExcludedLocations: diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr.py index f3fa4dfde167..8ed0cffa334a 100644 --- 
a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr.py +++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr.py @@ -7,17 +7,16 @@ import pytest import pytest_asyncio -from azure.core.pipeline.transport._aiohttp import AioHttpTransport from azure.core.exceptions import ServiceResponseError import test_config -from azure.cosmos import PartitionKey, _partition_health_tracker +from azure.cosmos import PartitionKey, _partition_health_tracker, _location_cache from azure.cosmos import CosmosClient from azure.cosmos.exceptions import CosmosHttpResponseError from _fault_injection_transport import FaultInjectionTransport from test_per_partition_circuit_breaker_mm import perform_write_operation, perform_read_operation from test_per_partition_circuit_breaker_mm_async import create_doc, PK_VALUE, write_operations_and_errors, \ - read_operations_and_errors, CHANGE_FEED, QUERY, READ_ALL_ITEMS, operations + read_operations_and_errors, CHANGE_FEED, QUERY, READ_ALL_ITEMS, operations, REGION_2, REGION_1 from test_per_partition_circuit_breaker_sm_mrr_async import validate_unhealthy_partitions COLLECTION = "created_collection" @@ -35,7 +34,7 @@ def setup_teardown(): created_database.delete_container(TestPerPartitionCircuitBreakerSmMrr.TEST_CONTAINER_SINGLE_PARTITION_ID) os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "False" -@pytest.mark.cosmosEmulator +@pytest.mark.cosmosPPCMMultiRegion @pytest.mark.usefixtures("setup_teardown") class TestPerPartitionCircuitBreakerSmMrr: host = test_config.TestConfig.host @@ -44,7 +43,7 @@ class TestPerPartitionCircuitBreakerSmMrr: TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID TEST_CONTAINER_SINGLE_PARTITION_ID = os.path.basename(__file__) + str(uuid.uuid4()) - def setup_method_with_custom_transport(self, custom_transport: AioHttpTransport, default_endpoint=host, **kwargs): + def setup_method_with_custom_transport(self, custom_transport, default_endpoint=host, **kwargs): client = 
CosmosClient(default_endpoint, self.master_key, consistency_level="Session", preferred_locations=["Write Region", "Read Region"], transport=custom_transport, **kwargs) @@ -52,33 +51,10 @@ def setup_method_with_custom_transport(self, custom_transport: AioHttpTransport, container = db.get_container_client(self.TEST_CONTAINER_SINGLE_PARTITION_ID) return {"client": client, "db": db, "col": container} - def create_custom_transport_sm_mrr(self): - custom_transport = FaultInjectionTransport() - # Inject rule to disallow writes in the read-only region - is_write_operation_in_read_region_predicate = lambda \ - r: FaultInjectionTransport.predicate_is_write_operation(r, self.host) - - custom_transport.add_fault( - is_write_operation_in_read_region_predicate, - lambda r: FaultInjectionTransport.error_write_forbidden()) - - # Inject topology transformation that would make Emulator look like a single write region - # account with two read regions - is_get_account_predicate = lambda r: FaultInjectionTransport.predicate_is_database_account_call(r) - emulator_as_multi_region_sm_account_transformation = \ - lambda r, inner: FaultInjectionTransport.transform_topology_swr_mrr( - write_region_name="Write Region", - read_region_name="Read Region", - inner=inner) - custom_transport.add_response_transformation( - is_get_account_predicate, - emulator_as_multi_region_sm_account_transformation) - return custom_transport - def setup_info(self, error): - expected_uri = self.host - uri_down = expected_uri.replace("localhost", "127.0.0.1") - custom_transport = self.create_custom_transport_sm_mrr() + expected_uri = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_2) + uri_down = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_1) + custom_transport = FaultInjectionTransport() # two documents targeted to same partition, one will always fail and the other will succeed doc = create_doc() predicate = lambda r: 
(FaultInjectionTransport.predicate_is_document_operation(r) and diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr_async.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr_async.py index f7501485a162..eefe22c72a68 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr_async.py @@ -12,13 +12,15 @@ from azure.core.exceptions import ServiceResponseError import test_config -from azure.cosmos import PartitionKey, _partition_health_tracker +from azure.cosmos import PartitionKey, _partition_health_tracker, _location_cache from azure.cosmos._partition_health_tracker import HEALTH_STATUS, UNHEALTHY, UNHEALTHY_TENTATIVE from azure.cosmos.aio import CosmosClient from azure.cosmos.exceptions import CosmosHttpResponseError from _fault_injection_transport_async import FaultInjectionTransportAsync -from test_per_partition_circuit_breaker_mm_async import perform_write_operation, create_doc, PK_VALUE, write_operations_and_errors, \ - cleanup_method, read_operations_and_errors, perform_read_operation, CHANGE_FEED, QUERY, READ_ALL_ITEMS, operations +from test_per_partition_circuit_breaker_mm_async import perform_write_operation, create_doc, PK_VALUE, \ + write_operations_and_errors, \ + cleanup_method, read_operations_and_errors, perform_read_operation, CHANGE_FEED, QUERY, READ_ALL_ITEMS, operations, \ + REGION_2, REGION_1 COLLECTION = "created_collection" @pytest_asyncio.fixture() @@ -102,9 +104,9 @@ async def create_custom_transport_sm_mrr(self): return custom_transport async def setup_info(self, error): - expected_uri = self.host - uri_down = expected_uri.replace("localhost", "127.0.0.1") - custom_transport = await self.create_custom_transport_sm_mrr() + expected_uri = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_2) + uri_down = 
_location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_1) + custom_transport = FaultInjectionTransportAsync() # two documents targeted to same partition, one will always fail and the other will succeed doc = create_doc() predicate = lambda r: (FaultInjectionTransportAsync.predicate_is_document_operation(r) and diff --git a/sdk/cosmos/live-platform-matrix.json b/sdk/cosmos/live-platform-matrix.json index 546a6de8543d..19aec8a30f39 100644 --- a/sdk/cosmos/live-platform-matrix.json +++ b/sdk/cosmos/live-platform-matrix.json @@ -26,27 +26,36 @@ } }, { - "PPCBTestConfig": { - "Ubuntu2004_39_ppcb": { + "PPCBMultiWriteTestConfig": { + "Ubuntu2004_313_ppcb": { "OSVmImage": "env:LINUXVMIMAGE", "Pool": "env:LINUXPOOL", - "PythonVersion": "3.9", + "PythonVersion": "3.13", "CoverageArg": "--disablecov", "TestSamples": "false", - "TestMarkArgument": "cosmosPPCB" - }, - "Ubuntu2004_313_ppcb": { + "TestMarkArgument": "cosmosPPCBMultiWrite" + } + }, + "ArmConfig": { + "MultiMaster_MultiRegion": { + "ArmTemplateParameters": "@{ enableMultipleWriteLocations = $true; defaultConsistencyLevel = 'Session'; enableMultipleRegions = $true; ppcbEnabled = $true }" + } + } + }, + { + "PPCBMultiRegionTestConfig": { + "Ubuntu2004_39_ppcb": { "OSVmImage": "env:LINUXVMIMAGE", "Pool": "env:LINUXPOOL", - "PythonVersion": "3.13", + "PythonVersion": "3.9", "CoverageArg": "--disablecov", "TestSamples": "false", - "TestMarkArgument": "cosmosPPCB" + "TestMarkArgument": "cosmosPPCBMultiRegion" } }, "ArmConfig": { "MultiMaster_MultiRegion": { - "ArmTemplateParameters": "@{ enableMultipleWriteLocations = $true; defaultConsistencyLevel = 'Session'; enableMultipleRegions = $true; ppcbEnabled = $true }" + "ArmTemplateParameters": "@{ defaultConsistencyLevel = 'Session'; enableMultipleRegions = $true; ppcbEnabled = $true }" } } }, From 1b41c737aead17db14340f7dab3de5deda641ea3 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Sun, 27 Apr 2025 11:44:16 -0700 Subject: [PATCH 118/152] add logger for ppcb 
test --- sdk/cosmos/azure-cosmos/tests/test_crud_async.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_crud_async.py b/sdk/cosmos/azure-cosmos/tests/test_crud_async.py index 8b20b20b293c..3cb5b026b198 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_crud_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_crud_async.py @@ -4,6 +4,7 @@ """End-to-end test. """ +import logging import os import time import unittest @@ -79,7 +80,8 @@ def setUpClass(cls): "tests.") async def asyncSetUp(self): - print("Circuit Breaker enabled: " + os.environ.get("AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER", "True")) + logger = logging.getLogger("TestCRUDOperationsAsync") + logger.info("Circuit Breaker enabled: " + os.environ.get("AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER", "True")) use_multiple_write_locations = False if os.environ.get("AZURE_COSMOS_ENABLE_MULTIPLE_WRITE_LOCATIONS", "False") == "True": use_multiple_write_locations = True From 28a4cc438b9b9385da7fa563f3ebaebe729d8221 Mon Sep 17 00:00:00 2001 From: Tomas Varon Saldarriaga Date: Sun, 27 Apr 2025 13:52:58 -0700 Subject: [PATCH 119/152] fix ci tests --- .../azure/cosmos/_global_endpoint_manager.py | 1 - .../azure/cosmos/_location_cache.py | 8 ++--- .../azure/cosmos/_retry_utility.py | 2 +- .../azure-cosmos/tests/test_globaldb.py | 3 +- .../azure-cosmos/tests/test_globaldb_mock.py | 30 +++++++------------ 5 files changed, 18 insertions(+), 26 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py index 8f561fd35b0d..f61bd37feba7 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py @@ -47,7 +47,6 @@ class _GlobalEndpointManager(object): # pylint: disable=too-many-instance-attrib def __init__(self, client): self.Client = client - self.EnableEndpointDiscovery = 
client.connection_policy.EnableEndpointDiscovery self.PreferredLocations = client.connection_policy.PreferredLocations self.DefaultEndpoint = client.url_connection self.refresh_time_interval_in_ms = self.get_refresh_time_interval_in_ms_stub() diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py index 2c5d1ff5e3c5..2aef0d6b9656 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py @@ -205,11 +205,11 @@ def _get_configured_excluded_locations(self, request: RequestObject) -> List[str excluded_locations = request.excluded_locations if excluded_locations is None: if self.connection_policy.ExcludedLocations: - excluded_locations = [] - # If excluded locations were only configured on client(connection_policy), use client level - # make copy of excluded locations to avoid modifying the original list - else: + # If excluded locations were only configured on client(connection_policy), use client level + # make copy of excluded locations to avoid modifying the original list excluded_locations = list(self.connection_policy.ExcludedLocations) + else: + excluded_locations = [] for excluded_location in request.excluded_locations_circuit_breaker: if excluded_location not in excluded_locations: excluded_locations.append(excluded_location) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py index 8c7174d8cb08..87c18c4d44b3 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py @@ -64,7 +64,7 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs): # pylin pk_range_wrapper = global_endpoint_manager.create_pk_range_wrapper(args[0]) # instantiate all retry policies here to be applied for each request execution endpointDiscovery_retry_policy = 
_endpoint_discovery_retry_policy.EndpointDiscoveryRetryPolicy( - client.connection_policy, global_endpoint_manager, pk_range_wrapper, *args + client.connection_policy, global_endpoint_manager, *args ) database_account_retry_policy = _database_account_retry_policy.DatabaseAccountRetryPolicy( client.connection_policy diff --git a/sdk/cosmos/azure-cosmos/tests/test_globaldb.py b/sdk/cosmos/azure-cosmos/tests/test_globaldb.py index 3fc83ef68f0d..e700ab5e9b1c 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_globaldb.py +++ b/sdk/cosmos/azure-cosmos/tests/test_globaldb.py @@ -429,7 +429,8 @@ def test_global_db_service_request_errors(self): cosmos_client.CosmosClient(self.host, self.masterKey, connection_retry_policy=mock_retry_policy) pytest.fail("Exception was not raised") except ServiceRequestError: - assert mock_retry_policy.counter == 3 + # Database account calls should not be retried in connection retry policy + assert mock_retry_policy.counter == 0 def test_global_db_endpoint_discovery_retry_policy_mock(self): client = cosmos_client.CosmosClient(self.host, self.masterKey) diff --git a/sdk/cosmos/azure-cosmos/tests/test_globaldb_mock.py b/sdk/cosmos/azure-cosmos/tests/test_globaldb_mock.py index 9b2a221880b1..8c1a2fbfb01e 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_globaldb_mock.py +++ b/sdk/cosmos/azure-cosmos/tests/test_globaldb_mock.py @@ -1,7 +1,6 @@ # The MIT License (MIT) # Copyright (c) Microsoft Corporation. All rights reserved. 
-import json import unittest import pytest @@ -13,13 +12,16 @@ import azure.cosmos.exceptions as exceptions import test_config from azure.cosmos import _retry_utility +from azure.cosmos._global_partition_endpoint_manager_circuit_breaker import \ + _GlobalPartitionEndpointManagerForCircuitBreaker from azure.cosmos.http_constants import StatusCodes location_changed = False -class MockGlobalEndpointManager: +class MockGlobalEndpointManager(_GlobalPartitionEndpointManagerForCircuitBreaker): def __init__(self, client): + super(MockGlobalEndpointManager, self).__init__(client) self.Client = client self.DefaultEndpoint = client.url_connection self._ReadEndpoint = client.url_connection @@ -73,10 +75,10 @@ def get_write_endpoint(self): def get_read_endpoint(self): return self._ReadEndpoint - def resolve_service_endpoint(self, request): + def resolve_service_endpoint(self, request, pk_range_wrapper): return - def refresh_endpoint_list(self): + def refresh_endpoint_list(self, database_account, **kwargs): return def can_use_multiple_write_locations(self, request): @@ -150,20 +152,6 @@ def tearDown(self): global_endpoint_manager._GlobalEndpointManager._GetDatabaseAccountStub = self.OriginalGetDatabaseAccountStub _retry_utility.ExecuteFunction = self.OriginalExecuteFunction - def MockExecuteFunction(self, function, *args, **kwargs): - global location_changed - - if self.endpoint_discovery_retry_count == 2: - _retry_utility.ExecuteFunction = self.OriginalExecuteFunction - return json.dumps([{'id': 'mock database'}]), None - else: - self.endpoint_discovery_retry_count += 1 - location_changed = True - raise exceptions.CosmosHttpResponseError( - status_code=StatusCodes.FORBIDDEN, - message="Forbidden", - response=test_config.FakeResponse({'x-ms-substatus': 3})) - def MockGetDatabaseAccountStub(self, endpoint): raise exceptions.CosmosHttpResponseError( status_code=StatusCodes.INTERNAL_SERVER_ERROR, message="Internal Server Error") @@ -176,6 +164,8 @@ def 
test_global_db_endpoint_discovery_retry_policy(self): TestGlobalDBMock.masterKey, consistency_level="Session", connection_policy=connection_policy) + write_location_client.client_connection._global_endpoint_manager = MockGlobalEndpointManager(write_location_client.client_connection) + write_location_client.client_connection._global_endpoint_manager.refresh_endpoint_list(None) self.assertEqual(write_location_client.client_connection.WriteEndpoint, TestGlobalDBMock.write_location_host) @@ -188,6 +178,8 @@ def test_global_db_database_account_unavailable(self): client = cosmos_client.CosmosClient(TestGlobalDBMock.host, TestGlobalDBMock.masterKey, consistency_level="Session", connection_policy=connection_policy) + client.client_connection._global_endpoint_manager = MockGlobalEndpointManager(client.client_connection) + client.client_connection._global_endpoint_manager.refresh_endpoint_list(None) self.assertEqual(client.client_connection.WriteEndpoint, TestGlobalDBMock.write_location_host) self.assertEqual(client.client_connection.ReadEndpoint, TestGlobalDBMock.write_location_host) @@ -195,7 +187,7 @@ def test_global_db_database_account_unavailable(self): global_endpoint_manager._GlobalEndpointManager._GetDatabaseAccountStub = self.MockGetDatabaseAccountStub client.client_connection.DatabaseAccountAvailable = False - client.client_connection._global_endpoint_manager.refresh_endpoint_list() + client.client_connection._global_endpoint_manager.refresh_endpoint_list(None) self.assertEqual(client.client_connection.WriteEndpoint, TestGlobalDBMock.host) self.assertEqual(client.client_connection.ReadEndpoint, TestGlobalDBMock.host) From 0096d84a695c800a553b307d5649a51c2759f284 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Sun, 27 Apr 2025 15:00:02 -0700 Subject: [PATCH 120/152] fix tests --- .../azure/cosmos/_retry_utility.py | 14 ++++++---- .../azure/cosmos/aio/_retry_utility_async.py | 12 ++++++--- sdk/cosmos/azure-cosmos/pytest.ini | 1 - 
.../azure-cosmos/tests/test_change_feed.py | 4 +++ .../tests/test_change_feed_async.py | 5 +++- sdk/cosmos/azure-cosmos/tests/test_crud.py | 5 +++- .../azure-cosmos/tests/test_crud_async.py | 2 -- .../tests/test_excluded_locations.py | 2 +- ...st_per_partition_circuit_breaker_sm_mrr.py | 2 +- ..._partition_circuit_breaker_sm_mrr_async.py | 2 +- sdk/cosmos/azure-cosmos/tests/test_query.py | 6 ++++- .../azure-cosmos/tests/test_query_async.py | 6 ++++- sdk/cosmos/live-platform-matrix.json | 27 +++++++------------ 13 files changed, 51 insertions(+), 37 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py index 8c7174d8cb08..45399d8fcb5e 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py @@ -178,8 +178,9 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs): # pylin retry_policy.container_rid = cached_container["_rid"] request.headers[retry_policy._intended_headers] = retry_policy.container_rid elif e.status_code == StatusCodes.REQUEST_TIMEOUT or e.status_code >= StatusCodes.INTERNAL_SERVER_ERROR: - # record the failure for circuit breaker tracking - global_endpoint_manager.record_failure(args[0]) + if args: + # record the failure for circuit breaker tracking + global_endpoint_manager.record_failure(args[0]) retry_policy = timeout_failover_retry_policy else: retry_policy = defaultRetry_policy @@ -205,7 +206,8 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs): # pylin if client_timeout: kwargs['timeout'] = client_timeout - (time.time() - start_time) if kwargs['timeout'] <= 0: - global_endpoint_manager.record_failure(args[0]) + if args: + global_endpoint_manager.record_failure(args[0]) raise exceptions.CosmosClientTimeoutError() except ServiceRequestError as e: @@ -213,7 +215,8 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs): # pylin if not 
database_account_retry_policy.ShouldRetry(e): raise e else: - global_endpoint_manager.record_failure(args[0]) + if args: + global_endpoint_manager.record_failure(args[0]) _handle_service_request_retries(client, service_request_retry_policy, e, *args) except ServiceResponseError as e: @@ -221,7 +224,8 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs): # pylin if not database_account_retry_policy.ShouldRetry(e): raise e else: - global_endpoint_manager.record_failure(args[0]) + if args: + global_endpoint_manager.record_failure(args[0]) _handle_service_response_retries(request, client, service_response_retry_policy, e, *args) def ExecuteFunction(function, *args, **kwargs): diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py index 2ed4de959757..b337d409bce1 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py @@ -177,7 +177,8 @@ async def ExecuteAsync(client, global_endpoint_manager, function, *args, **kwarg request.headers[retry_policy._intended_headers] = retry_policy.container_rid elif e.status_code == StatusCodes.REQUEST_TIMEOUT or e.status_code >= StatusCodes.INTERNAL_SERVER_ERROR: # record the failure for circuit breaker tracking - await global_endpoint_manager.record_failure(args[0]) + if args: + await global_endpoint_manager.record_failure(args[0]) retry_policy = timeout_failover_retry_policy else: retry_policy = defaultRetry_policy @@ -203,7 +204,8 @@ async def ExecuteAsync(client, global_endpoint_manager, function, *args, **kwarg if client_timeout: kwargs['timeout'] = client_timeout - (time.time() - start_time) if kwargs['timeout'] <= 0: - await global_endpoint_manager.record_failure(args[0]) + if args: + await global_endpoint_manager.record_failure(args[0]) raise exceptions.CosmosClientTimeoutError() except ServiceRequestError as e: @@ -225,11 
+227,13 @@ async def ExecuteAsync(client, global_endpoint_manager, function, *args, **kwarg if isinstance(e.inner_exception, ClientConnectionError): _handle_service_request_retries(client, service_request_retry_policy, e, *args) else: - await global_endpoint_manager.record_failure(args[0]) + if args: + await global_endpoint_manager.record_failure(args[0]) _handle_service_response_retries(request, client, service_response_retry_policy, e, *args) # in case customer is not using aiohttp except ImportError: - await global_endpoint_manager.record_failure(args[0]) + if args: + await global_endpoint_manager.record_failure(args[0]) _handle_service_response_retries(request, client, service_response_retry_policy, e, *args) diff --git a/sdk/cosmos/azure-cosmos/pytest.ini b/sdk/cosmos/azure-cosmos/pytest.ini index 35bfa6b6e494..0db8e9cd12eb 100644 --- a/sdk/cosmos/azure-cosmos/pytest.ini +++ b/sdk/cosmos/azure-cosmos/pytest.ini @@ -6,4 +6,3 @@ markers = cosmosSplit: marks test where there are partition splits on CosmosDB live account. cosmosMultiRegion: marks tests running on a Cosmos DB live account with multi-region and multi-write enabled. cosmosPPCBMultiWrite: marks tests running on Cosmos DB live account with per partition circuit breaker enabled and multi-write enabled. - cosmosPPCBMultiRegion: marks tests running on Cosmos DB live account with per partition circuit breaker enabled with multiple regions. 
diff --git a/sdk/cosmos/azure-cosmos/tests/test_change_feed.py b/sdk/cosmos/azure-cosmos/tests/test_change_feed.py index 415ba47a63a9..32e39ee1e15f 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_change_feed.py +++ b/sdk/cosmos/azure-cosmos/tests/test_change_feed.py @@ -17,6 +17,9 @@ @pytest.fixture(scope="class") def setup(): config = test_config.TestConfig() + use_multiple_write_locations = False + if os.environ.get("AZURE_COSMOS_ENABLE_MULTIPLE_WRITE_LOCATIONS", "False") == "True": + use_multiple_write_locations = True if (config.masterKey == '[YOUR_KEY_HERE]' or config.host == '[YOUR_ENDPOINT_HERE]'): raise Exception( @@ -33,6 +36,7 @@ def round_time(): utc_now = datetime.now(timezone.utc) return utc_now - timedelta(microseconds=utc_now.microsecond) +@pytest.mark.cosmosPPCBMultiWrite @pytest.mark.cosmosQuery @pytest.mark.unittest @pytest.mark.usefixtures("setup") diff --git a/sdk/cosmos/azure-cosmos/tests/test_change_feed_async.py b/sdk/cosmos/azure-cosmos/tests/test_change_feed_async.py index b3a666ad43e9..24cf945f5b62 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_change_feed_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_change_feed_async.py @@ -18,13 +18,16 @@ @pytest_asyncio.fixture() async def setup(): + use_multiple_write_locations = False + if os.environ.get("AZURE_COSMOS_ENABLE_MULTIPLE_WRITE_LOCATIONS", "False") == "True": + use_multiple_write_locations = True config = test_config.TestConfig() if config.masterKey == '[YOUR_KEY_HERE]' or config.host == '[YOUR_ENDPOINT_HERE]': raise Exception( "You must specify your Azure Cosmos account values for " "'masterKey' and 'host' at the top of this class to run the " "tests.") - test_client = CosmosClient(config.host, config.masterKey) + test_client = CosmosClient(config.host, config.masterKey, multiple_write_locations=use_multiple_write_locations) created_db = await test_client.create_database_if_not_exists(config.TEST_DATABASE_ID) created_db_data = { "created_db": created_db, diff --git 
a/sdk/cosmos/azure-cosmos/tests/test_crud.py b/sdk/cosmos/azure-cosmos/tests/test_crud.py index ae90850e5a98..282071276210 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_crud.py +++ b/sdk/cosmos/azure-cosmos/tests/test_crud.py @@ -75,13 +75,16 @@ def __AssertHTTPFailureWithStatus(self, status_code, func, *args, **kwargs): @classmethod def setUpClass(cls): + use_multiple_write_locations = False + if os.environ.get("AZURE_COSMOS_ENABLE_MULTIPLE_WRITE_LOCATIONS", "False") == "True": + use_multiple_write_locations = True if (cls.masterKey == '[YOUR_KEY_HERE]' or cls.host == '[YOUR_ENDPOINT_HERE]'): raise Exception( "You must specify your Azure Cosmos account values for " "'masterKey' and 'host' at the top of this class to run the " "tests.") - cls.client = cosmos_client.CosmosClient(cls.host, cls.masterKey) + cls.client = cosmos_client.CosmosClient(cls.host, cls.masterKey, multiple_write_locations=use_multiple_write_locations) cls.databaseForTest = cls.client.get_database_client(cls.configs.TEST_DATABASE_ID) def test_partitioned_collection_document_crud_and_query(self): diff --git a/sdk/cosmos/azure-cosmos/tests/test_crud_async.py b/sdk/cosmos/azure-cosmos/tests/test_crud_async.py index 3cb5b026b198..7b0b4476befd 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_crud_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_crud_async.py @@ -80,8 +80,6 @@ def setUpClass(cls): "tests.") async def asyncSetUp(self): - logger = logging.getLogger("TestCRUDOperationsAsync") - logger.info("Circuit Breaker enabled: " + os.environ.get("AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER", "True")) use_multiple_write_locations = False if os.environ.get("AZURE_COSMOS_ENABLE_MULTIPLE_WRITE_LOCATIONS", "False") == "True": use_multiple_write_locations = True diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py index c1b72b3f71a3..b3288e2b2fd9 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py +++ 
b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py @@ -221,7 +221,7 @@ def _verify_endpoint(messages, client, expected_locations): actual_locations.add(location) break - assert list(actual_locations) == expected_locations + assert list(actual_locations) == list(set(expected_locations)) @pytest.mark.cosmosMultiRegion class TestExcludedLocations: diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr.py index 8ed0cffa334a..a1e7bbe19f1d 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr.py +++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr.py @@ -34,7 +34,7 @@ def setup_teardown(): created_database.delete_container(TestPerPartitionCircuitBreakerSmMrr.TEST_CONTAINER_SINGLE_PARTITION_ID) os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "False" -@pytest.mark.cosmosPPCMMultiRegion +@pytest.mark.cosmosPPCBMultiWrite @pytest.mark.usefixtures("setup_teardown") class TestPerPartitionCircuitBreakerSmMrr: host = test_config.TestConfig.host diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr_async.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr_async.py index eefe22c72a68..c945c4d2d9c0 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr_async.py @@ -57,7 +57,7 @@ def validate_unhealthy_partitions(global_endpoint_manager, assert unhealthy_partitions == expected_unhealthy_partitions -@pytest.mark.cosmosPPCBMultiRegion +@pytest.mark.cosmosPPCBMultiWrite @pytest.mark.asyncio @pytest.mark.usefixtures("setup_teardown") class TestPerPartitionCircuitBreakerSmMrrAsync: diff --git a/sdk/cosmos/azure-cosmos/tests/test_query.py b/sdk/cosmos/azure-cosmos/tests/test_query.py index 2a99263ed457..8a17e76b4f6d 100644 --- 
a/sdk/cosmos/azure-cosmos/tests/test_query.py +++ b/sdk/cosmos/azure-cosmos/tests/test_query.py @@ -17,6 +17,7 @@ from azure.cosmos.documents import _DistinctType from azure.cosmos.partition_key import PartitionKey +@pytest.mark.cosmosPPCBMultiWrite @pytest.mark.cosmosQuery class TestQuery(unittest.TestCase): """Test to ensure escaping of non-ascii characters from partition key""" @@ -32,7 +33,10 @@ class TestQuery(unittest.TestCase): @classmethod def setUpClass(cls): - cls.client = cosmos_client.CosmosClient(cls.host, cls.credential) + use_multiple_write_locations = False + if os.environ.get("AZURE_COSMOS_ENABLE_MULTIPLE_WRITE_LOCATIONS", "False") == "True": + use_multiple_write_locations = True + cls.client = cosmos_client.CosmosClient(cls.host, cls.credential, multiple_write_locations=use_multiple_write_locations) cls.created_db = cls.client.get_database_client(cls.TEST_DATABASE_ID) if cls.host == "https://localhost:8081/": os.environ["AZURE_COSMOS_DISABLE_NON_STREAMING_ORDER_BY"] = "True" diff --git a/sdk/cosmos/azure-cosmos/tests/test_query_async.py b/sdk/cosmos/azure-cosmos/tests/test_query_async.py index 19e01d8149b4..e2ffc708dbaa 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_query_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_query_async.py @@ -18,6 +18,7 @@ from azure.cosmos.documents import _DistinctType from azure.cosmos.partition_key import PartitionKey +@pytest.mark.cosmosPPCBMultiWrite @pytest.mark.cosmosQuery class TestQueryAsync(unittest.IsolatedAsyncioTestCase): """Test to ensure escaping of non-ascii characters from partition key""" @@ -34,6 +35,9 @@ class TestQueryAsync(unittest.IsolatedAsyncioTestCase): @classmethod def setUpClass(cls): + cls.use_multiple_write_locations = False + if os.environ.get("AZURE_COSMOS_ENABLE_MULTIPLE_WRITE_LOCATIONS", "False") == "True": + cls.use_multiple_write_locations = True if (cls.masterKey == '[YOUR_KEY_HERE]' or cls.host == '[YOUR_ENDPOINT_HERE]'): raise Exception( @@ -42,7 +46,7 @@ def setUpClass(cls): 
"tests.") async def asyncSetUp(self): - self.client = CosmosClient(self.host, self.masterKey) + self.client = CosmosClient(self.host, self.masterKey, multiple_write_locations=self.use_multiple_write_locations) self.created_db = self.client.get_database_client(self.TEST_DATABASE_ID) if self.host == "https://localhost:8081/": os.environ["AZURE_COSMOS_DISABLE_NON_STREAMING_ORDER_BY"] = "True" diff --git a/sdk/cosmos/live-platform-matrix.json b/sdk/cosmos/live-platform-matrix.json index 19aec8a30f39..58e90fb4bd8c 100644 --- a/sdk/cosmos/live-platform-matrix.json +++ b/sdk/cosmos/live-platform-matrix.json @@ -27,35 +27,26 @@ }, { "PPCBMultiWriteTestConfig": { - "Ubuntu2004_313_ppcb": { + "Ubuntu2004_39_ppcb": { "OSVmImage": "env:LINUXVMIMAGE", "Pool": "env:LINUXPOOL", - "PythonVersion": "3.13", + "PythonVersion": "3.9", "CoverageArg": "--disablecov", "TestSamples": "false", - "TestMarkArgument": "cosmosPPCBMultiWrite" - } - }, - "ArmConfig": { - "MultiMaster_MultiRegion": { - "ArmTemplateParameters": "@{ enableMultipleWriteLocations = $true; defaultConsistencyLevel = 'Session'; enableMultipleRegions = $true; ppcbEnabled = $true }" - } - } - }, - { - "PPCBMultiRegionTestConfig": { - "Ubuntu2004_39_ppcb": { + "TestMarkArgument": "cosmosPPCBMultiRegion" + }, + "Ubuntu2004_313_ppcb": { "OSVmImage": "env:LINUXVMIMAGE", "Pool": "env:LINUXPOOL", - "PythonVersion": "3.9", + "PythonVersion": "3.13", "CoverageArg": "--disablecov", "TestSamples": "false", - "TestMarkArgument": "cosmosPPCBMultiRegion" + "TestMarkArgument": "cosmosPPCBMultiWrite" } }, "ArmConfig": { - "MultiMaster_MultiRegion": { - "ArmTemplateParameters": "@{ defaultConsistencyLevel = 'Session'; enableMultipleRegions = $true; ppcbEnabled = $true }" + "MultiMaster": { + "ArmTemplateParameters": "@{ enableMultipleWriteLocations = $true; defaultConsistencyLevel = 'Session'; enableMultipleRegions = $true; ppcbEnabled = $true }" } } }, From 153903eaab0415a7b2ebd6d8e78c27d3df7bf0f3 Mon Sep 17 00:00:00 2001 From: tvaron3 
Date: Mon, 28 Apr 2025 16:19:17 -0700 Subject: [PATCH 121/152] fixed resource token bug and tests --- sdk/cosmos/azure-cosmos/CHANGELOG.md | 1 + sdk/cosmos/azure-cosmos/azure/cosmos/auth.py | 10 ++++++---- sdk/cosmos/azure-cosmos/tests/test_change_feed.py | 2 +- .../azure-cosmos/tests/test_change_feed_async.py | 1 + sdk/cosmos/azure-cosmos/tests/test_crud.py | 6 ++---- 5 files changed, 11 insertions(+), 9 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/CHANGELOG.md b/sdk/cosmos/azure-cosmos/CHANGELOG.md index 9290f48478ed..07fc1079dba0 100644 --- a/sdk/cosmos/azure-cosmos/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos/CHANGELOG.md @@ -10,6 +10,7 @@ #### Breaking Changes #### Bugs Fixed +* Fixed how resource tokens are parsed for metadata calls in the lifecycle of a document operation. See [PR 40302](https://github.com/Azure/azure-sdk-for-python/pull/40302). * Fixed bug where change feed requests would not respect the partition key filter. See [PR 40677](https://github.com/Azure/azure-sdk-for-python/pull/40677). * Fixed how the environment variables in the sdk are parsed. See [PR 40303](https://github.com/Azure/azure-sdk-for-python/pull/40303). * Fixed health check to check the first write region when it is not specified in the preferred regions. See [PR 40588](https://github.com/Azure/azure-sdk-for-python/pull/40588). 
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/auth.py b/sdk/cosmos/azure-cosmos/azure/cosmos/auth.py index 754a9c93cdf3..497007f2dbc1 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/auth.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/auth.py @@ -124,8 +124,8 @@ def __get_authorization_token_using_resource_token(resource_tokens, path, resour # used for creating the auth header as the service will accept any token in this case path = urllib.parse.unquote(path) if not path and not resource_id_or_fullname: - for value in resource_tokens.values(): - return value + for resource_token in resource_tokens.values(): + return resource_token if resource_tokens.get(resource_id_or_fullname): return resource_tokens[resource_id_or_fullname] @@ -151,7 +151,9 @@ def __get_authorization_token_using_resource_token(resource_tokens, path, resour for i in range(len(path_parts), 1, -1): segment = path_parts[i - 1] sub_path = "/".join(path_parts[:i]) - if not segment in resource_types and sub_path in resource_tokens: - return resource_tokens[sub_path] + if not segment in resource_types: + for path, resource_token in resource_tokens.items(): + if sub_path in path: + return resource_tokens[path] return None diff --git a/sdk/cosmos/azure-cosmos/tests/test_change_feed.py b/sdk/cosmos/azure-cosmos/tests/test_change_feed.py index 32e39ee1e15f..5c93e56be041 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_change_feed.py +++ b/sdk/cosmos/azure-cosmos/tests/test_change_feed.py @@ -1,6 +1,6 @@ # The MIT License (MIT) # Copyright (c) Microsoft Corporation. All rights reserved. 
- +import os import unittest import uuid from datetime import datetime, timedelta, timezone diff --git a/sdk/cosmos/azure-cosmos/tests/test_change_feed_async.py b/sdk/cosmos/azure-cosmos/tests/test_change_feed_async.py index 24cf945f5b62..17b128927ef7 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_change_feed_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_change_feed_async.py @@ -3,6 +3,7 @@ import unittest import uuid +import os from asyncio import sleep from datetime import datetime, timedelta, timezone diff --git a/sdk/cosmos/azure-cosmos/tests/test_crud.py b/sdk/cosmos/azure-cosmos/tests/test_crud.py index 282071276210..c9f8487a0198 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_crud.py +++ b/sdk/cosmos/azure-cosmos/tests/test_crud.py @@ -5,12 +5,11 @@ """End-to-end test. """ -import json -import os.path import time import unittest import urllib.parse as urllib import uuid +import os import pytest import requests @@ -19,7 +18,6 @@ from azure.core.pipeline.transport import RequestsTransport, RequestsTransportResponse from urllib3.util.retry import Retry -import azure.cosmos._base as base import azure.cosmos.cosmos_client as cosmos_client import azure.cosmos.documents as documents import azure.cosmos.exceptions as exceptions @@ -1146,7 +1144,7 @@ def test_client_request_timeout(self): container = databaseForTest.get_container_client(self.configs.TEST_SINGLE_PARTITION_CONTAINER_ID) container.create_item(body={'id': str(uuid.uuid4()), 'name': 'sample'}) - async def test_read_timeout_async(self): + def test_read_timeout_async(self): connection_policy = documents.ConnectionPolicy() # making timeout 0 ms to make sure it will throw connection_policy.DBAReadTimeout = 0.000000000001 From ff22510331bc38bf39bdec513180b53903aeaca2 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Mon, 28 Apr 2025 21:06:39 -0700 Subject: [PATCH 122/152] fix tests and split them up across two live test pipelines --- .../aio/_cosmos_client_connection_async.py | 18 ++++++-------
sdk/cosmos/azure-cosmos/azure/cosmos/auth.py | 4 +-- .../azure-cosmos/azure/cosmos/container.py | 3 +-- sdk/cosmos/azure-cosmos/pytest.ini | 3 ++- .../azure-cosmos/tests/test_change_feed.py | 2 +- sdk/cosmos/azure-cosmos/tests/test_crud.py | 2 +- .../azure-cosmos/tests/test_crud_async.py | 2 +- .../test_per_partition_circuit_breaker_mm.py | 2 +- ..._per_partition_circuit_breaker_mm_async.py | 2 +- ...st_per_partition_circuit_breaker_sm_mrr.py | 2 +- ..._partition_circuit_breaker_sm_mrr_async.py | 2 +- sdk/cosmos/azure-cosmos/tests/test_query.py | 2 +- .../azure-cosmos/tests/test_query_async.py | 2 +- sdk/cosmos/live-platform-matrix.json | 27 ++++++++++++------- sdk/cosmos/test-resources.bicep | 4 +-- 15 files changed, 43 insertions(+), 34 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py index 6278a8460f8d..088a334b6181 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py @@ -2882,6 +2882,7 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: cont_prop = kwargs.pop("containerProperties", None) if cont_prop: cont_prop = await cont_prop() + # TODO: @tvaron3 move this logic as this isn't thread safe options["containerRID"] = cont_prop["_rid"] @@ -2923,11 +2924,18 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: else: raise SystemError("Unexpected query compatibility mode.") + # Query operations will use ReadEndpoint even though it uses POST(for regular query operations) + req_headers = base.GetHeaders(self, initial_headers, "post", path, id_, typ, + documents._OperationType.SqlQuery, + options, partition_key_range_id) + request_params = _request_object.RequestObject(typ, documents._OperationType.SqlQuery, req_headers) + 
request_params.set_excluded_location_from_options(options) + # check if query has prefix partition key partition_key = options.get("partitionKey", None) isPrefixPartitionQuery = False partition_key_definition = None - if cont_prop and partition_key: + if cont_prop and partition_key is not None: pk_properties = cont_prop["partitionKey"] partition_key_definition = PartitionKey(path=pk_properties["paths"], kind=pk_properties["kind"]) if partition_key_definition.kind == "MultiHash" and \ @@ -2935,13 +2943,6 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: len(partition_key_definition['paths']) != len(partition_key)): isPrefixPartitionQuery = True - # Query operations will use ReadEndpoint even though it uses POST(for regular query operations) - req_headers = base.GetHeaders(self, initial_headers, "post", path, id_, typ, - documents._OperationType.SqlQuery, - options, partition_key_range_id) - request_params = _request_object.RequestObject(typ, documents._OperationType.SqlQuery, req_headers) - request_params.set_excluded_location_from_options(options) - if isPrefixPartitionQuery and partition_key_definition: # here get the overlapping ranges req_headers.pop(http_constants.HttpHeaders.PartitionKey, None) @@ -3297,7 +3298,6 @@ async def DeleteAllItemsByPartitionKey( request_params = _request_object.RequestObject("partitionkey", documents._OperationType.Delete, headers) request_params.set_excluded_location_from_options(options) - request_params.set_excluded_location_from_options(options) _, last_response_headers = await self.__Post(path=path, request_params=request_params, req_headers=headers, body=None, **kwargs) self.last_response_headers = last_response_headers diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/auth.py b/sdk/cosmos/azure-cosmos/azure/cosmos/auth.py index 497007f2dbc1..370d8b76bf17 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/auth.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/auth.py @@ -152,8 +152,8 @@ def 
__get_authorization_token_using_resource_token(resource_tokens, path, resour segment = path_parts[i - 1] sub_path = "/".join(path_parts[:i]) if not segment in resource_types: - for path, resource_token in resource_tokens.items(): - if sub_path in path: + for resource_path, resource_token in resource_tokens.items(): + if sub_path in resource_path: return resource_tokens[path] return None diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/container.py b/sdk/cosmos/azure-cosmos/azure/cosmos/container.py index 66b2557cdca5..964df6dee354 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/container.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/container.py @@ -681,11 +681,11 @@ def query_items( # pylint:disable=docstring-missing-param feed_options["populateQueryMetrics"] = populate_query_metrics if populate_index_metrics is not None: feed_options["populateIndexMetrics"] = populate_index_metrics + properties = self._get_properties() if partition_key is not None: partition_key_value = self._set_partition_key(partition_key) if self.__is_prefix_partitionkey(partition_key): kwargs["isPrefixPartitionQuery"] = True - properties = self._get_properties() kwargs["partitionKeyDefinition"] = properties["partitionKey"] kwargs["partitionKeyDefinition"]["partition_key"] = partition_key_value else: @@ -701,7 +701,6 @@ def query_items( # pylint:disable=docstring-missing-param feed_options["responseContinuationTokenLimitInKb"] = continuation_token_limit if response_hook and hasattr(response_hook, "clear"): response_hook.clear() - self._get_properties() feed_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] items = self.client_connection.QueryItems( database_or_container_link=self.container_link, diff --git a/sdk/cosmos/azure-cosmos/pytest.ini b/sdk/cosmos/azure-cosmos/pytest.ini index 0db8e9cd12eb..6255aded49ad 100644 --- a/sdk/cosmos/azure-cosmos/pytest.ini +++ b/sdk/cosmos/azure-cosmos/pytest.ini @@ -5,4 +5,5 @@ markers = cosmosQuery: marks tests 
running queries on Cosmos DB live account. cosmosSplit: marks test where there are partition splits on CosmosDB live account. cosmosMultiRegion: marks tests running on a Cosmos DB live account with multi-region and multi-write enabled. - cosmosPPCBMultiWrite: marks tests running on Cosmos DB live account with per partition circuit breaker enabled and multi-write enabled. + cosmosCircuitBreaker: marks tests running on Cosmos DB live account with per partition circuit breaker enabled and multi-write enabled. + cosmosCircuitBreakerMultiRegion: marks tests running on Cosmos DB live account with per partition circuit breaker enabled. diff --git a/sdk/cosmos/azure-cosmos/tests/test_change_feed.py b/sdk/cosmos/azure-cosmos/tests/test_change_feed.py index 5c93e56be041..6b56701738ad 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_change_feed.py +++ b/sdk/cosmos/azure-cosmos/tests/test_change_feed.py @@ -36,7 +36,7 @@ def round_time(): utc_now = datetime.now(timezone.utc) return utc_now - timedelta(microseconds=utc_now.microsecond) -@pytest.mark.cosmosPPCBMultiWrite +@pytest.mark.cosmosCircuitBreaker @pytest.mark.cosmosQuery @pytest.mark.unittest @pytest.mark.usefixtures("setup") diff --git a/sdk/cosmos/azure-cosmos/tests/test_crud.py b/sdk/cosmos/azure-cosmos/tests/test_crud.py index c9f8487a0198..e87a28a80049 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_crud.py +++ b/sdk/cosmos/azure-cosmos/tests/test_crud.py @@ -45,7 +45,7 @@ def send(self, *args, **kwargs): response = RequestsTransportResponse(None, output) return response -@pytest.mark.cosmosPPCBMultiWrite +@pytest.mark.cosmosCircuitBreaker @pytest.mark.cosmosLong class TestCRUDOperations(unittest.TestCase): """Python CRUD Tests. 
diff --git a/sdk/cosmos/azure-cosmos/tests/test_crud_async.py b/sdk/cosmos/azure-cosmos/tests/test_crud_async.py index 7b0b4476befd..81cb5baf4d45 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_crud_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_crud_async.py @@ -44,7 +44,7 @@ async def send(self, *args, **kwargs): response = AsyncioRequestsTransportResponse(None, output) return response -@pytest.mark.cosmosPPCBMultiWrite +@pytest.mark.cosmosCircuitBreaker @pytest.mark.cosmosLong class TestCRUDOperationsAsync(unittest.IsolatedAsyncioTestCase): """Python CRUD Tests. diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py index 1299e18566f8..6ee803be24d5 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py +++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py @@ -109,7 +109,7 @@ def perform_read_operation(operation, container, doc_id, pk, expected_uri): for _ in container.read_all_items(): pass -@pytest.mark.cosmosPPCBMultiWrite +@pytest.mark.cosmosCircuitBreaker @pytest.mark.usefixtures("setup_teardown") class TestPerPartitionCircuitBreakerMM: host = test_config.TestConfig.host diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py index 8f614b58dd3e..cdb606fd258a 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py @@ -199,7 +199,7 @@ async def cleanup_method(initialized_objects: List[Dict[str, Any]]): method_client: CosmosClient = obj["client"] await method_client.close() -@pytest.mark.cosmosPPCBMultiWrite +@pytest.mark.cosmosCircuitBreaker @pytest.mark.asyncio @pytest.mark.usefixtures("setup_teardown") class TestPerPartitionCircuitBreakerMMAsync: diff --git 
a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr.py index a1e7bbe19f1d..c27fed2c0325 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr.py +++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr.py @@ -34,7 +34,7 @@ def setup_teardown(): created_database.delete_container(TestPerPartitionCircuitBreakerSmMrr.TEST_CONTAINER_SINGLE_PARTITION_ID) os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "False" -@pytest.mark.cosmosPPCBMultiWrite +@pytest.mark.cosmosCircuitBreakerMultiRegion @pytest.mark.usefixtures("setup_teardown") class TestPerPartitionCircuitBreakerSmMrr: host = test_config.TestConfig.host diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr_async.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr_async.py index c945c4d2d9c0..aee621f3f3d9 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr_async.py @@ -57,7 +57,7 @@ def validate_unhealthy_partitions(global_endpoint_manager, assert unhealthy_partitions == expected_unhealthy_partitions -@pytest.mark.cosmosPPCBMultiWrite +@pytest.mark.cosmosCircuitBreakerMultiRegion @pytest.mark.asyncio @pytest.mark.usefixtures("setup_teardown") class TestPerPartitionCircuitBreakerSmMrrAsync: diff --git a/sdk/cosmos/azure-cosmos/tests/test_query.py b/sdk/cosmos/azure-cosmos/tests/test_query.py index 8a17e76b4f6d..d1e6514bb06b 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_query.py +++ b/sdk/cosmos/azure-cosmos/tests/test_query.py @@ -17,7 +17,7 @@ from azure.cosmos.documents import _DistinctType from azure.cosmos.partition_key import PartitionKey -@pytest.mark.cosmosPPCBMultiWrite +@pytest.mark.cosmosCircuitBreaker @pytest.mark.cosmosQuery class TestQuery(unittest.TestCase): """Test to ensure 
escaping of non-ascii characters from partition key""" diff --git a/sdk/cosmos/azure-cosmos/tests/test_query_async.py b/sdk/cosmos/azure-cosmos/tests/test_query_async.py index 9981b36fe266..bfb35bd95b79 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_query_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_query_async.py @@ -18,7 +18,7 @@ from azure.cosmos.documents import _DistinctType from azure.cosmos.partition_key import PartitionKey -@pytest.mark.cosmosPPCBMultiWrite +@pytest.mark.cosmosCircuitBreaker @pytest.mark.cosmosQuery class TestQueryAsync(unittest.IsolatedAsyncioTestCase): """Test to ensure escaping of non-ascii characters from partition key""" diff --git a/sdk/cosmos/live-platform-matrix.json b/sdk/cosmos/live-platform-matrix.json index 58e90fb4bd8c..4e7250199a81 100644 --- a/sdk/cosmos/live-platform-matrix.json +++ b/sdk/cosmos/live-platform-matrix.json @@ -27,21 +27,13 @@ }, { "PPCBMultiWriteTestConfig": { - "Ubuntu2004_39_ppcb": { - "OSVmImage": "env:LINUXVMIMAGE", - "Pool": "env:LINUXPOOL", - "PythonVersion": "3.9", - "CoverageArg": "--disablecov", - "TestSamples": "false", - "TestMarkArgument": "cosmosPPCBMultiRegion" - }, "Ubuntu2004_313_ppcb": { "OSVmImage": "env:LINUXVMIMAGE", "Pool": "env:LINUXPOOL", "PythonVersion": "3.13", "CoverageArg": "--disablecov", "TestSamples": "false", - "TestMarkArgument": "cosmosPPCBMultiWrite" + "TestMarkArgument": "cosmosCircuitBreaker" } }, "ArmConfig": { @@ -50,6 +42,23 @@ } } }, + { + "PPCBMultiRegionTestConfig": { + "Ubuntu2004_39_ppcb": { + "OSVmImage": "env:LINUXVMIMAGE", + "Pool": "env:LINUXPOOL", + "PythonVersion": "3.9", + "CoverageArg": "--disablecov", + "TestSamples": "false", + "TestMarkArgument": "cosmosCircuitBreakerMultiRegion" + } + }, + "ArmConfig": { + "MultiRegion": { + "ArmTemplateParameters": "@{ defaultConsistencyLevel = 'Session'; enableMultipleRegions = $true; ppcbEnabled = $true }" + } + } + }, { "MacTestConfig": { "macos311_search_query": { diff --git a/sdk/cosmos/test-resources.bicep 
b/sdk/cosmos/test-resources.bicep index 5824c936bb8f..a9635c29c546 100644 --- a/sdk/cosmos/test-resources.bicep +++ b/sdk/cosmos/test-resources.bicep @@ -13,7 +13,7 @@ param enableMultipleRegions bool = false param location string = resourceGroup().location @description('Whether Per Partition Circuit Breaker should be enabled.') -param ppcbEnabled bool = false +param circuitBreakerEnabled bool = false @description('The api version to be used by Bicep to create resources') param apiVersion string = '2023-04-15' @@ -104,6 +104,6 @@ resource accountName_roleAssignmentId 'Microsoft.DocumentDB/databaseAccounts/sql } } -output AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER bool = ppcbEnabled +output AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER bool = circuitBreakerEnabled output ACCOUNT_HOST string = reference(resourceId, apiVersion).documentEndpoint output ACCOUNT_KEY string = listKeys(resourceId, apiVersion).primaryMasterKey From d9c24c4e36512b0597099cc1b0e9630ee63c33c9 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Mon, 28 Apr 2025 21:42:31 -0700 Subject: [PATCH 123/152] fix tests and cspell --- .../azure/cosmos/_cosmos_client_connection.py | 8 -------- ...artition_endpoint_manager_circuit_breaker_core.py | 3 --- .../cosmos/aio/_cosmos_client_connection_async.py | 8 ++++---- sdk/cosmos/live-platform-matrix.json | 12 ++++++------ 4 files changed, 10 insertions(+), 21 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py index 9d7410f659dc..0292352090f2 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py @@ -2070,7 +2070,6 @@ def PatchItem( documents._OperationType.Patch, headers) request_params.set_excluded_location_from_options(options) - request_params.set_excluded_location_from_options(options) request_data = {} if options.get("filterPredicate"): request_data["condition"] = 
options.get("filterPredicate") @@ -2162,7 +2161,6 @@ def _Batch( documents._OperationType.Batch, headers) request_params.set_excluded_location_from_options(options) - request_params.set_excluded_location_from_options(options) return cast( Tuple[List[Dict[str, Any]], CaseInsensitiveDict], self.__Post(path, request_params, batch_operations, headers, **kwargs) @@ -2226,7 +2224,6 @@ def DeleteAllItemsByPartitionKey( documents._OperationType.Delete, headers) request_params.set_excluded_location_from_options(options) - request_params.set_excluded_location_from_options(options) _, last_response_headers = self.__Post( path=path, request_params=request_params, @@ -2398,7 +2395,6 @@ def ExecuteStoredProcedure( # ExecuteStoredProcedure will use WriteEndpoint since it uses POST operation request_params = RequestObject("sprocs", documents._OperationType.ExecuteJavaScript, headers) - request_params.set_excluded_location_from_options(options) result, self.last_response_headers = self.__Post(path, request_params, params, headers, **kwargs) return result @@ -2732,7 +2728,6 @@ def Upsert( # Upsert will use WriteEndpoint since it uses POST operation request_params = RequestObject(typ, documents._OperationType.Upsert, headers) request_params.set_excluded_location_from_options(options) - request_params.set_excluded_location_from_options(options) result, last_response_headers = self.__Post(path, request_params, body, headers, **kwargs) self.last_response_headers = last_response_headers # update session for write request @@ -2777,7 +2772,6 @@ def Replace( # Replace will use WriteEndpoint since it uses PUT operation request_params = RequestObject(typ, documents._OperationType.Replace, headers) request_params.set_excluded_location_from_options(options) - request_params.set_excluded_location_from_options(options) result, last_response_headers = self.__Put(path, request_params, resource, headers, **kwargs) self.last_response_headers = last_response_headers @@ -2820,7 +2814,6 @@ def Read( # 
Read will use ReadEndpoint since it uses GET operation request_params = RequestObject(typ, documents._OperationType.Read, headers) request_params.set_excluded_location_from_options(options) - request_params.set_excluded_location_from_options(options) result, last_response_headers = self.__Get(path, request_params, headers, **kwargs) self.last_response_headers = last_response_headers if response_hook: @@ -2861,7 +2854,6 @@ def DeleteResource( # Delete will use WriteEndpoint since it uses DELETE operation request_params = RequestObject(typ, documents._OperationType.Delete, headers) request_params.set_excluded_location_from_options(options) - request_params.set_excluded_location_from_options(options) result, last_response_headers = self.__Delete(path, request_params, headers, **kwargs) self.last_response_headers = last_response_headers diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py index 6e23efabbed3..f5bb1dfc78cd 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py @@ -21,7 +21,6 @@ """Internal class for global endpoint manager for circuit breaker. 
""" -import logging import os from azure.cosmos import documents @@ -34,8 +33,6 @@ from azure.cosmos._constants import _Constants as Constants -logger = logging.getLogger("azure.cosmos._GlobalEndpointManagerForCircuitBreakerCore") - class _GlobalPartitionEndpointManagerForCircuitBreakerCore(object): """ This internal class implements the logic for partition endpoint management for diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py index 088a334b6181..ace038f11efd 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py @@ -2271,6 +2271,9 @@ def QueryItems( collection_id = base.GetResourceIdOrFullNameFromLink(database_or_container_link) async def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], CaseInsensitiveDict]: + await kwargs["containerProperties"] + new_options = dict(options) + new_options["containerRID"] = self.__container_properties_cache[database_or_container_link]["_rid"] return ( await self.__QueryFeed( path, @@ -2279,7 +2282,7 @@ async def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], Ca lambda r: r["Documents"], lambda _, b: b, query, - options, + new_options, response_hook=response_hook, **kwargs ), @@ -2882,9 +2885,6 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: cont_prop = kwargs.pop("containerProperties", None) if cont_prop: cont_prop = await cont_prop() - # TODO: @tvaron3 move this logic as this isn't thread safe - options["containerRID"] = cont_prop["_rid"] - # Copy to make sure that default_headers won't be changed. 
diff --git a/sdk/cosmos/live-platform-matrix.json b/sdk/cosmos/live-platform-matrix.json index 4e7250199a81..508e96034d66 100644 --- a/sdk/cosmos/live-platform-matrix.json +++ b/sdk/cosmos/live-platform-matrix.json @@ -26,8 +26,8 @@ } }, { - "PPCBMultiWriteTestConfig": { - "Ubuntu2004_313_ppcb": { + "CircuitBreakerMultiWriteTestConfig": { + "Ubuntu2004_313_circuit_breaker": { "OSVmImage": "env:LINUXVMIMAGE", "Pool": "env:LINUXPOOL", "PythonVersion": "3.13", @@ -38,13 +38,13 @@ }, "ArmConfig": { "MultiMaster": { - "ArmTemplateParameters": "@{ enableMultipleWriteLocations = $true; defaultConsistencyLevel = 'Session'; enableMultipleRegions = $true; ppcbEnabled = $true }" + "ArmTemplateParameters": "@{ enableMultipleWriteLocations = $true; defaultConsistencyLevel = 'Session'; enableMultipleRegions = $true; circuitBreakerEnabled = $true }" } } }, { - "PPCBMultiRegionTestConfig": { - "Ubuntu2004_39_ppcb": { + "CircuitBreakerMultiRegionTestConfig": { + "Ubuntu2004_39_circuit_breaker": { "OSVmImage": "env:LINUXVMIMAGE", "Pool": "env:LINUXPOOL", "PythonVersion": "3.9", @@ -55,7 +55,7 @@ }, "ArmConfig": { "MultiRegion": { - "ArmTemplateParameters": "@{ defaultConsistencyLevel = 'Session'; enableMultipleRegions = $true; ppcbEnabled = $true }" + "ArmTemplateParameters": "@{ defaultConsistencyLevel = 'Session'; enableMultipleRegions = $true; circuitBreakerEnabled = $true }" } } }, From 6e1c8d5140ce07e0d808985de84fbcb2469e5eac Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Mon, 28 Apr 2025 23:55:18 -0700 Subject: [PATCH 124/152] add sync deleta all item partiton key tests --- .../azure/cosmos/_partition_health_tracker.py | 4 +- .../azure/cosmos/_routing/routing_range.py | 4 +- .../aio/_cosmos_client_connection_async.py | 2 +- .../tests/test_circuit_breaker_emulator.py | 268 ++++++++++++++++++ .../test_per_partition_circuit_breaker_mm.py | 2 - ..._per_partition_circuit_breaker_mm_async.py | 18 +- ...st_per_partition_circuit_breaker_sm_mrr.py | 6 +- 
..._partition_circuit_breaker_sm_mrr_async.py | 28 +- 8 files changed, 284 insertions(+), 48 deletions(-) create mode 100644 sdk/cosmos/azure-cosmos/tests/test_circuit_breaker_emulator.py diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py index 673e653ca9fc..03457a2967eb 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py @@ -123,7 +123,7 @@ def __init__(self) -> None: # partition -> regions -> health info self.pk_range_wrapper_to_health_info: Dict[PartitionKeyRangeWrapper, Dict[str, _PartitionHealthInfo]] = {} self.last_refresh = current_time_millis() - self.stale_partiton_lock = threading.Lock() + self.stale_partition_lock = threading.Lock() def mark_partition_unavailable(self, pk_range_wrapper: PartitionKeyRangeWrapper, location: str) -> None: # mark the partition key range as unavailable @@ -180,7 +180,7 @@ def check_stale_partition_info( if _should_mark_healthy_tentative(partition_health_info, current_time): # unhealthy or unhealthy tentative -> healthy tentative # only one request should be used to recover - with self.stale_partiton_lock: + with self.stale_partition_lock: if _should_mark_healthy_tentative(partition_health_info, current_time): # this will trigger one attempt to recover partition_health_info.transition_health_status(UNHEALTHY, current_time) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_range.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_range.py index dd566accedcf..1f75754f67db 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_range.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_range.py @@ -100,10 +100,10 @@ def to_normalized_range(self): normalized_max = self.max if not self.isMinInclusive: - normalized_min = self.add_to_effective_partition_key(self.min, -1).upper() + normalized_min = 
self.add_to_effective_partition_key(self.min, -1) if self.isMaxInclusive: - normalized_max = self.add_to_effective_partition_key(self.max, 1).upper() + normalized_max = self.add_to_effective_partition_key(self.max, 1) return Range(normalized_min, normalized_max, True, False) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py index ace038f11efd..c8bc8d8582af 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py @@ -2271,7 +2271,7 @@ def QueryItems( collection_id = base.GetResourceIdOrFullNameFromLink(database_or_container_link) async def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], CaseInsensitiveDict]: - await kwargs["containerProperties"] + await kwargs["containerProperties"]() new_options = dict(options) new_options["containerRID"] = self.__container_properties_cache[database_or_container_link]["_rid"] return ( diff --git a/sdk/cosmos/azure-cosmos/tests/test_circuit_breaker_emulator.py b/sdk/cosmos/azure-cosmos/tests/test_circuit_breaker_emulator.py new file mode 100644 index 000000000000..5f0bb4cc1952 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_circuit_breaker_emulator.py @@ -0,0 +1,268 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. 
+import os +import unittest +import uuid +from time import sleep + +import pytest +import pytest_asyncio +from azure.core.exceptions import ServiceResponseError + +import test_config +from azure.cosmos import PartitionKey, _partition_health_tracker +from azure.cosmos import CosmosClient +from azure.cosmos.exceptions import CosmosHttpResponseError +from _fault_injection_transport import FaultInjectionTransport +from test_per_partition_circuit_breaker_mm import perform_write_operation +from test_per_partition_circuit_breaker_mm_async import create_doc, PK_VALUE, create_errors +from test_per_partition_circuit_breaker_sm_mrr_async import validate_unhealthy_partitions +from tests.test_per_partition_circuit_breaker_mm_async import DELETE_ALL_ITEMS_BY_PARTITION_KEY + +COLLECTION = "created_collection" +@pytest_asyncio.fixture(scope="class", autouse=True) +def setup_teardown(): + client = CosmosClient(TestCircuitBreakerEmulator.host, TestCircuitBreakerEmulator.master_key) + created_database = client.get_database_client(TestCircuitBreakerEmulator.TEST_DATABASE_ID) + created_database.create_container(TestCircuitBreakerEmulator.TEST_CONTAINER_SINGLE_PARTITION_ID, + partition_key=PartitionKey("/pk"), + offer_throughput=10000) + # allow some time for the container to be created as this method is in different event loop + sleep(3) + yield + created_database.delete_container(TestCircuitBreakerEmulator.TEST_CONTAINER_SINGLE_PARTITION_ID) + +def create_custom_transport_mm(): + custom_transport = FaultInjectionTransport() + is_get_account_predicate = lambda r: FaultInjectionTransport.predicate_is_database_account_call(r) + emulator_as_multi_write_region_account_transformation = \ + lambda r, inner: FaultInjectionTransport.transform_topology_mwr( + first_region_name="First Region", + second_region_name="Second Region", + inner=inner) + custom_transport.add_response_transformation( + is_get_account_predicate, + emulator_as_multi_write_region_account_transformation) + return 
custom_transport + + +@pytest.mark.cosmosEmulator +@pytest.mark.usefixtures("setup_teardown") +class TestCircuitBreakerEmulator: + host = test_config.TestConfig.host + master_key = test_config.TestConfig.masterKey + connectionPolicy = test_config.TestConfig.connectionPolicy + TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID + TEST_CONTAINER_SINGLE_PARTITION_ID = os.path.basename(__file__) + str(uuid.uuid4()) + + def setup_method_with_custom_transport(self, custom_transport, default_endpoint=host, **kwargs): + client = CosmosClient(default_endpoint, self.master_key, consistency_level="Session", + preferred_locations=["Write Region", "Read Region"], + transport=custom_transport, **kwargs) + db = client.get_database_client(self.TEST_DATABASE_ID) + container = db.get_container_client(self.TEST_CONTAINER_SINGLE_PARTITION_ID) + return {"client": client, "db": db, "col": container} + + + def create_custom_transport_sm_mrr(self): + custom_transport = FaultInjectionTransport() + # Inject rule to disallow writes in the read-only region + is_write_operation_in_read_region_predicate = lambda \ + r: FaultInjectionTransport.predicate_is_write_operation(r, self.host) + + custom_transport.add_fault( + is_write_operation_in_read_region_predicate, + lambda r: FaultInjectionTransport.error_write_forbidden()) + + # Inject topology transformation that would make Emulator look like a single write region + # account with two read regions + is_get_account_predicate = lambda r: FaultInjectionTransport.predicate_is_database_account_call(r) + emulator_as_multi_region_sm_account_transformation = \ + lambda r, inner: FaultInjectionTransport.transform_topology_swr_mrr( + write_region_name="Write Region", + read_region_name="Read Region", + inner=inner) + custom_transport.add_response_transformation( + is_get_account_predicate, + emulator_as_multi_region_sm_account_transformation) + return custom_transport + + def setup_info(self, error, mm=False): + expected_uri = self.host + uri_down 
= self.host.replace("localhost", "127.0.0.1") + custom_transport = create_custom_transport_mm() if mm else self.create_custom_transport_sm_mrr() + # two documents targeted to same partition, one will always fail and the other will succeed + doc = create_doc() + predicate = lambda r: (FaultInjectionTransport.predicate_is_document_operation(r) and + FaultInjectionTransport.predicate_targets_region(r, uri_down)) + custom_transport.add_fault(predicate, + error) + custom_setup = self.setup_method_with_custom_transport(custom_transport, default_endpoint=self.host) + setup = self.setup_method_with_custom_transport(None, default_endpoint=self.host) + return setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate + + @pytest.mark.parametrize("error", create_errors()) + def test_write_consecutive_failure_threshold_delete_all_items_by_pk_sm(self, setup_teardown, error): + error_lambda = lambda r: FaultInjectionTransport.error_after_delay(0, error) + setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = self.setup_info(error_lambda) + container = setup['col'] + fault_injection_container = custom_setup['col'] + global_endpoint_manager = container.client_connection._global_endpoint_manager + + # writes should fail in sm mrr with circuit breaker and should not mark unavailable a partition + for i in range(6): + validate_unhealthy_partitions(global_endpoint_manager, 0) + with pytest.raises((CosmosHttpResponseError, ServiceResponseError)): + perform_write_operation( + DELETE_ALL_ITEMS_BY_PARTITION_KEY, + container, + fault_injection_container, + str(uuid.uuid4()), + PK_VALUE, + expected_uri, + ) + + validate_unhealthy_partitions(global_endpoint_manager, 0) + + + @pytest.mark.parametrize("error", create_errors()) + def test_write_consecutive_failure_threshold_delete_all_items_by_pk_mm(self, setup_teardown, error): + error_lambda = lambda r: FaultInjectionTransport.error_after_delay(0, error) + setup, doc, expected_uri, uri_down, 
custom_setup, custom_transport, predicate = self.setup_info(error_lambda, mm=True) + container = setup['col'] + fault_injection_container = custom_setup['col'] + global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager + + with pytest.raises((CosmosHttpResponseError, ServiceResponseError)) as exc_info: + perform_write_operation(DELETE_ALL_ITEMS_BY_PARTITION_KEY, + container, + fault_injection_container, + str(uuid.uuid4()), + PK_VALUE, + expected_uri) + assert exc_info.value == error + + validate_unhealthy_partitions(global_endpoint_manager, 0) + + # writes should fail but still be tracked + for i in range(4): + with pytest.raises((CosmosHttpResponseError, ServiceResponseError)) as exc_info: + perform_write_operation(DELETE_ALL_ITEMS_BY_PARTITION_KEY, + container, + fault_injection_container, + str(uuid.uuid4()), + PK_VALUE, + expected_uri) + assert exc_info.value == error + + # writes should now succeed because going to the other region + perform_write_operation(DELETE_ALL_ITEMS_BY_PARTITION_KEY, + container, + fault_injection_container, + str(uuid.uuid4()), + PK_VALUE, + expected_uri) + + validate_unhealthy_partitions(global_endpoint_manager, 1) + # remove faults and reduce initial recover time and perform a write + original_unavailable_time = _partition_health_tracker.INITIAL_UNAVAILABLE_TIME + _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = 1 + custom_transport.faults = [] + try: + perform_write_operation(DELETE_ALL_ITEMS_BY_PARTITION_KEY, + container, + fault_injection_container, + str(uuid.uuid4()), + PK_VALUE, + uri_down) + finally: + _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = original_unavailable_time + validate_unhealthy_partitions(global_endpoint_manager, 0) + + + @pytest.mark.parametrize("error", create_errors()) + def test_write_failure_rate_threshold_delete_all_items_by_pk_mm(self, setup_teardown, error): + error_lambda = lambda r: FaultInjectionTransport.error_after_delay(0, error) + setup, doc, 
expected_uri, uri_down, custom_setup, custom_transport, predicate = self.setup_info(error_lambda, mm=True) + container = setup['col'] + fault_injection_container = custom_setup['col'] + global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager + # lower minimum requests for testing + _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 10 + os.environ["AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED"] = "80" + try: + # writes should fail but still be tracked and mark unavailable a partition after crossing threshold + for i in range(10): + validate_unhealthy_partitions(global_endpoint_manager, 0) + if i == 4 or i == 8: + # perform some successful creates to reset consecutive counter + # remove faults and perform a write + custom_transport.faults = [] + fault_injection_container.upsert_item(body=doc) + custom_transport.add_fault(predicate, + lambda r: FaultInjectionTransport.error_after_delay( + 0, + error + )) + else: + with pytest.raises((CosmosHttpResponseError, ServiceResponseError)) as exc_info: + perform_write_operation(DELETE_ALL_ITEMS_BY_PARTITION_KEY, + container, + fault_injection_container, + str(uuid.uuid4()), + PK_VALUE, + expected_uri) + assert exc_info.value == error + + validate_unhealthy_partitions(global_endpoint_manager, 1) + + finally: + os.environ["AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED"] = "90" + # restore minimum requests + _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 + + @pytest.mark.parametrize("error", create_errors()) + def test_write_failure_rate_threshold_delete_all_items_by_pk_sm(self, setup_teardown, write_operation, error): + error_lambda = lambda r: FaultInjectionTransport.error_after_delay(0, error) + setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = self.setup_info(error_lambda) + container = setup['col'] + fault_injection_container = custom_setup['col'] + global_endpoint_manager = 
fault_injection_container.client_connection._global_endpoint_manager + # lower minimum requests for testing + _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 10 + os.environ["AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED"] = "80" + try: + # writes should fail but still be tracked and mark unavailable a partition after crossing threshold + for i in range(10): + validate_unhealthy_partitions(global_endpoint_manager, 0) + if i == 4 or i == 8: + # perform some successful creates to reset consecutive counter + # remove faults and perform a write + custom_transport.faults = [] + fault_injection_container.upsert_item(body=doc) + custom_transport.add_fault(predicate, + lambda r: FaultInjectionTransport.error_after_delay( + 0, + error + )) + else: + with pytest.raises((CosmosHttpResponseError, ServiceResponseError)) as exc_info: + perform_write_operation(write_operation, + container, + fault_injection_container, + str(uuid.uuid4()), + PK_VALUE, + expected_uri) + assert exc_info.value == error + + validate_unhealthy_partitions(global_endpoint_manager, 0) + + finally: + os.environ["AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED"] = "90" + # restore minimum requests + _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 + + # test cosmos client timeout + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py index 6ee803be24d5..db89278bb229 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py +++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py @@ -21,7 +21,6 @@ @pytest.fixture(scope="class", autouse=True) def setup_teardown(): - os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "True" client = CosmosClient(TestPerPartitionCircuitBreakerMM.host, TestPerPartitionCircuitBreakerMM.master_key) created_database = 
client.get_database_client(TestPerPartitionCircuitBreakerMM.TEST_DATABASE_ID) @@ -33,7 +32,6 @@ def setup_teardown(): yield created_database.delete_container(TestPerPartitionCircuitBreakerMM.TEST_CONTAINER_SINGLE_PARTITION_ID) - os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "False" def perform_write_operation(operation, container, fault_injection_container, doc_id, pk, expected_uri): doc = {'id': doc_id, diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py index cdb606fd258a..2716d8d9cee2 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py @@ -40,7 +40,6 @@ COLLECTION = "created_collection" @pytest_asyncio.fixture(scope="class", autouse=True) async def setup_teardown(): - os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "True" client = CosmosClient(TestPerPartitionCircuitBreakerMMAsync.host, TestPerPartitionCircuitBreakerMMAsync.master_key) created_database = client.get_database_client(TestPerPartitionCircuitBreakerMMAsync.TEST_DATABASE_ID) @@ -52,10 +51,8 @@ async def setup_teardown(): yield await created_database.delete_container(TestPerPartitionCircuitBreakerMMAsync.TEST_CONTAINER_SINGLE_PARTITION_ID) await client.close() - os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "False" -def write_operations_and_errors(): - write_operations = [CREATE, UPSERT, REPLACE, DELETE, PATCH, BATCH]# "delete_all_items_by_partition_key"] +def create_errors(): errors = [] error_codes = [408, 500, 502, 503] for error_code in error_codes: @@ -63,6 +60,11 @@ def write_operations_and_errors(): status_code=error_code, message="Some injected error.")) errors.append(ServiceResponseError(message="Injected Service Response Error.")) + return errors + +def write_operations_and_errors(): + write_operations = [CREATE, UPSERT, REPLACE, DELETE, 
PATCH, BATCH] + errors = create_errors() params = [] for write_operation in write_operations: for error in errors: @@ -78,13 +80,7 @@ def create_doc(): def read_operations_and_errors(): read_operations = [READ, QUERY_PK, CHANGE_FEED, CHANGE_FEED_PK, CHANGE_FEED_EPK] - errors = [] - error_codes = [408, 500, 502, 503] - for error_code in error_codes: - errors.append(CosmosHttpResponseError( - status_code=error_code, - message="Some injected error.")) - errors.append(ServiceResponseError(message="Injected Service Response Error.")) + errors = create_errors() params = [] for read_operation in read_operations: for error in errors: diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr.py index c27fed2c0325..a066828fad09 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr.py +++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr.py @@ -20,9 +20,8 @@ from test_per_partition_circuit_breaker_sm_mrr_async import validate_unhealthy_partitions COLLECTION = "created_collection" -@pytest_asyncio.fixture() +@pytest_asyncio.fixture(scope="class", autouse=True) def setup_teardown(): - os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "True" client = CosmosClient(TestPerPartitionCircuitBreakerSmMrr.host, TestPerPartitionCircuitBreakerSmMrr.master_key) created_database = client.get_database_client(TestPerPartitionCircuitBreakerSmMrr.TEST_DATABASE_ID) created_database.create_container(TestPerPartitionCircuitBreakerSmMrr.TEST_CONTAINER_SINGLE_PARTITION_ID, @@ -32,7 +31,6 @@ def setup_teardown(): sleep(3) yield created_database.delete_container(TestPerPartitionCircuitBreakerSmMrr.TEST_CONTAINER_SINGLE_PARTITION_ID) - os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "False" @pytest.mark.cosmosCircuitBreakerMultiRegion @pytest.mark.usefixtures("setup_teardown") @@ -45,7 +43,7 @@ class TestPerPartitionCircuitBreakerSmMrr: 
def setup_method_with_custom_transport(self, custom_transport, default_endpoint=host, **kwargs): client = CosmosClient(default_endpoint, self.master_key, consistency_level="Session", - preferred_locations=["Write Region", "Read Region"], + preferred_locations=[REGION_1, REGION_2], transport=custom_transport, **kwargs) db = client.get_database_client(self.TEST_DATABASE_ID) container = db.get_container_client(self.TEST_CONTAINER_SINGLE_PARTITION_ID) diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr_async.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr_async.py index aee621f3f3d9..247e77289cfe 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr_async.py @@ -23,7 +23,7 @@ REGION_2, REGION_1 COLLECTION = "created_collection" -@pytest_asyncio.fixture() +@pytest_asyncio.fixture(scope="class", autouse=True) async def setup_teardown(): os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "True" client = CosmosClient(TestPerPartitionCircuitBreakerSmMrrAsync.host, TestPerPartitionCircuitBreakerSmMrrAsync.master_key) @@ -36,7 +36,6 @@ async def setup_teardown(): yield await created_database.delete_container(TestPerPartitionCircuitBreakerSmMrrAsync.TEST_CONTAINER_SINGLE_PARTITION_ID) await client.close() - os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "False" def validate_unhealthy_partitions(global_endpoint_manager, expected_unhealthy_partitions): @@ -69,7 +68,7 @@ class TestPerPartitionCircuitBreakerSmMrrAsync: async def setup_method_with_custom_transport(self, custom_transport: AioHttpTransport, default_endpoint=host, **kwargs): client = CosmosClient(default_endpoint, self.master_key, consistency_level="Session", - preferred_locations=["Write Region", "Read Region"], + preferred_locations=[REGION_1, REGION_2], transport=custom_transport, **kwargs) db = 
client.get_database_client(self.TEST_DATABASE_ID) container = db.get_container_client(self.TEST_CONTAINER_SINGLE_PARTITION_ID) @@ -80,29 +79,6 @@ async def cleanup_method(initialized_objects: Dict[str, Any]): method_client: CosmosClient = initialized_objects["client"] await method_client.close() - async def create_custom_transport_sm_mrr(self): - custom_transport = FaultInjectionTransportAsync() - # Inject rule to disallow writes in the read-only region - is_write_operation_in_read_region_predicate = lambda \ - r: FaultInjectionTransportAsync.predicate_is_write_operation(r, self.host) - - custom_transport.add_fault( - is_write_operation_in_read_region_predicate, - lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_write_forbidden())) - - # Inject topology transformation that would make Emulator look like a single write region - # account with two read regions - is_get_account_predicate = lambda r: FaultInjectionTransportAsync.predicate_is_database_account_call(r) - emulator_as_multi_region_sm_account_transformation = \ - lambda r, inner: FaultInjectionTransportAsync.transform_topology_swr_mrr( - write_region_name="Write Region", - read_region_name="Read Region", - inner=inner) - custom_transport.add_response_transformation( - is_get_account_predicate, - emulator_as_multi_region_sm_account_transformation) - return custom_transport - async def setup_info(self, error): expected_uri = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_2) uri_down = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_1) From b83b1c9ddc8ef0fa07eafcabe5f9805fafa90005 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Tue, 29 Apr 2025 00:11:32 -0700 Subject: [PATCH 125/152] fix tests --- sdk/cosmos/azure-cosmos/tests/test_circuit_breaker_emulator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_circuit_breaker_emulator.py b/sdk/cosmos/azure-cosmos/tests/test_circuit_breaker_emulator.py index 
5f0bb4cc1952..f083a83397f4 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_circuit_breaker_emulator.py +++ b/sdk/cosmos/azure-cosmos/tests/test_circuit_breaker_emulator.py @@ -17,7 +17,7 @@ from test_per_partition_circuit_breaker_mm import perform_write_operation from test_per_partition_circuit_breaker_mm_async import create_doc, PK_VALUE, create_errors from test_per_partition_circuit_breaker_sm_mrr_async import validate_unhealthy_partitions -from tests.test_per_partition_circuit_breaker_mm_async import DELETE_ALL_ITEMS_BY_PARTITION_KEY +from test_per_partition_circuit_breaker_mm_async import DELETE_ALL_ITEMS_BY_PARTITION_KEY COLLECTION = "created_collection" @pytest_asyncio.fixture(scope="class", autouse=True) From 09934aa7b31459e0877251ec578dc54f777ea7cb Mon Sep 17 00:00:00 2001 From: Tomas Varon Saldarriaga Date: Tue, 29 Apr 2025 01:07:39 -0700 Subject: [PATCH 126/152] fix ci tests --- ...tition_endpoint_manager_circuit_breaker.py | 1 - ...n_endpoint_manager_circuit_breaker_core.py | 5 +- .../tests/_fault_injection_transport.py | 9 ++- .../tests/test_circuit_breaker_emulator.py | 55 ++++++++++--------- .../test_per_partition_circuit_breaker_mm.py | 9 ++- ..._per_partition_circuit_breaker_mm_async.py | 2 - ..._partition_circuit_breaker_sm_mrr_async.py | 2 +- 7 files changed, 46 insertions(+), 37 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py index 87748b0bacdf..5a340eaf6f6c 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py @@ -53,7 +53,6 @@ def is_circuit_breaker_applicable(self, request: RequestObject) -> bool: def create_pk_range_wrapper(self, request: RequestObject) -> PartitionKeyRangeWrapper: container_rid = 
request.headers[HttpHeaders.IntendedCollectionRID] - print(request.headers) properties = self.Client._container_properties_cache[container_rid] # pylint: disable=protected-access # get relevant information from container cache to get the overlapping ranges container_link = properties["container_link"] diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py index f5bb1dfc78cd..78c03763e1b9 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py @@ -58,8 +58,9 @@ def is_circuit_breaker_applicable(self, request: RequestObject) -> bool: and documents._OperationType.IsWriteOperation(request.operation_type)): # pylint: disable=protected-access return False - if (request.resource_type != ResourceType.Document - or request.operation_type == documents._OperationType.QueryPlan): # pylint: disable=protected-access + if ((request.resource_type != ResourceType.Document + and request.resource_type != ResourceType.PartitionKey) + or request.operation_type == documents._OperationType.QueryPlan): # pylint: disable=protected-access return False # this is for certain cross partition queries and read all items where we cannot discern partition information diff --git a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py index 6a56daf4a789..f745f3247b5d 100644 --- a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py +++ b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py @@ -25,6 +25,7 @@ import json import logging import sys +from importlib.resources import is_resource from time import sleep from typing import Callable, Optional, Any, Dict, List, MutableMapping @@ -148,6 +149,12 @@ def 
predicate_is_operation_type(r: HttpRequest, operation_type: str) -> bool: return is_operation_type + @staticmethod + def predicate_is_resource_type(r: HttpRequest, resource_type: str) -> bool: + is_resource_type = r.headers.get(HttpHeaders.ThinClientProxyResourceType) == resource_type + + return is_resource_type + @staticmethod def predicate_is_write_operation(r: HttpRequest, uri_prefix: str) -> bool: is_write_document_operation = documents._OperationType.IsWriteOperation( @@ -211,7 +218,7 @@ def transform_topology_mwr( first_region_name: str, second_region_name: str, inner: Callable[[], RequestsTransportResponse], - first_region_url: str = None, + first_region_url: str = test_config.TestConfig.local_host.replace("localhost", "127.0.0.1"), second_region_url: str = test_config.TestConfig.local_host ) -> RequestsTransportResponse: diff --git a/sdk/cosmos/azure-cosmos/tests/test_circuit_breaker_emulator.py b/sdk/cosmos/azure-cosmos/tests/test_circuit_breaker_emulator.py index f083a83397f4..6d1c99fb0116 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_circuit_breaker_emulator.py +++ b/sdk/cosmos/azure-cosmos/tests/test_circuit_breaker_emulator.py @@ -10,18 +10,21 @@ from azure.core.exceptions import ServiceResponseError import test_config -from azure.cosmos import PartitionKey, _partition_health_tracker +from azure.cosmos import PartitionKey, _partition_health_tracker, documents from azure.cosmos import CosmosClient from azure.cosmos.exceptions import CosmosHttpResponseError from _fault_injection_transport import FaultInjectionTransport +from azure.cosmos.http_constants import ResourceType from test_per_partition_circuit_breaker_mm import perform_write_operation -from test_per_partition_circuit_breaker_mm_async import create_doc, PK_VALUE, create_errors -from test_per_partition_circuit_breaker_sm_mrr_async import validate_unhealthy_partitions -from test_per_partition_circuit_breaker_mm_async import DELETE_ALL_ITEMS_BY_PARTITION_KEY +from 
test_per_partition_circuit_breaker_mm_async import (create_doc, PK_VALUE, create_errors, + DELETE_ALL_ITEMS_BY_PARTITION_KEY, + validate_unhealthy_partitions as validate_unhealthy_partitions_mm) +from test_per_partition_circuit_breaker_sm_mrr_async import validate_unhealthy_partitions as validate_unhealthy_partitions_sm_mrr COLLECTION = "created_collection" @pytest_asyncio.fixture(scope="class", autouse=True) def setup_teardown(): + os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "True" client = CosmosClient(TestCircuitBreakerEmulator.host, TestCircuitBreakerEmulator.master_key) created_database = client.get_database_client(TestCircuitBreakerEmulator.TEST_DATABASE_ID) created_database.create_container(TestCircuitBreakerEmulator.TEST_CONTAINER_SINGLE_PARTITION_ID, @@ -31,14 +34,15 @@ def setup_teardown(): sleep(3) yield created_database.delete_container(TestCircuitBreakerEmulator.TEST_CONTAINER_SINGLE_PARTITION_ID) + os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "False" def create_custom_transport_mm(): custom_transport = FaultInjectionTransport() is_get_account_predicate = lambda r: FaultInjectionTransport.predicate_is_database_account_call(r) emulator_as_multi_write_region_account_transformation = \ lambda r, inner: FaultInjectionTransport.transform_topology_mwr( - first_region_name="First Region", - second_region_name="Second Region", + first_region_name="Write Region", + second_region_name="Read Region", inner=inner) custom_transport.add_response_transformation( is_get_account_predicate, @@ -93,25 +97,26 @@ def setup_info(self, error, mm=False): custom_transport = create_custom_transport_mm() if mm else self.create_custom_transport_sm_mrr() # two documents targeted to same partition, one will always fail and the other will succeed doc = create_doc() - predicate = lambda r: (FaultInjectionTransport.predicate_is_document_operation(r) and + predicate = lambda r: (FaultInjectionTransport.predicate_is_resource_type(r, ResourceType.Collection) and + 
FaultInjectionTransport.predicate_is_operation_type(r, documents._OperationType.Delete) and FaultInjectionTransport.predicate_targets_region(r, uri_down)) custom_transport.add_fault(predicate, error) - custom_setup = self.setup_method_with_custom_transport(custom_transport, default_endpoint=self.host) - setup = self.setup_method_with_custom_transport(None, default_endpoint=self.host) + custom_setup = self.setup_method_with_custom_transport(custom_transport, default_endpoint=self.host, multiple_write_locations=mm) + setup = self.setup_method_with_custom_transport(None, default_endpoint=self.host, multiple_write_locations=mm) return setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate @pytest.mark.parametrize("error", create_errors()) def test_write_consecutive_failure_threshold_delete_all_items_by_pk_sm(self, setup_teardown, error): error_lambda = lambda r: FaultInjectionTransport.error_after_delay(0, error) setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = self.setup_info(error_lambda) - container = setup['col'] fault_injection_container = custom_setup['col'] - global_endpoint_manager = container.client_connection._global_endpoint_manager + container = setup['col'] + global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager # writes should fail in sm mrr with circuit breaker and should not mark unavailable a partition for i in range(6): - validate_unhealthy_partitions(global_endpoint_manager, 0) + validate_unhealthy_partitions_sm_mrr(global_endpoint_manager, 0) with pytest.raises((CosmosHttpResponseError, ServiceResponseError)): perform_write_operation( DELETE_ALL_ITEMS_BY_PARTITION_KEY, @@ -122,15 +127,15 @@ def test_write_consecutive_failure_threshold_delete_all_items_by_pk_sm(self, set expected_uri, ) - validate_unhealthy_partitions(global_endpoint_manager, 0) + validate_unhealthy_partitions_sm_mrr(global_endpoint_manager, 0) @pytest.mark.parametrize("error", 
create_errors()) def test_write_consecutive_failure_threshold_delete_all_items_by_pk_mm(self, setup_teardown, error): error_lambda = lambda r: FaultInjectionTransport.error_after_delay(0, error) setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = self.setup_info(error_lambda, mm=True) - container = setup['col'] fault_injection_container = custom_setup['col'] + container = setup['col'] global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager with pytest.raises((CosmosHttpResponseError, ServiceResponseError)) as exc_info: @@ -142,7 +147,7 @@ def test_write_consecutive_failure_threshold_delete_all_items_by_pk_mm(self, set expected_uri) assert exc_info.value == error - validate_unhealthy_partitions(global_endpoint_manager, 0) + validate_unhealthy_partitions_mm(global_endpoint_manager, 0) # writes should fail but still be tracked for i in range(4): @@ -163,7 +168,7 @@ def test_write_consecutive_failure_threshold_delete_all_items_by_pk_mm(self, set PK_VALUE, expected_uri) - validate_unhealthy_partitions(global_endpoint_manager, 1) + validate_unhealthy_partitions_mm(global_endpoint_manager, 1) # remove faults and reduce initial recover time and perform a write original_unavailable_time = _partition_health_tracker.INITIAL_UNAVAILABLE_TIME _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = 1 @@ -177,15 +182,15 @@ def test_write_consecutive_failure_threshold_delete_all_items_by_pk_mm(self, set uri_down) finally: _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = original_unavailable_time - validate_unhealthy_partitions(global_endpoint_manager, 0) + validate_unhealthy_partitions_mm(global_endpoint_manager, 0) @pytest.mark.parametrize("error", create_errors()) def test_write_failure_rate_threshold_delete_all_items_by_pk_mm(self, setup_teardown, error): error_lambda = lambda r: FaultInjectionTransport.error_after_delay(0, error) setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = 
self.setup_info(error_lambda, mm=True) - container = setup['col'] fault_injection_container = custom_setup['col'] + container = setup['col'] global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager # lower minimum requests for testing _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 10 @@ -193,7 +198,7 @@ def test_write_failure_rate_threshold_delete_all_items_by_pk_mm(self, setup_tear try: # writes should fail but still be tracked and mark unavailable a partition after crossing threshold for i in range(10): - validate_unhealthy_partitions(global_endpoint_manager, 0) + validate_unhealthy_partitions_mm(global_endpoint_manager, 0) if i == 4 or i == 8: # perform some successful creates to reset consecutive counter # remove faults and perform a write @@ -214,7 +219,7 @@ def test_write_failure_rate_threshold_delete_all_items_by_pk_mm(self, setup_tear expected_uri) assert exc_info.value == error - validate_unhealthy_partitions(global_endpoint_manager, 1) + validate_unhealthy_partitions_mm(global_endpoint_manager, 1) finally: os.environ["AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED"] = "90" @@ -222,11 +227,11 @@ def test_write_failure_rate_threshold_delete_all_items_by_pk_mm(self, setup_tear _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 @pytest.mark.parametrize("error", create_errors()) - def test_write_failure_rate_threshold_delete_all_items_by_pk_sm(self, setup_teardown, write_operation, error): + def test_write_failure_rate_threshold_delete_all_items_by_pk_sm(self, setup_teardown, error): error_lambda = lambda r: FaultInjectionTransport.error_after_delay(0, error) setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = self.setup_info(error_lambda) - container = setup['col'] fault_injection_container = custom_setup['col'] + container = setup['col'] global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager # lower minimum requests for testing 
_partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 10 @@ -234,7 +239,7 @@ def test_write_failure_rate_threshold_delete_all_items_by_pk_sm(self, setup_tear try: # writes should fail but still be tracked and mark unavailable a partition after crossing threshold for i in range(10): - validate_unhealthy_partitions(global_endpoint_manager, 0) + validate_unhealthy_partitions_sm_mrr(global_endpoint_manager, 0) if i == 4 or i == 8: # perform some successful creates to reset consecutive counter # remove faults and perform a write @@ -247,7 +252,7 @@ def test_write_failure_rate_threshold_delete_all_items_by_pk_sm(self, setup_tear )) else: with pytest.raises((CosmosHttpResponseError, ServiceResponseError)) as exc_info: - perform_write_operation(write_operation, + perform_write_operation(DELETE_ALL_ITEMS_BY_PARTITION_KEY, container, fault_injection_container, str(uuid.uuid4()), @@ -255,7 +260,7 @@ def test_write_failure_rate_threshold_delete_all_items_by_pk_sm(self, setup_tear expected_uri) assert exc_info.value == error - validate_unhealthy_partitions(global_endpoint_manager, 0) + validate_unhealthy_partitions_sm_mrr(global_endpoint_manager, 0) finally: os.environ["AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED"] = "90" diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py index db89278bb229..274895d83581 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py +++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py @@ -17,6 +17,7 @@ QUERY_PK, QUERY, CHANGE_FEED, CHANGE_FEED_PK, CHANGE_FEED_EPK, READ_ALL_ITEMS, REGION_1, REGION_2, \ write_operations_and_errors, validate_unhealthy_partitions, read_operations_and_errors, PK_VALUE, operations, \ create_doc +from tests.test_per_partition_circuit_breaker_mm_async import DELETE_ALL_ITEMS_BY_PARTITION_KEY @pytest.fixture(scope="class", autouse=True) @@ -65,11 +66,9 @@ def 
perform_write_operation(operation, container, fault_injection_container, doc ] resp = fault_injection_container.execute_item_batch(batch_operations, partition_key=doc['pk']) # this will need to be emulator only - # elif operation == "delete_all_items_by_partition_key": - # await container.create_item(body=doc) - # await container.create_item(body=doc) - # await container.create_item(body=doc) - # resp = await fault_injection_container.delete_all_items_by_partition_key(pk) + elif operation == DELETE_ALL_ITEMS_BY_PARTITION_KEY: + container.create_item(body=doc) + resp = fault_injection_container.delete_all_items_by_partition_key(pk) if resp: validate_response_uri(resp, expected_uri) diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py index 2716d8d9cee2..13672381ea9f 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py @@ -134,8 +134,6 @@ async def perform_write_operation(operation, container, fault_injection_containe resp = await fault_injection_container.execute_item_batch(batch_operations, partition_key=doc['pk']) # this will need to be emulator only elif operation == DELETE_ALL_ITEMS_BY_PARTITION_KEY: - await container.create_item(body=doc) - await container.create_item(body=doc) await container.create_item(body=doc) resp = await fault_injection_container.delete_all_items_by_partition_key(pk) if resp: diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr_async.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr_async.py index 247e77289cfe..0e35314e629a 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr_async.py @@ -13,7 +13,7 @@ import test_config from 
azure.cosmos import PartitionKey, _partition_health_tracker, _location_cache -from azure.cosmos._partition_health_tracker import HEALTH_STATUS, UNHEALTHY, UNHEALTHY_TENTATIVE +from azure.cosmos._partition_health_tracker import UNHEALTHY_TENTATIVE, UNHEALTHY, HEALTH_STATUS from azure.cosmos.aio import CosmosClient from azure.cosmos.exceptions import CosmosHttpResponseError from _fault_injection_transport_async import FaultInjectionTransportAsync From 3d71dc55ff7a2f2d75e3803c69fc134bf412d7b1 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Tue, 29 Apr 2025 11:06:34 -0700 Subject: [PATCH 127/152] fix tests --- sdk/cosmos/azure-cosmos/azure/cosmos/auth.py | 2 +- .../tests/test_per_partition_circuit_breaker_mm.py | 2 ++ .../tests/test_per_partition_circuit_breaker_sm_mrr_async.py | 3 ++- sdk/cosmos/live-platform-matrix.json | 4 ++-- sdk/cosmos/test-resources.bicep | 4 ++-- 5 files changed, 9 insertions(+), 6 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/auth.py b/sdk/cosmos/azure-cosmos/azure/cosmos/auth.py index 370d8b76bf17..221109a35bab 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/auth.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/auth.py @@ -154,6 +154,6 @@ def __get_authorization_token_using_resource_token(resource_tokens, path, resour if not segment in resource_types: for resource_path, resource_token in resource_tokens.items(): if sub_path in resource_path: - return resource_tokens[path] + return resource_tokens[resource_path] return None diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py index 274895d83581..5fb0e864da6a 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py +++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py @@ -20,8 +20,10 @@ from tests.test_per_partition_circuit_breaker_mm_async import DELETE_ALL_ITEMS_BY_PARTITION_KEY +logger = logging.getLogger('test') 
@pytest.fixture(scope="class", autouse=True) def setup_teardown(): + logger.info(os.environ.get("AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER")) client = CosmosClient(TestPerPartitionCircuitBreakerMM.host, TestPerPartitionCircuitBreakerMM.master_key) created_database = client.get_database_client(TestPerPartitionCircuitBreakerMM.TEST_DATABASE_ID) diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr_async.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr_async.py index 0e35314e629a..95768b5595c8 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr_async.py @@ -23,9 +23,10 @@ REGION_2, REGION_1 COLLECTION = "created_collection" +logger = logging.getLogger('test') @pytest_asyncio.fixture(scope="class", autouse=True) async def setup_teardown(): - os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "True" + logger.info(os.getenv("AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER")) client = CosmosClient(TestPerPartitionCircuitBreakerSmMrrAsync.host, TestPerPartitionCircuitBreakerSmMrrAsync.master_key) created_database = client.get_database_client(TestPerPartitionCircuitBreakerSmMrrAsync.TEST_DATABASE_ID) await created_database.create_container(TestPerPartitionCircuitBreakerSmMrrAsync.TEST_CONTAINER_SINGLE_PARTITION_ID, diff --git a/sdk/cosmos/live-platform-matrix.json b/sdk/cosmos/live-platform-matrix.json index 508e96034d66..c4fe1426b1c6 100644 --- a/sdk/cosmos/live-platform-matrix.json +++ b/sdk/cosmos/live-platform-matrix.json @@ -38,7 +38,7 @@ }, "ArmConfig": { "MultiMaster": { - "ArmTemplateParameters": "@{ enableMultipleWriteLocations = $true; defaultConsistencyLevel = 'Session'; enableMultipleRegions = $true; circuitBreakerEnabled = $true }" + "ArmTemplateParameters": "@{ enableMultipleWriteLocations = $true; defaultConsistencyLevel = 'Session'; enableMultipleRegions = $true; circuitBreakerEnabled = 'True' }" } 
} }, @@ -55,7 +55,7 @@ }, "ArmConfig": { "MultiRegion": { - "ArmTemplateParameters": "@{ defaultConsistencyLevel = 'Session'; enableMultipleRegions = $true; circuitBreakerEnabled = $true }" + "ArmTemplateParameters": "@{ defaultConsistencyLevel = 'Session'; enableMultipleRegions = $true; circuitBreakerEnabled = 'True' }" } } }, diff --git a/sdk/cosmos/test-resources.bicep b/sdk/cosmos/test-resources.bicep index a9635c29c546..735c1a0e66ee 100644 --- a/sdk/cosmos/test-resources.bicep +++ b/sdk/cosmos/test-resources.bicep @@ -13,7 +13,7 @@ param enableMultipleRegions bool = false param location string = resourceGroup().location @description('Whether Per Partition Circuit Breaker should be enabled.') -param circuitBreakerEnabled bool = false +param circuitBreakerEnabled string = 'False' @description('The api version to be used by Bicep to create resources') param apiVersion string = '2023-04-15' @@ -104,6 +104,6 @@ resource accountName_roleAssignmentId 'Microsoft.DocumentDB/databaseAccounts/sql } } -output AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER bool = circuitBreakerEnabled +output AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER string = circuitBreakerEnabled output ACCOUNT_HOST string = reference(resourceId, apiVersion).documentEndpoint output ACCOUNT_KEY string = listKeys(resourceId, apiVersion).primaryMasterKey From d09a4452b3737ec4cbdbf23f9b04379ba88396c2 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Tue, 29 Apr 2025 11:54:06 -0700 Subject: [PATCH 128/152] fix ci --- .../tests/test_per_partition_circuit_breaker_mm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py index 5fb0e864da6a..f84fea27163c 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py +++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py @@ -1,5 +1,6 @@ # The MIT License (MIT) # Copyright (c) Microsoft 
Corporation. All rights reserved. +import logging import os import unittest import uuid @@ -17,7 +18,7 @@ QUERY_PK, QUERY, CHANGE_FEED, CHANGE_FEED_PK, CHANGE_FEED_EPK, READ_ALL_ITEMS, REGION_1, REGION_2, \ write_operations_and_errors, validate_unhealthy_partitions, read_operations_and_errors, PK_VALUE, operations, \ create_doc -from tests.test_per_partition_circuit_breaker_mm_async import DELETE_ALL_ITEMS_BY_PARTITION_KEY +from test_per_partition_circuit_breaker_mm_async import DELETE_ALL_ITEMS_BY_PARTITION_KEY logger = logging.getLogger('test') From 1eb29e9ba943953246de845010c89ac7dff1a062 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Tue, 29 Apr 2025 12:37:34 -0700 Subject: [PATCH 129/152] fix ci --- .../_global_partition_endpoint_manager_circuit_breaker_core.py | 3 +-- .../azure/cosmos/aio/_global_endpoint_manager_async.py | 2 +- .../tests/test_per_partition_circuit_breaker_sm_mrr_async.py | 1 + 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py index 78c03763e1b9..548430dd6806 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py @@ -58,8 +58,7 @@ def is_circuit_breaker_applicable(self, request: RequestObject) -> bool: and documents._OperationType.IsWriteOperation(request.operation_type)): # pylint: disable=protected-access return False - if ((request.resource_type != ResourceType.Document - and request.resource_type != ResourceType.PartitionKey) + if (request.resource_type not in (ResourceType.Document, ResourceType.PartitionKey) or request.operation_type == documents._OperationType.QueryPlan): # pylint: disable=protected-access return False diff --git 
a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py index 773f6f8e7d1d..5a7aea5330ba 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py @@ -38,7 +38,7 @@ # pylint: disable=protected-access -logger = logging.getLogger("azure.cosmos.aio_GlobalEndpointManager") +logger = logging.getLogger("azure.cosmos.aio._GlobalEndpointManager") class _GlobalEndpointManager(object): # pylint: disable=too-many-instance-attributes """ diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr_async.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr_async.py index 95768b5595c8..a2e1d6418cc6 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr_async.py @@ -1,6 +1,7 @@ # The MIT License (MIT) # Copyright (c) Microsoft Corporation. All rights reserved. 
import asyncio +import logging import os import unittest import uuid From e9cc6f198b93dda0bfaf4090023a0b86fa9bf4f1 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Tue, 29 Apr 2025 16:15:10 -0700 Subject: [PATCH 130/152] refactor tests, removed print statements, and added async emulator delete all items by partition key tests --- .../azure/cosmos/_partition_health_tracker.py | 6 - .../azure/cosmos/aio/_retry_utility_async.py | 1 - .../tests/test_circuit_breaker_emulator.py | 3 +- .../test_circuit_breaker_emulator_async.py | 274 ++++++++++++++++++ .../test_per_partition_circuit_breaker_mm.py | 4 +- ..._per_partition_circuit_breaker_mm_async.py | 3 + ...st_per_partition_circuit_breaker_sm_mrr.py | 93 +----- ..._partition_circuit_breaker_sm_mrr_async.py | 105 +------ 8 files changed, 287 insertions(+), 202 deletions(-) create mode 100644 sdk/cosmos/azure-cosmos/tests/test_circuit_breaker_emulator_async.py diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py index 03457a2967eb..d086d9d2b497 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py @@ -256,9 +256,6 @@ def add_failure( failure_rate_threshold, consecutive_failure_threshold ) - print(pk_range_wrapper) - print(location) - print(self.pk_range_wrapper_to_health_info[pk_range_wrapper][location]) def _check_thresholds( self, @@ -297,9 +294,6 @@ def add_success(self, pk_range_wrapper: PartitionKeyRangeWrapper, operation_type else: health_info.read_success_count += 1 health_info.read_consecutive_failure_count = 0 - print(pk_range_wrapper) - print(location) - print(self.pk_range_wrapper_to_health_info[pk_range_wrapper][location]) self._transition_health_status_on_success(pk_range_wrapper, location) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py 
b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py index b337d409bce1..82f09adbe4f5 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py @@ -277,7 +277,6 @@ async def send(self, request): start_time = time.time() try: _configure_timeout(request, absolute_timeout, per_request_timeout) - print("RetryUtility - Sending request ", request_params.operation_type, request_params.resource_type) response = await self.next.send(request) break except ClientAuthenticationError: # pylint:disable=try-except-raise diff --git a/sdk/cosmos/azure-cosmos/tests/test_circuit_breaker_emulator.py b/sdk/cosmos/azure-cosmos/tests/test_circuit_breaker_emulator.py index 6d1c99fb0116..3e450191fda2 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_circuit_breaker_emulator.py +++ b/sdk/cosmos/azure-cosmos/tests/test_circuit_breaker_emulator.py @@ -6,7 +6,6 @@ from time import sleep import pytest -import pytest_asyncio from azure.core.exceptions import ServiceResponseError import test_config @@ -22,7 +21,7 @@ from test_per_partition_circuit_breaker_sm_mrr_async import validate_unhealthy_partitions as validate_unhealthy_partitions_sm_mrr COLLECTION = "created_collection" -@pytest_asyncio.fixture(scope="class", autouse=True) +@pytest.fixture(scope="class", autouse=True) def setup_teardown(): os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "True" client = CosmosClient(TestCircuitBreakerEmulator.host, TestCircuitBreakerEmulator.master_key) diff --git a/sdk/cosmos/azure-cosmos/tests/test_circuit_breaker_emulator_async.py b/sdk/cosmos/azure-cosmos/tests/test_circuit_breaker_emulator_async.py new file mode 100644 index 000000000000..29d2323381fb --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_circuit_breaker_emulator_async.py @@ -0,0 +1,274 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. 
+import asyncio +import os +import unittest +import uuid +from time import sleep + +import pytest +import pytest_asyncio +from azure.core.exceptions import ServiceResponseError + +import test_config +from azure.cosmos import PartitionKey, _partition_health_tracker, documents +from azure.cosmos.aio import CosmosClient +from azure.cosmos.exceptions import CosmosHttpResponseError +from azure.cosmos.http_constants import ResourceType +from test_per_partition_circuit_breaker_mm_async import (create_doc, PK_VALUE, create_errors, + DELETE_ALL_ITEMS_BY_PARTITION_KEY, + validate_unhealthy_partitions as validate_unhealthy_partitions_mm, + perform_write_operation) +from test_per_partition_circuit_breaker_sm_mrr_async import validate_unhealthy_partitions as validate_unhealthy_partitions_sm_mrr +from _fault_injection_transport_async import FaultInjectionTransportAsync + +COLLECTION = "created_collection" +@pytest_asyncio.fixture(scope="class", autouse=True) +async def setup_teardown(): + os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "True" + client = CosmosClient(TestCircuitBreakerEmulatorAsync.host, TestCircuitBreakerEmulatorAsync.master_key) + created_database = client.get_database_client(TestCircuitBreakerEmulatorAsync.TEST_DATABASE_ID) + await created_database.create_container(TestCircuitBreakerEmulatorAsync.TEST_CONTAINER_SINGLE_PARTITION_ID, + partition_key=PartitionKey("/pk"), + offer_throughput=10000) + # allow some time for the container to be created as this method is in different event loop + sleep(3) + yield + await created_database.delete_container(TestCircuitBreakerEmulatorAsync.TEST_CONTAINER_SINGLE_PARTITION_ID) + os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "False" + +async def create_custom_transport_mm(): + custom_transport = FaultInjectionTransportAsync() + is_get_account_predicate = lambda r: FaultInjectionTransportAsync.predicate_is_database_account_call(r) + emulator_as_multi_write_region_account_transformation = \ + lambda r, inner:
FaultInjectionTransportAsync.transform_topology_mwr( + first_region_name="Write Region", + second_region_name="Read Region", + inner=inner) + custom_transport.add_response_transformation( + is_get_account_predicate, + emulator_as_multi_write_region_account_transformation) + return custom_transport + + +@pytest.mark.cosmosEmulator +@pytest.mark.asyncio +@pytest.mark.usefixtures("setup_teardown") +class TestCircuitBreakerEmulatorAsync: + host = test_config.TestConfig.host + master_key = test_config.TestConfig.masterKey + connectionPolicy = test_config.TestConfig.connectionPolicy + TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID + TEST_CONTAINER_SINGLE_PARTITION_ID = os.path.basename(__file__) + str(uuid.uuid4()) + + async def setup_method_with_custom_transport(self, custom_transport, default_endpoint=host, **kwargs): + client = CosmosClient(default_endpoint, self.master_key, consistency_level="Session", + preferred_locations=["Write Region", "Read Region"], + transport=custom_transport, **kwargs) + db = client.get_database_client(self.TEST_DATABASE_ID) + container = db.get_container_client(self.TEST_CONTAINER_SINGLE_PARTITION_ID) + return {"client": client, "db": db, "col": container} + + + async def create_custom_transport_sm_mrr(self): + custom_transport = FaultInjectionTransportAsync() + # Inject rule to disallow writes in the read-only region + is_write_operation_in_read_region_predicate = lambda \ + r: FaultInjectionTransportAsync.predicate_is_write_operation(r, self.host) + + custom_transport.add_fault( + is_write_operation_in_read_region_predicate, + lambda r: FaultInjectionTransportAsync.error_write_forbidden()) + + # Inject topology transformation that would make Emulator look like a single write region + # account with two read regions + is_get_account_predicate = lambda r: FaultInjectionTransportAsync.predicate_is_database_account_call(r) + emulator_as_multi_region_sm_account_transformation = \ + lambda r, inner: 
FaultInjectionTransportAsync.transform_topology_swr_mrr( + write_region_name="Write Region", + read_region_name="Read Region", + inner=inner) + custom_transport.add_response_transformation( + is_get_account_predicate, + emulator_as_multi_region_sm_account_transformation) + return custom_transport + + async def setup_info(self, error, mm=False): + expected_uri = self.host + uri_down = self.host.replace("localhost", "127.0.0.1") + custom_transport = await create_custom_transport_mm() if mm else await self.create_custom_transport_sm_mrr() + # two documents targeted to same partition, one will always fail and the other will succeed + doc = create_doc() + predicate = lambda r: (FaultInjectionTransportAsync.predicate_is_resource_type(r, ResourceType.Collection) and + FaultInjectionTransportAsync.predicate_is_operation_type(r, documents._OperationType.Delete) and + FaultInjectionTransportAsync.predicate_targets_region(r, uri_down)) + custom_transport.add_fault(predicate, error) + custom_setup = await self.setup_method_with_custom_transport(custom_transport, default_endpoint=self.host, multiple_write_locations=mm) + setup = await self.setup_method_with_custom_transport(None, default_endpoint=self.host, multiple_write_locations=mm) + return setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate + + @pytest.mark.parametrize("error", create_errors()) + async def test_write_consecutive_failure_threshold_delete_all_items_by_pk_sm(self, setup_teardown, error): + error_lambda = lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay(0, error)) + setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = await self.setup_info(error_lambda)
+ validate_unhealthy_partitions_sm_mrr(global_endpoint_manager, 0) + with pytest.raises((CosmosHttpResponseError, ServiceResponseError)): + await perform_write_operation( + DELETE_ALL_ITEMS_BY_PARTITION_KEY, + container, + fault_injection_container, + str(uuid.uuid4()), + PK_VALUE, + expected_uri, + ) + + validate_unhealthy_partitions_sm_mrr(global_endpoint_manager, 0) + + + @pytest.mark.parametrize("error", create_errors()) + async def test_write_consecutive_failure_threshold_delete_all_items_by_pk_mm(self, setup_teardown, error): + error_lambda = lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay(0, error)) + setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = await self.setup_info(error_lambda, mm=True) + fault_injection_container = custom_setup['col'] + container = setup['col'] + global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager + + with pytest.raises((CosmosHttpResponseError, ServiceResponseError)) as exc_info: + await perform_write_operation(DELETE_ALL_ITEMS_BY_PARTITION_KEY, + container, + fault_injection_container, + str(uuid.uuid4()), + PK_VALUE, + expected_uri) + assert exc_info.value == error + + validate_unhealthy_partitions_mm(global_endpoint_manager, 0) + + # writes should fail but still be tracked + for i in range(4): + with pytest.raises((CosmosHttpResponseError, ServiceResponseError)) as exc_info: + await perform_write_operation(DELETE_ALL_ITEMS_BY_PARTITION_KEY, + container, + fault_injection_container, + str(uuid.uuid4()), + PK_VALUE, + expected_uri) + assert exc_info.value == error + + # writes should now succeed because going to the other region + await perform_write_operation(DELETE_ALL_ITEMS_BY_PARTITION_KEY, + container, + fault_injection_container, + str(uuid.uuid4()), + PK_VALUE, + expected_uri) + + validate_unhealthy_partitions_mm(global_endpoint_manager, 1) + # remove faults and reduce initial recover time and perform a write + 
original_unavailable_time = _partition_health_tracker.INITIAL_UNAVAILABLE_TIME + _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = 1 + custom_transport.faults = [] + try: + await perform_write_operation(DELETE_ALL_ITEMS_BY_PARTITION_KEY, + container, + fault_injection_container, + str(uuid.uuid4()), + PK_VALUE, + uri_down) + finally: + _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = original_unavailable_time + validate_unhealthy_partitions_mm(global_endpoint_manager, 0) + + + @pytest.mark.parametrize("error", create_errors()) + async def test_write_failure_rate_threshold_delete_all_items_by_pk_mm(self, setup_teardown, error): + error_lambda = lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay(0, error)) + setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = await self.setup_info(error_lambda, mm=True) + fault_injection_container = custom_setup['col'] + container = setup['col'] + global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager + # lower minimum requests for testing + _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 10 + os.environ["AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED"] = "80" + try: + # writes should fail but still be tracked and mark unavailable a partition after crossing threshold + for i in range(10): + validate_unhealthy_partitions_mm(global_endpoint_manager, 0) + if i == 4 or i == 8: + # perform some successful creates to reset consecutive counter + # remove faults and perform a write + custom_transport.faults = [] + await fault_injection_container.upsert_item(body=doc) + custom_transport.add_fault(predicate, + lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( + 0, + error + ))) + else: + with pytest.raises((CosmosHttpResponseError, ServiceResponseError)) as exc_info: + await perform_write_operation(DELETE_ALL_ITEMS_BY_PARTITION_KEY, + container, + fault_injection_container, + str(uuid.uuid4()), + PK_VALUE, +
expected_uri) + assert exc_info.value == error + + validate_unhealthy_partitions_mm(global_endpoint_manager, 1) + + finally: + os.environ["AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED"] = "90" + # restore minimum requests + _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 + + @pytest.mark.parametrize("error", create_errors()) + async def test_write_failure_rate_threshold_delete_all_items_by_pk_sm(self, setup_teardown, error): + error_lambda = lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay(0, error)) + setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = await self.setup_info(error_lambda) + fault_injection_container = custom_setup['col'] + container = setup['col'] + global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager + # lower minimum requests for testing + _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 10 + os.environ["AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED"] = "80" + try: + # writes should fail but still be tracked and mark unavailable a partition after crossing threshold + for i in range(10): + validate_unhealthy_partitions_sm_mrr(global_endpoint_manager, 0) + if i == 4 or i == 8: + # perform some successful creates to reset consecutive counter + # remove faults and perform a write + custom_transport.faults = [] + await fault_injection_container.upsert_item(body=doc) + custom_transport.add_fault(predicate, + lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( + 0, + error + ))) + else: + with pytest.raises((CosmosHttpResponseError, ServiceResponseError)) as exc_info: + await perform_write_operation(DELETE_ALL_ITEMS_BY_PARTITION_KEY, + container, + fault_injection_container, + str(uuid.uuid4()), + PK_VALUE, + expected_uri) + assert exc_info.value == error + + validate_unhealthy_partitions_sm_mrr(global_endpoint_manager, 0) + + finally: + os.environ["AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED"] = "90" + # restore
minimum requests + _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 + + # test cosmos client timeout + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py index f84fea27163c..8825cf5a0705 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py +++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py @@ -21,10 +21,8 @@ from test_per_partition_circuit_breaker_mm_async import DELETE_ALL_ITEMS_BY_PARTITION_KEY -logger = logging.getLogger('test') @pytest.fixture(scope="class", autouse=True) def setup_teardown(): - logger.info(os.environ.get("AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER")) client = CosmosClient(TestPerPartitionCircuitBreakerMM.host, TestPerPartitionCircuitBreakerMM.master_key) created_database = client.get_database_client(TestPerPartitionCircuitBreakerMM.TEST_DATABASE_ID) @@ -181,6 +179,7 @@ def test_write_consecutive_failure_threshold(self, setup_teardown, write_operati _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = original_unavailable_time validate_unhealthy_partitions(global_endpoint_manager, 0) + @pytest.mark.cosmosCircuitBreakerMultiRegion @pytest.mark.parametrize("read_operation, error", read_operations_and_errors()) def test_read_consecutive_failure_threshold(self, setup_teardown, read_operation, error): error_lambda = lambda r: FaultInjectionTransport.error_after_delay( @@ -273,6 +272,7 @@ def test_write_failure_rate_threshold(self, setup_teardown, write_operation, err # restore minimum requests _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 + @pytest.mark.cosmosCircuitBreakerMultiRegion @pytest.mark.parametrize("read_operation, error", read_operations_and_errors()) def test_read_failure_rate_threshold(self, setup_teardown, read_operation, error): error_lambda = lambda r: 
FaultInjectionTransport.error_after_delay( diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py index 13672381ea9f..f72d2031bf15 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py @@ -40,6 +40,7 @@ COLLECTION = "created_collection" @pytest_asyncio.fixture(scope="class", autouse=True) async def setup_teardown(): + os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "True" client = CosmosClient(TestPerPartitionCircuitBreakerMMAsync.host, TestPerPartitionCircuitBreakerMMAsync.master_key) created_database = client.get_database_client(TestPerPartitionCircuitBreakerMMAsync.TEST_DATABASE_ID) @@ -284,6 +285,7 @@ async def setup_info(self, error): return setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate + @pytest.mark.cosmosCircuitBreakerMultiRegion @pytest.mark.parametrize("read_operation, error", read_operations_and_errors()) async def test_read_consecutive_failure_threshold_async(self, setup_teardown, read_operation, error): error_lambda = lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( @@ -382,6 +384,7 @@ async def test_write_failure_rate_threshold_async(self, setup_teardown, write_op _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 await cleanup_method([custom_setup, setup]) + @pytest.mark.cosmosCircuitBreakerMultiRegion @pytest.mark.parametrize("read_operation, error", read_operations_and_errors()) async def test_read_failure_rate_threshold_async(self, setup_teardown, read_operation, error): error_lambda = lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr.py index 
a066828fad09..813d12f86524 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr.py +++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr.py @@ -16,12 +16,13 @@ from _fault_injection_transport import FaultInjectionTransport from test_per_partition_circuit_breaker_mm import perform_write_operation, perform_read_operation from test_per_partition_circuit_breaker_mm_async import create_doc, PK_VALUE, write_operations_and_errors, \ - read_operations_and_errors, CHANGE_FEED, QUERY, READ_ALL_ITEMS, operations, REGION_2, REGION_1 + operations, REGION_2, REGION_1 from test_per_partition_circuit_breaker_sm_mrr_async import validate_unhealthy_partitions COLLECTION = "created_collection" @pytest_asyncio.fixture(scope="class", autouse=True) def setup_teardown(): + os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "True" client = CosmosClient(TestPerPartitionCircuitBreakerSmMrr.host, TestPerPartitionCircuitBreakerSmMrr.master_key) created_database = client.get_database_client(TestPerPartitionCircuitBreakerSmMrr.TEST_DATABASE_ID) created_database.create_container(TestPerPartitionCircuitBreakerSmMrr.TEST_CONTAINER_SINGLE_PARTITION_ID, @@ -86,56 +87,6 @@ def test_write_consecutive_failure_threshold(self, setup_teardown, write_operati validate_unhealthy_partitions(global_endpoint_manager, 0) - - @pytest.mark.parametrize("read_operation, error", read_operations_and_errors()) - def test_read_consecutive_failure_threshold(self, setup_teardown, read_operation, error): - error_lambda = lambda r: FaultInjectionTransport.error_after_delay(0, error) - setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = self.setup_info(error_lambda) - container = setup['col'] - fault_injection_container = custom_setup['col'] - - global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager - - # create a document to read - container.create_item(body=doc) - - # reads should fail over and only 
the relevant partition should be marked as unavailable - perform_read_operation(read_operation, - fault_injection_container, - doc['id'], - doc['pk'], - expected_uri) - # partition should not have been marked unavailable after one error - validate_unhealthy_partitions(global_endpoint_manager, 0) - - for i in range(10): - perform_read_operation(read_operation, - fault_injection_container, - doc['id'], - doc['pk'], - expected_uri) - - # the partition should have been marked as unavailable after breaking read threshold - if read_operation in (CHANGE_FEED, QUERY, READ_ALL_ITEMS): - # these operations are cross partition so they would mark both partitions as unavailable - expected_unhealthy_partitions = 2 - else: - expected_unhealthy_partitions = 1 - validate_unhealthy_partitions(global_endpoint_manager, expected_unhealthy_partitions) - # remove faults and reduce initial recover time and perform a read - original_unavailable_time = _partition_health_tracker.INITIAL_UNAVAILABLE_TIME - _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = 1 - custom_transport.faults = [] - try: - perform_read_operation(read_operation, - fault_injection_container, - doc['id'], - doc['pk'], - uri_down) - finally: - _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = original_unavailable_time - validate_unhealthy_partitions(global_endpoint_manager, 0) - @pytest.mark.parametrize("write_operation, error", write_operations_and_errors()) def test_write_failure_rate_threshold(self, setup_teardown, write_operation, error): error_lambda = lambda r: FaultInjectionTransport.error_after_delay(0, error) @@ -177,44 +128,6 @@ def test_write_failure_rate_threshold(self, setup_teardown, write_operation, err # restore minimum requests _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 - @pytest.mark.parametrize("read_operation, error", read_operations_and_errors()) - def test_read_failure_rate_threshold(self, setup_teardown, read_operation, error): - error_lambda = lambda r: 
FaultInjectionTransport.error_after_delay(0, error) - setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = self.setup_info(error_lambda) - container = setup['col'] - fault_injection_container = custom_setup['col'] - container.upsert_item(body=doc) - global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager - # lower minimum requests for testing - _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 8 - os.environ["AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED"] = "80" - try: - if isinstance(error, ServiceResponseError): - # service response error retries in region 3 additional times before failing over - num_operations = 2 - else: - num_operations = 8 - for i in range(num_operations): - validate_unhealthy_partitions(global_endpoint_manager, 0) - # read will fail and retry in other region - perform_read_operation(read_operation, - fault_injection_container, - doc['id'], - PK_VALUE, - expected_uri) - if read_operation in (CHANGE_FEED, QUERY, READ_ALL_ITEMS): - # these operations are cross partition so they would mark both partitions as unavailable - expected_unhealthy_partitions = 2 - else: - expected_unhealthy_partitions = 1 - - validate_unhealthy_partitions(global_endpoint_manager, expected_unhealthy_partitions) - finally: - os.environ["AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED"] = "90" - # restore minimum requests - _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 - - @pytest.mark.parametrize("read_operation, write_operation", operations()) def test_service_request_error(self, read_operation, write_operation): # the region should be tried 4 times before failing over and mark the partition as unavailable @@ -253,6 +166,8 @@ def test_service_request_error(self, read_operation, write_operation): custom_transport.add_fault(predicate, lambda r: FaultInjectionTransport.error_region_down()) + # The global endpoint would be used for the write operation + expected_uri = self.host 
perform_write_operation(write_operation, container, fault_injection_container, diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr_async.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr_async.py index a2e1d6418cc6..0e8562d12394 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr_async.py @@ -1,7 +1,6 @@ # The MIT License (MIT) # Copyright (c) Microsoft Corporation. All rights reserved. import asyncio -import logging import os import unittest import uuid @@ -19,15 +18,11 @@ from azure.cosmos.exceptions import CosmosHttpResponseError from _fault_injection_transport_async import FaultInjectionTransportAsync from test_per_partition_circuit_breaker_mm_async import perform_write_operation, create_doc, PK_VALUE, \ - write_operations_and_errors, \ - cleanup_method, read_operations_and_errors, perform_read_operation, CHANGE_FEED, QUERY, READ_ALL_ITEMS, operations, \ - REGION_2, REGION_1 + write_operations_and_errors, cleanup_method, perform_read_operation, operations, REGION_2, REGION_1 COLLECTION = "created_collection" -logger = logging.getLogger('test') @pytest_asyncio.fixture(scope="class", autouse=True) async def setup_teardown(): - logger.info(os.getenv("AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER")) client = CosmosClient(TestPerPartitionCircuitBreakerSmMrrAsync.host, TestPerPartitionCircuitBreakerSmMrrAsync.master_key) created_database = client.get_database_client(TestPerPartitionCircuitBreakerSmMrrAsync.TEST_DATABASE_ID) await created_database.create_container(TestPerPartitionCircuitBreakerSmMrrAsync.TEST_CONTAINER_SINGLE_PARTITION_ID, @@ -123,59 +118,6 @@ async def test_write_consecutive_failure_threshold_async(self, setup_teardown, w await cleanup_method([custom_setup, setup]) - @pytest.mark.parametrize("read_operation, error", read_operations_and_errors()) - async def 
test_read_consecutive_failure_threshold_async(self, setup_teardown, read_operation, error): - error_lambda = lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( - 0, - error - )) - setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = await self.setup_info(error_lambda) - container = setup['col'] - fault_injection_container = custom_setup['col'] - - global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager - - # create a document to read - await container.create_item(body=doc) - - # reads should fail over and only the relevant partition should be marked as unavailable - await perform_read_operation(read_operation, - fault_injection_container, - doc['id'], - doc['pk'], - expected_uri) - # partition should not have been marked unavailable after one error - validate_unhealthy_partitions(global_endpoint_manager, 0) - - for i in range(10): - await perform_read_operation(read_operation, - fault_injection_container, - doc['id'], - doc['pk'], - expected_uri) - - # the partition should have been marked as unavailable after breaking read threshold - if read_operation in (CHANGE_FEED, QUERY, READ_ALL_ITEMS): - # these operations are cross partition so they would mark both partitions as unavailable - expected_unhealthy_partitions = 2 - else: - expected_unhealthy_partitions = 1 - validate_unhealthy_partitions(global_endpoint_manager, expected_unhealthy_partitions) - # remove faults and reduce initial recover time and perform a read - original_unavailable_time = _partition_health_tracker.INITIAL_UNAVAILABLE_TIME - _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = 1 - custom_transport.faults = [] - try: - await perform_read_operation(read_operation, - fault_injection_container, - doc['id'], - doc['pk'], - uri_down) - finally: - _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = original_unavailable_time - validate_unhealthy_partitions(global_endpoint_manager, 0) - await 
cleanup_method([custom_setup, setup]) - @pytest.mark.parametrize("write_operation, error", write_operations_and_errors()) async def test_write_failure_rate_threshold_async(self, setup_teardown, write_operation, error): error_lambda = lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( @@ -221,48 +163,6 @@ async def test_write_failure_rate_threshold_async(self, setup_teardown, write_op _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 await cleanup_method([custom_setup, setup]) - @pytest.mark.parametrize("read_operation, error", read_operations_and_errors()) - async def test_read_failure_rate_threshold_async(self, setup_teardown, read_operation, error): - error_lambda = lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( - 0, - error - )) - setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = await self.setup_info(error_lambda) - container = setup['col'] - fault_injection_container = custom_setup['col'] - await container.upsert_item(body=doc) - global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager - # lower minimum requests for testing - _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 8 - os.environ["AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED"] = "80" - try: - if isinstance(error, ServiceResponseError): - # service response error retries in region 3 additional times before failing over - num_operations = 2 - else: - num_operations = 8 - for i in range(num_operations): - validate_unhealthy_partitions(global_endpoint_manager, 0) - # read will fail and retry in other region - await perform_read_operation(read_operation, - fault_injection_container, - doc['id'], - PK_VALUE, - expected_uri) - if read_operation in (CHANGE_FEED, QUERY, READ_ALL_ITEMS): - # these operations are cross partition so they would mark both partitions as unavailable - expected_unhealthy_partitions = 2 - else: - expected_unhealthy_partitions = 1 - - 
validate_unhealthy_partitions(global_endpoint_manager, expected_unhealthy_partitions) - finally: - os.environ["AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED"] = "90" - # restore minimum requests - _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 - await cleanup_method([custom_setup, setup]) - - @pytest.mark.parametrize("read_operation, write_operation", operations()) async def test_service_request_error_async(self, read_operation, write_operation): # the region should be tried 4 times before failing over and mark the partition as unavailable @@ -300,7 +200,8 @@ async def test_service_request_error_async(self, read_operation, write_operation custom_transport.add_fault(predicate, lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_region_down())) - + # The global endpoint would be used for the write operation + expected_uri = self.host await perform_write_operation(write_operation, container, fault_injection_container, From adf1cadf566e4a2380744cae46438dda9d0a803e Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Tue, 29 Apr 2025 16:26:11 -0700 Subject: [PATCH 131/152] fix tests --- .../azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py | 1 - .../tests/test_per_partition_circuit_breaker_mm_async.py | 1 - .../tests/test_per_partition_circuit_breaker_sm_mrr.py | 1 - 3 files changed, 3 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py index 8825cf5a0705..269efa678113 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py +++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py @@ -1,6 +1,5 @@ # The MIT License (MIT) # Copyright (c) Microsoft Corporation. All rights reserved. 
-import logging import os import unittest import uuid diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py index f72d2031bf15..1e890b0ff2a5 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py @@ -40,7 +40,6 @@ COLLECTION = "created_collection" @pytest_asyncio.fixture(scope="class", autouse=True) async def setup_teardown(): - os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "True" client = CosmosClient(TestPerPartitionCircuitBreakerMMAsync.host, TestPerPartitionCircuitBreakerMMAsync.master_key) created_database = client.get_database_client(TestPerPartitionCircuitBreakerMMAsync.TEST_DATABASE_ID) diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr.py index 813d12f86524..4d02597f690c 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr.py +++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr.py @@ -22,7 +22,6 @@ COLLECTION = "created_collection" @pytest_asyncio.fixture(scope="class", autouse=True) def setup_teardown(): - os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "True" client = CosmosClient(TestPerPartitionCircuitBreakerSmMrr.host, TestPerPartitionCircuitBreakerSmMrr.master_key) created_database = client.get_database_client(TestPerPartitionCircuitBreakerSmMrr.TEST_DATABASE_ID) created_database.create_container(TestPerPartitionCircuitBreakerSmMrr.TEST_CONTAINER_SINGLE_PARTITION_ID, From c4e28f8ce39c095a2af8a3d9e1bb8028d20d9300 Mon Sep 17 00:00:00 2001 From: Tomas Varon Saldarriaga Date: Tue, 29 Apr 2025 16:37:19 -0700 Subject: [PATCH 132/152] fix ci tests --- ...n_endpoint_manager_circuit_breaker_async.py | 1 - .../tests/_fault_injection_transport_async.py | 6 
++++++ .../test_circuit_breaker_emulator_async.py | 18 +++++++++--------- 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py index 0140c04ab033..7a6a06c3e8c7 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py @@ -51,7 +51,6 @@ def __init__(self, client: "CosmosClientConnection"): async def create_pk_range_wrapper(self, request: RequestObject) -> PartitionKeyRangeWrapper: container_rid = request.headers[HttpHeaders.IntendedCollectionRID] - print(request.headers) properties = self.client._container_properties_cache[container_rid] # get relevant information from container cache to get the overlapping ranges container_link = properties["container_link"] diff --git a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py index 95a46a7696e0..ce10504164bb 100644 --- a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py +++ b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py @@ -26,6 +26,7 @@ import json import logging import sys +from importlib.resources import is_resource from typing import Callable, Optional, Any, Dict, List, Awaitable, MutableMapping import aiohttp from azure.core.pipeline.transport import AioHttpTransport, AioHttpTransportResponse @@ -142,6 +143,11 @@ def predicate_is_document_operation(r: HttpRequest) -> bool: return is_document_operation + @staticmethod + def predicate_is_resource_type(r: HttpRequest, resource_type: str) -> bool: + is_resource_type = r.headers.get(HttpHeaders.ThinClientProxyResourceType) == resource_type + return is_resource_type + @staticmethod 
def predicate_is_operation_type(r: HttpRequest, operation_type: str) -> bool: is_operation_type = r.headers.get(HttpHeaders.ThinClientProxyOperationType) == operation_type diff --git a/sdk/cosmos/azure-cosmos/tests/test_circuit_breaker_emulator_async.py b/sdk/cosmos/azure-cosmos/tests/test_circuit_breaker_emulator_async.py index 29d2323381fb..b0096d63fe9d 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_circuit_breaker_emulator_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_circuit_breaker_emulator_async.py @@ -28,13 +28,13 @@ async def setup_teardown(): os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "True" client = CosmosClient(TestCircuitBreakerEmulatorAsync.host, TestCircuitBreakerEmulatorAsync.master_key) created_database = client.get_database_client(TestCircuitBreakerEmulatorAsync.TEST_DATABASE_ID) - created_database.create_container(TestCircuitBreakerEmulatorAsync.TEST_CONTAINER_SINGLE_PARTITION_ID, + await created_database.create_container(TestCircuitBreakerEmulatorAsync.TEST_CONTAINER_SINGLE_PARTITION_ID, partition_key=PartitionKey("/pk"), offer_throughput=10000) # allow some time for the container to be created as this method is in different event loop sleep(3) yield - created_database.delete_container(TestCircuitBreakerEmulatorAsync.TEST_CONTAINER_SINGLE_PARTITION_ID) + await created_database.delete_container(TestCircuitBreakerEmulatorAsync.TEST_CONTAINER_SINGLE_PARTITION_ID) os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "False" async def create_custom_transport_mm(): @@ -96,7 +96,7 @@ async def create_custom_transport_sm_mrr(self): async def setup_info(self, error, mm=False): expected_uri = self.host uri_down = self.host.replace("localhost", "127.0.0.1") - custom_transport = create_custom_transport_mm() if mm else self.create_custom_transport_sm_mrr() + custom_transport = await create_custom_transport_mm() if mm else await self.create_custom_transport_sm_mrr() # two documents targeted to same partition, one will always fail and the other 
will succeed doc = create_doc() predicate = lambda r: (FaultInjectionTransportAsync.predicate_is_resource_type(r, ResourceType.Collection) and @@ -108,7 +108,7 @@ async def setup_info(self, error, mm=False): return setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate @pytest.mark.parametrize("error", create_errors()) - async def test_write_consecutive_failure_threshold_delete_all_items_by_pk_sm(self, setup_teardown, error): + async def test_write_consecutive_failure_threshold_delete_all_items_by_pk_sm_async(self, setup_teardown, error): error_lambda = lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay(0, error)) setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = await self.setup_info(error_lambda) fault_injection_container = custom_setup['col'] @@ -132,7 +132,7 @@ async def test_write_consecutive_failure_threshold_delete_all_items_by_pk_sm(sel @pytest.mark.parametrize("error", create_errors()) - async def test_write_consecutive_failure_threshold_delete_all_items_by_pk_mm(self, setup_teardown, error): + async def test_write_consecutive_failure_threshold_delete_all_items_by_pk_mm_async(self, setup_teardown, error): error_lambda = lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay(0, error)) setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = await self.setup_info(error_lambda, mm=True) fault_injection_container = custom_setup['col'] @@ -187,7 +187,7 @@ async def test_write_consecutive_failure_threshold_delete_all_items_by_pk_mm(sel @pytest.mark.parametrize("error", create_errors()) - async def test_write_failure_rate_threshold_delete_all_items_by_pk_mm(self, setup_teardown, error): + async def test_write_failure_rate_threshold_delete_all_items_by_pk_mm_async(self, setup_teardown, error): error_lambda = lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay(0, error)) setup, doc, expected_uri, uri_down, 
custom_setup, custom_transport, predicate = await self.setup_info(error_lambda, mm=True) fault_injection_container = custom_setup['col'] @@ -204,7 +204,7 @@ async def test_write_failure_rate_threshold_delete_all_items_by_pk_mm(self, setu # perform some successful creates to reset consecutive counter # remove faults and perform a write custom_transport.faults = [] - fault_injection_container.upsert_item(body=doc) + await fault_injection_container.upsert_item(body=doc) custom_transport.add_fault(predicate, lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( 0, @@ -228,7 +228,7 @@ async def test_write_failure_rate_threshold_delete_all_items_by_pk_mm(self, setu _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 @pytest.mark.parametrize("error", create_errors()) - async def test_write_failure_rate_threshold_delete_all_items_by_pk_sm(self, setup_teardown, error): + async def test_write_failure_rate_threshold_delete_all_items_by_pk_sm_async(self, setup_teardown, error): error_lambda = lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay(0, error)) setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = await self.setup_info(error_lambda) fault_injection_container = custom_setup['col'] @@ -245,7 +245,7 @@ async def test_write_failure_rate_threshold_delete_all_items_by_pk_sm(self, setu # perform some successful creates to reset consecutive counter # remove faults and perform a write custom_transport.faults = [] - fault_injection_container.upsert_item(body=doc) + await fault_injection_container.upsert_item(body=doc) custom_transport.add_fault(predicate, lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( 0, From 416bb910dd36ebb29794d8a1e6bae0683c1ecd6d Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Thu, 1 May 2025 08:05:12 -0400 Subject: [PATCH 133/152] fix live tests --- sdk/cosmos/azure-cosmos/tests/test_crud_async.py | 1 - 
.../azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py | 2 +- .../tests/test_per_partition_circuit_breaker_mm_async.py | 2 +- .../tests/test_per_partition_circuit_breaker_sm_mrr.py | 2 +- .../tests/test_per_partition_circuit_breaker_sm_mrr_async.py | 2 +- 5 files changed, 4 insertions(+), 5 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_crud_async.py b/sdk/cosmos/azure-cosmos/tests/test_crud_async.py index 81cb5baf4d45..d39c944da309 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_crud_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_crud_async.py @@ -4,7 +4,6 @@ """End-to-end test. """ -import logging import os import time import unittest diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py index 269efa678113..ad3208a2d03b 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py +++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py @@ -29,7 +29,7 @@ def setup_teardown(): partition_key=PartitionKey("/pk"), offer_throughput=10000) # allow some time for the container to be created as this method is in different event loop - sleep(3) + sleep(6) yield created_database.delete_container(TestPerPartitionCircuitBreakerMM.TEST_CONTAINER_SINGLE_PARTITION_ID) diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py index 1e890b0ff2a5..ae421b2ddb15 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py @@ -47,7 +47,7 @@ async def setup_teardown(): partition_key=PartitionKey("/pk"), offer_throughput=10000) # allow some time for the container to be created as this method is in different event loop - await asyncio.sleep(3) + await asyncio.sleep(6) yield await 
created_database.delete_container(TestPerPartitionCircuitBreakerMMAsync.TEST_CONTAINER_SINGLE_PARTITION_ID) await client.close() diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr.py index 4d02597f690c..89bf93266ec9 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr.py +++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr.py @@ -28,7 +28,7 @@ def setup_teardown(): partition_key=PartitionKey("/pk"), offer_throughput=10000) # allow some time for the container to be created as this method is in different event loop - sleep(3) + sleep(6) yield created_database.delete_container(TestPerPartitionCircuitBreakerSmMrr.TEST_CONTAINER_SINGLE_PARTITION_ID) diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr_async.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr_async.py index 0e8562d12394..b7a5fd7de349 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr_async.py @@ -29,7 +29,7 @@ async def setup_teardown(): partition_key=PartitionKey("/pk"), offer_throughput=10000) # allow some time for the container to be created as this method is in different event loop - await asyncio.sleep(3) + await asyncio.sleep(6) yield await created_database.delete_container(TestPerPartitionCircuitBreakerSmMrrAsync.TEST_CONTAINER_SINGLE_PARTITION_ID) await client.close() From 399df8a290e17088770572efc12956a91827ec1e Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Thu, 1 May 2025 11:36:58 -0400 Subject: [PATCH 134/152] fix tests --- .../azure-cosmos/azure/cosmos/_location_cache.py | 2 +- .../azure-cosmos/azure/cosmos/_retry_utility.py | 15 ++++++++++++--- .../azure/cosmos/aio/_retry_utility_async.py | 2 -- sdk/cosmos/azure-cosmos/tests/test_change_feed.py | 5 
+++-- .../azure-cosmos/tests/test_change_feed_async.py | 2 +- sdk/cosmos/azure-cosmos/tests/test_crud.py | 2 +- sdk/cosmos/azure-cosmos/tests/test_crud_async.py | 2 +- ...test_per_partition_circuit_breaker_mm_async.py | 3 ++- sdk/cosmos/azure-cosmos/tests/test_query.py | 2 +- sdk/cosmos/azure-cosmos/tests/test_query_async.py | 2 +- 10 files changed, 23 insertions(+), 14 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py index 2aef0d6b9656..d4af5281da9e 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py @@ -142,7 +142,7 @@ def _get_applicable_regional_routing_contexts(regional_routing_contexts: List[Re return applicable_regional_routing_contexts -def current_time_millis(): +def current_time_millis() -> int: return int(round(time.time() * 1000)) class LocationCache(object): # pylint: disable=too-many-public-methods,too-many-instance-attributes diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py index fa15e4ff7d81..a8d5cf16bb07 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py @@ -206,8 +206,6 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs): # pylin if client_timeout: kwargs['timeout'] = client_timeout - (time.time() - start_time) if kwargs['timeout'] <= 0: - if args: - global_endpoint_manager.record_failure(args[0]) raise exceptions.CosmosClientTimeoutError() except ServiceRequestError as e: @@ -247,7 +245,16 @@ def _has_database_account_header(request_headers): return True return False -def _handle_service_request_retries(client, request_retry_policy, exception, *args): +def _handle_service_request_retries( + client, + request_retry_policy, + exception, + start_time, + global_endpoint_manager, + client_timeout, + *args, + 
**kwargs +): # we resolve the request endpoint to the next preferred region # once we are out of preferred regions we stop retrying retry_policy = request_retry_policy @@ -255,6 +262,8 @@ def _handle_service_request_retries(client, request_retry_policy, exception, *ar if args and args[0].should_clear_session_token_on_session_read_failure and client.session: client.session.clear_session_token(client.last_response_headers) raise exception + else: + check_client_timeout(args, client_timeout, global_endpoint_manager, start_time, kwargs) def _handle_service_response_retries(request, client, response_retry_policy, exception, *args): if request and _has_read_retryable_headers(request.headers): diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py index 82f09adbe4f5..93576224c2dc 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py @@ -204,8 +204,6 @@ async def ExecuteAsync(client, global_endpoint_manager, function, *args, **kwarg if client_timeout: kwargs['timeout'] = client_timeout - (time.time() - start_time) if kwargs['timeout'] <= 0: - if args: - await global_endpoint_manager.record_failure(args[0]) raise exceptions.CosmosClientTimeoutError() except ServiceRequestError as e: diff --git a/sdk/cosmos/azure-cosmos/tests/test_change_feed.py b/sdk/cosmos/azure-cosmos/tests/test_change_feed.py index 6b56701738ad..b5660043b702 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_change_feed.py +++ b/sdk/cosmos/azure-cosmos/tests/test_change_feed.py @@ -18,7 +18,7 @@ def setup(): config = test_config.TestConfig() use_multiple_write_locations = False - if os.environ.get("AZURE_COSMOS_ENABLE_MULTIPLE_WRITE_LOCATIONS", "False") == "True": + if os.environ.get("AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER", "False") == "True": use_multiple_write_locations = True if (config.masterKey == '[YOUR_KEY_HERE]' or 
config.host == '[YOUR_ENDPOINT_HERE]'): @@ -26,7 +26,8 @@ def setup(): "You must specify your Azure Cosmos account values for " "'masterKey' and 'host' at the top of this class to run the " "tests.") - test_client = cosmos_client.CosmosClient(config.host, config.masterKey), + test_client = cosmos_client.CosmosClient(config.host, config.masterKey, + multiple_write_location=use_multiple_write_locations), return { "created_db": test_client[0].get_database_client(config.TEST_DATABASE_ID), "is_emulator": config.is_emulator diff --git a/sdk/cosmos/azure-cosmos/tests/test_change_feed_async.py b/sdk/cosmos/azure-cosmos/tests/test_change_feed_async.py index 17b128927ef7..1ab899a6bd47 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_change_feed_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_change_feed_async.py @@ -20,7 +20,7 @@ @pytest_asyncio.fixture() async def setup(): use_multiple_write_locations = False - if os.environ.get("AZURE_COSMOS_ENABLE_MULTIPLE_WRITE_LOCATIONS", "False") == "True": + if os.environ.get("AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER", "False") == "True": use_multiple_write_locations = True config = test_config.TestConfig() if config.masterKey == '[YOUR_KEY_HERE]' or config.host == '[YOUR_ENDPOINT_HERE]': diff --git a/sdk/cosmos/azure-cosmos/tests/test_crud.py b/sdk/cosmos/azure-cosmos/tests/test_crud.py index e87a28a80049..43f847ad45c5 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_crud.py +++ b/sdk/cosmos/azure-cosmos/tests/test_crud.py @@ -74,7 +74,7 @@ def __AssertHTTPFailureWithStatus(self, status_code, func, *args, **kwargs): @classmethod def setUpClass(cls): use_multiple_write_locations = False - if os.environ.get("AZURE_COSMOS_ENABLE_MULTIPLE_WRITE_LOCATIONS", "False") == "True": + if os.environ.get("AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER", "False") == "True": use_multiple_write_locations = True if (cls.masterKey == '[YOUR_KEY_HERE]' or cls.host == '[YOUR_ENDPOINT_HERE]'): diff --git a/sdk/cosmos/azure-cosmos/tests/test_crud_async.py 
b/sdk/cosmos/azure-cosmos/tests/test_crud_async.py index d39c944da309..594f47e451f6 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_crud_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_crud_async.py @@ -80,7 +80,7 @@ def setUpClass(cls): async def asyncSetUp(self): use_multiple_write_locations = False - if os.environ.get("AZURE_COSMOS_ENABLE_MULTIPLE_WRITE_LOCATIONS", "False") == "True": + if os.environ.get("AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER", "False") == "True": use_multiple_write_locations = True self.client = CosmosClient(self.host, self.masterKey, multiple_write_locations=use_multiple_write_locations) self.database_for_test = self.client.get_database_client(self.configs.TEST_DATABASE_ID) diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py index ae421b2ddb15..683f943e58cc 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py @@ -1,6 +1,7 @@ # The MIT License (MIT) # Copyright (c) Microsoft Corporation. All rights reserved. 
import asyncio +import logging import os import unittest import uuid @@ -47,7 +48,7 @@ async def setup_teardown(): partition_key=PartitionKey("/pk"), offer_throughput=10000) # allow some time for the container to be created as this method is in different event loop - await asyncio.sleep(6) + await asyncio.sleep(10) yield await created_database.delete_container(TestPerPartitionCircuitBreakerMMAsync.TEST_CONTAINER_SINGLE_PARTITION_ID) await client.close() diff --git a/sdk/cosmos/azure-cosmos/tests/test_query.py b/sdk/cosmos/azure-cosmos/tests/test_query.py index d1e6514bb06b..aa17116b2f39 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_query.py +++ b/sdk/cosmos/azure-cosmos/tests/test_query.py @@ -34,7 +34,7 @@ class TestQuery(unittest.TestCase): @classmethod def setUpClass(cls): use_multiple_write_locations = False - if os.environ.get("AZURE_COSMOS_ENABLE_MULTIPLE_WRITE_LOCATIONS", "False") == "True": + if os.environ.get("AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER", "False") == "True": use_multiple_write_locations = True cls.client = cosmos_client.CosmosClient(cls.host, cls.credential, multiple_write_locations=use_multiple_write_locations) cls.created_db = cls.client.get_database_client(cls.TEST_DATABASE_ID) diff --git a/sdk/cosmos/azure-cosmos/tests/test_query_async.py b/sdk/cosmos/azure-cosmos/tests/test_query_async.py index bfb35bd95b79..12b26ca363a7 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_query_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_query_async.py @@ -36,7 +36,7 @@ class TestQueryAsync(unittest.IsolatedAsyncioTestCase): @classmethod def setUpClass(cls): cls.use_multiple_write_locations = False - if os.environ.get("AZURE_COSMOS_ENABLE_MULTIPLE_WRITE_LOCATIONS", "False") == "True": + if os.environ.get("AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER", "False") == "True": cls.use_multiple_write_locations = True if (cls.masterKey == '[YOUR_KEY_HERE]' or cls.host == '[YOUR_ENDPOINT_HERE]'): From bdf50257d1e05fc7cd5462f7905569d13e47ec1a Mon Sep 17 00:00:00 2001 From: 
tvaron3 Date: Sat, 3 May 2025 17:03:20 -0400 Subject: [PATCH 135/152] remove unnecessary line --- .../azure-cosmos/azure/cosmos/_cosmos_client_connection.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py index 0292352090f2..7c5613100067 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py @@ -2680,7 +2680,6 @@ def Create( request_params = RequestObject(typ, documents._OperationType.Create, headers) request_params.set_excluded_location_from_options(options) - request_params.set_excluded_location_from_options(options) result, last_response_headers = self.__Post(path, request_params, body, headers, **kwargs) self.last_response_headers = last_response_headers From fd755c0605e71c86af26660f123a06ba59e4f2f4 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Sun, 4 May 2025 22:24:32 -0400 Subject: [PATCH 136/152] fix tests --- .../tests/test_excluded_locations.py | 120 ++++------ .../tests/test_excluded_locations_async.py | 207 ++---------------- 2 files changed, 63 insertions(+), 264 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py index b3288e2b2fd9..337a09bafd24 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py +++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py @@ -2,14 +2,14 @@ # Copyright (c) Microsoft Corporation. All rights reserved. 
import logging +import re import unittest import uuid import test_config import pytest import azure.cosmos.cosmos_client as cosmos_client -from azure.cosmos.partition_key import PartitionKey -from azure.cosmos.exceptions import CosmosResourceNotFoundError +from azure.cosmos.http_constants import ResourceType class MockHandler(logging.Handler): @@ -102,66 +102,23 @@ def read_item_test_data(): all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] return all_test_data -def query_items_change_feed_test_data(): - client_only_output_data = [ - [L1, L1, L1, L1], #0 - [L2, L2, L2, L2], #1 - [L1, L1, L1, L1], #2 - [L1, L1, L1, L1] #3 - ] - client_and_request_output_data = [ - [L1, L1, L2, L2], #0 - [L2, L2, L2, L2], #1 - [L1, L1, L2, L2], #2 - [L2, L2, L1, L1], #3 - [L1, L1, L1, L1], #4 - [L2, L2, L1, L1], #5 - [L1, L1, L1, L1], #6 - [L1, L1, L1, L1], #7 - ] - all_output_test_data = client_only_output_data + client_and_request_output_data - - all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] - return all_test_data - -def replace_item_test_data(): - client_only_output_data = [ - [L1, L1], #0 - [L2, L2], #1 - [L1, L0], #2 - [L1, L1] #3 - ] - client_and_request_output_data = [ - [L2, L2], #0 - [L2, L2], #1 - [L2, L2], #2 - [L1, L0], #3 - [L1, L0], #4 - [L1, L1], #5 - [L1, L1], #6 - [L1, L1], #7 - ] - all_output_test_data = client_only_output_data + client_and_request_output_data - - all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] - return all_test_data -def patch_item_test_data(): +def write_item_test_data(): client_only_output_data = [ - [L1], #0 - [L2], #1 - [L0], #3 - [L1] #4 + [L1], # 0 + [L2], # 1 + [L0], # 2 + [L1], # 3 ] client_and_request_output_data = [ - [L2], #0 - [L2], #1 - [L2], #2 - [L0], #3 - [L0], #4 - [L1], #5 - [L1], #6 - [L1], #7 + [L2], # 0 + 
[L2], # 1 + [L2], # 2 + [L0], # 3 + [L0], # 4 + [L1], # 5 + [L1], # 6 + [L1], # 7 ] all_output_test_data = client_only_output_data + client_and_request_output_data @@ -212,16 +169,23 @@ def _verify_endpoint(messages, client, expected_locations): # get location actual_locations = set() for req_url in req_urls: - if req_url.startswith(default_endpoint): - actual_locations.add(L0) - else: - for endpoint in location_mapping: - if req_url.startswith(endpoint): - location = location_mapping[endpoint] - actual_locations.add(location) - break - - assert list(actual_locations) == list(set(expected_locations)) + print(req_url) + match = re.search(r"x\-ms\-thinclient\-proxy\-resource\-type': '([^']+)'", req_url) + if match: + resource_type = match.group(1) + # only check the document and partition key requests because that is where excluded locations + # is applicable + if resource_type in (ResourceType.Document, ResourceType.PartitionKey): + if req_url.startswith(default_endpoint): + actual_locations.add(L0) + else: + for endpoint in location_mapping: + if req_url.startswith(endpoint): + location = location_mapping[endpoint] + actual_locations.add(location) + break + + assert actual_locations == set(expected_locations) @pytest.mark.cosmosMultiRegion class TestExcludedLocations: @@ -276,7 +240,7 @@ def test_query_items(self, test_data): # Verify endpoint locations _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) - @pytest.mark.parametrize('test_data', query_items_change_feed_test_data()) + @pytest.mark.parametrize('test_data', read_item_test_data()) def test_query_items_change_feed(self, test_data): # Init test variables preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data @@ -295,7 +259,7 @@ def test_query_items_change_feed(self, test_data): _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) - @pytest.mark.parametrize('test_data', replace_item_test_data()) + 
@pytest.mark.parametrize('test_data', write_item_test_data()) def test_replace_item(self, test_data): # Init test variables preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data @@ -314,9 +278,9 @@ def test_replace_item(self, test_data): if multiple_write_locations: _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) else: - _verify_endpoint(MOCK_HANDLER.messages, client, [expected_locations[0], L1]) + _verify_endpoint(MOCK_HANDLER.messages, client, [L1]) - @pytest.mark.parametrize('test_data', replace_item_test_data()) + @pytest.mark.parametrize('test_data', write_item_test_data()) def test_upsert_item(self, test_data): # Init test variables preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data @@ -336,9 +300,9 @@ def test_upsert_item(self, test_data): if multiple_write_locations: _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) else: - _verify_endpoint(MOCK_HANDLER.messages, client, [expected_locations[0], L1]) + _verify_endpoint(MOCK_HANDLER.messages, client, [L1]) - @pytest.mark.parametrize('test_data', replace_item_test_data()) + @pytest.mark.parametrize('test_data', write_item_test_data()) def test_create_item(self, test_data): # Init test variables preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data @@ -355,9 +319,9 @@ def test_create_item(self, test_data): if multiple_write_locations: _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) else: - _verify_endpoint(MOCK_HANDLER.messages, client, [expected_locations[0], L1]) + _verify_endpoint(MOCK_HANDLER.messages, client, [L1]) - @pytest.mark.parametrize('test_data', patch_item_test_data()) + @pytest.mark.parametrize('test_data', write_item_test_data()) def test_patch_item(self, test_data): # Init test variables preferred_locations, client_excluded_locations, request_excluded_locations, 
expected_locations = test_data @@ -385,7 +349,7 @@ def test_patch_item(self, test_data): else: _verify_endpoint(MOCK_HANDLER.messages, client, [L1]) - @pytest.mark.parametrize('test_data', patch_item_test_data()) + @pytest.mark.parametrize('test_data', write_item_test_data()) def test_execute_item_batch(self, test_data): # Init test variables preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data @@ -414,7 +378,7 @@ def test_execute_item_batch(self, test_data): else: _verify_endpoint(MOCK_HANDLER.messages, client, [L1]) - @pytest.mark.parametrize('test_data', patch_item_test_data()) + @pytest.mark.parametrize('test_data', write_item_test_data()) def test_delete_item(self, test_data): # Init test variables preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py index 11ababfdfafd..adc5f8bc8d4a 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py @@ -10,18 +10,7 @@ from azure.cosmos.aio import CosmosClient from azure.cosmos.partition_key import PartitionKey -from test_excluded_locations import _verify_endpoint - -class MockHandler(logging.Handler): - def __init__(self): - super(MockHandler, self).__init__() - self.messages = [] - - def reset(self): - self.messages = [] - - def emit(self, record): - self.messages.append(record.msg) +from test_excluded_locations import _verify_endpoint, MockHandler, read_item_test_data, write_item_test_data, L1 MOCK_HANDLER = MockHandler() CONFIG = test_config.TestConfig() @@ -34,160 +23,6 @@ def emit(self, record): ITEM_PK_VALUE = 'pk' TEST_ITEM = {'id': ITEM_ID, PARTITION_KEY: ITEM_PK_VALUE} -L0 = "Default" -L1 = "West US 3" -L2 = "West US" -L3 = "East US 2" - -# L0 = "Default" -# L1 = "East US 2" -# L2 
= "East US" -# L3 = "West US 2" - -CLIENT_ONLY_TEST_DATA = [ - # preferred_locations, client_excluded_locations, excluded_locations_request - # 0. No excluded location - [[L1, L2], [], None], - # 1. Single excluded location - [[L1, L2], [L1], None], - # 2. Exclude all locations - [[L1, L2], [L1, L2], None], - # 3. Exclude a location not in preferred locations - [[L1, L2], [L3], None], -] - -CLIENT_AND_REQUEST_TEST_DATA = [ - # preferred_locations, client_excluded_locations, excluded_locations_request - # 0. No client excluded locations + a request excluded location - [[L1, L2], [], [L1]], - # 1. The same client and request excluded location - [[L1, L2], [L1], [L1]], - # 2. Less request excluded locations - [[L1, L2], [L1, L2], [L1]], - # 3. More request excluded locations - [[L1, L2], [L1], [L1, L2]], - # 4. All locations were excluded - [[L1, L2], [L1, L2], [L1, L2]], - # 5. No common excluded locations - [[L1, L2], [L1], [L2, L3]], - # 6. Request excluded location not in preferred locations - [[L1, L2], [L1, L2], [L3]], - # 7. 
Empty excluded locations, remove all client excluded locations - [[L1, L2], [L1, L2], []], -] - -ALL_INPUT_TEST_DATA = CLIENT_ONLY_TEST_DATA + CLIENT_AND_REQUEST_TEST_DATA - -def read_item_test_data(): - client_only_output_data = [ - [L1], # 0 - [L2], # 1 - [L1], # 2 - [L1], # 3 - ] - client_and_request_output_data = [ - [L2], # 0 - [L2], # 1 - [L2], # 2 - [L1], # 3 - [L1], # 4 - [L1], # 5 - [L1], # 6 - [L1], # 7 - ] - all_output_test_data = client_only_output_data + client_and_request_output_data - - all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] - return all_test_data - -def read_all_item_test_data(): - client_only_output_data = [ - [L1, L1], # 0 - [L2, L2], # 1 - [L1, L1], # 2 - [L1, L1], # 3 - ] - client_and_request_output_data = [ - [L2, L2], # 0 - [L2, L2], # 1 - [L2, L2], # 2 - [L1, L1], # 3 - [L1, L1], # 4 - [L1, L1], # 5 - [L1, L1], # 6 - [L1, L1], # 7 - ] - all_output_test_data = client_only_output_data + client_and_request_output_data - - all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] - return all_test_data - -def query_items_change_feed_test_data(): - client_only_output_data = [ - [L1, L1, L1, L1], #0 - [L2, L2, L2, L2], #1 - [L1, L1, L1, L1], #2 - [L1, L1, L1, L1] #3 - ] - client_and_request_output_data = [ - [L1, L2, L2, L2], #0 - [L2, L2, L2, L2], #1 - [L1, L2, L2, L2], #2 - [L2, L1, L1, L1], #3 - [L1, L1, L1, L1], #4 - [L2, L1, L1, L1], #5 - [L1, L1, L1, L1], #6 - [L1, L1, L1, L1], #7 - ] - all_output_test_data = client_only_output_data + client_and_request_output_data - - all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] - return all_test_data - -def replace_item_test_data(): - client_only_output_data = [ - [L1], #0 - [L2], #1 - [L0], #2 - [L1] #3 - ] - client_and_request_output_data = [ - [L2], #0 - [L2], #1 - [L2], #2 - [L0], 
#3 - [L0], #4 - [L1], #5 - [L1], #6 - [L1], #7 - ] - all_output_test_data = client_only_output_data + client_and_request_output_data - - all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] - return all_test_data - -def patch_item_test_data(): - client_only_output_data = [ - [L1], #0 - [L2], #1 - [L0], #2 - [L1] #3 - ] - client_and_request_output_data = [ - [L2], #0 - [L2], #1 - [L2], #2 - [L0], #3 - [L0], #4 - [L1], #5 - [L1], #6 - [L1], #7 - ] - all_output_test_data = client_only_output_data + client_and_request_output_data - - all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] - return all_test_data - async def _create_item_with_excluded_locations(container, body, excluded_locations): if excluded_locations is None: await container.create_item(body=body) @@ -222,9 +57,9 @@ async def _init_container(preferred_locations, client_excluded_locations, multip @pytest.mark.cosmosMultiRegion @pytest.mark.asyncio @pytest.mark.usefixtures("setup_and_teardown") -class TestExcludedLocations: +class TestExcludedLocationsAsync: @pytest.mark.parametrize('test_data', read_item_test_data()) - async def test_read_item(self, test_data): + async def test_read_item_async(self, test_data): # Init test variables preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data @@ -240,8 +75,8 @@ async def test_read_item(self, test_data): # Verify endpoint locations _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) - @pytest.mark.parametrize('test_data', read_all_item_test_data()) - async def test_read_all_items(self, test_data): + @pytest.mark.parametrize('test_data', read_item_test_data()) + async def test_read_all_items_async(self, test_data): # Init test variables preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data @@ -257,8 +92,8 
@@ async def test_read_all_items(self, test_data): # Verify endpoint locations _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) - @pytest.mark.parametrize('test_data', read_all_item_test_data()) - async def test_query_items(self, test_data): + @pytest.mark.parametrize('test_data', read_item_test_data()) + async def test_query_items_async(self, test_data): # Init test variables preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data @@ -274,8 +109,8 @@ async def test_query_items(self, test_data): # Verify endpoint locations _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) - @pytest.mark.parametrize('test_data', query_items_change_feed_test_data()) - async def test_query_items_change_feed(self, test_data): + @pytest.mark.parametrize('test_data', read_item_test_data()) + async def test_query_items_change_feed_async(self, test_data): # Init test variables preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data @@ -293,8 +128,8 @@ async def test_query_items_change_feed(self, test_data): _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) - @pytest.mark.parametrize('test_data', replace_item_test_data()) - async def test_replace_item(self, test_data): + @pytest.mark.parametrize('test_data', write_item_test_data()) + async def test_replace_item_async(self, test_data): # Init test variables preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data @@ -314,8 +149,8 @@ async def test_replace_item(self, test_data): else: _verify_endpoint(MOCK_HANDLER.messages, client, [L1]) - @pytest.mark.parametrize('test_data', replace_item_test_data()) - async def test_upsert_item(self, test_data): + @pytest.mark.parametrize('test_data', write_item_test_data()) + async def test_upsert_item_async(self, test_data): # Init test variables preferred_locations, client_excluded_locations, 
request_excluded_locations, expected_locations = test_data @@ -336,8 +171,8 @@ async def test_upsert_item(self, test_data): else: _verify_endpoint(MOCK_HANDLER.messages, client, [L1]) - @pytest.mark.parametrize('test_data', replace_item_test_data()) - async def test_create_item(self, test_data): + @pytest.mark.parametrize('test_data', write_item_test_data()) + async def test_create_item_async(self, test_data): # Init test variables preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data @@ -355,8 +190,8 @@ async def test_create_item(self, test_data): else: _verify_endpoint(MOCK_HANDLER.messages, client, [L1]) - @pytest.mark.parametrize('test_data', patch_item_test_data()) - async def test_patch_item(self, test_data): + @pytest.mark.parametrize('test_data', write_item_test_data()) + async def test_patch_item_async(self, test_data): # Init test variables preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data @@ -383,8 +218,8 @@ async def test_patch_item(self, test_data): else: _verify_endpoint(MOCK_HANDLER.messages, client, [L1]) - @pytest.mark.parametrize('test_data', patch_item_test_data()) - async def test_execute_item_batch(self, test_data): + @pytest.mark.parametrize('test_data', write_item_test_data()) + async def test_execute_item_batch_async(self, test_data): # Init test variables preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data @@ -412,8 +247,8 @@ async def test_execute_item_batch(self, test_data): else: _verify_endpoint(MOCK_HANDLER.messages, client, [L1]) - @pytest.mark.parametrize('test_data', patch_item_test_data()) - async def test_delete_item(self, test_data): + @pytest.mark.parametrize('test_data', write_item_test_data()) + async def test_delete_item_async(self, test_data): # Init test variables preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = 
test_data From c6ddb12521d46dd9353144eca06f5b25b5a8013a Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Sun, 4 May 2025 23:01:46 -0400 Subject: [PATCH 137/152] fix tests --- .../azure-cosmos/azure/cosmos/_retry_utility.py | 8 +------- .../tests/test_circuit_breaker_emulator.py | 13 ++----------- .../tests/test_circuit_breaker_emulator_async.py | 13 ++----------- 3 files changed, 5 insertions(+), 29 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py index a8d5cf16bb07..0ee2cd1146cf 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py @@ -249,11 +249,7 @@ def _handle_service_request_retries( client, request_retry_policy, exception, - start_time, - global_endpoint_manager, - client_timeout, - *args, - **kwargs + *args ): # we resolve the request endpoint to the next preferred region # once we are out of preferred regions we stop retrying @@ -262,8 +258,6 @@ def _handle_service_request_retries( if args and args[0].should_clear_session_token_on_session_read_failure and client.session: client.session.clear_session_token(client.last_response_headers) raise exception - else: - check_client_timeout(args, client_timeout, global_endpoint_manager, start_time, kwargs) def _handle_service_response_retries(request, client, response_retry_policy, exception, *args): if request and _has_read_retryable_headers(request.headers): diff --git a/sdk/cosmos/azure-cosmos/tests/test_circuit_breaker_emulator.py b/sdk/cosmos/azure-cosmos/tests/test_circuit_breaker_emulator.py index 3e450191fda2..0b8dff4460c9 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_circuit_breaker_emulator.py +++ b/sdk/cosmos/azure-cosmos/tests/test_circuit_breaker_emulator.py @@ -3,13 +3,12 @@ import os import unittest import uuid -from time import sleep import pytest from azure.core.exceptions import ServiceResponseError import test_config -from azure.cosmos 
import PartitionKey, _partition_health_tracker, documents +from azure.cosmos import _partition_health_tracker, documents from azure.cosmos import CosmosClient from azure.cosmos.exceptions import CosmosHttpResponseError from _fault_injection_transport import FaultInjectionTransport @@ -24,15 +23,7 @@ @pytest.fixture(scope="class", autouse=True) def setup_teardown(): os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "True" - client = CosmosClient(TestCircuitBreakerEmulator.host, TestCircuitBreakerEmulator.master_key) - created_database = client.get_database_client(TestCircuitBreakerEmulator.TEST_DATABASE_ID) - created_database.create_container(TestCircuitBreakerEmulator.TEST_CONTAINER_SINGLE_PARTITION_ID, - partition_key=PartitionKey("/pk"), - offer_throughput=10000) - # allow some time for the container to be created as this method is in different event loop - sleep(3) yield - created_database.delete_container(TestCircuitBreakerEmulator.TEST_CONTAINER_SINGLE_PARTITION_ID) os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "False" def create_custom_transport_mm(): @@ -56,7 +47,7 @@ class TestCircuitBreakerEmulator: master_key = test_config.TestConfig.masterKey connectionPolicy = test_config.TestConfig.connectionPolicy TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID - TEST_CONTAINER_SINGLE_PARTITION_ID = os.path.basename(__file__) + str(uuid.uuid4()) + TEST_CONTAINER_SINGLE_PARTITION_ID = test_config.TestConfig.TEST_MULTI_PARTITION_CONTAINER_ID def setup_method_with_custom_transport(self, custom_transport, default_endpoint=host, **kwargs): client = CosmosClient(default_endpoint, self.master_key, consistency_level="Session", diff --git a/sdk/cosmos/azure-cosmos/tests/test_circuit_breaker_emulator_async.py b/sdk/cosmos/azure-cosmos/tests/test_circuit_breaker_emulator_async.py index b0096d63fe9d..5e63ba42e33b 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_circuit_breaker_emulator_async.py +++ 
b/sdk/cosmos/azure-cosmos/tests/test_circuit_breaker_emulator_async.py @@ -4,14 +4,13 @@ import os import unittest import uuid -from time import sleep import pytest import pytest_asyncio from azure.core.exceptions import ServiceResponseError import test_config -from azure.cosmos import PartitionKey, _partition_health_tracker, documents +from azure.cosmos import _partition_health_tracker, documents from azure.cosmos.aio import CosmosClient from azure.cosmos.exceptions import CosmosHttpResponseError from azure.cosmos.http_constants import ResourceType @@ -26,15 +25,7 @@ @pytest_asyncio.fixture(scope="class", autouse=True) async def setup_teardown(): os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "True" - client = CosmosClient(TestCircuitBreakerEmulatorAsync.host, TestCircuitBreakerEmulatorAsync.master_key) - created_database = client.get_database_client(TestCircuitBreakerEmulatorAsync.TEST_DATABASE_ID) - await created_database.create_container(TestCircuitBreakerEmulatorAsync.TEST_CONTAINER_SINGLE_PARTITION_ID, - partition_key=PartitionKey("/pk"), - offer_throughput=10000) - # allow some time for the container to be created as this method is in different event loop - sleep(3) yield - await created_database.delete_container(TestCircuitBreakerEmulatorAsync.TEST_CONTAINER_SINGLE_PARTITION_ID) os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "False" async def create_custom_transport_mm(): @@ -59,7 +50,7 @@ class TestCircuitBreakerEmulatorAsync: master_key = test_config.TestConfig.masterKey connectionPolicy = test_config.TestConfig.connectionPolicy TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID - TEST_CONTAINER_SINGLE_PARTITION_ID = os.path.basename(__file__) + str(uuid.uuid4()) + TEST_CONTAINER_SINGLE_PARTITION_ID = test_config.TestConfig.TEST_MULTI_PARTITION_CONTAINER_ID async def setup_method_with_custom_transport(self, custom_transport, default_endpoint=host, **kwargs): client = CosmosClient(default_endpoint, self.master_key, 
consistency_level="Session", From ba7ac08b54b583b5063a267ec2df26f89cb39b83 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Mon, 5 May 2025 09:22:56 -0400 Subject: [PATCH 138/152] fix tests --- .../azure/cosmos/_location_cache.py | 4 ++- .../azure-cosmos/tests/test_change_feed.py | 2 +- .../test_circuit_breaker_emulator_async.py | 6 +++- .../test_per_partition_circuit_breaker_mm.py | 29 ++++--------------- ..._per_partition_circuit_breaker_mm_async.py | 28 ++++-------------- ...st_per_partition_circuit_breaker_sm_mrr.py | 20 ++----------- ..._partition_circuit_breaker_sm_mrr_async.py | 22 +++----------- 7 files changed, 28 insertions(+), 83 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py index d4af5281da9e..65ee347be079 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py @@ -173,7 +173,9 @@ def get_read_regional_routing_contexts(self): return self.read_regional_routing_contexts def get_location_from_endpoint(self, endpoint: str) -> str: - return self.account_locations_by_read_regional_routing_context[endpoint] + if endpoint in self.account_locations_by_read_regional_routing_context: + return self.account_locations_by_read_regional_routing_context[endpoint] + return self.account_write_locations[0] def get_write_regional_routing_context(self): return self.get_write_regional_routing_contexts()[0].get_primary() diff --git a/sdk/cosmos/azure-cosmos/tests/test_change_feed.py b/sdk/cosmos/azure-cosmos/tests/test_change_feed.py index b5660043b702..94713d543003 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_change_feed.py +++ b/sdk/cosmos/azure-cosmos/tests/test_change_feed.py @@ -27,7 +27,7 @@ def setup(): "'masterKey' and 'host' at the top of this class to run the " "tests.") test_client = cosmos_client.CosmosClient(config.host, config.masterKey, - multiple_write_location=use_multiple_write_locations), + 
multiple_write_locations=use_multiple_write_locations), return { "created_db": test_client[0].get_database_client(config.TEST_DATABASE_ID), "is_emulator": config.is_emulator diff --git a/sdk/cosmos/azure-cosmos/tests/test_circuit_breaker_emulator_async.py b/sdk/cosmos/azure-cosmos/tests/test_circuit_breaker_emulator_async.py index 5e63ba42e33b..3ec83a004bc9 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_circuit_breaker_emulator_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_circuit_breaker_emulator_async.py @@ -17,7 +17,7 @@ from test_per_partition_circuit_breaker_mm_async import (create_doc, PK_VALUE, create_errors, DELETE_ALL_ITEMS_BY_PARTITION_KEY, validate_unhealthy_partitions as validate_unhealthy_partitions_mm, - perform_write_operation) + perform_write_operation, cleanup_method) from test_per_partition_circuit_breaker_sm_mrr_async import validate_unhealthy_partitions as validate_unhealthy_partitions_sm_mrr from _fault_injection_transport_async import FaultInjectionTransportAsync @@ -120,6 +120,7 @@ async def test_write_consecutive_failure_threshold_delete_all_items_by_pk_sm_asy ) validate_unhealthy_partitions_sm_mrr(global_endpoint_manager, 0) + await cleanup_method([custom_setup, setup]) @pytest.mark.parametrize("error", create_errors()) @@ -175,6 +176,7 @@ async def test_write_consecutive_failure_threshold_delete_all_items_by_pk_mm_asy finally: _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = original_unavailable_time validate_unhealthy_partitions_mm(global_endpoint_manager, 0) + await cleanup_method([custom_setup, setup]) @pytest.mark.parametrize("error", create_errors()) @@ -217,6 +219,7 @@ async def test_write_failure_rate_threshold_delete_all_items_by_pk_mm_async(self os.environ["AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED"] = "90" # restore minimum requests _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 + await cleanup_method([custom_setup, setup]) @pytest.mark.parametrize("error", create_errors()) async def 
test_write_failure_rate_threshold_delete_all_items_by_pk_sm_async(self, setup_teardown, error): @@ -258,6 +261,7 @@ async def test_write_failure_rate_threshold_delete_all_items_by_pk_sm_async(self os.environ["AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED"] = "90" # restore minimum requests _partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 + await cleanup_method([custom_setup, setup]) # test cosmos client timeout diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py index ad3208a2d03b..756427d6162d 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py +++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py @@ -8,8 +8,7 @@ from azure.core.exceptions import ServiceResponseError import test_config -from time import sleep -from azure.cosmos import PartitionKey, _location_cache, _partition_health_tracker +from azure.cosmos import _location_cache, _partition_health_tracker from azure.cosmos import CosmosClient from azure.cosmos.exceptions import CosmosHttpResponseError from _fault_injection_transport import FaultInjectionTransport @@ -19,21 +18,6 @@ create_doc from test_per_partition_circuit_breaker_mm_async import DELETE_ALL_ITEMS_BY_PARTITION_KEY - -@pytest.fixture(scope="class", autouse=True) -def setup_teardown(): - client = CosmosClient(TestPerPartitionCircuitBreakerMM.host, - TestPerPartitionCircuitBreakerMM.master_key) - created_database = client.get_database_client(TestPerPartitionCircuitBreakerMM.TEST_DATABASE_ID) - created_database.create_container(TestPerPartitionCircuitBreakerMM.TEST_CONTAINER_SINGLE_PARTITION_ID, - partition_key=PartitionKey("/pk"), - offer_throughput=10000) - # allow some time for the container to be created as this method is in different event loop - sleep(6) - yield - - created_database.delete_container(TestPerPartitionCircuitBreakerMM.TEST_CONTAINER_SINGLE_PARTITION_ID) - def 
perform_write_operation(operation, container, fault_injection_container, doc_id, pk, expected_uri): doc = {'id': doc_id, 'pk': pk, @@ -107,12 +91,11 @@ def perform_read_operation(operation, container, doc_id, pk, expected_uri): pass @pytest.mark.cosmosCircuitBreaker -@pytest.mark.usefixtures("setup_teardown") class TestPerPartitionCircuitBreakerMM: host = test_config.TestConfig.host master_key = test_config.TestConfig.masterKey TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID - TEST_CONTAINER_SINGLE_PARTITION_ID = os.path.basename(__file__) + str(uuid.uuid4()) + TEST_CONTAINER_SINGLE_PARTITION_ID = test_config.TestConfig.TEST_MULTI_PARTITION_CONTAINER_ID def setup_method_with_custom_transport(self, custom_transport, default_endpoint=host, **kwargs): client = CosmosClient(default_endpoint, self.master_key, @@ -124,7 +107,7 @@ def setup_method_with_custom_transport(self, custom_transport, default_endpoint= return {"client": client, "db": db, "col": container} @pytest.mark.parametrize("write_operation, error", write_operations_and_errors()) - def test_write_consecutive_failure_threshold(self, setup_teardown, write_operation, error): + def test_write_consecutive_failure_threshold(self, write_operation, error): error_lambda = lambda r: FaultInjectionTransport.error_after_delay( 0, error @@ -180,7 +163,7 @@ def test_write_consecutive_failure_threshold(self, setup_teardown, write_operati @pytest.mark.cosmosCircuitBreakerMultiRegion @pytest.mark.parametrize("read_operation, error", read_operations_and_errors()) - def test_read_consecutive_failure_threshold(self, setup_teardown, read_operation, error): + def test_read_consecutive_failure_threshold(self, read_operation, error): error_lambda = lambda r: FaultInjectionTransport.error_after_delay( 0, error @@ -230,7 +213,7 @@ def test_read_consecutive_failure_threshold(self, setup_teardown, read_operation validate_unhealthy_partitions(global_endpoint_manager, 0) @pytest.mark.parametrize("write_operation, error", 
write_operations_and_errors()) - def test_write_failure_rate_threshold(self, setup_teardown, write_operation, error): + def test_write_failure_rate_threshold(self, write_operation, error): error_lambda = lambda r: FaultInjectionTransport.error_after_delay( 0, error @@ -273,7 +256,7 @@ def test_write_failure_rate_threshold(self, setup_teardown, write_operation, err @pytest.mark.cosmosCircuitBreakerMultiRegion @pytest.mark.parametrize("read_operation, error", read_operations_and_errors()) - def test_read_failure_rate_threshold(self, setup_teardown, read_operation, error): + def test_read_failure_rate_threshold(self, read_operation, error): error_lambda = lambda r: FaultInjectionTransport.error_after_delay( 0, error diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py index 683f943e58cc..169a4bbc92c4 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py @@ -1,19 +1,17 @@ # The MIT License (MIT) # Copyright (c) Microsoft Corporation. All rights reserved. 
import asyncio -import logging import os import unittest import uuid from typing import Dict, Any, List import pytest -import pytest_asyncio from azure.core.pipeline.transport._aiohttp import AioHttpTransport from azure.core.exceptions import ServiceResponseError import test_config -from azure.cosmos import PartitionKey, _location_cache, _partition_health_tracker +from azure.cosmos import _location_cache, _partition_health_tracker from azure.cosmos._partition_health_tracker import HEALTH_STATUS, UNHEALTHY, UNHEALTHY_TENTATIVE from azure.cosmos.aio import CosmosClient from azure.cosmos.exceptions import CosmosHttpResponseError @@ -39,19 +37,6 @@ COLLECTION = "created_collection" -@pytest_asyncio.fixture(scope="class", autouse=True) -async def setup_teardown(): - client = CosmosClient(TestPerPartitionCircuitBreakerMMAsync.host, - TestPerPartitionCircuitBreakerMMAsync.master_key) - created_database = client.get_database_client(TestPerPartitionCircuitBreakerMMAsync.TEST_DATABASE_ID) - await created_database.create_container(TestPerPartitionCircuitBreakerMMAsync.TEST_CONTAINER_SINGLE_PARTITION_ID, - partition_key=PartitionKey("/pk"), - offer_throughput=10000) - # allow some time for the container to be created as this method is in different event loop - await asyncio.sleep(10) - yield - await created_database.delete_container(TestPerPartitionCircuitBreakerMMAsync.TEST_CONTAINER_SINGLE_PARTITION_ID) - await client.close() def create_errors(): errors = [] @@ -196,12 +181,11 @@ async def cleanup_method(initialized_objects: List[Dict[str, Any]]): @pytest.mark.cosmosCircuitBreaker @pytest.mark.asyncio -@pytest.mark.usefixtures("setup_teardown") class TestPerPartitionCircuitBreakerMMAsync: host = test_config.TestConfig.host master_key = test_config.TestConfig.masterKey TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID - TEST_CONTAINER_SINGLE_PARTITION_ID = os.path.basename(__file__) + str(uuid.uuid4()) + TEST_CONTAINER_SINGLE_PARTITION_ID = 
test_config.TestConfig.TEST_MULTI_PARTITION_CONTAINER_ID async def setup_method_with_custom_transport(self, custom_transport: AioHttpTransport, default_endpoint=host, **kwargs): client = CosmosClient(default_endpoint, self.master_key, @@ -213,7 +197,7 @@ async def setup_method_with_custom_transport(self, custom_transport: AioHttpTran return {"client": client, "db": db, "col": container} @pytest.mark.parametrize("write_operation, error", write_operations_and_errors()) - async def test_write_consecutive_failure_threshold_async(self, setup_teardown, write_operation, error): + async def test_write_consecutive_failure_threshold_async(self, write_operation, error): error_lambda = lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( 0, error @@ -287,7 +271,7 @@ async def setup_info(self, error): @pytest.mark.cosmosCircuitBreakerMultiRegion @pytest.mark.parametrize("read_operation, error", read_operations_and_errors()) - async def test_read_consecutive_failure_threshold_async(self, setup_teardown, read_operation, error): + async def test_read_consecutive_failure_threshold_async(self, read_operation, error): error_lambda = lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( 0, error @@ -340,7 +324,7 @@ async def test_read_consecutive_failure_threshold_async(self, setup_teardown, re await cleanup_method([custom_setup, setup]) @pytest.mark.parametrize("write_operation, error", write_operations_and_errors()) - async def test_write_failure_rate_threshold_async(self, setup_teardown, write_operation, error): + async def test_write_failure_rate_threshold_async(self, write_operation, error): error_lambda = lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( 0, error @@ -386,7 +370,7 @@ async def test_write_failure_rate_threshold_async(self, setup_teardown, write_op @pytest.mark.cosmosCircuitBreakerMultiRegion @pytest.mark.parametrize("read_operation, error", read_operations_and_errors()) - async def 
test_read_failure_rate_threshold_async(self, setup_teardown, read_operation, error): + async def test_read_failure_rate_threshold_async(self, read_operation, error): error_lambda = lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( 0, error diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr.py index 89bf93266ec9..336f5d304889 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr.py +++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr.py @@ -3,14 +3,12 @@ import os import unittest import uuid -from time import sleep import pytest -import pytest_asyncio from azure.core.exceptions import ServiceResponseError import test_config -from azure.cosmos import PartitionKey, _partition_health_tracker, _location_cache +from azure.cosmos import _partition_health_tracker, _location_cache from azure.cosmos import CosmosClient from azure.cosmos.exceptions import CosmosHttpResponseError from _fault_injection_transport import FaultInjectionTransport @@ -20,20 +18,8 @@ from test_per_partition_circuit_breaker_sm_mrr_async import validate_unhealthy_partitions COLLECTION = "created_collection" -@pytest_asyncio.fixture(scope="class", autouse=True) -def setup_teardown(): - client = CosmosClient(TestPerPartitionCircuitBreakerSmMrr.host, TestPerPartitionCircuitBreakerSmMrr.master_key) - created_database = client.get_database_client(TestPerPartitionCircuitBreakerSmMrr.TEST_DATABASE_ID) - created_database.create_container(TestPerPartitionCircuitBreakerSmMrr.TEST_CONTAINER_SINGLE_PARTITION_ID, - partition_key=PartitionKey("/pk"), - offer_throughput=10000) - # allow some time for the container to be created as this method is in different event loop - sleep(6) - yield - created_database.delete_container(TestPerPartitionCircuitBreakerSmMrr.TEST_CONTAINER_SINGLE_PARTITION_ID) 
@pytest.mark.cosmosCircuitBreakerMultiRegion -@pytest.mark.usefixtures("setup_teardown") class TestPerPartitionCircuitBreakerSmMrr: host = test_config.TestConfig.host master_key = test_config.TestConfig.masterKey @@ -64,7 +50,7 @@ def setup_info(self, error): return setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate @pytest.mark.parametrize("write_operation, error", write_operations_and_errors()) - def test_write_consecutive_failure_threshold(self, setup_teardown, write_operation, error): + def test_write_consecutive_failure_threshold(self, write_operation, error): error_lambda = lambda r: FaultInjectionTransport.error_after_delay(0, error) setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = self.setup_info(error_lambda) container = setup['col'] @@ -87,7 +73,7 @@ def test_write_consecutive_failure_threshold(self, setup_teardown, write_operati validate_unhealthy_partitions(global_endpoint_manager, 0) @pytest.mark.parametrize("write_operation, error", write_operations_and_errors()) - def test_write_failure_rate_threshold(self, setup_teardown, write_operation, error): + def test_write_failure_rate_threshold(self, write_operation, error): error_lambda = lambda r: FaultInjectionTransport.error_after_delay(0, error) setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate = self.setup_info(error_lambda) container = setup['col'] diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr_async.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr_async.py index b7a5fd7de349..cb4fdc128c75 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr_async.py @@ -7,12 +7,11 @@ from typing import Dict, Any import pytest -import pytest_asyncio from azure.core.pipeline.transport._aiohttp import AioHttpTransport from azure.core.exceptions import 
ServiceResponseError import test_config -from azure.cosmos import PartitionKey, _partition_health_tracker, _location_cache +from azure.cosmos import _partition_health_tracker, _location_cache from azure.cosmos._partition_health_tracker import UNHEALTHY_TENTATIVE, UNHEALTHY, HEALTH_STATUS from azure.cosmos.aio import CosmosClient from azure.cosmos.exceptions import CosmosHttpResponseError @@ -21,18 +20,6 @@ write_operations_and_errors, cleanup_method, perform_read_operation, operations, REGION_2, REGION_1 COLLECTION = "created_collection" -@pytest_asyncio.fixture(scope="class", autouse=True) -async def setup_teardown(): - client = CosmosClient(TestPerPartitionCircuitBreakerSmMrrAsync.host, TestPerPartitionCircuitBreakerSmMrrAsync.master_key) - created_database = client.get_database_client(TestPerPartitionCircuitBreakerSmMrrAsync.TEST_DATABASE_ID) - await created_database.create_container(TestPerPartitionCircuitBreakerSmMrrAsync.TEST_CONTAINER_SINGLE_PARTITION_ID, - partition_key=PartitionKey("/pk"), - offer_throughput=10000) - # allow some time for the container to be created as this method is in different event loop - await asyncio.sleep(6) - yield - await created_database.delete_container(TestPerPartitionCircuitBreakerSmMrrAsync.TEST_CONTAINER_SINGLE_PARTITION_ID) - await client.close() def validate_unhealthy_partitions(global_endpoint_manager, expected_unhealthy_partitions): @@ -55,13 +42,12 @@ def validate_unhealthy_partitions(global_endpoint_manager, @pytest.mark.cosmosCircuitBreakerMultiRegion @pytest.mark.asyncio -@pytest.mark.usefixtures("setup_teardown") class TestPerPartitionCircuitBreakerSmMrrAsync: host = test_config.TestConfig.host master_key = test_config.TestConfig.masterKey connectionPolicy = test_config.TestConfig.connectionPolicy TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID - TEST_CONTAINER_SINGLE_PARTITION_ID = os.path.basename(__file__) + str(uuid.uuid4()) + TEST_CONTAINER_SINGLE_PARTITION_ID = 
test_config.TestConfig.TEST_MULTI_PARTITION_CONTAINER_ID async def setup_method_with_custom_transport(self, custom_transport: AioHttpTransport, default_endpoint=host, **kwargs): client = CosmosClient(default_endpoint, self.master_key, consistency_level="Session", @@ -91,7 +77,7 @@ async def setup_info(self, error): return setup, doc, expected_uri, uri_down, custom_setup, custom_transport, predicate @pytest.mark.parametrize("write_operation, error", write_operations_and_errors()) - async def test_write_consecutive_failure_threshold_async(self, setup_teardown, write_operation, error): + async def test_write_consecutive_failure_threshold_async(self, write_operation, error): error_lambda = lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( 0, error @@ -119,7 +105,7 @@ async def test_write_consecutive_failure_threshold_async(self, setup_teardown, w @pytest.mark.parametrize("write_operation, error", write_operations_and_errors()) - async def test_write_failure_rate_threshold_async(self, setup_teardown, write_operation, error): + async def test_write_failure_rate_threshold_async(self, write_operation, error): error_lambda = lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( 0, error From afd8b610efdeabad16b2d81f046b72a517b997f1 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Mon, 5 May 2025 11:18:01 -0400 Subject: [PATCH 139/152] fix tests --- .../azure/cosmos/aio/_global_endpoint_manager_async.py | 3 +++ .../tests/test_per_partition_circuit_breaker_mm.py | 4 ++-- .../tests/test_per_partition_circuit_breaker_mm_async.py | 4 ++-- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py index 5a7aea5330ba..051c94bcf3a0 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py +++ 
b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py @@ -87,6 +87,9 @@ def mark_endpoint_unavailable_for_write(self, endpoint, refresh_cache): def get_ordered_write_locations(self): return self.location_cache.get_ordered_write_locations() + def get_ordered_read_locations(self): + return self.location_cache.get_ordered_read_locations() + def can_use_multiple_write_locations(self, request): return self.location_cache.can_use_multiple_write_locations_for_request(request) diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py index 756427d6162d..a1070850dd77 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py +++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py @@ -194,7 +194,7 @@ def test_read_consecutive_failure_threshold(self, read_operation, error): # the partition should have been marked as unavailable after breaking read threshold if read_operation in (CHANGE_FEED, QUERY, READ_ALL_ITEMS): # these operations are cross partition so they would mark both partitions as unavailable - expected_unhealthy_partitions = 2 + expected_unhealthy_partitions = 5 else: expected_unhealthy_partitions = 1 validate_unhealthy_partitions(global_endpoint_manager, expected_unhealthy_partitions) @@ -283,7 +283,7 @@ def test_read_failure_rate_threshold(self, read_operation, error): expected_uri) if read_operation in (CHANGE_FEED, QUERY, READ_ALL_ITEMS): # these operations are cross partition so they would mark both partitions as unavailable - expected_unhealthy_partitions = 2 + expected_unhealthy_partitions = 5 else: expected_unhealthy_partitions = 1 diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py index 169a4bbc92c4..f24397ff8fe6 100644 --- 
a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py @@ -304,7 +304,7 @@ async def test_read_consecutive_failure_threshold_async(self, read_operation, er # the partition should have been marked as unavailable after breaking read threshold if read_operation in (CHANGE_FEED, QUERY, READ_ALL_ITEMS): # these operations are cross partition so they would mark both partitions as unavailable - expected_unhealthy_partitions = 2 + expected_unhealthy_partitions = 5 else: expected_unhealthy_partitions = 1 validate_unhealthy_partitions(global_endpoint_manager, expected_unhealthy_partitions) @@ -399,7 +399,7 @@ async def test_read_failure_rate_threshold_async(self, read_operation, error): expected_uri) if read_operation in (CHANGE_FEED, QUERY, READ_ALL_ITEMS): # these operations are cross partition so they would mark both partitions as unavailable - expected_unhealthy_partitions = 2 + expected_unhealthy_partitions = 5 else: expected_unhealthy_partitions = 1 From a9e449e96a531ebfdaa9dfa5e747d48d5512f4e4 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Mon, 5 May 2025 17:42:57 -0400 Subject: [PATCH 140/152] fix tests --- sdk/cosmos/azure-cosmos/tests/test_crud_async.py | 2 ++ sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py | 1 - .../tests/test_per_partition_circuit_breaker_mm.py | 4 ++++ .../tests/test_per_partition_circuit_breaker_mm_async.py | 3 +++ sdk/cosmos/azure-cosmos/tests/test_query_async.py | 4 +++- sdk/cosmos/live-platform-matrix.json | 2 +- 6 files changed, 13 insertions(+), 3 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_crud_async.py b/sdk/cosmos/azure-cosmos/tests/test_crud_async.py index 594f47e451f6..40f4ff1d422c 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_crud_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_crud_async.py @@ -4,6 +4,7 @@ """End-to-end test. 
""" +import asyncio import os import time import unittest @@ -1000,6 +1001,7 @@ async def test_query_iterable_functionality_async(self): doc1 = await collection.upsert_item(body={'id': 'doc1', 'prop1': 'value1'}) doc2 = await collection.upsert_item(body={'id': 'doc2', 'prop1': 'value2'}) doc3 = await collection.upsert_item(body={'id': 'doc3', 'prop1': 'value3'}) + await asyncio.sleep(1) resources = { 'coll': collection, 'doc1': doc1, diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py index 337a09bafd24..d467e0ac1d64 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py +++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py @@ -169,7 +169,6 @@ def _verify_endpoint(messages, client, expected_locations): # get location actual_locations = set() for req_url in req_urls: - print(req_url) match = re.search(r"x\-ms\-thinclient\-proxy\-resource\-type': '([^']+)'", req_url) if match: resource_type = match.group(1) diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py index a1070850dd77..aa05f31f6282 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py +++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py @@ -3,6 +3,7 @@ import os import unittest import uuid +from time import sleep import pytest from azure.core.exceptions import ServiceResponseError @@ -29,6 +30,7 @@ def perform_write_operation(operation, container, fault_injection_container, doc resp = fault_injection_container.upsert_item(body=doc) elif operation == REPLACE: container.create_item(body=doc) + sleep(1) new_doc = {'id': doc_id, 'pk': pk, 'name': 'sample document' + str(uuid), @@ -36,9 +38,11 @@ def perform_write_operation(operation, container, fault_injection_container, doc resp = fault_injection_container.replace_item(item=doc['id'], body=new_doc) elif 
operation == DELETE: container.create_item(body=doc) + sleep(1) resp = fault_injection_container.delete_item(item=doc['id'], partition_key=doc['pk']) elif operation == PATCH: container.create_item(body=doc) + sleep(1) operations = [{"op": "incr", "path": "/company", "value": 3}] resp = fault_injection_container.patch_item(item=doc['id'], partition_key=doc['pk'], patch_operations=operations) elif operation == BATCH: diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py index f24397ff8fe6..2e038e525874 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py @@ -102,12 +102,15 @@ async def perform_write_operation(operation, container, fault_injection_containe 'pk': pk, 'name': 'sample document' + str(uuid), 'key': 'value'} + await asyncio.sleep(1) resp = await fault_injection_container.replace_item(item=doc['id'], body=new_doc) elif operation == DELETE: await container.create_item(body=doc) + await asyncio.sleep(1) resp = await fault_injection_container.delete_item(item=doc['id'], partition_key=doc['pk']) elif operation == PATCH: await container.create_item(body=doc) + await asyncio.sleep(1) operations = [{"op": "incr", "path": "/company", "value": 3}] resp = await fault_injection_container.patch_item(item=doc['id'], partition_key=doc['pk'], patch_operations=operations) elif operation == BATCH: diff --git a/sdk/cosmos/azure-cosmos/tests/test_query_async.py b/sdk/cosmos/azure-cosmos/tests/test_query_async.py index 12b26ca363a7..c518cf899a3a 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_query_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_query_async.py @@ -1,6 +1,6 @@ # The MIT License (MIT) # Copyright (c) Microsoft Corporation. All rights reserved. 
- +import asyncio import os import unittest import uuid @@ -60,6 +60,7 @@ async def test_first_and_last_slashes_trimmed_for_query_string_async(self): doc_id = 'myId' + str(uuid.uuid4()) document_definition = {'pk': 'pk', 'id': doc_id} await created_collection.create_item(body=document_definition) + await asyncio.sleep(1) query = 'SELECT * from c' query_iterable = created_collection.query_items( @@ -107,6 +108,7 @@ async def test_populate_index_metrics_async(self): doc_id = 'MyId' + str(uuid.uuid4()) document_definition = {'pk': 'pk', 'id': doc_id} await created_collection.create_item(body=document_definition) + await asyncio.sleep(1) query = 'SELECT * from c' query_iterable = created_collection.query_items( diff --git a/sdk/cosmos/live-platform-matrix.json b/sdk/cosmos/live-platform-matrix.json index c4fe1426b1c6..dffc939e6379 100644 --- a/sdk/cosmos/live-platform-matrix.json +++ b/sdk/cosmos/live-platform-matrix.json @@ -120,7 +120,7 @@ "Windows2022_38": { "OSVmImage": "env:WINDOWSVMIMAGE", "Pool": "env:WINDOWSPOOL", - "PythonVersion": "3.8", + "PythonVersion": "3.9", "CoverageArg": "--disablecov", "TestSamples": "false", "TestMarkArgument": "cosmosMultiRegion" From bf6a7c202ef7a186352f9b0c0a9a796327f0aa93 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Mon, 5 May 2025 19:36:03 -0400 Subject: [PATCH 141/152] fix tests --- .../tests/test_per_partition_circuit_breaker_sm_mrr.py | 2 +- sdk/cosmos/azure-cosmos/tests/test_query_async.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr.py index 336f5d304889..2147b8c037ec 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr.py +++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr.py @@ -25,7 +25,7 @@ class TestPerPartitionCircuitBreakerSmMrr: master_key = test_config.TestConfig.masterKey connectionPolicy 
= test_config.TestConfig.connectionPolicy TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID - TEST_CONTAINER_SINGLE_PARTITION_ID = os.path.basename(__file__) + str(uuid.uuid4()) + TEST_CONTAINER_SINGLE_PARTITION_ID = test_config.TestConfig.TEST_MULTI_PARTITION_CONTAINER_ID def setup_method_with_custom_transport(self, custom_transport, default_endpoint=host, **kwargs): client = CosmosClient(default_endpoint, self.master_key, consistency_level="Session", diff --git a/sdk/cosmos/azure-cosmos/tests/test_query_async.py b/sdk/cosmos/azure-cosmos/tests/test_query_async.py index c518cf899a3a..11cbc8768fde 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_query_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_query_async.py @@ -80,6 +80,7 @@ async def test_populate_query_metrics_async(self): doc_id = 'MyId' + str(uuid.uuid4()) document_definition = {'pk': 'pk', 'id': doc_id} await created_collection.create_item(body=document_definition) + await asyncio.sleep(1) query = 'SELECT * from c' query_iterable = created_collection.query_items( From c46cdc702dbe79c3f524b1e15d61114c50a69ef8 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Mon, 5 May 2025 22:52:35 -0400 Subject: [PATCH 142/152] fix tests --- .../azure/cosmos/_global_endpoint_manager.py | 4 +- ...tition_endpoint_manager_circuit_breaker.py | 37 +++++++++++------ ...n_endpoint_manager_circuit_breaker_core.py | 9 +++-- .../azure/cosmos/_partition_health_tracker.py | 2 +- .../aio/_cosmos_client_connection_async.py | 7 ---- .../aio/_global_endpoint_manager_async.py | 4 +- ..._endpoint_manager_circuit_breaker_async.py | 40 +++++++++++++------ .../tests/_fault_injection_transport.py | 1 - .../tests/_fault_injection_transport_async.py | 1 - .../tests/test_circuit_breaker_emulator.py | 2 +- .../test_circuit_breaker_emulator_async.py | 2 +- 11 files changed, 65 insertions(+), 44 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py 
b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py index f61bd37feba7..e660a608ca5a 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py @@ -24,7 +24,7 @@ """ import threading -from typing import Tuple +from typing import Tuple, Optional from azure.core.exceptions import AzureError @@ -71,7 +71,7 @@ def get_read_endpoint(self): def resolve_service_endpoint( self, request: RequestObject, - pk_range_wrapper: PartitionKeyRangeWrapper # pylint: disable=unused-argument + pk_range_wrapper: Optional[PartitionKeyRangeWrapper] # pylint: disable=unused-argument ) -> str: return self.location_cache.resolve_service_endpoint(request) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py index 5a340eaf6f6c..d2ea988bcb49 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py @@ -21,7 +21,8 @@ """Internal class for global endpoint manager for circuit breaker. 
""" -from typing import TYPE_CHECKING +import logging +from typing import TYPE_CHECKING, Optional from azure.cosmos.partition_key import PartitionKey from azure.cosmos._global_partition_endpoint_manager_circuit_breaker_core import \ @@ -35,6 +36,8 @@ if TYPE_CHECKING: from azure.cosmos._cosmos_client_connection import CosmosClientConnection +logger = logging.getLogger("azure.cosmos._GlobalPartitionEndpointManagerForCircuitBreaker") + class _GlobalPartitionEndpointManagerForCircuitBreaker(_GlobalEndpointManager): """ This internal class implements the logic for partition endpoint management for @@ -51,28 +54,35 @@ def is_circuit_breaker_applicable(self, request: RequestObject) -> bool: return self.global_partition_endpoint_manager_core.is_circuit_breaker_applicable(request) - def create_pk_range_wrapper(self, request: RequestObject) -> PartitionKeyRangeWrapper: - container_rid = request.headers[HttpHeaders.IntendedCollectionRID] + def create_pk_range_wrapper(self, request: RequestObject) -> Optional[PartitionKeyRangeWrapper]: + if HttpHeaders.IntendedCollectionRID in request.headers: + container_rid = request.headers[HttpHeaders.IntendedCollectionRID] + else: + logger.warning("Illegal state: the request does not contain container information. 
" + "Circuit breaker cannot be performed.") + return None properties = self.Client._container_properties_cache[container_rid] # pylint: disable=protected-access # get relevant information from container cache to get the overlapping ranges container_link = properties["container_link"] partition_key_definition = properties["partitionKey"] partition_key = PartitionKey(path=partition_key_definition["paths"], kind=partition_key_definition["kind"]) - if request.headers.get(HttpHeaders.PartitionKey): + if HttpHeaders.PartitionKey in request.headers: partition_key_value = request.headers[HttpHeaders.PartitionKey] # get the partition key range for the given partition key epk_range = [partition_key._get_epk_range_for_partition_key(partition_key_value)] # pylint: disable=protected-access partition_ranges = (self.Client._routing_map_provider # pylint: disable=protected-access .get_overlapping_ranges(container_link, epk_range)) partition_range = Range.PartitionKeyRangeToRange(partition_ranges[0]) - elif request.headers.get(HttpHeaders.PartitionKeyRangeID): + elif HttpHeaders.PartitionKeyRangeID in request.headers: pk_range_id = request.headers[HttpHeaders.PartitionKeyRangeID] epk_range =(self.Client._routing_map_provider # pylint: disable=protected-access .get_range_by_partition_key_range_id(container_link, pk_range_id)) partition_range = Range.PartitionKeyRangeToRange(epk_range) else: - raise RuntimeError("Illegal state: the request does not contain partition information.") + logger.warning("Illegal state: the request does not contain partition information. 
" + "Circuit breaker cannot be performed.") + return None return PartitionKeyRangeWrapper(partition_range, container_rid) @@ -83,10 +93,11 @@ def record_failure( ) -> None: if self.is_circuit_breaker_applicable(request): pk_range_wrapper = self.create_pk_range_wrapper(request) - self.global_partition_endpoint_manager_core.record_failure(request, pk_range_wrapper) + if pk_range_wrapper: + self.global_partition_endpoint_manager_core.record_failure(request, pk_range_wrapper) - def resolve_service_endpoint(self, request: RequestObject, pk_range_wrapper: PartitionKeyRangeWrapper): - if self.is_circuit_breaker_applicable(request): + def resolve_service_endpoint(self, request: RequestObject, pk_range_wrapper: Optional[PartitionKeyRangeWrapper]): + if self.is_circuit_breaker_applicable(request) and pk_range_wrapper: self.global_partition_endpoint_manager_core.check_stale_partition_info(request, pk_range_wrapper) request = self.global_partition_endpoint_manager_core.add_excluded_locations_to_request(request, pk_range_wrapper) @@ -96,9 +107,10 @@ def resolve_service_endpoint(self, request: RequestObject, pk_range_wrapper: Par def mark_partition_unavailable( self, request: RequestObject, - pk_range_wrapper: PartitionKeyRangeWrapper + pk_range_wrapper: Optional[PartitionKeyRangeWrapper] ) -> None: - self.global_partition_endpoint_manager_core.mark_partition_unavailable(request, pk_range_wrapper) + if pk_range_wrapper: + self.global_partition_endpoint_manager_core.mark_partition_unavailable(request, pk_range_wrapper) def record_success( self, @@ -106,4 +118,5 @@ def record_success( ) -> None: if self.global_partition_endpoint_manager_core.is_circuit_breaker_applicable(request): pk_range_wrapper = self.create_pk_range_wrapper(request) - self.global_partition_endpoint_manager_core.record_success(request, pk_range_wrapper) + if pk_range_wrapper: + self.global_partition_endpoint_manager_core.record_success(request, pk_range_wrapper) diff --git 
a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py index 548430dd6806..db4215ca694f 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py @@ -21,6 +21,7 @@ """Internal class for global endpoint manager for circuit breaker. """ +import logging import os from azure.cosmos import documents @@ -32,6 +33,7 @@ from azure.cosmos.http_constants import ResourceType, HttpHeaders from azure.cosmos._constants import _Constants as Constants +logger = logging.getLogger("azure.cosmos._GlobalPartitionEndpointManagerForCircuitBreakerCore") class _GlobalPartitionEndpointManagerForCircuitBreakerCore(object): """ @@ -39,7 +41,6 @@ class _GlobalPartitionEndpointManagerForCircuitBreakerCore(object): geo-replicated database accounts. 
""" - def __init__(self, client, location_cache: LocationCache): self.partition_health_tracker = _PartitionHealthTracker() self.location_cache = location_cache @@ -63,8 +64,8 @@ def is_circuit_breaker_applicable(self, request: RequestObject) -> bool: return False # this is for certain cross partition queries and read all items where we cannot discern partition information - if (not request.headers.get(HttpHeaders.PartitionKeyRangeID) - and not request.headers.get(HttpHeaders.PartitionKey)): + if (HttpHeaders.PartitionKeyRangeID not in request.headers + and HttpHeaders.PartitionKey not in request.headers): return False return True @@ -95,7 +96,7 @@ def add_excluded_locations_to_request( pk_range_wrapper: PartitionKeyRangeWrapper ) -> RequestObject: request.set_excluded_locations_from_circuit_breaker( - self.partition_health_tracker.get_excluded_locations(request, pk_range_wrapper) + self.partition_health_tracker.get_unhealthy_locations(request, pk_range_wrapper) ) return request diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py index d086d9d2b497..a8400bf552d6 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py @@ -191,7 +191,7 @@ def check_stale_partition_info( self._reset_partition_health_tracker_stats() - def get_excluded_locations( + def get_unhealthy_locations( self, request: RequestObject, pk_range_wrapper: PartitionKeyRangeWrapper diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py index c8bc8d8582af..3e19f0ddc000 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py @@ -779,7 +779,6 @@ async def Create( request_params = 
_request_object.RequestObject(typ, documents._OperationType.Create, headers) request_params.set_excluded_location_from_options(options) - request_params.set_excluded_location_from_options(options) result, last_response_headers = await self.__Post(path, request_params, body, headers, **kwargs) self.last_response_headers = last_response_headers @@ -920,7 +919,6 @@ async def Upsert( # Upsert will use WriteEndpoint since it uses POST operation request_params = _request_object.RequestObject(typ, documents._OperationType.Upsert, headers) request_params.set_excluded_location_from_options(options) - request_params.set_excluded_location_from_options(options) result, last_response_headers = await self.__Post(path, request_params, body, headers, **kwargs) self.last_response_headers = last_response_headers # update session for write request @@ -1223,7 +1221,6 @@ async def Read( # Read will use ReadEndpoint since it uses GET operation request_params = _request_object.RequestObject(typ, documents._OperationType.Read, headers) request_params.set_excluded_location_from_options(options) - request_params.set_excluded_location_from_options(options) result, last_response_headers = await self.__Get(path, request_params, headers, **kwargs) self.last_response_headers = last_response_headers if response_hook: @@ -1483,7 +1480,6 @@ async def PatchItem( # Patch will use WriteEndpoint since it uses PUT operation request_params = _request_object.RequestObject(typ, documents._OperationType.Patch, headers) request_params.set_excluded_location_from_options(options) - request_params.set_excluded_location_from_options(options) request_data = {} if options.get("filterPredicate"): request_data["condition"] = options.get("filterPredicate") @@ -1589,7 +1585,6 @@ async def Replace( # Replace will use WriteEndpoint since it uses PUT operation request_params = _request_object.RequestObject(typ, documents._OperationType.Replace, headers) request_params.set_excluded_location_from_options(options) - 
request_params.set_excluded_location_from_options(options) result, last_response_headers = await self.__Put(path, request_params, resource, headers, **kwargs) self.last_response_headers = last_response_headers @@ -1914,7 +1909,6 @@ async def DeleteResource( # Delete will use WriteEndpoint since it uses DELETE operation request_params = _request_object.RequestObject(typ, documents._OperationType.Delete, headers) request_params.set_excluded_location_from_options(options) - request_params.set_excluded_location_from_options(options) result, last_response_headers = await self.__Delete(path, request_params, headers, **kwargs) self.last_response_headers = last_response_headers @@ -2029,7 +2023,6 @@ async def _Batch( documents._OperationType.Batch, options) request_params = _request_object.RequestObject("docs", documents._OperationType.Batch, headers) request_params.set_excluded_location_from_options(options) - request_params.set_excluded_location_from_options(options) result = await self.__Post(path, request_params, batch_operations, headers, **kwargs) return cast(Tuple[List[Dict[str, Any]], CaseInsensitiveDict], result) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py index 051c94bcf3a0..0d51aa83512a 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py @@ -25,7 +25,7 @@ import asyncio # pylint: disable=do-not-import-asyncio import logging -from typing import Tuple, Dict, Any +from typing import Tuple, Dict, Any, Optional from azure.core.exceptions import AzureError from azure.cosmos import DatabaseAccount @@ -74,7 +74,7 @@ def get_read_endpoint(self): def resolve_service_endpoint( self, request: RequestObject, - pk_range_wrapper: PartitionKeyRangeWrapper # pylint: disable=unused-argument + pk_range_wrapper: Optional[PartitionKeyRangeWrapper] # 
pylint: disable=unused-argument ) -> str: return self.location_cache.resolve_service_endpoint(request) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py index 7a6a06c3e8c7..92705ba39907 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py @@ -21,7 +21,8 @@ """Internal class for global endpoint manager for circuit breaker. """ -from typing import TYPE_CHECKING +import logging +from typing import TYPE_CHECKING, Optional from azure.cosmos import PartitionKey from azure.cosmos._global_partition_endpoint_manager_circuit_breaker_core import \ @@ -35,6 +36,7 @@ if TYPE_CHECKING: from azure.cosmos.aio._cosmos_client_connection_async import CosmosClientConnection +logger = logging.getLogger("azure.cosmos.aio._GlobalPartitionEndpointManagerForCircuitBreakerAsync") # pylint: disable=protected-access class _GlobalPartitionEndpointManagerForCircuitBreakerAsync(_GlobalEndpointManager): @@ -49,28 +51,35 @@ def __init__(self, client: "CosmosClientConnection"): self.global_partition_endpoint_manager_core = ( _GlobalPartitionEndpointManagerForCircuitBreakerCore(client, self.location_cache)) - async def create_pk_range_wrapper(self, request: RequestObject) -> PartitionKeyRangeWrapper: - container_rid = request.headers[HttpHeaders.IntendedCollectionRID] + async def create_pk_range_wrapper(self, request: RequestObject) -> Optional[PartitionKeyRangeWrapper]: + if HttpHeaders.IntendedCollectionRID in request.headers: + container_rid = request.headers[HttpHeaders.IntendedCollectionRID] + else: + logger.warning("Illegal state: the request does not contain container information. 
" + "Circuit breaker cannot be performed.") + return None properties = self.client._container_properties_cache[container_rid] # get relevant information from container cache to get the overlapping ranges container_link = properties["container_link"] partition_key_definition = properties["partitionKey"] partition_key = PartitionKey(path=partition_key_definition["paths"], kind=partition_key_definition["kind"]) - if request.headers.get(HttpHeaders.PartitionKey): + if HttpHeaders.PartitionKey in request.headers: partition_key_value = request.headers[HttpHeaders.PartitionKey] # get the partition key range for the given partition key epk_range = [partition_key._get_epk_range_for_partition_key(partition_key_value)] partition_ranges = await (self.client._routing_map_provider .get_overlapping_ranges(container_link, epk_range)) partition_range = Range.PartitionKeyRangeToRange(partition_ranges[0]) - elif request.headers.get(HttpHeaders.PartitionKeyRangeID): + elif HttpHeaders.PartitionKeyRangeID in request.headers: pk_range_id = request.headers[HttpHeaders.PartitionKeyRangeID] epk_range = await (self.client._routing_map_provider .get_range_by_partition_key_range_id(container_link, pk_range_id)) partition_range = Range.PartitionKeyRangeToRange(epk_range) else: - raise RuntimeError("Illegal state: the request does not contain partition information.") + logger.warning("Illegal state: the request does not contain partition information. 
" + "Circuit breaker cannot be performed.") + return None return PartitionKeyRangeWrapper(partition_range, container_rid) @@ -83,10 +92,15 @@ async def record_failure( ) -> None: if self.is_circuit_breaker_applicable(request): pk_range_wrapper = await self.create_pk_range_wrapper(request) - self.global_partition_endpoint_manager_core.record_failure(request, pk_range_wrapper) + if pk_range_wrapper: + self.global_partition_endpoint_manager_core.record_failure(request, pk_range_wrapper) - def resolve_service_endpoint(self, request: RequestObject, pk_range_wrapper: PartitionKeyRangeWrapper): - if self.is_circuit_breaker_applicable(request): + def resolve_service_endpoint( + self, + request: RequestObject, + pk_range_wrapper: Optional[PartitionKeyRangeWrapper] + ): + if self.is_circuit_breaker_applicable(request) and pk_range_wrapper: self.global_partition_endpoint_manager_core.check_stale_partition_info(request, pk_range_wrapper) request = self.global_partition_endpoint_manager_core.add_excluded_locations_to_request(request, pk_range_wrapper) @@ -96,9 +110,10 @@ def resolve_service_endpoint(self, request: RequestObject, pk_range_wrapper: Par def mark_partition_unavailable( self, request: RequestObject, - pk_range_wrapper: PartitionKeyRangeWrapper + pk_range_wrapper: Optional[PartitionKeyRangeWrapper] ) -> None: - self.global_partition_endpoint_manager_core.mark_partition_unavailable(request, pk_range_wrapper) + if pk_range_wrapper: + self.global_partition_endpoint_manager_core.mark_partition_unavailable(request, pk_range_wrapper) async def record_success( self, @@ -106,4 +121,5 @@ async def record_success( ) -> None: if self.is_circuit_breaker_applicable(request): pk_range_wrapper = await self.create_pk_range_wrapper(request) - self.global_partition_endpoint_manager_core.record_success(request, pk_range_wrapper) + if pk_range_wrapper: + self.global_partition_endpoint_manager_core.record_success(request, pk_range_wrapper) diff --git 
a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py index f745f3247b5d..7d5df4c69a72 100644 --- a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py +++ b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py @@ -25,7 +25,6 @@ import json import logging import sys -from importlib.resources import is_resource from time import sleep from typing import Callable, Optional, Any, Dict, List, MutableMapping diff --git a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py index ce10504164bb..eed8aaabd2f0 100644 --- a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py +++ b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py @@ -26,7 +26,6 @@ import json import logging import sys -from importlib.resources import is_resource from typing import Callable, Optional, Any, Dict, List, Awaitable, MutableMapping import aiohttp from azure.core.pipeline.transport import AioHttpTransport, AioHttpTransportResponse diff --git a/sdk/cosmos/azure-cosmos/tests/test_circuit_breaker_emulator.py b/sdk/cosmos/azure-cosmos/tests/test_circuit_breaker_emulator.py index 0b8dff4460c9..30448b2aa623 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_circuit_breaker_emulator.py +++ b/sdk/cosmos/azure-cosmos/tests/test_circuit_breaker_emulator.py @@ -260,4 +260,4 @@ def test_write_failure_rate_threshold_delete_all_items_by_pk_sm(self, setup_tear # test cosmos client timeout if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/test_circuit_breaker_emulator_async.py b/sdk/cosmos/azure-cosmos/tests/test_circuit_breaker_emulator_async.py index 3ec83a004bc9..23315499c2da 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_circuit_breaker_emulator_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_circuit_breaker_emulator_async.py @@ -266,4 
+266,4 @@ async def test_write_failure_rate_threshold_delete_all_items_by_pk_sm_async(self # test cosmos client timeout if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main() From 61b15e9527ddc93b3bca6b655ad01ac12747a39f Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Thu, 8 May 2025 09:21:44 -0400 Subject: [PATCH 143/152] remove unused logger --- .../_global_partition_endpoint_manager_circuit_breaker_core.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py index db4215ca694f..f4a850442d0a 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py @@ -21,7 +21,6 @@ """Internal class for global endpoint manager for circuit breaker. 
""" -import logging import os from azure.cosmos import documents @@ -33,7 +32,6 @@ from azure.cosmos.http_constants import ResourceType, HttpHeaders from azure.cosmos._constants import _Constants as Constants -logger = logging.getLogger("azure.cosmos._GlobalPartitionEndpointManagerForCircuitBreakerCore") class _GlobalPartitionEndpointManagerForCircuitBreakerCore(object): """ From d0feac2fc60fe28ac3df29dd7f2b6bf5563ef7c1 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Mon, 12 May 2025 12:37:56 -0400 Subject: [PATCH 144/152] add regression testing for cross partition queries --- .../azure-cosmos/tests/test_query_cross_partition.py | 7 +++++-- .../azure-cosmos/tests/test_query_cross_partition_async.py | 6 +++++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_query_cross_partition.py b/sdk/cosmos/azure-cosmos/tests/test_query_cross_partition.py index 2e518d075c73..1ee7b550f13b 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_query_cross_partition.py +++ b/sdk/cosmos/azure-cosmos/tests/test_query_cross_partition.py @@ -16,7 +16,7 @@ from azure.cosmos.documents import _DistinctType from azure.cosmos.partition_key import PartitionKey - +@pytest.mark.cosmosCircuitBreaker @pytest.mark.cosmosQuery class TestCrossPartitionQuery(unittest.TestCase): """Test to ensure escaping of non-ascii characters from partition key""" @@ -39,7 +39,10 @@ def setUpClass(cls): "'masterKey' and 'host' at the top of this class to run the " "tests.") - cls.client = cosmos_client.CosmosClient(cls.host, cls.masterKey) + use_multiple_write_locations = False + if os.environ.get("AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER", "False") == "True": + use_multiple_write_locations = True + cls.client = cosmos_client.CosmosClient(cls.host, cls.masterKey, multiple_write_locations=use_multiple_write_locations) cls.created_db = cls.client.get_database_client(cls.TEST_DATABASE_ID) if cls.host == "https://localhost:8081/": 
os.environ["AZURE_COSMOS_DISABLE_NON_STREAMING_ORDER_BY"] = "True" diff --git a/sdk/cosmos/azure-cosmos/tests/test_query_cross_partition_async.py b/sdk/cosmos/azure-cosmos/tests/test_query_cross_partition_async.py index d9a5b4b251b7..9a6961d1cd03 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_query_cross_partition_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_query_cross_partition_async.py @@ -16,6 +16,7 @@ from azure.cosmos.exceptions import CosmosHttpResponseError from azure.cosmos.partition_key import PartitionKey +@pytest.mark.cosmosCircuitBreaker @pytest.mark.cosmosQuery class TestQueryCrossPartitionAsync(unittest.IsolatedAsyncioTestCase): """Test to ensure escaping of non-ascii characters from partition key""" @@ -40,7 +41,10 @@ def setUpClass(cls): "tests.") async def asyncSetUp(self): - self.client = CosmosClient(self.host, self.masterKey) + use_multiple_write_locations = False + if os.environ.get("AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER", "False") == "True": + use_multiple_write_locations = True + self.client = CosmosClient(self.host, self.masterKey, multiple_write_locations=use_multiple_write_locations) self.created_db = self.client.get_database_client(self.TEST_DATABASE_ID) self.created_container = await self.created_db.create_container( self.TEST_CONTAINER_ID, From 9c34f41051b368030d6ff8a6cd340d606adca2de Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Mon, 12 May 2025 18:38:58 -0400 Subject: [PATCH 145/152] react to comments --- .../azure/cosmos/_global_endpoint_manager.py | 8 +++--- ...tition_endpoint_manager_circuit_breaker.py | 27 ++++++++++++------- ...n_endpoint_manager_circuit_breaker_core.py | 11 ++++++++ .../_routing/aio/routing_map_provider.py | 9 +++---- .../cosmos/_routing/routing_map_provider.py | 7 +++-- .../cosmos/_service_request_retry_policy.py | 4 +-- .../cosmos/_service_response_retry_policy.py | 6 ++--- .../azure/cosmos/_session_retry_policy.py | 12 ++++----- .../azure/cosmos/_synchronized_request.py | 2 +- 
.../cosmos/_timeout_failover_retry_policy.py | 2 +- .../azure/cosmos/aio/_asynchronous_request.py | 2 +- .../aio/_cosmos_client_connection_async.py | 8 +++--- .../aio/_global_endpoint_manager_async.py | 8 +++--- ..._endpoint_manager_circuit_breaker_async.py | 23 +++++++++------- .../azure-cosmos/tests/test_globaldb_mock.py | 2 +- ..._per_partition_circuit_breaker_mm_async.py | 1 + .../tests/test_streaming_failover.py | 4 +-- 17 files changed, 76 insertions(+), 60 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py index e660a608ca5a..c46f883e3703 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py @@ -24,14 +24,13 @@ """ import threading -from typing import Tuple, Optional +from typing import Tuple from azure.core.exceptions import AzureError from . import _constants as constants from . import exceptions from ._request_object import RequestObject -from ._routing.routing_range import PartitionKeyRangeWrapper from .documents import DatabaseAccount from ._location_cache import LocationCache, current_time_millis @@ -68,10 +67,9 @@ def get_write_endpoint(self): def get_read_endpoint(self): return self.location_cache.get_read_regional_routing_context() - def resolve_service_endpoint( + def _resolve_service_endpoint( self, - request: RequestObject, - pk_range_wrapper: Optional[PartitionKeyRangeWrapper] # pylint: disable=unused-argument + request: RequestObject ) -> str: return self.location_cache.resolve_service_endpoint(request) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py index d2ea988bcb49..bd980a038f61 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py +++ 
b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py @@ -21,7 +21,6 @@ """Internal class for global endpoint manager for circuit breaker. """ -import logging from typing import TYPE_CHECKING, Optional from azure.cosmos.partition_key import PartitionKey @@ -36,8 +35,6 @@ if TYPE_CHECKING: from azure.cosmos._cosmos_client_connection import CosmosClientConnection -logger = logging.getLogger("azure.cosmos._GlobalPartitionEndpointManagerForCircuitBreaker") - class _GlobalPartitionEndpointManagerForCircuitBreaker(_GlobalEndpointManager): """ This internal class implements the logic for partition endpoint management for @@ -58,8 +55,9 @@ def create_pk_range_wrapper(self, request: RequestObject) -> Optional[PartitionK if HttpHeaders.IntendedCollectionRID in request.headers: container_rid = request.headers[HttpHeaders.IntendedCollectionRID] else: - logger.warning("Illegal state: the request does not contain container information. " - "Circuit breaker cannot be performed.") + self.global_partition_endpoint_manager_core.log_warn_or_debug( + "Illegal state: the request does not contain container information. " + "Circuit breaker cannot be performed.") return None properties = self.Client._container_properties_cache[container_rid] # pylint: disable=protected-access # get relevant information from container cache to get the overlapping ranges @@ -78,10 +76,16 @@ def create_pk_range_wrapper(self, request: RequestObject) -> Optional[PartitionK pk_range_id = request.headers[HttpHeaders.PartitionKeyRangeID] epk_range =(self.Client._routing_map_provider # pylint: disable=protected-access .get_range_by_partition_key_range_id(container_link, pk_range_id)) + if not epk_range: + self.global_partition_endpoint_manager_core.log_warn_or_debug( + "Illegal state: partition key range cache not initialized correctly. 
" + "Circuit breaker cannot be performed.") + return None partition_range = Range.PartitionKeyRangeToRange(epk_range) else: - logger.warning("Illegal state: the request does not contain partition information. " - "Circuit breaker cannot be performed.") + self.global_partition_endpoint_manager_core.log_warn_or_debug( + "Illegal state: the request does not contain partition information. " + "Circuit breaker cannot be performed.") return None return PartitionKeyRangeWrapper(partition_range, container_rid) @@ -96,13 +100,16 @@ def record_failure( if pk_range_wrapper: self.global_partition_endpoint_manager_core.record_failure(request, pk_range_wrapper) - def resolve_service_endpoint(self, request: RequestObject, pk_range_wrapper: Optional[PartitionKeyRangeWrapper]): + def resolve_service_endpoint_for_partition( + self, + request: RequestObject, + pk_range_wrapper: Optional[PartitionKeyRangeWrapper] + ) -> str: if self.is_circuit_breaker_applicable(request) and pk_range_wrapper: self.global_partition_endpoint_manager_core.check_stale_partition_info(request, pk_range_wrapper) request = self.global_partition_endpoint_manager_core.add_excluded_locations_to_request(request, pk_range_wrapper) - return (super(_GlobalPartitionEndpointManagerForCircuitBreaker, self) - .resolve_service_endpoint(request, pk_range_wrapper)) + return self._resolve_service_endpoint(request) def mark_partition_unavailable( self, diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py index f4a850442d0a..a950b5557f35 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py @@ -21,6 +21,7 @@ """Internal class for global endpoint manager for circuit breaker. 
""" +import logging import os from azure.cosmos import documents @@ -32,6 +33,8 @@ from azure.cosmos.http_constants import ResourceType, HttpHeaders from azure.cosmos._constants import _Constants as Constants +logger = logging.getLogger("azure.cosmos._GlobalPartitionEndpointManagerForCircuitBreakerCore") +WARN_LEVEL_LOGGING_THRESHOLD = 10 class _GlobalPartitionEndpointManagerForCircuitBreakerCore(object): """ @@ -43,6 +46,14 @@ def __init__(self, client, location_cache: LocationCache): self.partition_health_tracker = _PartitionHealthTracker() self.location_cache = location_cache self.client = client + self.log_count = 0 + + def log_warn_or_debug(self, message: str) -> None: + self.log_count += 1 + if self.log_count >= WARN_LEVEL_LOGGING_THRESHOLD: + logger.debug(message) + else: + logger.warning(message) def is_circuit_breaker_applicable(self, request: RequestObject) -> bool: if not request: diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/aio/routing_map_provider.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/aio/routing_map_provider.py index 680191ac2228..81399da7090b 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/aio/routing_map_provider.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/aio/routing_map_provider.py @@ -22,13 +22,12 @@ """Internal class for partition key range cache implementation in the Azure Cosmos database service. """ -from typing import Dict, Any +from typing import Dict, Any, Optional from ... import _base from ..collection_routing_map import CollectionRoutingMap from .. 
import routing_range - # pylint: disable=protected-access @@ -71,10 +70,10 @@ async def init_collection_routing_map_if_needed( collection_id: str, **kwargs: Dict[str, Any] ): - client = self._documentClient collection_routing_map = self._collection_routing_map_by_item.get(collection_id) if collection_routing_map is None: - collection_pk_ranges = [pk async for pk in client._ReadPartitionKeyRanges(collection_link, **kwargs)] + collection_pk_ranges = [pk async for pk in + self._documentClient._ReadPartitionKeyRanges(collection_link, **kwargs)] # for large collections, a split may complete between the read partition key ranges query page responses, # causing the partitionKeyRanges to have both the children ranges and their parents. Therefore, we need # to discard the parent ranges to have a valid routing map. @@ -89,7 +88,7 @@ async def get_range_by_partition_key_range_id( collection_link: str, partition_key_range_id: int, **kwargs: Dict[str, Any] - ) -> Dict[str, Any]: + ) -> Optional[Dict[str, Any]]: collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) await self.init_collection_routing_map_if_needed(collection_link, str(collection_id), **kwargs) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_map_provider.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_map_provider.py index b22f05acb6b0..f1091fd629da 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_map_provider.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_map_provider.py @@ -22,7 +22,7 @@ """Internal class for partition key range cache implementation in the Azure Cosmos database service. """ -from typing import Dict, Any +from typing import Dict, Any, Optional from .. 
import _base from .collection_routing_map import CollectionRoutingMap @@ -57,10 +57,9 @@ def init_collection_routing_map_if_needed( collection_id: str, **kwargs: Dict[str, Any] ): - client = self._documentClient collection_routing_map = self._collection_routing_map_by_item.get(collection_id) if not collection_routing_map: - collection_pk_ranges = list(client._ReadPartitionKeyRanges(collection_link, **kwargs)) + collection_pk_ranges = list(self._documentClient._ReadPartitionKeyRanges(collection_link, **kwargs)) # for large collections, a split may complete between the read partition key ranges query page responses, # causing the partitionKeyRanges to have both the children ranges and their parents. Therefore, we need # to discard the parent ranges to have a valid routing map. @@ -89,7 +88,7 @@ def get_range_by_partition_key_range_id( collection_link: str, partition_key_range_id: int, **kwargs: Dict[str, Any] - ) -> Dict[str, Any]: + ) -> Optional[Dict[str, Any]]: collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) self.init_collection_routing_map_if_needed(collection_link, str(collection_id), **kwargs) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_service_request_retry_policy.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_service_request_retry_policy.py index b49185512fa4..a0853c0b1065 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_service_request_retry_policy.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_service_request_retry_policy.py @@ -100,7 +100,7 @@ def resolve_current_region_service_endpoint(self): # resolve the next service endpoint in the same region # since we maintain 2 endpoints per region for write operations self.request.route_to_location_with_preferred_location_flag(0, True) - return self.global_endpoint_manager.resolve_service_endpoint(self.request, self.pk_range_wrapper) + return self.global_endpoint_manager.resolve_service_endpoint_for_partition(self.request, self.pk_range_wrapper) # This function prepares the request to 
go to the next region def resolve_next_region_service_endpoint(self): @@ -114,7 +114,7 @@ def resolve_next_region_service_endpoint(self): self.request.route_to_location_with_preferred_location_flag(0, True) # Resolve the endpoint for the request and pin the resolution to the resolved endpoint # This enables marking the endpoint unavailability on endpoint failover/unreachability - return self.global_endpoint_manager.resolve_service_endpoint(self.request, self.pk_range_wrapper) + return self.global_endpoint_manager.resolve_service_endpoint_for_partition(self.request, self.pk_range_wrapper) def mark_endpoint_unavailable(self, unavailable_endpoint, refresh_cache: bool): if _OperationType.IsReadOnlyOperation(self.request.operation_type): diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_service_response_retry_policy.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_service_response_retry_policy.py index 330ffb5929a5..59fca57e1c76 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_service_response_retry_policy.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_service_response_retry_policy.py @@ -21,8 +21,8 @@ def __init__(self, connection_policy, global_endpoint_manager, pk_range_wrapper, self.connection_policy = connection_policy self.request = args[0] if args else None if self.request: - self.location_endpoint = self.global_endpoint_manager.resolve_service_endpoint(self.request, - pk_range_wrapper) + self.location_endpoint = (self.global_endpoint_manager + .resolve_service_endpoint_for_partition(self.request, pk_range_wrapper)) self.logger = logging.getLogger('azure.cosmos.ServiceResponseRetryPolicy') def ShouldRetry(self): @@ -59,4 +59,4 @@ def resolve_next_region_service_endpoint(self): self.request.route_to_location_with_preferred_location_flag(self.failover_retry_count, True) # Resolve the endpoint for the request and pin the resolution to the resolved endpoint # This enables marking the endpoint unavailability on endpoint failover/unreachability - return 
self.global_endpoint_manager.resolve_service_endpoint(self.request, self.pk_range_wrapper) + return self.global_endpoint_manager.resolve_service_endpoint_for_partition(self.request, self.pk_range_wrapper) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_session_retry_policy.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_session_retry_policy.py index e52fbe996e11..69b9d52f286d 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_session_retry_policy.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_session_retry_policy.py @@ -58,8 +58,8 @@ def __init__(self, endpoint_discovery_enable, global_endpoint_manager, pk_range_ # Resolve the endpoint for the request and pin the resolution to the resolved endpoint # This enables marking the endpoint unavailability on endpoint failover/unreachability - self.location_endpoint = self.global_endpoint_manager.resolve_service_endpoint(self.request, - self.pk_range_wrapper) + self.location_endpoint = (self.global_endpoint_manager + .resolve_service_endpoint_for_partition(self.request, self.pk_range_wrapper)) self.request.route_to_location(self.location_endpoint) def ShouldRetry(self, _exception): @@ -100,8 +100,8 @@ def ShouldRetry(self, _exception): # Resolve the endpoint for the request and pin the resolution to the resolved endpoint # This enables marking the endpoint unavailability on endpoint failover/unreachability - self.location_endpoint = self.global_endpoint_manager.resolve_service_endpoint(self.request, - self.pk_range_wrapper) + self.location_endpoint = (self.global_endpoint_manager + .resolve_service_endpoint_for_partition(self.request, self.pk_range_wrapper)) self.request.route_to_location(self.location_endpoint) return True @@ -116,7 +116,7 @@ def ShouldRetry(self, _exception): # Resolve the endpoint for the request and pin the resolution to the resolved endpoint # This enables marking the endpoint unavailability on endpoint failover/unreachability - self.location_endpoint = 
self.global_endpoint_manager.resolve_service_endpoint(self.request, - self.pk_range_wrapper) + self.location_endpoint = (self.global_endpoint_manager + .resolve_service_endpoint_for_partition(self.request, self.pk_range_wrapper)) self.request.route_to_location(self.location_endpoint) return True diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py index 9e7e38b31322..e41881429b20 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py @@ -110,7 +110,7 @@ def _Request(global_endpoint_manager, request_params, connection_policy, pipelin if global_endpoint_manager.is_circuit_breaker_applicable(request_params): # Circuit breaker is applicable, so we need to use the endpoint from the request pk_range_wrapper = global_endpoint_manager.create_pk_range_wrapper(request_params) - base_url = global_endpoint_manager.resolve_service_endpoint(request_params, pk_range_wrapper) + base_url = global_endpoint_manager.resolve_service_endpoint_for_partition(request_params, pk_range_wrapper) if not request.url.startswith(base_url): request.url = _replace_url_prefix(request.url, base_url) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py index 69bc973c3346..b77ce1a69f13 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py @@ -55,4 +55,4 @@ def resolve_next_region_service_endpoint(self): self.request.route_to_location_with_preferred_location_flag(self.retry_count, True) # Resolve the endpoint for the request and pin the resolution to the resolved endpoint # This enables marking the endpoint unavailability on endpoint failover/unreachability - return 
self.global_endpoint_manager.resolve_service_endpoint(self.request, self.pk_range_wrapper) + return self.global_endpoint_manager.resolve_service_endpoint_for_partition(self.request, self.pk_range_wrapper) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py index 53e87c1d0211..79e674eaa31c 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py @@ -79,7 +79,7 @@ async def _Request(global_endpoint_manager, request_params, connection_policy, p if global_endpoint_manager.is_circuit_breaker_applicable(request_params): # Circuit breaker is applicable, so we need to use the endpoint from the request pk_range_wrapper = await global_endpoint_manager.create_pk_range_wrapper(request_params) - base_url = global_endpoint_manager.resolve_service_endpoint(request_params, pk_range_wrapper) + base_url = global_endpoint_manager.resolve_service_endpoint_for_partition(request_params, pk_range_wrapper) if not request.url.startswith(base_url): request.url = _replace_url_prefix(request.url, base_url) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py index 3e19f0ddc000..9164ce550dfe 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py @@ -2875,10 +2875,10 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: return [] initial_headers = self.default_headers.copy() - cont_prop = kwargs.pop("containerProperties", None) - if cont_prop: - cont_prop = await cont_prop() - + cont_prop_func = kwargs.pop("containerProperties", None) + cont_prop = None + if cont_prop_func: + cont_prop = await cont_prop_func() # Copy to make sure that default_headers 
won't be changed. if query is None: diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py index 0d51aa83512a..705f6973892b 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py @@ -25,7 +25,7 @@ import asyncio # pylint: disable=do-not-import-asyncio import logging -from typing import Tuple, Dict, Any, Optional +from typing import Tuple, Dict, Any from azure.core.exceptions import AzureError from azure.cosmos import DatabaseAccount @@ -34,7 +34,6 @@ from .. import exceptions from .._location_cache import LocationCache, current_time_millis from .._request_object import RequestObject -from .._routing.routing_range import PartitionKeyRangeWrapper # pylint: disable=protected-access @@ -71,10 +70,9 @@ def get_write_endpoint(self): def get_read_endpoint(self): return self.location_cache.get_read_regional_routing_context() - def resolve_service_endpoint( + def _resolve_service_endpoint( self, - request: RequestObject, - pk_range_wrapper: Optional[PartitionKeyRangeWrapper] # pylint: disable=unused-argument + request: RequestObject ) -> str: return self.location_cache.resolve_service_endpoint(request) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py index 92705ba39907..1f7d9ccc49e4 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py @@ -21,7 +21,6 @@ """Internal class for global endpoint manager for circuit breaker. 
""" -import logging from typing import TYPE_CHECKING, Optional from azure.cosmos import PartitionKey @@ -36,7 +35,6 @@ if TYPE_CHECKING: from azure.cosmos.aio._cosmos_client_connection_async import CosmosClientConnection -logger = logging.getLogger("azure.cosmos.aio._GlobalPartitionEndpointManagerForCircuitBreakerAsync") # pylint: disable=protected-access class _GlobalPartitionEndpointManagerForCircuitBreakerAsync(_GlobalEndpointManager): @@ -45,7 +43,6 @@ class _GlobalPartitionEndpointManagerForCircuitBreakerAsync(_GlobalEndpointManag geo-replicated database accounts. """ - def __init__(self, client: "CosmosClientConnection"): super(_GlobalPartitionEndpointManagerForCircuitBreakerAsync, self).__init__(client) self.global_partition_endpoint_manager_core = ( @@ -55,8 +52,9 @@ async def create_pk_range_wrapper(self, request: RequestObject) -> Optional[Part if HttpHeaders.IntendedCollectionRID in request.headers: container_rid = request.headers[HttpHeaders.IntendedCollectionRID] else: - logger.warning("Illegal state: the request does not contain container information. " - "Circuit breaker cannot be performed.") + self.global_partition_endpoint_manager_core.log_warn_or_debug( + "Illegal state: the request does not contain container information. " + "Circuit breaker cannot be performed.") return None properties = self.client._container_properties_cache[container_rid] # get relevant information from container cache to get the overlapping ranges @@ -75,10 +73,16 @@ async def create_pk_range_wrapper(self, request: RequestObject) -> Optional[Part pk_range_id = request.headers[HttpHeaders.PartitionKeyRangeID] epk_range = await (self.client._routing_map_provider .get_range_by_partition_key_range_id(container_link, pk_range_id)) + if not epk_range: + self.global_partition_endpoint_manager_core.log_warn_or_debug( + "Illegal state: partition key range cache not initialized correctly. 
" + "Circuit breaker cannot be performed.") + return None partition_range = Range.PartitionKeyRangeToRange(epk_range) else: - logger.warning("Illegal state: the request does not contain partition information. " - "Circuit breaker cannot be performed.") + self.global_partition_endpoint_manager_core.log_warn_or_debug( + "Illegal state: the request does not contain partition information. " + "Circuit breaker cannot be performed.") return None return PartitionKeyRangeWrapper(partition_range, container_rid) @@ -95,7 +99,7 @@ async def record_failure( if pk_range_wrapper: self.global_partition_endpoint_manager_core.record_failure(request, pk_range_wrapper) - def resolve_service_endpoint( + def resolve_service_endpoint_for_partition( self, request: RequestObject, pk_range_wrapper: Optional[PartitionKeyRangeWrapper] @@ -104,8 +108,7 @@ def resolve_service_endpoint( self.global_partition_endpoint_manager_core.check_stale_partition_info(request, pk_range_wrapper) request = self.global_partition_endpoint_manager_core.add_excluded_locations_to_request(request, pk_range_wrapper) - return (super(_GlobalPartitionEndpointManagerForCircuitBreakerAsync, self) - .resolve_service_endpoint(request, pk_range_wrapper)) + return self._resolve_service_endpoint(request) def mark_partition_unavailable( self, diff --git a/sdk/cosmos/azure-cosmos/tests/test_globaldb_mock.py b/sdk/cosmos/azure-cosmos/tests/test_globaldb_mock.py index 8c1a2fbfb01e..a7aee626bfb8 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_globaldb_mock.py +++ b/sdk/cosmos/azure-cosmos/tests/test_globaldb_mock.py @@ -75,7 +75,7 @@ def get_write_endpoint(self): def get_read_endpoint(self): return self._ReadEndpoint - def resolve_service_endpoint(self, request, pk_range_wrapper): + def resolve_service_endpoint_for_partition(self, request, pk_range_wrapper): return def refresh_endpoint_list(self, database_account, **kwargs): diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py 
b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py index 2e038e525874..a88bf5011f2a 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py @@ -191,6 +191,7 @@ class TestPerPartitionCircuitBreakerMMAsync: TEST_CONTAINER_SINGLE_PARTITION_ID = test_config.TestConfig.TEST_MULTI_PARTITION_CONTAINER_ID async def setup_method_with_custom_transport(self, custom_transport: AioHttpTransport, default_endpoint=host, **kwargs): + os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "True" client = CosmosClient(default_endpoint, self.master_key, preferred_locations=[REGION_1, REGION_2], multiple_write_locations=True, diff --git a/sdk/cosmos/azure-cosmos/tests/test_streaming_failover.py b/sdk/cosmos/azure-cosmos/tests/test_streaming_failover.py index b034287fbd6d..bb2a986e6471 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_streaming_failover.py +++ b/sdk/cosmos/azure-cosmos/tests/test_streaming_failover.py @@ -143,8 +143,8 @@ def test_retry_policy_does_not_mark_null_locations_unavailable(self): endpoint_manager.mark_endpoint_unavailable_for_read = self._mock_mark_endpoint_unavailable_for_read self.original_mark_endpoint_unavailable_for_write_function = endpoint_manager.mark_endpoint_unavailable_for_write endpoint_manager.mark_endpoint_unavailable_for_write = self._mock_mark_endpoint_unavailable_for_write - self.original_resolve_service_endpoint = endpoint_manager.resolve_service_endpoint - endpoint_manager.resolve_service_endpoint = self._mock_resolve_service_endpoint + self.original_resolve_service_endpoint = endpoint_manager.resolve_service_endpoint_for_partition + endpoint_manager.resolve_service_endpoint_for_partition = self._mock_resolve_service_endpoint # Read and write counters count the number of times the endpoint manager's # mark_endpoint_unavailable_for_read() and mark_endpoint_unavailable_for_read() From 
08075bdc4b9c0cc70a26d04943bb6a6672f1d59f Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Tue, 13 May 2025 10:53:52 -0400 Subject: [PATCH 146/152] react to comments --- ...tition_endpoint_manager_circuit_breaker.py | 10 ---------- ...n_endpoint_manager_circuit_breaker_core.py | 4 ---- .../azure/cosmos/_partition_health_tracker.py | 11 ----------- .../azure/cosmos/_retry_utility.py | 3 ++- .../cosmos/_service_request_retry_policy.py | 9 +++------ ..._endpoint_manager_circuit_breaker_async.py | 8 -------- .../azure/cosmos/aio/_retry_utility_async.py | 3 ++- .../test_per_partition_circuit_breaker_mm.py | 10 +++++----- ..._per_partition_circuit_breaker_mm_async.py | 19 +++++++++---------- ...st_per_partition_circuit_breaker_sm_mrr.py | 8 ++++---- ..._partition_circuit_breaker_sm_mrr_async.py | 9 ++++----- 11 files changed, 29 insertions(+), 65 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py index bd980a038f61..0870979deb8c 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py @@ -41,7 +41,6 @@ class _GlobalPartitionEndpointManagerForCircuitBreaker(_GlobalEndpointManager): geo-replicated database accounts. 
""" - def __init__(self, client: "CosmosClientConnection"): super(_GlobalPartitionEndpointManagerForCircuitBreaker, self).__init__(client) self.global_partition_endpoint_manager_core = ( @@ -90,7 +89,6 @@ def create_pk_range_wrapper(self, request: RequestObject) -> Optional[PartitionK return PartitionKeyRangeWrapper(partition_range, container_rid) - def record_failure( self, request: RequestObject @@ -111,14 +109,6 @@ def resolve_service_endpoint_for_partition( pk_range_wrapper) return self._resolve_service_endpoint(request) - def mark_partition_unavailable( - self, - request: RequestObject, - pk_range_wrapper: Optional[PartitionKeyRangeWrapper] - ) -> None: - if pk_range_wrapper: - self.global_partition_endpoint_manager_core.mark_partition_unavailable(request, pk_range_wrapper) - def record_success( self, request: RequestObject diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py index a950b5557f35..93faf9b7a8c5 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py @@ -109,10 +109,6 @@ def add_excluded_locations_to_request( ) return request - def mark_partition_unavailable(self, request: RequestObject, pk_range_wrapper: PartitionKeyRangeWrapper) -> None: - location = self.location_cache.get_location_from_endpoint(str(request.location_endpoint_to_route)) - self.partition_health_tracker.mark_partition_unavailable(pk_range_wrapper, location) - def record_success( self, request: RequestObject, diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py index a8400bf552d6..8c00f10508dd 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py +++ 
b/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py @@ -30,7 +30,6 @@ from azure.cosmos._request_object import RequestObject from ._constants import _Constants as Constants - MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 MAX_UNAVAILABLE_TIME = 1200 * 1000 # milliseconds REFRESH_INTERVAL = 60 * 1000 # milliseconds @@ -58,7 +57,6 @@ def __init__(self) -> None: self.write_consecutive_failure_count: int = 0 self.unavailability_info: Dict[str, Any] = {} - def reset_health_stats(self) -> None: self.write_failure_count = 0 self.read_failure_count = 0 @@ -118,17 +116,12 @@ class _PartitionHealthTracker(object): This internal class implements the logic for tracking health thresholds for a partition. """ - def __init__(self) -> None: # partition -> regions -> health info self.pk_range_wrapper_to_health_info: Dict[PartitionKeyRangeWrapper, Dict[str, _PartitionHealthInfo]] = {} self.last_refresh = current_time_millis() self.stale_partition_lock = threading.Lock() - def mark_partition_unavailable(self, pk_range_wrapper: PartitionKeyRangeWrapper, location: str) -> None: - # mark the partition key range as unavailable - self._transition_health_status_on_failure(pk_range_wrapper, location) - def _transition_health_status_on_failure( self, pk_range_wrapper: PartitionKeyRangeWrapper, @@ -156,8 +149,6 @@ def _transition_health_status_on_failure( partition_health_info.transition_health_status(UNHEALTHY_TENTATIVE, current_time) self.pk_range_wrapper_to_health_info[pk_range_wrapper][location] = partition_health_info - - def _transition_health_status_on_success( self, pk_range_wrapper: PartitionKeyRangeWrapper, @@ -206,7 +197,6 @@ def get_unhealthy_locations( excluded_locations.append(location) return excluded_locations - def add_failure( self, pk_range_wrapper: PartitionKeyRangeWrapper, @@ -296,7 +286,6 @@ def add_success(self, pk_range_wrapper: PartitionKeyRangeWrapper, operation_type health_info.read_consecutive_failure_count = 0 
self._transition_health_status_on_success(pk_range_wrapper) - def _reset_partition_health_tracker_stats(self) -> None: for locations in self.pk_range_wrapper_to_health_info.values(): for health_info in locations.values(): diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py index 0ee2cd1146cf..ab34f07d0f7c 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py @@ -335,7 +335,6 @@ def send(self, request): if (not _has_database_account_header(request.http_request.headers) and not request_params.healthy_tentative_location): if retry_settings['connect'] > 0: - global_endpoint_manager.record_failure(request_params) retry_active = self.increment(retry_settings, response=request, error=err) if retry_active: self.sleep(retry_settings, request.context.transport) @@ -350,6 +349,8 @@ def send(self, request): raise err # This logic is based on the _retry.py file from azure-core if retry_settings['read'] > 0: + # record the failure for circuit breaker tracking for retries in connection retry policy + # retries in the execute function will mark those failures global_endpoint_manager.record_failure(request_params) retry_active = self.increment(retry_settings, response=request, error=err) if retry_active: diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_service_request_retry_policy.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_service_request_retry_policy.py index a0853c0b1065..5b4faf75df84 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_service_request_retry_policy.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_service_request_retry_policy.py @@ -45,12 +45,9 @@ def ShouldRetry(self): if self.request.resource_type == ResourceType.DatabaseAccount: return False - if self.global_endpoint_manager.is_circuit_breaker_applicable(self.request): - self.global_endpoint_manager.mark_partition_unavailable(self.request,
self.pk_range_wrapper) - else: - refresh_cache = self.request.last_routed_location_endpoint_within_region is not None - # This logic is for the last retry and mark the region unavailable - self.mark_endpoint_unavailable(self.request.location_endpoint_to_route, refresh_cache) + refresh_cache = self.request.last_routed_location_endpoint_within_region is not None + # This logic is for the last retry and mark the region unavailable + self.mark_endpoint_unavailable(self.request.location_endpoint_to_route, refresh_cache) # Check if it is safe to do another retry if self.in_region_retry_count >= self.total_in_region_retries: diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py index 1f7d9ccc49e4..ae5d82c760c1 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py @@ -110,14 +110,6 @@ def resolve_service_endpoint_for_partition( pk_range_wrapper) return self._resolve_service_endpoint(request) - def mark_partition_unavailable( - self, - request: RequestObject, - pk_range_wrapper: Optional[PartitionKeyRangeWrapper] - ) -> None: - if pk_range_wrapper: - self.global_partition_endpoint_manager_core.mark_partition_unavailable(request, pk_range_wrapper) - async def record_success( self, request: RequestObject diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py index 93576224c2dc..9d29a53a631d 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py @@ -293,7 +293,6 @@ async def send(self, request): if (not _has_database_account_header(request.http_request.headers) and not 
request_params.healthy_tentative_location): if retry_settings['connect'] > 0: - await global_endpoint_manager.record_failure(request_params) retry_active = self.increment(retry_settings, response=request, error=err) if retry_active: await self.sleep(retry_settings, request.context.transport) @@ -313,6 +312,8 @@ async def send(self, request): or _has_read_retryable_headers(request.http_request.headers)): # This logic is based on the _retry.py file from azure-core if retry_settings['read'] > 0: + # record the failure for circuit breaker tracking for retries in connection retry policy + # retries in the execute function will mark those failures await global_endpoint_manager.record_failure(request_params) retry_active = self.increment(retry_settings, response=request, error=err) if retry_active: diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py index aa05f31f6282..be6b3f0f9ab6 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py +++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm.py @@ -327,9 +327,9 @@ def test_service_request_error(self, read_operation, write_operation): expected_uri) global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager - validate_unhealthy_partitions(global_endpoint_manager, 1) + validate_unhealthy_partitions(global_endpoint_manager, 0) # there shouldn't be region marked as unavailable - assert len(global_endpoint_manager.location_cache.location_unavailability_info_by_endpoint) == 0 + assert len(global_endpoint_manager.location_cache.location_unavailability_info_by_endpoint) == 1 # recover partition # remove faults and reduce initial recover time and perform a write @@ -341,7 +341,7 @@ def test_service_request_error(self, read_operation, write_operation): fault_injection_container, doc['id'], PK_VALUE, - uri_down) + expected_uri) finally:
_partition_health_tracker.INITIAL_UNAVAILABLE_TIME = original_unavailable_time validate_unhealthy_partitions(global_endpoint_manager, 0) @@ -357,9 +357,9 @@ def test_service_request_error(self, read_operation, write_operation): expected_uri) global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager - validate_unhealthy_partitions(global_endpoint_manager, 1) + validate_unhealthy_partitions(global_endpoint_manager, 0) # there shouldn't be region marked as unavailable - assert len(global_endpoint_manager.location_cache.location_unavailability_info_by_endpoint) == 0 + assert len(global_endpoint_manager.location_cache.location_unavailability_info_by_endpoint) == 1 # test cosmos client timeout diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py index a88bf5011f2a..d4d9d28eb5af 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py @@ -191,7 +191,6 @@ class TestPerPartitionCircuitBreakerMMAsync: TEST_CONTAINER_SINGLE_PARTITION_ID = test_config.TestConfig.TEST_MULTI_PARTITION_CONTAINER_ID async def setup_method_with_custom_transport(self, custom_transport: AioHttpTransport, default_endpoint=host, **kwargs): - os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "True" client = CosmosClient(default_endpoint, self.master_key, preferred_locations=[REGION_1, REGION_2], multiple_write_locations=True, @@ -430,12 +429,11 @@ async def test_service_request_error_async(self, read_operation, write_operation expected_uri) global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager - validate_unhealthy_partitions(global_endpoint_manager, 1) - # there shouldn't be region marked as unavailable - assert len(global_endpoint_manager.location_cache.location_unavailability_info_by_endpoint) 
== 0 + validate_unhealthy_partitions(global_endpoint_manager, 0) + assert len(global_endpoint_manager.location_cache.location_unavailability_info_by_endpoint) == 1 # recover partition - # remove faults and reduce initial recover time and perform a write + # remove faults and reduce initial recover time and perform a read original_unavailable_time = _partition_health_tracker.INITIAL_UNAVAILABLE_TIME _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = 1 custom_transport.faults = [] @@ -444,10 +442,12 @@ async def test_service_request_error_async(self, read_operation, write_operation fault_injection_container, doc['id'], PK_VALUE, - uri_down) + expected_uri) finally: _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = original_unavailable_time validate_unhealthy_partitions(global_endpoint_manager, 0) + # ppcb should not regress connection timeouts marking the region as unavailable + assert len(global_endpoint_manager.location_cache.location_unavailability_info_by_endpoint) == 1 custom_transport.add_fault(predicate, lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_region_down())) @@ -460,9 +460,8 @@ async def test_service_request_error_async(self, read_operation, write_operation expected_uri) global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager - validate_unhealthy_partitions(global_endpoint_manager, 1) - # there shouldn't be region marked as unavailable - assert len(global_endpoint_manager.location_cache.location_unavailability_info_by_endpoint) == 0 + validate_unhealthy_partitions(global_endpoint_manager, 0) + assert len(global_endpoint_manager.location_cache.location_unavailability_info_by_endpoint) == 1 await cleanup_method([custom_setup, setup]) @@ -506,7 +505,7 @@ async def concurrent_upsert(): assert number_of_errors == 1 finally: _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = original_unavailable_time - await cleanup_method([custom_setup]) + await cleanup_method([custom_setup, setup]) # test 
cosmos client timeout diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr.py index 2147b8c037ec..34ce4109dd61 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr.py +++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr.py @@ -129,9 +129,9 @@ def test_service_request_error(self, read_operation, write_operation): expected_uri) global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager - validate_unhealthy_partitions(global_endpoint_manager, 1) - # there shouldn't be region marked as unavailable - assert len(global_endpoint_manager.location_cache.location_unavailability_info_by_endpoint) == 0 + # there shouldn't be partition marked as unavailable + validate_unhealthy_partitions(global_endpoint_manager, 0) + assert len(global_endpoint_manager.location_cache.location_unavailability_info_by_endpoint) == 1 # recover partition # remove faults and reduce initial recover time and perform a write @@ -143,7 +143,7 @@ def test_service_request_error(self, read_operation, write_operation): fault_injection_container, doc['id'], PK_VALUE, - uri_down) + expected_uri) finally: _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = original_unavailable_time validate_unhealthy_partitions(global_endpoint_manager, 0) diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr_async.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr_async.py index cb4fdc128c75..665eff84aa0a 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_sm_mrr_async.py @@ -165,9 +165,9 @@ async def test_service_request_error_async(self, read_operation, write_operation expected_uri) global_endpoint_manager = 
fault_injection_container.client_connection._global_endpoint_manager - validate_unhealthy_partitions(global_endpoint_manager, 1) + validate_unhealthy_partitions(global_endpoint_manager, 0) # there shouldn't be region marked as unavailable - assert len(global_endpoint_manager.location_cache.location_unavailability_info_by_endpoint) == 0 + assert len(global_endpoint_manager.location_cache.location_unavailability_info_by_endpoint) == 1 # recover partition # remove faults and reduce initial recover time and perform a write @@ -179,14 +179,14 @@ async def test_service_request_error_async(self, read_operation, write_operation fault_injection_container, doc['id'], PK_VALUE, - uri_down) + expected_uri) finally: _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = original_unavailable_time validate_unhealthy_partitions(global_endpoint_manager, 0) custom_transport.add_fault(predicate, lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_region_down())) - # The global endpoint would be used for the write operation + # The global endpoint would be used for the write operation in single region write expected_uri = self.host await perform_write_operation(write_operation, container, @@ -197,7 +197,6 @@ async def test_service_request_error_async(self, read_operation, write_operation global_endpoint_manager = fault_injection_container.client_connection._global_endpoint_manager validate_unhealthy_partitions(global_endpoint_manager, 0) - # there shouldn't be region marked as unavailable assert len(global_endpoint_manager.location_cache.location_unavailability_info_by_endpoint) == 1 await cleanup_method([custom_setup, setup]) From da362fcc2e5b22602781d222803e89ce96c2240d Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Tue, 27 May 2025 09:52:18 -0700 Subject: [PATCH 147/152] react to async client changes from merge --- .../_global_partition_endpoint_manager_circuit_breaker.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git 
a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py index 0870979deb8c..538eb1c2edb1 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py @@ -58,7 +58,7 @@ def create_pk_range_wrapper(self, request: RequestObject) -> Optional[PartitionK "Illegal state: the request does not contain container information. " "Circuit breaker cannot be performed.") return None - properties = self.Client._container_properties_cache[container_rid] # pylint: disable=protected-access + properties = self.client._container_properties_cache[container_rid] # pylint: disable=protected-access # get relevant information from container cache to get the overlapping ranges container_link = properties["container_link"] partition_key_definition = properties["partitionKey"] @@ -68,12 +68,12 @@ def create_pk_range_wrapper(self, request: RequestObject) -> Optional[PartitionK partition_key_value = request.headers[HttpHeaders.PartitionKey] # get the partition key range for the given partition key epk_range = [partition_key._get_epk_range_for_partition_key(partition_key_value)] # pylint: disable=protected-access - partition_ranges = (self.Client._routing_map_provider # pylint: disable=protected-access + partition_ranges = (self.client._routing_map_provider # pylint: disable=protected-access .get_overlapping_ranges(container_link, epk_range)) partition_range = Range.PartitionKeyRangeToRange(partition_ranges[0]) elif HttpHeaders.PartitionKeyRangeID in request.headers: pk_range_id = request.headers[HttpHeaders.PartitionKeyRangeID] - epk_range =(self.Client._routing_map_provider # pylint: disable=protected-access + epk_range =(self.client._routing_map_provider # pylint: disable=protected-access 
.get_range_by_partition_key_range_id(container_link, pk_range_id)) if not epk_range: self.global_partition_endpoint_manager_core.log_warn_or_debug( From c29c52a6b17b5e0374d6bd5bdda63351e2f4b080 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Tue, 27 May 2025 10:59:31 -0700 Subject: [PATCH 148/152] react to comments --- .../_global_partition_endpoint_manager_circuit_breaker.py | 4 +++- ...global_partition_endpoint_manager_circuit_breaker_async.py | 4 +++- sdk/cosmos/azure-cosmos/pytest.ini | 2 +- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py index 538eb1c2edb1..2eda20c926d0 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py @@ -62,7 +62,9 @@ def create_pk_range_wrapper(self, request: RequestObject) -> Optional[PartitionK # get relevant information from container cache to get the overlapping ranges container_link = properties["container_link"] partition_key_definition = properties["partitionKey"] - partition_key = PartitionKey(path=partition_key_definition["paths"], kind=partition_key_definition["kind"]) + partition_key = PartitionKey(path=partition_key_definition["paths"], + kind=partition_key_definition["kind"], + version=partition_key_definition["version"]) if HttpHeaders.PartitionKey in request.headers: partition_key_value = request.headers[HttpHeaders.PartitionKey] diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py index ae5d82c760c1..78e8b551ee7a 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py 
+++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py @@ -60,7 +60,9 @@ async def create_pk_range_wrapper(self, request: RequestObject) -> Optional[Part # get relevant information from container cache to get the overlapping ranges container_link = properties["container_link"] partition_key_definition = properties["partitionKey"] - partition_key = PartitionKey(path=partition_key_definition["paths"], kind=partition_key_definition["kind"]) + partition_key = PartitionKey(path=partition_key_definition["paths"], + kind=partition_key_definition["kind"], + version=partition_key_definition["version"]) if HttpHeaders.PartitionKey in request.headers: partition_key_value = request.headers[HttpHeaders.PartitionKey] diff --git a/sdk/cosmos/azure-cosmos/pytest.ini b/sdk/cosmos/azure-cosmos/pytest.ini index 6255aded49ad..aabe78b51f08 100644 --- a/sdk/cosmos/azure-cosmos/pytest.ini +++ b/sdk/cosmos/azure-cosmos/pytest.ini @@ -6,4 +6,4 @@ markers = cosmosSplit: marks test where there are partition splits on CosmosDB live account. cosmosMultiRegion: marks tests running on a Cosmos DB live account with multi-region and multi-write enabled. cosmosCircuitBreaker: marks tests running on Cosmos DB live account with per partition circuit breaker enabled and multi-write enabled. - cosmosCircuitBreakerMultiRegion: marks tests running on Cosmos DB live account with per partition circuit breaker enabled. + cosmosCircuitBreakerMultiRegion: marks tests running on Cosmos DB live account with one write region and multiple read regions and per partition circuit breaker enabled. 
From 2feca2ab0759d6e94877f20e66e98eea09023d69 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Tue, 27 May 2025 12:25:15 -0700 Subject: [PATCH 149/152] fix tests --- sdk/cosmos/azure-cosmos/CHANGELOG.md | 2 +- .../azure/cosmos/_retry_utility.py | 2 +- .../azure/cosmos/aio/_container.py | 6 +--- .../aio/_cosmos_client_connection_async.py | 14 ++++---- .../azure/cosmos/aio/_retry_utility_async.py | 2 +- .../tests/test_excluded_locations.py | 35 +++++++------------ ..._per_partition_circuit_breaker_mm_async.py | 2 +- 7 files changed, 23 insertions(+), 40 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/CHANGELOG.md b/sdk/cosmos/azure-cosmos/CHANGELOG.md index 3313807b7987..9c13da7e8300 100644 --- a/sdk/cosmos/azure-cosmos/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos/CHANGELOG.md @@ -1,9 +1,9 @@ ## Release History ### 4.12.0b2 (Unreleased) -* Added ability to use request level `excluded_locations` on metadata calls, such as getting container properties. See [PR 40905](https://github.com/Azure/azure-sdk-for-python/pull/40905) #### Features Added +* Added ability to use request level `excluded_locations` on metadata calls, such as getting container properties. See [PR 40905](https://github.com/Azure/azure-sdk-for-python/pull/40905) * Per partition circuit breaker support. It can be enabled through the environment variable `AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER`. See [PR 40302](https://github.com/Azure/azure-sdk-for-python/pull/40302). 
#### Bugs Fixed diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py index ab34f07d0f7c..d34a1068a41d 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py @@ -349,7 +349,7 @@ def send(self, request): raise err # This logic is based on the _retry.py file from azure-core if retry_settings['read'] > 0: - # record the failure for circuit breaker tracking for retrys in connection retry policy + # record the failure for circuit breaker tracking for retries in connection retry policy # retries in the execute function will mark those failures global_endpoint_manager.record_failure(request_params) retry_active = self.increment(retry_settings, response=request, error=err) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py index 8614de5d9af7..9105b759d3e2 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py @@ -181,10 +181,6 @@ async def read( :keyword Literal["High", "Low"] priority: Priority based execution allows users to set a priority for each request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. - :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations - in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. - If all preferred locations were excluded, primary/hub location will be used. - This excluded_location will override existing excluded_locations in client level. :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: Raised if the container couldn't be retrieved. This includes if the container does not exist. 
:returns: Dict representing the retrieved container. @@ -422,7 +418,7 @@ def read_all_items( response_hook.clear() if self.container_link in self.__get_client_container_caches(): feed_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] - kwargs["containerProperties"] = self._get_properties + kwargs["containerProperties"] = self._get_properties_with_feed_options items = self.client_connection.ReadItems( collection_link=self.container_link, feed_options=feed_options, response_hook=response_hook, **kwargs diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py index 29c885d89cc9..205a5ed4e376 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py @@ -2899,7 +2899,7 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: cont_prop_func = kwargs.pop("containerProperties", None) cont_prop = None if cont_prop_func: - cont_prop = await cont_prop_func(options) + cont_prop = await cont_prop_func(options) # get properties with feed options # Copy to make sure that default_headers won't be changed. 
if query is None: @@ -2950,13 +2950,11 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: request_params.set_excluded_location_from_options(options) # check if query has prefix partition key - cont_prop = kwargs.pop("containerProperties", None) partition_key_value = options.get("partitionKey", None) is_prefix_partition_query = False partition_key_obj = None if cont_prop and partition_key_value is not None: - properties = await cont_prop(options) # get properties with feed options - partition_key_definition = properties["partitionKey"] + partition_key_definition = cont_prop["partitionKey"] partition_key_obj = PartitionKey(path=partition_key_definition["paths"], kind=partition_key_definition["kind"], version=partition_key_definition["version"]) @@ -2965,9 +2963,9 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: if is_prefix_partition_query and partition_key_obj: # here get the overlapping ranges req_headers.pop(http_constants.HttpHeaders.PartitionKey, None) - feedrangeEPK = partition_key_obj._get_epk_range_for_prefix_partition_key( + feed_range_epk = partition_key_obj._get_epk_range_for_prefix_partition_key( partition_key_value) # cspell:disable-line - over_lapping_ranges = await self._routing_map_provider.get_overlapping_ranges(id_, [feedrangeEPK], options) + over_lapping_ranges = await self._routing_map_provider.get_overlapping_ranges(id_, [feed_range_epk], options) results: Dict[str, Any] = {} # For each over lapping range we will take a sub range of the feed range EPK that overlaps with the over # lapping physical partition. 
The EPK sub range will be one of four: @@ -2982,8 +2980,8 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: single_range = routing_range.Range.PartitionKeyRangeToRange(over_lapping_range) # Since the range min and max are all Upper Cased string Hex Values, # we can compare the values lexicographically - EPK_sub_range = routing_range.Range(range_min=max(single_range.min, feedrangeEPK.min), - range_max=min(single_range.max, feedrangeEPK.max), + EPK_sub_range = routing_range.Range(range_min=max(single_range.min, feed_range_epk.min), + range_max=min(single_range.max, feed_range_epk.max), isMinInclusive=True, isMaxInclusive=False) if single_range.min == EPK_sub_range.min and EPK_sub_range.max == single_range.max: # The Epk Sub Range spans exactly one physical partition diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py index 9d29a53a631d..20d048b3cef6 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py @@ -312,7 +312,7 @@ async def send(self, request): or _has_read_retryable_headers(request.http_request.headers)): # This logic is based on the _retry.py file from azure-core if retry_settings['read'] > 0: - # record the failure for circuit breaker tracking for retrys in connection retry policy + # record the failure for circuit breaker tracking for retries in connection retry policy # retries in the execute function will mark those failures await global_endpoint_manager.record_failure(request_params) retry_active = self.increment(retry_settings, response=request, error=err) diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py index 1dff2fb6f434..e5f7a4c281d5 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py +++ 
b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py @@ -51,11 +51,6 @@ class TestDataType: L2 = "West US" L3 = "East US 2" -# L0 = "Default" -# L1 = "East US 2" -# L2 = "East US" -# L3 = "West US 2" - CLIENT_ONLY_TEST_DATA = [ # preferred_locations, client_excluded_locations, excluded_locations_request # 0. No excluded location @@ -115,12 +110,12 @@ def read_item_test_data(): [L1], # 3 ] client_and_request_output_data = [ - [L2], # 0 + [L2, L1], # 0 [L2], # 1 - [L2], # 2 - [L1], # 3 + [L2, L1], # 2 + [L1, L2], # 3 [L1], # 4 - [L1], # 5 + [L1, L2], # 5 [L1], # 6 [L1], # 7 ] @@ -188,20 +183,14 @@ def verify_endpoint(messages, client, expected_locations, multiple_write_locatio # get location actual_locations = set() for req_url in req_urls: - match = re.search(r"x\-ms\-thinclient\-proxy\-resource\-type': '([^']+)'", req_url) - if match: - resource_type = match.group(1) - # only check the document and partition key requests because that is where excluded locations - # is applicable - if resource_type in (ResourceType.Document, ResourceType.PartitionKey): - if req_url.startswith(default_endpoint): - actual_locations.add(L0) - else: - for endpoint in location_mapping: - if req_url.startswith(endpoint): - location = location_mapping[endpoint] - actual_locations.add(location) - break + if req_url.startswith(default_endpoint): + actual_locations.add(L0) + else: + for endpoint in location_mapping: + if req_url.startswith(endpoint): + location = location_mapping[endpoint] + actual_locations.add(location) + break assert actual_locations == set(expected_locations) diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py index d4d9d28eb5af..79f50fa2a530 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py @@ -446,7 +446,7 @@ async def 
test_service_request_error_async(self, read_operation, write_operation finally: _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = original_unavailable_time validate_unhealthy_partitions(global_endpoint_manager, 0) - # ppcb should not regress connection timeouts marking the region as unavailable + # per partition circuit breaker should not regress connection timeouts marking the region as unavailable assert len(global_endpoint_manager.location_cache.location_unavailability_info_by_endpoint) == 1 custom_transport.add_fault(predicate, From c0441bfa77fbf23a440c496a2a07829de14046db Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Tue, 27 May 2025 17:07:07 -0700 Subject: [PATCH 150/152] react to comments and fix tests --- sdk/cosmos/azure-cosmos/azure/cosmos/_base.py | 2 +- .../_container_recreate_retry_policy.py | 2 +- .../azure/cosmos/_cosmos_client_connection.py | 6 +-- .../azure/cosmos/_global_endpoint_manager.py | 3 +- .../azure/cosmos/_location_cache.py | 6 +-- .../azure/cosmos/_partition_health_tracker.py | 3 +- .../azure/cosmos/_retry_utility.py | 5 ++- .../azure-cosmos/azure/cosmos/_utils.py | 5 ++- .../azure/cosmos/aio/_container.py | 41 +++++++++---------- .../aio/_cosmos_client_connection_async.py | 12 +++--- .../aio/_global_endpoint_manager_async.py | 3 +- .../azure-cosmos/azure/cosmos/container.py | 39 +++++++++--------- .../tests/test_excluded_locations.py | 14 +++---- .../tests/test_excluded_locations_async.py | 4 +- ..._per_partition_circuit_breaker_mm_async.py | 1 + 15 files changed, 75 insertions(+), 71 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py index af77a499ffa2..5b7fd0254421 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py @@ -877,7 +877,7 @@ def _format_batch_operations( return final_operations -def _set_properties_cache(properties: Dict[str, Any], container_link: str) -> Dict[str, Any]: +def 
_build_properties_cache(properties: Dict[str, Any], container_link: str) -> Dict[str, Any]: return { "_self": properties.get("_self", None), "_rid": properties.get("_rid", None), "partitionKey": properties.get("partitionKey", None), "container_link": container_link diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_container_recreate_retry_policy.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_container_recreate_retry_policy.py index 4080d3272d85..53ee57b8c3f8 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_container_recreate_retry_policy.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_container_recreate_retry_policy.py @@ -72,7 +72,7 @@ def ShouldRetry(self, exception: Optional[Any]) -> bool: def __find_container_link_with_rid(self, container_properties_caches: Optional[Dict[str, Any]], rid: str) -> \ Optional[str]: if container_properties_caches: - if container_properties_caches.get(rid): + if rid in container_properties_caches: return container_properties_caches[rid]["container_link"] # If we cannot get the container link at all it might mean the cache was somehow deleted, this isn't # a container request so this retry is not needed. Return None. diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py index 185e6d244b18..508611419ade 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py @@ -57,7 +57,7 @@ from . import documents from . 
import http_constants, exceptions from ._auth_policy import CosmosBearerTokenCredentialPolicy -from ._base import _set_properties_cache +from ._base import _build_properties_cache from ._change_feed.change_feed_iterable import ChangeFeedIterable from ._change_feed.change_feed_state import ChangeFeedState from ._constants import _Constants as Constants @@ -3393,7 +3393,7 @@ def _refresh_container_properties_cache(self, container_link: str): # If container properties cache is stale, refresh it by reading the container. container = self.ReadContainer(container_link, options=None) # Only cache Container Properties that will not change in the lifetime of the container - self._set_container_properties_cache(container_link, _set_properties_cache(container, container_link)) + self._set_container_properties_cache(container_link, _build_properties_cache(container, container_link)) def _UpdateSessionIfRequired( self, @@ -3436,5 +3436,5 @@ def _get_partition_key_definition( else: container = self.ReadContainer(collection_link, options) partition_key_definition = container.get("partitionKey") - self._set_container_properties_cache(collection_link, _set_properties_cache(container, collection_link)) + self._set_container_properties_cache(collection_link, _build_properties_cache(container, collection_link)) return partition_key_definition diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py index 09b4a9d2a8a7..c4612b629c53 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py @@ -32,7 +32,8 @@ from . 
import exceptions from ._request_object import RequestObject from .documents import DatabaseAccount -from ._location_cache import LocationCache, current_time_millis +from ._location_cache import LocationCache +from ._utils import current_time_millis # pylint: disable=protected-access diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py index c6f00b95ebef..90578c63e5dd 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py @@ -24,7 +24,6 @@ """ import collections import logging -import time from typing import Set, Mapping, List from urllib.parse import urlparse @@ -151,9 +150,6 @@ def _get_applicable_regional_routing_contexts(regional_routing_contexts: List[Re return applicable_regional_routing_contexts -def current_time_millis() -> int: - return int(round(time.time() * 1000)) - class LocationCache(object): # pylint: disable=too-many-public-methods,too-many-instance-attributes def __init__( @@ -183,7 +179,7 @@ def get_read_regional_routing_contexts(self): def get_location_from_endpoint(self, endpoint: str) -> str: if endpoint in self.account_locations_by_read_endpoints: - return self.account_locations_by_write_endpoints[endpoint] + return self.account_locations_by_read_endpoints[endpoint] return self.account_write_locations[0] def get_write_regional_routing_context(self): diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py index 8c00f10508dd..d1b2307c9093 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py @@ -26,8 +26,9 @@ import os from typing import Dict, Any, List from azure.cosmos._routing.routing_range import PartitionKeyRangeWrapper -from azure.cosmos._location_cache import current_time_millis, EndpointOperationType +from 
azure.cosmos._location_cache import EndpointOperationType from azure.cosmos._request_object import RequestObject +from ._utils import current_time_millis from ._constants import _Constants as Constants MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py index d34a1068a41d..91145ef217ba 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py @@ -45,7 +45,10 @@ # pylint: disable=protected-access, disable=too-many-lines, disable=too-many-statements, disable=too-many-branches - +# args [0] is the request object +# args [1] is the connection policy +# args [2] is the pipeline client +# args [3] is the http request def Execute(client, global_endpoint_manager, function, *args, **kwargs): # pylint: disable=too-many-locals """Executes the function with passed parameters applying all retry policies diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_utils.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_utils.py index 1b3d0370e6ef..c490a82a8a05 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_utils.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_utils.py @@ -26,6 +26,7 @@ import re import base64 import json +import time from typing import Any, Dict, Optional from ._version import VERSION @@ -37,7 +38,6 @@ def get_user_agent() -> str: user_agent = "azsdk-python-cosmos/{} Python/{} ({})".format(VERSION, python_version, os_name) return user_agent - def get_user_agent_async() -> str: os_name = safe_user_agent_header(platform.platform()) python_version = safe_user_agent_header(platform.python_version()) @@ -69,3 +69,6 @@ def get_index_metrics_info(delimited_string: Optional[str]) -> Dict[str, Any]: return result except (json.JSONDecodeError, ValueError): return {} + +def current_time_millis() -> int: + return int(round(time.time() * 1000)) diff --git 
a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py index 9105b759d3e2..d167a3e469ed 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py @@ -40,7 +40,7 @@ _deserialize_throughput, _replace_throughput, GenerateGuidId, - _set_properties_cache + _build_properties_cache ) from .._change_feed.feed_range_internal import FeedRangeInternalEpk from .._cosmos_responses import CosmosDict, CosmosList @@ -94,16 +94,16 @@ def __init__( self._scripts: Optional[ScriptsProxy] = None if properties: self.client_connection._set_container_properties_cache(self.container_link, - _set_properties_cache(properties, - self.container_link)) + _build_properties_cache(properties, + self.container_link)) def __repr__(self) -> str: return "".format(self.container_link)[:1024] - async def _get_properties_with_feed_options(self, feed_options: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + async def _get_properties_with_options(self, options: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: kwargs = {} - if feed_options and "excludedLocations" in feed_options: - kwargs['excluded_locations'] = feed_options['excludedLocations'] + if options and "excludedLocations" in options: + kwargs['excluded_locations'] = options['excludedLocations'] return await self._get_properties(**kwargs) async def _get_properties(self, **kwargs: Any) -> Dict[str, Any]: @@ -151,7 +151,7 @@ async def _get_epk_range_for_partition_key( self, partition_key_value: PartitionKeyType, feed_options: Optional[Dict[str, Any]] = None) -> Range: - container_properties = await self._get_properties_with_feed_options(feed_options) + container_properties = await self._get_properties_with_options(feed_options) partition_key_definition = container_properties["partitionKey"] partition_key = PartitionKey( path=partition_key_definition["paths"], @@ -204,8 +204,8 @@ async def read( 
request_options["populateQuotaInfo"] = populate_quota_info container = await self.client_connection.ReadContainer(self.container_link, options=request_options, **kwargs) # Only cache Container Properties that will not change in the lifetime of the container - self.client_connection._set_container_properties_cache(self.container_link, # pylint: disable=protected-access - _set_properties_cache(container, self.container_link)) + self.client_connection._set_container_properties_cache(self.container_link, # pylint: disable=protected-access + _build_properties_cache(container, self.container_link)) return container @distributed_trace_async @@ -286,7 +286,7 @@ async def create_item( request_options["disableAutomaticIdGeneration"] = not enable_automatic_id_generation if indexing_directive is not None: request_options["indexingDirective"] = indexing_directive - await self._get_properties() + await self._get_properties_with_options(request_options) request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] result = await self.client_connection.CreateItem( @@ -361,7 +361,7 @@ async def read_item( if max_integrated_cache_staleness_in_ms is not None: validate_cache_staleness_value(max_integrated_cache_staleness_in_ms) request_options["maxIntegratedCacheStaleness"] = max_integrated_cache_staleness_in_ms - await self._get_properties() + await self._get_properties_with_options(request_options) request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] return await self.client_connection.ReadItem(document_link=doc_link, options=request_options, **kwargs) @@ -418,7 +418,7 @@ def read_all_items( response_hook.clear() if self.container_link in self.__get_client_container_caches(): feed_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] - kwargs["containerProperties"] = self._get_properties_with_feed_options + kwargs["containerProperties"] = 
self._get_properties_with_options items = self.client_connection.ReadItems( collection_link=self.container_link, feed_options=feed_options, response_hook=response_hook, **kwargs @@ -523,9 +523,9 @@ def query_items( feed_options["populateIndexMetrics"] = populate_index_metrics if enable_scan_in_query is not None: feed_options["enableScanInQuery"] = enable_scan_in_query + kwargs["containerProperties"] = self._get_properties_with_options if partition_key is not None: feed_options["partitionKey"] = self._set_partition_key(partition_key) - kwargs["containerProperties"] = self._get_properties_with_feed_options else: feed_options["enableCrossPartitionQuery"] = True if max_integrated_cache_staleness_in_ms: @@ -539,7 +539,6 @@ def query_items( response_hook.clear() if self.container_link in self.__get_client_container_caches(): feed_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] - kwargs["containerProperties"] = self._get_properties_with_feed_options items = self.client_connection.QueryItems( database_or_container_link=self.container_link, @@ -772,7 +771,7 @@ def query_items_change_feed( # pylint: disable=unused-argument change_feed_state_context["continuation"] = feed_options.pop("continuation") feed_options["changeFeedStateContext"] = change_feed_state_context - feed_options["containerProperties"] = self._get_properties_with_feed_options(feed_options) + feed_options["containerProperties"] = self._get_properties_with_options(feed_options) response_hook = kwargs.pop("response_hook", None) if hasattr(response_hook, "clear"): @@ -854,7 +853,7 @@ async def upsert_item( kwargs["throughput_bucket"] = throughput_bucket request_options = _build_options(kwargs) request_options["disableAutomaticIdGeneration"] = True - await self._get_properties() + await self._get_properties_with_options(request_options) request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] result = await 
self.client_connection.UpsertItem( @@ -937,7 +936,7 @@ async def replace_item( kwargs["throughput_bucket"] = throughput_bucket request_options = _build_options(kwargs) request_options["disableAutomaticIdGeneration"] = True - await self._get_properties() + await self._get_properties_with_options(request_options) request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] result = await self.client_connection.ReplaceItem( @@ -1021,7 +1020,7 @@ async def patch_item( request_options["partitionKey"] = await self._set_partition_key(partition_key) if filter_predicate is not None: request_options["filterPredicate"] = filter_predicate - await self._get_properties() + await self._get_properties_with_options(request_options) request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] item_link = self._get_document_link(item) @@ -1093,7 +1092,7 @@ async def delete_item( kwargs["throughput_bucket"] = throughput_bucket request_options = _build_options(kwargs) request_options["partitionKey"] = await self._set_partition_key(partition_key) - await self._get_properties() + await self._get_properties_with_options(request_options) request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] document_link = self._get_document_link(item) @@ -1354,7 +1353,7 @@ async def delete_all_items_by_partition_key( request_options = _build_options(kwargs) # regardless if partition key is valid we set it as invalid partition keys are set to a default empty value request_options["partitionKey"] = await self._set_partition_key(partition_key) - await self._get_properties() + await self._get_properties_with_options(request_options) request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] await self.client_connection.DeleteAllItemsByPartitionKey(collection_link=self.container_link, @@ -1422,7 +1421,7 @@ async def execute_item_batch( 
request_options = _build_options(kwargs) request_options["partitionKey"] = await self._set_partition_key(partition_key) request_options["disableAutomaticIdGeneration"] = True - await self._get_properties() + await self._get_properties_with_options(request_options) request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] return await self.client_connection.Batch( diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py index 205a5ed4e376..23fb5ef45c0e 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py @@ -52,7 +52,7 @@ _GlobalPartitionEndpointManagerForCircuitBreakerAsync) from .. import _base as base -from .._base import _set_properties_cache +from .._base import _build_properties_cache from .. import documents from .._change_feed.aio.change_feed_iterable import ChangeFeedIterable from .._change_feed.change_feed_state import ChangeFeedState @@ -2282,7 +2282,7 @@ def QueryItems( collection_id = base.GetResourceIdOrFullNameFromLink(database_or_container_link) async def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], CaseInsensitiveDict]: - await kwargs["containerProperties"]() + await kwargs["containerProperties"](options) new_options = dict(options) new_options["containerRID"] = self.__container_properties_cache[database_or_container_link]["_rid"] return ( @@ -2965,7 +2965,9 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: req_headers.pop(http_constants.HttpHeaders.PartitionKey, None) feed_range_epk = partition_key_obj._get_epk_range_for_prefix_partition_key( partition_key_value) # cspell:disable-line - over_lapping_ranges = await self._routing_map_provider.get_overlapping_ranges(id_, [feed_range_epk], options) + over_lapping_ranges = await 
self._routing_map_provider.get_overlapping_ranges(id_, + [feed_range_epk], + options) results: Dict[str, Any] = {} # For each over lapping range we will take a sub range of the feed range EPK that overlaps with the over # lapping physical partition. The EPK sub range will be one of four: @@ -3235,7 +3237,7 @@ async def _refresh_container_properties_cache(self, container_link: str): # If container properties cache is stale, refresh it by reading the container. container = await self.ReadContainer(container_link, options=None) # Only cache Container Properties that will not change in the lifetime of the container - self._set_container_properties_cache(container_link, _set_properties_cache(container, container_link)) + self._set_container_properties_cache(container_link, _build_properties_cache(container, container_link)) async def _GetQueryPlanThroughGateway(self, query: str, resource_link: str, excluded_locations: Optional[str] = None, @@ -3339,5 +3341,5 @@ async def _get_partition_key_definition( else: container = await self.ReadContainer(collection_link, options) partition_key_definition = container.get("partitionKey") - self._set_container_properties_cache(collection_link, _set_properties_cache(container, collection_link)) + self._set_container_properties_cache(collection_link, _build_properties_cache(container, collection_link)) return partition_key_definition diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py index 705f6973892b..2d0184468149 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py @@ -32,7 +32,8 @@ from .. import _constants as constants from .. 
import exceptions -from .._location_cache import LocationCache, current_time_millis +from .._location_cache import LocationCache +from .._utils import current_time_millis from .._request_object import RequestObject # pylint: disable=protected-access diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/container.py b/sdk/cosmos/azure-cosmos/azure/cosmos/container.py index 54b670e1f0ea..5647a66a99f7 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/container.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/container.py @@ -37,7 +37,7 @@ _deserialize_throughput, _replace_throughput, GenerateGuidId, - _set_properties_cache + _build_properties_cache ) from ._change_feed.feed_range_internal import FeedRangeInternalEpk from ._cosmos_client_connection import CosmosClientConnection @@ -108,16 +108,16 @@ def __init__( self._scripts: Optional[ScriptsProxy] = None if properties: self.client_connection._set_container_properties_cache(self.container_link, - _set_properties_cache(properties, - self.container_link)) + _build_properties_cache(properties, + self.container_link)) def __repr__(self) -> str: return "".format(self.container_link)[:1024] - def _get_properties_with_feed_options(self, feed_options: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + def _get_properties_with_options(self, options: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: kwargs = {} - if feed_options and "excludedLocations" in feed_options: - kwargs['excluded_locations'] = feed_options['excludedLocations'] + if options and "excludedLocations" in options: + kwargs['excluded_locations'] = options['excludedLocations'] return self._get_properties(**kwargs) def _get_properties(self, **kwargs: Any) -> Dict[str, Any]: @@ -214,8 +214,8 @@ def read( # pylint:disable=docstring-missing-param request_options["populateQuotaInfo"] = populate_quota_info container = self.client_connection.ReadContainer(self.container_link, options=request_options, **kwargs) # Only cache Container Properties that will not change in 
the lifetime of the container - self.client_connection._set_container_properties_cache(self.container_link, # pylint: disable=protected-access - _set_properties_cache(container, self.container_link)) + self.client_connection._set_container_properties_cache(self.container_link, # pylint: disable=protected-access + _build_properties_cache(container, self.container_link)) return container @distributed_trace @@ -293,7 +293,7 @@ def read_item( # pylint:disable=docstring-missing-param if max_integrated_cache_staleness_in_ms is not None: validate_cache_staleness_value(max_integrated_cache_staleness_in_ms) request_options["maxIntegratedCacheStaleness"] = max_integrated_cache_staleness_in_ms - self._get_properties() + self._get_properties_with_options(request_options) request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] return self.client_connection.ReadItem(document_link=doc_link, options=request_options, **kwargs) @@ -357,7 +357,7 @@ def read_all_items( # pylint:disable=docstring-missing-param if response_hook and hasattr(response_hook, "clear"): response_hook.clear() - self._get_properties() + self._get_properties_with_options(feed_options) feed_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] items = self.client_connection.ReadItems( @@ -576,7 +576,7 @@ def query_items_change_feed( elif "start_time" in kwargs: change_feed_state_context["startTime"] = kwargs.pop("start_time") - container_properties = self._get_properties_with_feed_options(feed_options) + container_properties = self._get_properties_with_options(feed_options) if "partition_key" in kwargs: partition_key = kwargs.pop("partition_key") change_feed_state_context["partitionKey"] = self._set_partition_key(cast(PartitionKeyType, partition_key)) @@ -698,10 +698,9 @@ def query_items( # pylint:disable=docstring-missing-param feed_options["populateQueryMetrics"] = populate_query_metrics if populate_index_metrics is not None: 
feed_options["populateIndexMetrics"] = populate_index_metrics - properties = self._get_properties() + properties = self._get_properties_with_options(feed_options) if partition_key is not None: partition_key_value = self._set_partition_key(partition_key) - properties = self._get_properties_with_feed_options(feed_options) if is_prefix_partition_key(properties, partition_key): kwargs["isPrefixPartitionQuery"] = True kwargs["partitionKeyDefinition"] = properties["partitionKey"] @@ -813,7 +812,7 @@ def replace_item( # pylint:disable=docstring-missing-param ) request_options["populateQueryMetrics"] = populate_query_metrics - self._get_properties() + self._get_properties_with_options(request_options) request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] result = self.client_connection.ReplaceItem( document_link=item_link, @@ -899,7 +898,7 @@ def upsert_item( # pylint:disable=docstring-missing-param DeprecationWarning, ) request_options["populateQueryMetrics"] = populate_query_metrics - self._get_properties() + self._get_properties_with_options(request_options) request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] result = self.client_connection.UpsertItem( @@ -999,7 +998,7 @@ def create_item( # pylint:disable=docstring-missing-param request_options["populateQueryMetrics"] = populate_query_metrics if indexing_directive is not None: request_options["indexingDirective"] = indexing_directive - self._get_properties() + self._get_properties_with_options(request_options) request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] result = self.client_connection.CreateItem( database_or_container_link=self.container_link, document=body, options=request_options, **kwargs) @@ -1085,7 +1084,7 @@ def patch_item( if filter_predicate is not None: request_options["filterPredicate"] = filter_predicate - self._get_properties() + 
self._get_properties_with_options(request_options) request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] item_link = self._get_document_link(item) result = self.client_connection.PatchItem( @@ -1160,7 +1159,7 @@ def execute_item_batch( request_options = build_options(kwargs) request_options["partitionKey"] = self._set_partition_key(partition_key) request_options["disableAutomaticIdGeneration"] = True - container_properties = self._get_properties() + container_properties = self._get_properties_with_options(request_options) request_options["containerRID"] = container_properties["_rid"] return self.client_connection.Batch( @@ -1240,7 +1239,7 @@ def delete_item( # pylint:disable=docstring-missing-param request_options["preTriggerInclude"] = pre_trigger_include if post_trigger_include is not None: request_options["postTriggerInclude"] = post_trigger_include - self._get_properties() + self._get_properties_with_options(request_options) request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] document_link = self._get_document_link(item) self.client_connection.DeleteItem(document_link=document_link, options=request_options, **kwargs) @@ -1520,7 +1519,7 @@ def delete_all_items_by_partition_key( request_options = build_options(kwargs) # regardless if partition key is valid we set it as invalid partition keys are set to a default empty value request_options["partitionKey"] = self._set_partition_key(partition_key) - self._get_properties() + self._get_properties_with_options(request_options) request_options["containerRID"] = self.__get_client_container_caches()[self.container_link]["_rid"] self.client_connection.DeleteAllItemsByPartitionKey( diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py index e5f7a4c281d5..829f6a163745 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py +++ 
b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py @@ -2,7 +2,6 @@ # Copyright (c) Microsoft Corporation. All rights reserved. import logging -import re import unittest import uuid import test_config @@ -10,7 +9,6 @@ import time from azure.cosmos import CosmosClient -from azure.cosmos.http_constants import ResourceType class MockHandler(logging.Handler): @@ -110,12 +108,12 @@ def read_item_test_data(): [L1], # 3 ] client_and_request_output_data = [ - [L2, L1], # 0 + [L2], # 0 [L2], # 1 - [L2, L1], # 2 - [L1, L2], # 3 + [L2], # 2 + [L1], # 3 [L1], # 4 - [L1, L2], # 5 + [L1], # 5 [L1], # 6 [L1], # 7 ] @@ -354,7 +352,7 @@ def test_create_item(self, test_data): # Single write verify_endpoint(MOCK_HANDLER.messages, client, expected_locations, multiple_write_locations) - @pytest.mark.parametrize('test_data', write_item_test_data()) + @pytest.mark.parametrize('test_data', read_and_write_item_test_data()) def test_patch_item(self, test_data): # Init test variables preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data @@ -379,7 +377,7 @@ def test_patch_item(self, test_data): # get location from mock_handler verify_endpoint(MOCK_HANDLER.messages, client, expected_locations, multiple_write_locations) - @pytest.mark.parametrize('test_data', write_item_test_data()) + @pytest.mark.parametrize('test_data', read_and_write_item_test_data()) def test_execute_item_batch(self, test_data): # Init test variables preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py index 8210106a5e5a..1b2928de217e 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py @@ -252,7 +252,7 @@ async def test_create_item(self, test_data): # get location from mock_handler 
verify_endpoint(MOCK_HANDLER.messages, client, expected_locations, multiple_write_locations) - @pytest.mark.parametrize('test_data', write_item_test_data()) + @pytest.mark.parametrize('test_data', read_and_write_item_test_data()) async def test_patch_item_async(self, test_data): # Init test variables preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data @@ -280,7 +280,7 @@ async def test_patch_item_async(self, test_data): # get location from mock_handler verify_endpoint(MOCK_HANDLER.messages, client, expected_locations, multiple_write_locations) - @pytest.mark.parametrize('test_data', write_item_test_data()) + @pytest.mark.parametrize('test_data', read_and_write_item_test_data()) async def test_execute_item_batch_async(self, test_data): # Init test variables preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py index 79f50fa2a530..76e8e1d85ffb 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py @@ -191,6 +191,7 @@ class TestPerPartitionCircuitBreakerMMAsync: TEST_CONTAINER_SINGLE_PARTITION_ID = test_config.TestConfig.TEST_MULTI_PARTITION_CONTAINER_ID async def setup_method_with_custom_transport(self, custom_transport: AioHttpTransport, default_endpoint=host, **kwargs): + os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "True" client = CosmosClient(default_endpoint, self.master_key, preferred_locations=[REGION_1, REGION_2], multiple_write_locations=True, From 5f8b7e8c202bb7ae1d45ae64fbfafc1212d0fb1e Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Tue, 27 May 2025 17:08:45 -0700 Subject: [PATCH 151/152] react to comments and fix tests --- 
.../azure-cosmos/azure/cosmos/aio/_retry_utility_async.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py index 20d048b3cef6..33b9c0785b38 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py @@ -46,6 +46,10 @@ # pylint: disable=protected-access, disable=too-many-lines, disable=too-many-statements, disable=too-many-branches +# args [0] is the request object +# args [1] is the connection policy +# args [2] is the pipeline client +# args [3] is the http request async def ExecuteAsync(client, global_endpoint_manager, function, *args, **kwargs): # pylint: disable=too-many-locals """Executes the function with passed parameters applying all retry policies From 3010aa505efff3436b67fa74976f98042b667e4c Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Wed, 28 May 2025 10:19:49 -0700 Subject: [PATCH 152/152] fix tests --- sdk/cosmos/azure-cosmos/tests/test_crud_async.py | 2 ++ .../tests/test_per_partition_circuit_breaker_mm_async.py | 6 ++---- .../azure-cosmos/tests/test_query_hybrid_search_async.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_crud_async.py b/sdk/cosmos/azure-cosmos/tests/test_crud_async.py index 40f4ff1d422c..7124c16e88b0 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_crud_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_crud_async.py @@ -10,6 +10,7 @@ import unittest import urllib.parse as urllib import uuid +from asyncio import sleep import pytest import requests @@ -1130,6 +1131,7 @@ async def test_get_resource_with_dictionary_and_object_async(self): assert read_container.id == created_container.id created_item = await created_container.create_item({'id': '1' + str(uuid.uuid4()), 'pk': 'pk'}) + await sleep(5) # read item with id read_item = await 
created_container.read_item(item=created_item['id'], partition_key=created_item['pk']) diff --git a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py index 76e8e1d85ffb..095bd25bbb8b 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_per_partition_circuit_breaker_mm_async.py @@ -468,7 +468,7 @@ async def test_service_request_error_async(self, read_operation, write_operation await cleanup_method([custom_setup, setup]) - # send 5 write concurrent requests when trying to recover + # send 15 write concurrent requests when trying to recover # verify that only one failed async def test_recovering_only_fails_one_requests_async(self): error_lambda = lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( @@ -500,7 +500,7 @@ async def concurrent_upsert(): _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = 1 try: tasks = [] - for i in range(10): + for i in range(15): tasks.append(concurrent_upsert()) await asyncio.gather(*tasks) assert number_of_errors == 1 @@ -508,7 +508,5 @@ async def concurrent_upsert(): _partition_health_tracker.INITIAL_UNAVAILABLE_TIME = original_unavailable_time await cleanup_method([custom_setup, setup]) - # test cosmos client timeout - if __name__ == '__main__': unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/test_query_hybrid_search_async.py b/sdk/cosmos/azure-cosmos/tests/test_query_hybrid_search_async.py index 86f33f658204..d779ac2472e0 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_query_hybrid_search_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_query_hybrid_search_async.py @@ -295,7 +295,7 @@ async def test_hybrid_search_weighted_reciprocal_rank_fusion_async(self): query = "SELECT c.index, c.title FROM c " \ "ORDER BY RANK RRF(FullTextScore(c.text, 'United States'), VectorDistance(c.vector, {}), [1,1]) " \ "OFFSET 0 LIMIT 
10".format(item_vector) - results = self.test_container.query_items(query, enable_cross_partition_query=True) + results = self.test_container.query_items(query) result_list = [res async for res in results] assert len(result_list) == 10 result_list = [res['index'] for res in result_list]