diff --git a/samples/model-builder/conftest.py b/samples/model-builder/conftest.py index d9328b042d..3c009ae13d 100644 --- a/samples/model-builder/conftest.py +++ b/samples/model-builder/conftest.py @@ -114,7 +114,7 @@ def mock_create_image_dataset(mock_image_dataset): @pytest.fixture def mock_create_tabular_dataset(mock_tabular_dataset): with patch.object( - aiplatform.TabularDataset, "create" + aiplatform.TabularDataset, "create" ) as mock_create_tabular_dataset: mock_create_tabular_dataset.return_value = mock_tabular_dataset yield mock_create_tabular_dataset @@ -123,7 +123,7 @@ def mock_create_tabular_dataset(mock_tabular_dataset): @pytest.fixture def mock_create_time_series_dataset(mock_time_series_dataset): with patch.object( - aiplatform.TimeSeriesDataset, "create" + aiplatform.TimeSeriesDataset, "create" ) as mock_create_time_series_dataset: mock_create_time_series_dataset.return_value = mock_time_series_dataset yield mock_create_time_series_dataset @@ -462,7 +462,7 @@ def mock_get_entity_type(mock_entity_type): @pytest.fixture def mock_create_featurestore(mock_featurestore): with patch.object( - aiplatform.featurestore.Featurestore, "create" + aiplatform.featurestore.Featurestore, "create" ) as mock_create_featurestore: mock_create_featurestore.return_value = mock_featurestore yield mock_create_featurestore @@ -471,7 +471,7 @@ def mock_create_featurestore(mock_featurestore): @pytest.fixture def mock_create_entity_type(mock_entity_type): with patch.object( - aiplatform.featurestore.EntityType, "create" + aiplatform.featurestore.EntityType, "create" ) as mock_create_entity_type: mock_create_entity_type.return_value = mock_entity_type yield mock_create_entity_type @@ -499,7 +499,7 @@ def mock_batch_serve_to_bq(mock_featurestore): @pytest.fixture def mock_batch_create_features(mock_entity_type): with patch.object( - mock_entity_type, "batch_create_features" + mock_entity_type, "batch_create_features" ) as mock_batch_create_features: yield mock_batch_create_features @@ -513,7 +513,7 @@ def mock_read_feature_values(mock_entity_type): @pytest.fixture def mock_import_feature_values(mock_entity_type): with patch.object( - mock_entity_type, "ingest_from_gcs" + mock_entity_type, "ingest_from_gcs" ) as mock_import_feature_values: yield mock_import_feature_values @@ -644,7 +644,7 @@ def mock_context_list(mock_context): @pytest.fixture def mock_create_schema_base_context(mock_context): with patch.object( - aiplatform.metadata.schema.base_context.BaseContextSchema, "create" + aiplatform.metadata.schema.base_context.BaseContextSchema, "create" ) as mock_create_schema_base_context: mock_create_schema_base_context.return_value = mock_context yield mock_create_schema_base_context @@ -702,7 +702,7 @@ def mock_create_artifact(mock_artifact): @pytest.fixture def mock_create_schema_base_artifact(mock_artifact): with patch.object( - aiplatform.metadata.schema.base_artifact.BaseArtifactSchema, "create" + aiplatform.metadata.schema.base_artifact.BaseArtifactSchema, "create" ) as mock_create_schema_base_artifact: mock_create_schema_base_artifact.return_value = mock_artifact yield mock_create_schema_base_artifact @@ -711,7 +711,7 @@ def mock_create_schema_base_artifact(mock_artifact): @pytest.fixture def mock_create_schema_base_execution(mock_execution): with patch.object( - aiplatform.metadata.schema.base_execution.BaseExecutionSchema, "create" + aiplatform.metadata.schema.base_execution.BaseExecutionSchema, "create" ) as mock_create_schema_base_execution: mock_create_schema_base_execution.return_value = mock_execution yield mock_create_schema_base_execution @@ -757,7 +757,7 @@ def mock_log_metrics(): @pytest.fixture def mock_log_time_series_metrics(): with patch.object( - aiplatform, "log_time_series_metrics" + aiplatform, "log_time_series_metrics" ) as mock_log_time_series_metrics: mock_log_time_series_metrics.return_value = None yield mock_log_time_series_metrics @@ -822,7 +822,75 @@ def mock_get_params(mock_params, mock_experiment_run): @pytest.fixture def mock_get_time_series_metrics(mock_time_series_metrics, mock_experiment_run): with patch.object( - mock_experiment_run, "get_time_series_data_frame" + mock_experiment_run, "get_time_series_data_frame" ) as mock_get_time_series_metrics: mock_get_time_series_metrics.return_value = mock_time_series_metrics yield mock_get_time_series_metrics + + +""" +---------------------------------------------------------------------------- +Model Versioning Fixtures +---------------------------------------------------------------------------- +""" + + +@pytest.fixture +def mock_model_registry(): + mock = MagicMock(aiplatform.models.ModelRegistry) + yield mock + + +@pytest.fixture +def mock_version_info(): + mock = MagicMock(aiplatform.models.VersionInfo) + yield mock + + +@pytest.fixture +def mock_init_model_registry(mock_model_registry): + with patch.object(aiplatform.models, "ModelRegistry") as mock: + mock.return_value = mock_model_registry + yield mock + + +@pytest.fixture +def mock_get_model(mock_model_registry): + with patch.object(mock_model_registry, "get_model") as mock_get_model: + mock_get_model.return_value = mock_model + yield mock_get_model + + +@pytest.fixture +def mock_get_model_version_info(mock_model_registry): + with patch.object(mock_model_registry, "get_version_info") as mock_get_model_version_info: + mock_get_model_version_info.return_value = mock_version_info + yield mock_get_model_version_info + + +@pytest.fixture +def mock_list_versions(mock_model_registry, mock_version_info): + with patch.object(mock_model_registry, "list_versions") as mock_list_versions: + mock_list_versions.return_value = [mock_version_info, mock_version_info] + yield mock_list_versions + + +@pytest.fixture +def mock_delete_version(mock_model_registry): + with patch.object(mock_model_registry, "delete_version") as mock_delete_version: + mock_delete_version.return_value = None + yield mock_delete_version + + +@pytest.fixture +def mock_add_version_aliases(mock_model_registry): + with patch.object(mock_model_registry, "add_version_aliases") as mock_add_version_aliases: + mock_add_version_aliases.return_value = None + yield mock_add_version_aliases + + +@pytest.fixture +def mock_remove_version_aliases(mock_model_registry): + with patch.object(mock_model_registry, "remove_version_aliases") as mock_remove_version_aliases: + mock_remove_version_aliases.return_value = None + yield mock_remove_version_aliases diff --git a/samples/model-builder/model_registry/assign_aliases_model_version_sample.py b/samples/model-builder/model_registry/assign_aliases_model_version_sample.py new file mode 100644 index 0000000000..e9b1d3178e --- /dev/null +++ b/samples/model-builder/model_registry/assign_aliases_model_version_sample.py @@ -0,0 +1,51 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# [START aiplatform_model_registry_assign_aliases_model_version_sample] + +from typing import List + +from google.cloud import aiplatform + + +def assign_aliases_model_version_sample( + model_id: str, + version_aliases: List[str], + version_id: str, + project: str, + location: str, +): + """ + Assign aliases to a model version. + Args: + model_id: The ID of the model. + version_aliases: The version aliases to assign. + version_id: The version ID of the model to assign the aliases to. + project: The project name. + location: The location name. + Returns + None. + """ + # Initialize the client. + aiplatform.init(project=project, location=location) + + # Initialize the Model Registry resource with the ID 'model_id'.The parent_name of create method can be also + # 'projects//locations//models/' + model_registry = aiplatform.models.ModelRegistry(model=model_id) + + # Assign the version aliases to the model with the version 'version_id'. + model_registry.add_version_aliases(new_aliases=version_aliases, version=version_id) + + +# [END aiplatform_model_registry_assign_aliases_model_version_sample] diff --git a/samples/model-builder/model_registry/assign_aliases_model_version_sample_test.py b/samples/model-builder/model_registry/assign_aliases_model_version_sample_test.py new file mode 100644 index 0000000000..7a3f3a3356 --- /dev/null +++ b/samples/model-builder/model_registry/assign_aliases_model_version_sample_test.py @@ -0,0 +1,46 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import assign_aliases_model_version_sample + +import test_constants as constants + + +def test_assign_aliases_model_version_sample( + mock_sdk_init, mock_init_model_registry, mock_add_version_aliases, mock_model +): + + # Assign aliases to a model version. + assign_aliases_model_version_sample.assign_aliases_model_version_sample( + model_id=constants.MODEL_NAME, + version_id=constants.VERSION_ID, + version_aliases=constants.VERSION_ALIASES, + project=constants.PROJECT, + location=constants.LOCATION, + ) + + # Check client initialization. + mock_sdk_init.assert_called_with( + project=constants.PROJECT, location=constants.LOCATION + ) + + # Check model registry initialization. + mock_init_model_registry.assert_called_with(model=constants.MODEL_NAME) + + # Check that the model version was assigned the aliases. + mock_add_version_aliases.assert_called_with( + new_aliases=constants.VERSION_ALIASES, + version=constants.VERSION_ID, + ) diff --git a/samples/model-builder/model_registry/create_aliased_model_sample.py b/samples/model-builder/model_registry/create_aliased_model_sample.py new file mode 100644 index 0000000000..f66b2dd6f9 --- /dev/null +++ b/samples/model-builder/model_registry/create_aliased_model_sample.py @@ -0,0 +1,45 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# [START aiplatform_model_registry_create_aliased_model_sample] + +from google.cloud import aiplatform + + +def create_aliased_model_sample( + model_id: str, version_id: str, project: str, location: str +): + """ + Initialize a Model resource to represent an existing model version with custom alias. + Args: + model_id: The ID of the model to initialize. Parent resource name of the model is also accepted. + version_id: The version ID or version alias of the model to initialize. + project: The project. + location: The location. + Returns: + Model resource. + """ + # Initialize the client. + aiplatform.init(project=project, location=location) + + # Initialize the Model resource with the ID 'model_id'. The version can be also provided using @ annotation in + # the parent resource name: + # 'projects//locations//models/@'. + + aliased_model = aiplatform.Model(model_name=model_id, version=version_id) + + return aliased_model + + +# [END aiplatform_model_registry_create_aliased_model_sample] diff --git a/samples/model-builder/model_registry/create_aliased_model_sample_test.py b/samples/model-builder/model_registry/create_aliased_model_sample_test.py new file mode 100644 index 0000000000..38bd8150bb --- /dev/null +++ b/samples/model-builder/model_registry/create_aliased_model_sample_test.py @@ -0,0 +1,38 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import create_aliased_model_sample + +import test_constants as constants + + +def test_create_aliased_model_sample(mock_sdk_init, mock_init_model): + # Create a model with alias 'default'. + create_aliased_model_sample.create_aliased_model_sample( + model_id=constants.MODEL_NAME, + version_id=constants.VERSION_ID, + project=constants.PROJECT, + location=constants.LOCATION, + ) + + # Check client initialization. + mock_sdk_init.assert_called_with( + project=constants.PROJECT, location=constants.LOCATION + ) + + # Check that the model was created. + mock_init_model.assert_called_with( + model_name=constants.MODEL_NAME, version=constants.VERSION_ID + ) diff --git a/samples/model-builder/model_registry/create_default_model_sample.py b/samples/model-builder/model_registry/create_default_model_sample.py new file mode 100644 index 0000000000..9228e03a6a --- /dev/null +++ b/samples/model-builder/model_registry/create_default_model_sample.py @@ -0,0 +1,40 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# [START aiplatform_model_registry_create_default_model_sample] + +from google.cloud import aiplatform + + +def create_default_model_sample(model_id: str, project: str, location: str): + """ + Initialize a Model resource to represent an existing model version with alias 'default'. + Args: + model_id: The ID of the model to initialize. Parent resource name of the model is also accepted. + project: The project. + location: The location. + Returns: + Model resource. + """ + # Initialize the client. + aiplatform.init(project=project, location=location) + + # Initialize the Model resource with the ID 'model_id'. The parent_name of create method can be also + # 'projects//locations//models/' + default_model = aiplatform.Model(model_name=model_id) + + return default_model + + +# [END aiplatform_model_registry_create_default_model_sample] diff --git a/samples/model-builder/model_registry/create_default_model_sample_test.py b/samples/model-builder/model_registry/create_default_model_sample_test.py new file mode 100644 index 0000000000..2b3345e53d --- /dev/null +++ b/samples/model-builder/model_registry/create_default_model_sample_test.py @@ -0,0 +1,35 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import create_default_model_sample + +import test_constants as constants + + +def test_create_default_model_sample(mock_sdk_init, mock_init_model): + # Create a model with alias 'default'. + create_default_model_sample.create_default_model_sample( + model_id=constants.MODEL_NAME, + project=constants.PROJECT, + location=constants.LOCATION, + ) + + # Check client initialization. + mock_sdk_init.assert_called_with( + project=constants.PROJECT, location=constants.LOCATION + ) + + # Check that the model was created. + mock_init_model.assert_called_with(model_name=constants.MODEL_NAME) diff --git a/samples/model-builder/model_registry/create_model_registry_sample.py b/samples/model-builder/model_registry/create_model_registry_sample.py new file mode 100644 index 0000000000..8b853eaff5 --- /dev/null +++ b/samples/model-builder/model_registry/create_model_registry_sample.py @@ -0,0 +1,41 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# [START aiplatform_model_registry_create_model_registry_sample] + +from google.cloud import aiplatform + + +def create_model_registry_sample(model_id: str, project: str, location: str): + """ + Create a ModelRegistry resource associated to model_id + Args: + model_id: The ID of the model. + project: The project name. + location: The location name. + Returns: + ModelRegistry resource. + """ + + # Initialize the client. + aiplatform.init(project=project, location=location) + + # Initialize the Model Registry resource with the ID 'model_id'.The parent_name of create method can be also + # 'projects//locations//models/' + model_registry = aiplatform.models.ModelRegistry(model=model_id) + + return model_registry + + +# [END aiplatform_model_registry_create_model_registry_sample] diff --git a/samples/model-builder/model_registry/create_model_registry_sample_test.py b/samples/model-builder/model_registry/create_model_registry_sample_test.py new file mode 100644 index 0000000000..53247711fc --- /dev/null +++ b/samples/model-builder/model_registry/create_model_registry_sample_test.py @@ -0,0 +1,35 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import create_model_registry_sample + +import test_constants as constants + + +def test_create_model_registry_sample(mock_sdk_init, mock_init_model_registry): + # Create a model registry. + create_model_registry_sample.create_model_registry_sample( + model_id=constants.MODEL_NAME, + project=constants.PROJECT, + location=constants.LOCATION, + ) + + # Check client initialization. + mock_sdk_init.assert_called_with( + project=constants.PROJECT, location=constants.LOCATION + ) + + # Check that the model registry was created. + mock_init_model_registry.assert_called_with(model=constants.MODEL_NAME) diff --git a/samples/model-builder/model_registry/delete_aliases_model_version_sample.py b/samples/model-builder/model_registry/delete_aliases_model_version_sample.py new file mode 100644 index 0000000000..6a271d096e --- /dev/null +++ b/samples/model-builder/model_registry/delete_aliases_model_version_sample.py @@ -0,0 +1,53 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# [START aiplatform_model_registry_delete_aliases_model_version_sample] + +from typing import List + +from google.cloud import aiplatform + + +def delete_aliases_model_version_sample( + model_id: str, + version_aliases: List[str], + version_id: str, + project: str, + location: str, +): + """ + Delete aliases to a model version. + Args: + model_id: The ID of the model. + version_aliases: The version aliases to assign. + version_id: The version ID of the model to assign the aliases to. + project: The project name. + location: The location name. + Returns + None. + """ + # Initialize the client. + aiplatform.init(project=project, location=location) + + # Initialize the Model Registry resource with the ID 'model_id'.The parent_name of create method can be also + # 'projects//locations//models/' + model_registry = aiplatform.models.ModelRegistry(model=model_id) + + # Remove the version aliases to the model version with the version 'version'. + model_registry.remove_version_aliases( + target_aliases=version_aliases, version=version_id + ) + + +# [END aiplatform_model_registry_delete_aliases_model_version_sample] diff --git a/samples/model-builder/model_registry/delete_aliases_model_version_sample_test.py b/samples/model-builder/model_registry/delete_aliases_model_version_sample_test.py new file mode 100644 index 0000000000..889ea2135c --- /dev/null +++ b/samples/model-builder/model_registry/delete_aliases_model_version_sample_test.py @@ -0,0 +1,42 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import delete_aliases_model_version_sample + +import test_constants as constants + + +def test_delete_aliases_model_version_sample( + mock_sdk_init, mock_init_model_registry, mock_remove_version_aliases +): + # Delete aliases from a model version. + delete_aliases_model_version_sample.delete_aliases_model_version_sample( + model_id=constants.MODEL_NAME, + version_aliases=constants.VERSION_ALIASES, + version_id=constants.VERSION_ID, + project=constants.PROJECT, + location=constants.LOCATION, + ) + + # Check client initialization. + mock_sdk_init.assert_called_with( + project=constants.PROJECT, location=constants.LOCATION + ) + + # Check model registry initialization. + mock_init_model_registry.assert_called_with(model=constants.MODEL_NAME) + + # Check that the model version was removed the aliases. + mock_remove_version_aliases.assert_called_once() diff --git a/samples/model-builder/model_registry/delete_model_sample.py b/samples/model-builder/model_registry/delete_model_sample.py new file mode 100644 index 0000000000..ce43255bb1 --- /dev/null +++ b/samples/model-builder/model_registry/delete_model_sample.py @@ -0,0 +1,41 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# [START aiplatform_model_registry_delete_model_sample] + +from google.cloud import aiplatform + + +def delete_model_sample(model_id: str, project: str, location: str): + """ + Delete a Model resource. + Args: + model_id: The ID of the model to delete. Parent resource name of the model is also accepted. + project: The project. + location: The location. + Returns + None. + """ + # Initialize the client. + aiplatform.init(project=project, location=location) + + # Get the model with the ID 'model_id'. The parent_name of delete method can be also + # 'projects//locations//models/' + model = aiplatform.Model(model_name=model_id) + + # Delete the model. + model.delete() + + +# [END aiplatform_model_registry_delete_model_sample] diff --git a/samples/model-builder/model_registry/delete_model_sample_test.py b/samples/model-builder/model_registry/delete_model_sample_test.py new file mode 100644 index 0000000000..837e483f17 --- /dev/null +++ b/samples/model-builder/model_registry/delete_model_sample_test.py @@ -0,0 +1,38 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import delete_model_sample + +import test_constants as constants + + +def test_delete_model_sample(mock_sdk_init, mock_init_model, mock_model): + # Delete a model. + delete_model_sample.delete_model_sample( + model_id=constants.MODEL_NAME, + project=constants.PROJECT, + location=constants.LOCATION, + ) + + # Check client initialization. + mock_sdk_init.assert_called_with( + project=constants.PROJECT, location=constants.LOCATION + ) + + # Check model initialization. + mock_init_model.assert_called_with(model_name=constants.MODEL_NAME) + + # Check that the model was deleted. + mock_model.delete.assert_called_once() diff --git a/samples/model-builder/model_registry/delete_model_version_sample.py b/samples/model-builder/model_registry/delete_model_version_sample.py new file mode 100644 index 0000000000..1cd857578a --- /dev/null +++ b/samples/model-builder/model_registry/delete_model_version_sample.py @@ -0,0 +1,44 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# [START aiplatform_model_registry_delete_model_version_sample] + +from google.cloud import aiplatform + + +def delete_model_version_sample( + model_id: str, version_id: str, project: str, location: str +): + """ + Delete a Model version. + Args: + model_id: The ID of the model to delete. Parent resource name of the model is also accepted. + version_id: The version ID or version alias of the model to delete. + project: The project. + location: The location. + Returns + None. + """ + # Initialize the client. + aiplatform.init(project=project, location=location) + + # Initialize the Model Registry resource with the ID 'model_id'.The parent_name of create method can be also + # 'projects//locations//models/' + model_registry = aiplatform.models.ModelRegistry(model=model_id) + + # Delete the model version with the version 'version'. + model_registry.delete_version(version=version_id) + + +# [END aiplatform_model_registry_delete_model_version_sample] diff --git a/samples/model-builder/model_registry/delete_model_version_sample_test.py b/samples/model-builder/model_registry/delete_model_version_sample_test.py new file mode 100644 index 0000000000..5b124befe2 --- /dev/null +++ b/samples/model-builder/model_registry/delete_model_version_sample_test.py @@ -0,0 +1,41 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import delete_model_version_sample + +import test_constants as constants + + +def test_delete_model_version_sample( + mock_sdk_init, mock_init_model_registry, mock_model_registry +): + # Delete a model. + delete_model_version_sample.delete_model_version_sample( + model_id=constants.MODEL_NAME, + version_id=constants.VERSION_ID, + project=constants.PROJECT, + location=constants.LOCATION, + ) + + # Check client initialization. + mock_sdk_init.assert_called_with( + project=constants.PROJECT, location=constants.LOCATION + ) + + # Check model initialization. + mock_init_model_registry.assert_called_with(model=constants.MODEL_NAME) + + # Check that the model version was deleted. + mock_model_registry.delete_version.assert_called_with(version=constants.VERSION_ID) diff --git a/samples/model-builder/model_registry/get_model_version_info_sample.py b/samples/model-builder/model_registry/get_model_version_info_sample.py new file mode 100644 index 0000000000..00c8e0ae17 --- /dev/null +++ b/samples/model-builder/model_registry/get_model_version_info_sample.py @@ -0,0 +1,47 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# [START aiplatform_model_registry_get_model_version_info_sample] + +from google.cloud import aiplatform + + +def get_model_version_info_sample( + model_id: str, version_id: str, project: str, location: str +): + """ + Get model version info. + Args: + model_id: The ID of the model. + version_id: The version ID of the model version. + project: The project name. + location: The location name. + Returns: + VersionInfo resource. + """ + + # Initialize the client. + aiplatform.init(project=project, location=location) + + # Initialize the Model Registry resource with the ID 'model_id'.The parent_name of create method can be also + # 'projects//locations//models/' + model_registry = aiplatform.models.ModelRegistry(model=model_id) + + # Get model version info with the version 'version_id'. + model_version_info = model_registry.get_version_info(version=version_id) + + return model_version_info + + +# [END aiplatform_model_registry_get_model_version_info_sample] diff --git a/samples/model-builder/model_registry/get_model_version_info_sample_test.py b/samples/model-builder/model_registry/get_model_version_info_sample_test.py new file mode 100644 index 0000000000..f1cd0a578f --- /dev/null +++ b/samples/model-builder/model_registry/get_model_version_info_sample_test.py @@ -0,0 +1,42 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import get_model_version_info_sample + +import test_constants as constants + + +def test_get_model_version_info_sample( + mock_sdk_init, mock_init_model_registry, mock_get_model_version_info +): + + # Get model version information. + get_model_version_info_sample.get_model_version_info_sample( + model_id=constants.MODEL_NAME, + version_id=constants.VERSION_ID, + project=constants.PROJECT, + location=constants.LOCATION, + ) + + # Check client initialization. + mock_sdk_init.assert_called_with( + project=constants.PROJECT, location=constants.LOCATION + ) + + # Check model registry initialization. + mock_init_model_registry.assert_called_with(model=constants.MODEL_NAME) + + # Check that the model version information was retrieved. + mock_get_model_version_info.assert_called_with(version=constants.VERSION_ID) diff --git a/samples/model-builder/model_registry/get_registered_model_version_sample.py b/samples/model-builder/model_registry/get_registered_model_version_sample.py new file mode 100644 index 0000000000..cda5fcdd63 --- /dev/null +++ b/samples/model-builder/model_registry/get_registered_model_version_sample.py @@ -0,0 +1,48 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# [START aiplatform_model_registry_get_registered_model_version_sample] + +from typing import Optional + +from google.cloud import aiplatform + + +def get_registered_model_version_sample( + model_id: str, project: str, location: str, version_id: Optional[str] = None +): + """ + Get a registered model version. + Args: + model_id: The ID of the model. Parent resource name of the model is also accepted. + project: The project. + location: The location. + version_id: The version ID of the model. + Returns: + Model resource. + """ + # Initialize the client. + aiplatform.init(project=project, location=location) + + # Initialize the Model Registry resource with the ID 'model_id'. The parent_name of get method can be also + # 'projects//locations//models/' + model_registry = aiplatform.models.ModelRegistry(model=model_id) + + # Get the registered model with version 'version_id'. + registered_model_version = model_registry.get_model(version=version_id) + + return registered_model_version + + +# [END aiplatform_model_registry_get_registered_model_version_sample] diff --git a/samples/model-builder/model_registry/get_registered_model_version_sample_test.py b/samples/model-builder/model_registry/get_registered_model_version_sample_test.py new file mode 100644 index 0000000000..0b5adc7f6d --- /dev/null +++ b/samples/model-builder/model_registry/get_registered_model_version_sample_test.py @@ -0,0 +1,41 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import get_registered_model_version_sample + +import test_constants as constants + + +def test_get_registered_model_version_sample( + mock_sdk_init, mock_init_model_registry, mock_get_model +): + # Get the registered model version. + get_registered_model_version_sample.get_registered_model_version_sample( + model_id=constants.MODEL_NAME, + project=constants.PROJECT, + location=constants.LOCATION, + version_id=constants.VERSION_ID, + ) + + # Check client initialization. + mock_sdk_init.assert_called_with( + project=constants.PROJECT, location=constants.LOCATION + ) + + # Check that the model registry was created. + mock_init_model_registry.assert_called_with(model=constants.MODEL_NAME) + + # Check that the model version was retrieved. + mock_get_model.assert_called_with(version=constants.VERSION_ID) diff --git a/samples/model-builder/model_registry/list_model_versions_with_model_registry_sample.py b/samples/model-builder/model_registry/list_model_versions_with_model_registry_sample.py new file mode 100644 index 0000000000..dd4938ec7f --- /dev/null +++ b/samples/model-builder/model_registry/list_model_versions_with_model_registry_sample.py @@ -0,0 +1,43 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# [START aiplatform_model_registry_list_model_versions_with_model_registry_sample] + +from google.cloud import aiplatform + + +def list_model_versions_sample(model_id: str, project: str, location: str): + """ + List all model versions of a model. + Args: + model_id: The ID of the model to list. Parent resource name of the model is also accepted. + project: The project. + location: The location. + Returns: + versions: List of model versions. + """ + # Initialize the client. + aiplatform.init(project=project, location=location) + + # Initialize the Model Registry resource with the ID 'model_id'.The parent_name of create method can be also + # 'projects//locations//models/' + model_registry = aiplatform.models.ModelRegistry(model=model_id) + + # List all model versions of the model. + versions = model_registry.list_versions() + + return versions + + +# [END aiplatform_model_registry_list_model_versions_with_model_registry_sample] diff --git a/samples/model-builder/model_registry/list_model_versions_with_model_registry_sample_test.py b/samples/model-builder/model_registry/list_model_versions_with_model_registry_sample_test.py new file mode 100644 index 0000000000..6b741be87f --- /dev/null +++ b/samples/model-builder/model_registry/list_model_versions_with_model_registry_sample_test.py @@ -0,0 +1,45 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import list_model_versions_with_model_registry_sample + +import test_constants as constants + + +def test_list_model_versions_sample( + mock_sdk_init, mock_init_model_registry, mock_list_versions, mock_version_info +): + versions = ( + list_model_versions_with_model_registry_sample.list_model_versions_sample( + model_id=constants.MODEL_NAME, + project=constants.PROJECT, + location=constants.LOCATION, + ) + ) + + # Check client initialization. + mock_sdk_init.assert_called_with( + project=constants.PROJECT, location=constants.LOCATION + ) + + # Check model registry initialization. + mock_init_model_registry.assert_called_with(model=constants.MODEL_NAME) + + # Check model versions. + assert len(versions) == 2 + + # Check model version info. + assert versions[0] is mock_version_info + assert versions[1] is mock_version_info diff --git a/samples/model-builder/model_registry/upload_new_aliased_model_version_sample.py b/samples/model-builder/model_registry/upload_new_aliased_model_version_sample.py new file mode 100644 index 0000000000..3a6d4d9711 --- /dev/null +++ b/samples/model-builder/model_registry/upload_new_aliased_model_version_sample.py @@ -0,0 +1,63 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# [START aiplatform_model_registry_upload_new_aliased_model_version_sample] + +from typing import List + +from google.cloud import aiplatform + + +def upload_new_aliased_model_version_sample( + parent_name: str, + artifact_uri: str, + serving_container_image: str, + is_default_version: bool, + version_aliases: List[str], + version_description: str, + project: str, + location: str, +): + """ + Uploads a new aliased version of a model with ID 'model_id'. + Args: + parent_name: The parent resource name of an existing model. + artifact_uri: The URI of the model artifact to upload. + serving_container_image: The name of the serving container image to use. + is_default_version: Whether this version is the default version of the model. + version_aliases: The aliases of the model version. + version_description: The description of the model version. + project: The project. + location: The location. + Returns: + The new version of the model. + """ + # Initialize the client. + aiplatform.init(project=project, location=location) + + # Upload a new aliased version of the Model resource with the ID 'model_id'. The parent_name of upload method can + # be also 'projects//locations//models/' + model = aiplatform.Model.upload( + artifact_uri=artifact_uri, + serving_container_image=serving_container_image, + parent_name=parent_name, + is_default_version=is_default_version, + version_aliases=version_aliases, + version_description=version_description, + ) + + return model + + +# [END aiplatform_model_registry_upload_new_aliased_model_version_sample] diff --git a/samples/model-builder/model_registry/upload_new_aliased_model_version_sample_test.py b/samples/model-builder/model_registry/upload_new_aliased_model_version_sample_test.py new file mode 100644 index 0000000000..e53082e7d7 --- /dev/null +++ b/samples/model-builder/model_registry/upload_new_aliased_model_version_sample_test.py @@ -0,0 +1,47 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import upload_new_aliased_model_version_sample + +import test_constants as constants + + +def test_upload_new_model_version_sample(mock_sdk_init, mock_upload_model): + # Upload a new version of the model. + upload_new_aliased_model_version_sample.upload_new_aliased_model_version_sample( + parent_name=constants.MODEL_NAME, + artifact_uri=constants.MODEL_ARTIFACT_URI, + serving_container_image=constants.SERVING_CONTAINER_IMAGE, + is_default_version=constants.IS_DEFAULT_VERSION, + version_aliases=constants.VERSION_ALIASES, + version_description=constants.VERSION_DESCRIPTION, + project=constants.PROJECT, + location=constants.LOCATION, + ) + + # Check client initialization. + mock_sdk_init.assert_called_with( + project=constants.PROJECT, location=constants.LOCATION + ) + + # Check that the model was uploaded. + mock_upload_model.assert_called_with( + artifact_uri=constants.MODEL_ARTIFACT_URI, + serving_container_image=constants.SERVING_CONTAINER_IMAGE, + is_default_version=constants.IS_DEFAULT_VERSION, + version_aliases=constants.VERSION_ALIASES, + version_description=constants.VERSION_DESCRIPTION, + parent_name=constants.MODEL_NAME, + ) diff --git a/samples/model-builder/model_registry/upload_new_default_model_version_sample.py b/samples/model-builder/model_registry/upload_new_default_model_version_sample.py new file mode 100644 index 0000000000..31672b58ba --- /dev/null +++ b/samples/model-builder/model_registry/upload_new_default_model_version_sample.py @@ -0,0 +1,54 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# [START aiplatform_model_registry_upload_new_default_model_version_sample] + +from google.cloud import aiplatform + + +def upload_new_default_model_version_sample( + parent_name: str, + artifact_uri: str, + serving_container_image: str, + project: str, + location: str, +): + """ + Uploads a new default version of a model with ID 'model_id'. + Args: + parent_name: The parent resource name of the existing model. + artifact_uri: The URI of the model artifact to upload. + serving_container_image: The name of the serving container image to use. + project: The project. + location: The location. + + Returns: + The new version of the model. + """ + # Initialize the client. + aiplatform.init(project=project, location=location) + + # Upload a new default version of the Model resource with the ID 'model_id'. + # The parent_name of upload method can be also + # 'projects//locations//models/' + model = aiplatform.Model.upload( + artifact_uri=artifact_uri, + serving_container_image=serving_container_image, + parent_name=parent_name, + ) + + return model + + +# [END aiplatform_model_registry_upload_new_default_model_version_sample] diff --git a/samples/model-builder/model_registry/upload_new_default_model_version_sample_test.py b/samples/model-builder/model_registry/upload_new_default_model_version_sample_test.py new file mode 100644 index 0000000000..06e092331f --- /dev/null +++ b/samples/model-builder/model_registry/upload_new_default_model_version_sample_test.py @@ -0,0 +1,41 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import upload_new_default_model_version_sample + +import test_constants as constants + + +def test_upload_new_model_version_sample(mock_sdk_init, mock_upload_model): + # Upload a new version of the model. + upload_new_default_model_version_sample.upload_new_default_model_version_sample( + parent_name=constants.MODEL_NAME, + artifact_uri=constants.MODEL_ARTIFACT_URI, + serving_container_image=constants.SERVING_CONTAINER_IMAGE, + project=constants.PROJECT, + location=constants.LOCATION, + ) + + # Initialize the client. + mock_sdk_init.assert_called_with( + project=constants.PROJECT, location=constants.LOCATION + ) + + # Check that the model was uploaded. + mock_upload_model.assert_called_with( + artifact_uri=constants.MODEL_ARTIFACT_URI, + serving_container_image=constants.SERVING_CONTAINER_IMAGE, + parent_name=constants.MODEL_NAME, + ) diff --git a/samples/model-builder/model_registry/upload_new_model_version_using_custom_training_pipeline_sample.py b/samples/model-builder/model_registry/upload_new_model_version_using_custom_training_pipeline_sample.py new file mode 100644 index 0000000000..a5c4d28447 --- /dev/null +++ b/samples/model-builder/model_registry/upload_new_model_version_using_custom_training_pipeline_sample.py @@ -0,0 +1,95 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# [START aiplatform_model_registry_upload_new_model_version_using_custom_training_pipeline] + +from typing import List + +from google.cloud import aiplatform + + +def upload_new_model_version_using_custom_training_pipeline( + display_name: str, + script_path: str, + container_uri, + model_serving_container_image_uri: str, + dataset_id: str, + replica_count: int, + machine_type: str, + accelerator_type: str, + accelerator_count: int, + parent_model: str, + args: List[str], + model_version_aliases: List[str], + model_version_description: str, + is_default_version: bool, + project: str, + location: str, +): + """ + Uploads a new model version using a custom training pipeline. + Args: + display_name: The display name of the model version. + script_path: The path to the Python script that trains the model. + container_uri: The URI of the container to use for training. + model_serving_container_image_uri: The URI of the serving container image to use. + dataset_id: The ID of the dataset to use for training. + replica_count: The number of replicas to use for training. + machine_type: The machine type to use for training. + accelerator_type: The accelerator type to use for training. + accelerator_count: The number of accelerators to use for training. + parent_model: The parent resource name of an existing model. + args: A list of arguments to pass to the training script. + model_version_aliases: The aliases of the model version to create. + model_version_description: The description of the model version. + is_default_version: Whether the model version is the default version. + project: The project. + location: The location. + Returns: + The new version of the model. + """ + # Initialize the client. + aiplatform.init(project=project, location=location) + + # Create the training job. + # This job will upload a new, non-default version of the my-training-job model + job = aiplatform.CustomTrainingJob( + display_name=display_name, + script_path=script_path, + container_uri=container_uri, + model_serving_container_image_uri=model_serving_container_image_uri, + ) + + # Create dataset + # This examples uses a TabularDataset, but you can use any dataset type. + dataset = aiplatform.TabularDataset(dataset_id) if dataset_id else None + + # Run the training job. + model = job.run( + dataset=dataset, + args=args, + replica_count=replica_count, + machine_type=machine_type, + accelerator_type=accelerator_type, + accelerator_count=accelerator_count, + parent_model=parent_model, + model_version_aliases=model_version_aliases, + model_version_description=model_version_description, + is_default_version=is_default_version, + ) + + return model + + +# [END aiplatform_model_registry_upload_new_model_version_using_custom_training_pipeline] diff --git a/samples/model-builder/model_registry/upload_new_model_version_using_custom_training_pipeline_test.py b/samples/model-builder/model_registry/upload_new_model_version_using_custom_training_pipeline_test.py new file mode 100644 index 0000000000..39882c1cc9 --- /dev/null +++ b/samples/model-builder/model_registry/upload_new_model_version_using_custom_training_pipeline_test.py @@ -0,0 +1,74 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import upload_new_model_version_using_custom_training_pipeline_sample + +import test_constants as constants + + +def test_upload_new_model_version_using_custom_training_pipeline_sample( + mock_sdk_init, + mock_tabular_dataset, + mock_get_tabular_dataset, + mock_get_custom_training_job, + mock_run_custom_training_job, + mock_upload_model, +): + + upload_new_model_version_using_custom_training_pipeline_sample.upload_new_model_version_using_custom_training_pipeline( + display_name=constants.DISPLAY_NAME, + script_path=constants.SCRIPT_PATH, + container_uri=constants.CONTAINER_URI, + model_serving_container_image_uri=constants.CONTAINER_URI, + dataset_id=constants.RESOURCE_ID, + replica_count=constants.REPLICA_COUNT, + machine_type=constants.MACHINE_TYPE, + accelerator_type=constants.ACCELERATOR_TYPE, + accelerator_count=constants.ACCELERATOR_COUNT, + parent_model=constants.MODEL_NAME, + args=constants.ARGS, + model_version_aliases=constants.VERSION_ALIASES, + model_version_description=constants.MODEL_DESCRIPTION, + is_default_version=constants.IS_DEFAULT_VERSION, + project=constants.PROJECT, + location=constants.LOCATION, + ) + + # Check if the client was initialized. + mock_sdk_init.assert_called_with( + project=constants.PROJECT, location=constants.LOCATION + ) + + # Check if the training job was created. + mock_get_custom_training_job.assert_called_once_with( + display_name=constants.DISPLAY_NAME, + script_path=constants.SCRIPT_PATH, + container_uri=constants.CONTAINER_URI, + model_serving_container_image_uri=constants.CONTAINER_URI, + ) + + # Check if the training job was run. + mock_run_custom_training_job.assert_called_once_with( + dataset=mock_tabular_dataset, + args=constants.ARGS, + replica_count=constants.REPLICA_COUNT, + machine_type=constants.MACHINE_TYPE, + accelerator_type=constants.ACCELERATOR_TYPE, + accelerator_count=constants.ACCELERATOR_COUNT, + parent_model=constants.MODEL_NAME, + model_version_aliases=constants.VERSION_ALIASES, + model_version_description=constants.MODEL_DESCRIPTION, + is_default_version=constants.IS_DEFAULT_VERSION, + ) diff --git a/samples/model-builder/test_constants.py b/samples/model-builder/test_constants.py index 01f8f6080d..9e8cbae6eb 100644 --- a/samples/model-builder/test_constants.py +++ b/samples/model-builder/test_constants.py @@ -41,6 +41,7 @@ DATASET_NAME = f"{PARENT}/datasets/{RESOURCE_ID}" ENDPOINT_NAME = f"{PARENT}/endpoints/{RESOURCE_ID}" MODEL_NAME = f"{PARENT}/models/{RESOURCE_ID}" +VERSION_ID = "test-version" TRAINING_JOB_NAME = f"{PARENT}/trainingJobs/{RESOURCE_ID}" BIGQUERY_SOURCE = f"bq://{PROJECT}.{DATASET_NAME}.table1" @@ -276,3 +277,8 @@ STEP = 1 TIMESTAMP = timestamp_pb2.Timestamp() + +VERSION_ID = "test-version" +IS_DEFAULT_VERSION = False +VERSION_ALIASES = ["test-version-alias"] +VERSION_DESCRIPTION = "test-version-description"