Skip to content

Commit

Permalink
fix: Retry for etag errors on context update.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 537445290
  • Loading branch information
vertex-sdk-bot authored and copybara-github committed Jun 3, 2023
1 parent 635ae9c commit d3d5f9a
Show file tree
Hide file tree
Showing 2 changed files with 158 additions and 0 deletions.
47 changes: 47 additions & 0 deletions google/cloud/aiplatform/metadata/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing import Optional, Dict, List, Sequence

import proto
import re
import threading

from google.auth import credentials as auth_credentials
Expand All @@ -37,6 +38,12 @@
from google.cloud.aiplatform.metadata import execution
from google.cloud.aiplatform.metadata import metadata_store
from google.cloud.aiplatform.metadata import resource
from google.api_core.exceptions import Aborted

_ETAG_ERROR_MAX_RETRY_COUNT = 5
_ETAG_ERROR_REGEX = re.compile(
r"Specified Context \`etag\`: \`(\d+)\` does not match server \`etag\`: \`(\d+)\`"
)


class Context(resource._Resource):
Expand Down Expand Up @@ -278,6 +285,46 @@ def _create_resource(
context_id=resource_id,
)

def update(
self,
metadata: Optional[Dict] = None,
description: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
):
"""Updates an existing Metadata Context with new metadata.
This is implemented with retry on etag errors, up to
_ETAG_ERROR_MAX_RETRY_COUNT times.
Args:
metadata (Dict):
Optional. metadata contains the updated metadata information.
description (str):
Optional. Description describes the resource to be updated.
credentials (auth_credentials.Credentials):
Custom credentials to use to update this resource. Overrides
credentials set in aiplatform.init.
"""
for _ in range(_ETAG_ERROR_MAX_RETRY_COUNT - 1):
try:
super().update(
metadata=metadata, description=description, credentials=credentials
)
return
except Aborted as aborted_exception:
regex_match = _ETAG_ERROR_REGEX.match(aborted_exception.message)
if regex_match:
local_etag = regex_match.group(1)
server_etag = regex_match.group(2)
if local_etag < server_etag:
self.sync_resource()
continue
raise aborted_exception

# Expose result/exception directly in the last retry.
super().update(
metadata=metadata, description=description, credentials=credentials
)

@classmethod
def _update_resource(
cls,
Expand Down
111 changes: 111 additions & 0 deletions tests/unit/aiplatform/test_metadata_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,53 @@ def update_context_mock():
yield update_context_mock


@pytest.fixture
def update_context_with_errors_mock():
with patch.object(
MetadataServiceClient, "update_context"
) as update_context_with_errors_mock:
update_context_with_errors_mock.side_effect = [
exceptions.Aborted(
"Specified Context `etag`: `1` does not match server `etag`: `2`"
),
GapicContext(
name=_TEST_CONTEXT_NAME,
display_name=_TEST_DISPLAY_NAME,
schema_title=_TEST_SCHEMA_TITLE,
schema_version=_TEST_SCHEMA_VERSION,
description=_TEST_DESCRIPTION,
metadata=_TEST_UPDATED_METADATA,
),
]
yield update_context_with_errors_mock


@pytest.fixture
def update_context_with_errors_mock_2():
with patch.object(
MetadataServiceClient, "update_context"
) as update_context_with_errors_mock_2:
update_context_with_errors_mock_2.side_effect = [
exceptions.Aborted(
"Specified Context `etag`: `2` does not match server `etag`: `1`"
)
]
yield update_context_with_errors_mock_2


@pytest.fixture
def update_context_with_errors_mock_3():
with patch.object(
MetadataServiceClient, "update_context"
) as update_context_with_errors_mock_3:
update_context_with_errors_mock_3.side_effect = [
exceptions.Aborted(
"Specified Context `etag`: `1` does not match server `etag`: `2`"
)
] * 6
yield update_context_with_errors_mock_2


@pytest.fixture
def add_context_artifacts_and_executions_mock():
with patch.object(
Expand Down Expand Up @@ -482,6 +529,70 @@ def test_update_context(self, update_context_mock):
update_context_mock.assert_called_once_with(context=updated_context)
assert my_context._gca_resource == updated_context

@pytest.mark.usefixtures("get_context_mock")
@pytest.mark.usefixtures("create_context_mock")
def test_update_context_with_retry_success(self, update_context_with_errors_mock):
aiplatform.init(project=_TEST_PROJECT)

my_context = context.Context._create(
resource_id=_TEST_CONTEXT_ID,
schema_title=_TEST_SCHEMA_TITLE,
display_name=_TEST_DISPLAY_NAME,
schema_version=_TEST_SCHEMA_VERSION,
description=_TEST_DESCRIPTION,
metadata=_TEST_METADATA,
metadata_store_id=_TEST_METADATA_STORE,
)
my_context.update(_TEST_UPDATED_METADATA)

updated_context = GapicContext(
name=_TEST_CONTEXT_NAME,
schema_title=_TEST_SCHEMA_TITLE,
schema_version=_TEST_SCHEMA_VERSION,
display_name=_TEST_DISPLAY_NAME,
description=_TEST_DESCRIPTION,
metadata=_TEST_UPDATED_METADATA,
)

update_context_with_errors_mock.assert_called_with(context=updated_context)
assert my_context._gca_resource == updated_context

@pytest.mark.usefixtures("get_context_mock")
@pytest.mark.usefixtures("create_context_mock")
@pytest.mark.usefixtures("update_context_with_errors_mock_2")
def test_update_context_with_retry_etag_order_failure(self):
aiplatform.init(project=_TEST_PROJECT)

my_context = context.Context._create(
resource_id=_TEST_CONTEXT_ID,
schema_title=_TEST_SCHEMA_TITLE,
display_name=_TEST_DISPLAY_NAME,
schema_version=_TEST_SCHEMA_VERSION,
description=_TEST_DESCRIPTION,
metadata=_TEST_METADATA,
metadata_store_id=_TEST_METADATA_STORE,
)
with pytest.raises(exceptions.Aborted):
my_context.update(_TEST_UPDATED_METADATA)

@pytest.mark.usefixtures("get_context_mock")
@pytest.mark.usefixtures("create_context_mock")
@pytest.mark.usefixtures("update_context_with_errors_mock_3")
def test_update_context_with_retry_too_many_error_failure(self):
aiplatform.init(project=_TEST_PROJECT)

my_context = context.Context._create(
resource_id=_TEST_CONTEXT_ID,
schema_title=_TEST_SCHEMA_TITLE,
display_name=_TEST_DISPLAY_NAME,
schema_version=_TEST_SCHEMA_VERSION,
description=_TEST_DESCRIPTION,
metadata=_TEST_METADATA,
metadata_store_id=_TEST_METADATA_STORE,
)
with pytest.raises(exceptions.Aborted):
my_context.update(_TEST_UPDATED_METADATA)

@pytest.mark.usefixtures("get_context_mock")
def test_list_contexts(self, list_contexts_mock):
aiplatform.init(project=_TEST_PROJECT)
Expand Down

0 comments on commit d3d5f9a

Please sign in to comment.