diff --git a/google/cloud/aiplatform/metadata/context.py b/google/cloud/aiplatform/metadata/context.py index 88e83c40be..3707bc1f3e 100644 --- a/google/cloud/aiplatform/metadata/context.py +++ b/google/cloud/aiplatform/metadata/context.py @@ -18,6 +18,7 @@ from typing import Optional, Dict, List, Sequence import proto +import re import threading from google.auth import credentials as auth_credentials @@ -37,6 +38,12 @@ from google.cloud.aiplatform.metadata import execution from google.cloud.aiplatform.metadata import metadata_store from google.cloud.aiplatform.metadata import resource +from google.api_core.exceptions import Aborted + +_ETAG_ERROR_MAX_RETRY_COUNT = 5 +_ETAG_ERROR_REGEX = re.compile( + r"Specified Context \`etag\`: \`(\d+)\` does not match server \`etag\`: \`(\d+)\`" +) class Context(resource._Resource): @@ -278,6 +285,46 @@ def _create_resource( context_id=resource_id, ) + def update( + self, + metadata: Optional[Dict] = None, + description: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ): + """Updates an existing Metadata Context with new metadata. + + This is implemented with retry on etag errors, up to + _ETAG_ERROR_MAX_RETRY_COUNT times. + Args: + metadata (Dict): + Optional. metadata contains the updated metadata information. + description (str): + Optional. Description describes the resource to be updated. + credentials (auth_credentials.Credentials): + Custom credentials to use to update this resource. Overrides + credentials set in aiplatform.init. + """ + for _ in range(_ETAG_ERROR_MAX_RETRY_COUNT - 1): + try: + super().update( + metadata=metadata, description=description, credentials=credentials + ) + return + except Aborted as aborted_exception: + regex_match = _ETAG_ERROR_REGEX.match(aborted_exception.message) + if regex_match: + local_etag = regex_match.group(1) + server_etag = regex_match.group(2) + if local_etag < server_etag: + self.sync_resource() + continue + raise aborted_exception + + # Expose result/exception directly in the last retry. + super().update( + metadata=metadata, description=description, credentials=credentials + ) + @classmethod def _update_resource( cls, diff --git a/tests/unit/aiplatform/test_metadata_resources.py b/tests/unit/aiplatform/test_metadata_resources.py index ee966a2852..7c3f5269f8 100644 --- a/tests/unit/aiplatform/test_metadata_resources.py +++ b/tests/unit/aiplatform/test_metadata_resources.py @@ -153,6 +153,53 @@ def update_context_mock(): yield update_context_mock +@pytest.fixture +def update_context_with_errors_mock(): + with patch.object( + MetadataServiceClient, "update_context" + ) as update_context_with_errors_mock: + update_context_with_errors_mock.side_effect = [ + exceptions.Aborted( + "Specified Context `etag`: `1` does not match server `etag`: `2`" + ), + GapicContext( + name=_TEST_CONTEXT_NAME, + display_name=_TEST_DISPLAY_NAME, + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_UPDATED_METADATA, + ), + ] + yield update_context_with_errors_mock + + +@pytest.fixture +def update_context_with_errors_mock_2(): + with patch.object( + MetadataServiceClient, "update_context" + ) as update_context_with_errors_mock_2: + update_context_with_errors_mock_2.side_effect = [ + exceptions.Aborted( + "Specified Context `etag`: `2` does not match server `etag`: `1`" + ) + ] + yield update_context_with_errors_mock_2 + + +@pytest.fixture +def update_context_with_errors_mock_3(): + with patch.object( + MetadataServiceClient, "update_context" + ) as update_context_with_errors_mock_3: + update_context_with_errors_mock_3.side_effect = [ + exceptions.Aborted( + "Specified Context `etag`: `1` does not match server `etag`: `2`" + ) + ] * 6 + yield update_context_with_errors_mock_2 + + @pytest.fixture def add_context_artifacts_and_executions_mock(): with patch.object( @@ -482,6 +529,70 @@ def test_update_context(self, update_context_mock): update_context_mock.assert_called_once_with(context=updated_context) assert my_context._gca_resource == updated_context + @pytest.mark.usefixtures("get_context_mock") + @pytest.mark.usefixtures("create_context_mock") + def test_update_context_with_retry_success(self, update_context_with_errors_mock): + aiplatform.init(project=_TEST_PROJECT) + + my_context = context.Context._create( + resource_id=_TEST_CONTEXT_ID, + schema_title=_TEST_SCHEMA_TITLE, + display_name=_TEST_DISPLAY_NAME, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + metadata_store_id=_TEST_METADATA_STORE, + ) + my_context.update(_TEST_UPDATED_METADATA) + + updated_context = GapicContext( + name=_TEST_CONTEXT_NAME, + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + display_name=_TEST_DISPLAY_NAME, + description=_TEST_DESCRIPTION, + metadata=_TEST_UPDATED_METADATA, + ) + + update_context_with_errors_mock.assert_called_with(context=updated_context) + assert my_context._gca_resource == updated_context + + @pytest.mark.usefixtures("get_context_mock") + @pytest.mark.usefixtures("create_context_mock") + @pytest.mark.usefixtures("update_context_with_errors_mock_2") + def test_update_context_with_retry_etag_order_failure(self): + aiplatform.init(project=_TEST_PROJECT) + + my_context = context.Context._create( + resource_id=_TEST_CONTEXT_ID, + schema_title=_TEST_SCHEMA_TITLE, + display_name=_TEST_DISPLAY_NAME, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + metadata_store_id=_TEST_METADATA_STORE, + ) + with pytest.raises(exceptions.Aborted): + my_context.update(_TEST_UPDATED_METADATA) + + @pytest.mark.usefixtures("get_context_mock") + @pytest.mark.usefixtures("create_context_mock") + @pytest.mark.usefixtures("update_context_with_errors_mock_3") + def test_update_context_with_retry_too_many_error_failure(self): + aiplatform.init(project=_TEST_PROJECT) + + my_context = context.Context._create( + resource_id=_TEST_CONTEXT_ID, + schema_title=_TEST_SCHEMA_TITLE, + display_name=_TEST_DISPLAY_NAME, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + metadata_store_id=_TEST_METADATA_STORE, + ) + with pytest.raises(exceptions.Aborted): + my_context.update(_TEST_UPDATED_METADATA) + @pytest.mark.usefixtures("get_context_mock") def test_list_contexts(self, list_contexts_mock): aiplatform.init(project=_TEST_PROJECT)