Skip to content

Commit

Permalink
feat: Add update_version to Model Registry
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 506139032
  • Loading branch information
vertex-sdk-bot authored and copybara-github committed Jan 31, 2023
1 parent 7ab6e0b commit 8621e24
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 0 deletions.
54 changes: 54 additions & 0 deletions google/cloud/aiplatform/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4905,6 +4905,60 @@ def delete_version(

_LOGGER.info(f"Deleted version {version} for {self.model_resource_name}")

def update_version(
self,
version: str,
version_description: Optional[str] = None,
labels: Optional[Dict[str, str]] = None,
) -> None:
"""Updates a model version.
Args:
version (str): Required. The version ID to receive the new alias(es).
version_description (str):
The description of the model version.
labels (Dict[str, str]):
Optional. The labels with user-defined metadata to
organize your Model versions.
Label keys and values can be no longer than 64
characters (Unicode codepoints), can only
contain lowercase letters, numeric characters,
underscores and dashes. International characters
are allowed.
See https://goo.gl/xmQnxf for more information
and examples of labels.
Raises:
ValueError: If `labels` is not the correct format.
"""

current_model_proto = self.get_model(version).gca_resource
copied_model_proto = current_model_proto.__class__(current_model_proto)

update_mask: List[str] = []

if version_description:
copied_model_proto.version_description = version_description
update_mask.append("version_description")

if labels:
utils.validate_labels(labels)

copied_model_proto.labels = labels
update_mask.append("labels")

update_mask = field_mask_pb2.FieldMask(paths=update_mask)
versioned_name = self._get_versioned_name(self.model_resource_name, version)

_LOGGER.info(f"Updating model {versioned_name}")

self.client.update_model(
model=copied_model_proto,
update_mask=update_mask,
)

_LOGGER.info(f"Completed updating model {versioned_name}")

def add_version_aliases(
self,
new_aliases: List[str],
Expand Down
21 changes: 21 additions & 0 deletions tests/unit/aiplatform/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2714,6 +2714,27 @@ def test_delete_version(self, delete_model_version_mock, get_model_with_version)
)
)

@pytest.mark.usefixtures("get_model_mock")
def test_update_version(
self, update_model_mock, get_model_mock, get_model_with_version
):
my_model = models.Model(_TEST_MODEL_NAME, _TEST_PROJECT, _TEST_LOCATION)
my_model.versioning_registry.update_version(
_TEST_VERSION_ALIAS_1,
version_description="update version",
labels=_TEST_LABEL,
)

model_to_update = _TEST_MODEL_OBJ_WITH_VERSION
model_to_update.version_description = "update version"
model_to_update.labels = _TEST_LABEL

update_mask = field_mask_pb2.FieldMask(paths=["version_description", "labels"])

update_model_mock.assert_called_once_with(
model=model_to_update, update_mask=update_mask
)

def test_add_versions(self, merge_version_aliases_mock, get_model_with_version):
my_model = models.Model(_TEST_MODEL_NAME, _TEST_PROJECT, _TEST_LOCATION)
my_model.versioning_registry.add_version_aliases(
Expand Down

0 comments on commit 8621e24

Please sign in to comment.