From 8621e24cd02cb545e353f54562bf111616d7a9f2 Mon Sep 17 00:00:00 2001 From: A Vertex SDK engineer Date: Tue, 31 Jan 2023 15:46:06 -0800 Subject: [PATCH] feat: Add update_version to Model Registry PiperOrigin-RevId: 506139032 --- google/cloud/aiplatform/models.py | 54 ++++++++++++++++++++++++++++ tests/unit/aiplatform/test_models.py | 21 +++++++++++ 2 files changed, 75 insertions(+) diff --git a/google/cloud/aiplatform/models.py b/google/cloud/aiplatform/models.py index 47b46ed209..8bca932103 100644 --- a/google/cloud/aiplatform/models.py +++ b/google/cloud/aiplatform/models.py @@ -4905,6 +4905,60 @@ def delete_version( _LOGGER.info(f"Deleted version {version} for {self.model_resource_name}") + def update_version( + self, + version: str, + version_description: Optional[str] = None, + labels: Optional[Dict[str, str]] = None, + ) -> None: + """Updates a model version. + + Args: + version (str): Required. The version ID to receive the new alias(es). + version_description (str): + The description of the model version. + labels (Dict[str, str]): + Optional. The labels with user-defined metadata to + organize your Model versions. + Label keys and values can be no longer than 64 + characters (Unicode codepoints), can only + contain lowercase letters, numeric characters, + underscores and dashes. International characters + are allowed. + See https://goo.gl/xmQnxf for more information + and examples of labels. + + Raises: + ValueError: If `labels` is not the correct format. + """ + + current_model_proto = self.get_model(version).gca_resource + copied_model_proto = current_model_proto.__class__(current_model_proto) + + update_mask: List[str] = [] + + if version_description: + copied_model_proto.version_description = version_description + update_mask.append("version_description") + + if labels: + utils.validate_labels(labels) + + copied_model_proto.labels = labels + update_mask.append("labels") + + update_mask = field_mask_pb2.FieldMask(paths=update_mask) + versioned_name = self._get_versioned_name(self.model_resource_name, version) + + _LOGGER.info(f"Updating model {versioned_name}") + + self.client.update_model( + model=copied_model_proto, + update_mask=update_mask, + ) + + _LOGGER.info(f"Completed updating model {versioned_name}") + def add_version_aliases( self, new_aliases: List[str], diff --git a/tests/unit/aiplatform/test_models.py b/tests/unit/aiplatform/test_models.py index 1c9e7dd4a0..7df01ecd82 100644 --- a/tests/unit/aiplatform/test_models.py +++ b/tests/unit/aiplatform/test_models.py @@ -2714,6 +2714,27 @@ def test_delete_version(self, delete_model_version_mock, get_model_with_version) ) ) + @pytest.mark.usefixtures("get_model_mock") + def test_update_version( + self, update_model_mock, get_model_mock, get_model_with_version + ): + my_model = models.Model(_TEST_MODEL_NAME, _TEST_PROJECT, _TEST_LOCATION) + my_model.versioning_registry.update_version( + _TEST_VERSION_ALIAS_1, + version_description="update version", + labels=_TEST_LABEL, + ) + + model_to_update = _TEST_MODEL_OBJ_WITH_VERSION + model_to_update.version_description = "update version" + model_to_update.labels = _TEST_LABEL + + update_mask = field_mask_pb2.FieldMask(paths=["version_description", "labels"]) + + update_model_mock.assert_called_once_with( + model=model_to_update, update_mask=update_mask + ) + def test_add_versions(self, merge_version_aliases_mock, get_model_with_version): my_model = models.Model(_TEST_MODEL_NAME, _TEST_PROJECT, _TEST_LOCATION) my_model.versioning_registry.add_version_aliases(