From f7b2d9fa09370cae379d6804fe7a650335dc4b01 Mon Sep 17 00:00:00 2001 From: ivanmkc Date: Wed, 21 Apr 2021 17:54:15 -0400 Subject: [PATCH 1/7] Added tabular explanation sample --- samples/model-builder/conftest.py | 10 ++- .../model-builder/explain_tabular_sample.py | 32 +++++++++ .../explain_tabular_sample_test.py | 40 +++++++++++ samples/model-builder/test_constants.py | 36 ++++++++-- ...xplain_tabular_managed_container_sample.py | 70 +++++++++++++++++++ ...n_tabular_managed_container_sample_test.py | 63 +++++++++++++++++ 6 files changed, 244 insertions(+), 7 deletions(-) create mode 100644 samples/model-builder/explain_tabular_sample.py create mode 100644 samples/model-builder/explain_tabular_sample_test.py create mode 100644 samples/model-builder/upload_model_explain_tabular_managed_container_sample.py create mode 100644 samples/model-builder/upload_model_explain_tabular_managed_container_sample_test.py diff --git a/samples/model-builder/conftest.py b/samples/model-builder/conftest.py index 112d5c200b..a1c2222d70 100644 --- a/samples/model-builder/conftest.py +++ b/samples/model-builder/conftest.py @@ -235,13 +235,11 @@ def mock_batch_predict_model(mock_model): with patch.object(mock_model, "batch_predict") as mock: yield mock - @pytest.fixture def mock_upload_model(): - with patch.object(aiplatform.models.Model, "upload") as mock: + with patch.object(aiplatform.Model, "upload") as mock: yield mock - @pytest.fixture def mock_deploy_model(mock_model, mock_endpoint): with patch.object(mock_model, "deploy") as mock: @@ -286,3 +284,9 @@ def mock_get_endpoint(mock_endpoint): with patch.object(aiplatform, "Endpoint") as mock_get_endpoint: mock_get_endpoint.return_value = mock_endpoint yield mock_get_endpoint + +@pytest.fixture +def mock_endpoint_explain(mock_endpoint): + with patch.object(mock_endpoint, "explain") as mock_endpoint_explain: + mock_get_endpoint.return_value = mock_endpoint + yield mock_endpoint_explain \ No newline at end of file diff --git a/samples/model-builder/explain_tabular_sample.py b/samples/model-builder/explain_tabular_sample.py new file mode 100644 index 0000000000..9206ee13c6 --- /dev/null +++ b/samples/model-builder/explain_tabular_sample.py @@ -0,0 +1,32 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from google.cloud import aiplatform +from typing import Dict + +# [START aiplatform_sdk_explain_tabular_sample] +def explain_tabular_sample(project: str, location: str, endpoint_id: str, instance_dict: Dict): + + aiplatform.init(project=project, location=location) + + endpoint = aiplatform.Endpoint(endpoint_id) + + response = endpoint.explain(instances=[instance_dict], parameters={}) + + for prediction_ in response.predictions: + print(prediction_) + + +# [END aiplatform_sdk_explain_tabular_sample] diff --git a/samples/model-builder/explain_tabular_sample_test.py b/samples/model-builder/explain_tabular_sample_test.py new file mode 100644 index 0000000000..aca78ad995 --- /dev/null +++ b/samples/model-builder/explain_tabular_sample_test.py @@ -0,0 +1,40 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import explain_tabular_sample +import test_constants as constants + + +def test_explain_tabular_sample(mock_sdk_init, mock_endpoint, mock_get_endpoint, mock_endpoint_explain): + + explain_tabular_sample.explain_tabular_sample( + project=constants.PROJECT, + location=constants.LOCATION, + endpoint_id=constants.ENDPOINT_NAME, + instance_dict=constants.PREDICTION_TABULAR_INSTANCE, + ) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, location=constants.LOCATION + ) + + mock_get_endpoint.assert_called_once_with( + constants.ENDPOINT_NAME, + ) + + mock_endpoint_explain.assert_called_once_with( + instances=[constants.PREDICTION_TABULAR_INSTANCE], + parameters={} + ) \ No newline at end of file diff --git a/samples/model-builder/test_constants.py b/samples/model-builder/test_constants.py index 994a8724ee..641fa1c490 100644 --- a/samples/model-builder/test_constants.py +++ b/samples/model-builder/test_constants.py @@ -131,15 +131,31 @@ inputs={ "features": { "input_tensor_name": "dense_input", - "encoding": "BAG_OF_FEATURES", + # Input is tabular data "modality": "numeric", - "index_feature_mapping": ["abc", "def", "ghj"], + # Assign feature names to the inputs for explanation + "encoding": "BAG_OF_FEATURES", + "index_feature_mapping": [ + "crim", + "zn", + "indus", + "chas", + "nox", + "rm", + "age", + "dis", + "rad", + "tax", + "ptratio", + "b", + "lstat", + ], } }, - outputs={"medv": {"output_tensor_name": "dense_2"}}, + outputs={"prediction": {"output_tensor_name": "dense_2"}}, ) EXPLANATION_PARAMETERS = aiplatform.explain.ExplanationParameters( - {"sampled_shapley_attribution": {"path_count": 10}} + {"xrai_attribution": {"step_count": 1}} ) # Endpoint constants @@ -148,4 +164,16 @@ TRAFFIC_SPLIT = {"a": 99, "b": 1} MIN_REPLICA_COUNT = 1 MAX_REPLICA_COUNT = 1 +ACCELERATOR_TYPE = "NVIDIA_TESLA_P100" +ACCELERATOR_COUNT = 2 ENDPOINT_DEPLOY_METADATA = () +PREDICTION_TABULAR_INSTANCE = { + "longitude": "-124.35", + "latitude": "40.54", + "housing_median_age": "52.0", + "total_rooms": "1820.0", + "total_bedrooms": "300.0", + "population": "806", + "households": "270.0", + "median_income": "3.014700", +} diff --git a/samples/model-builder/upload_model_explain_tabular_managed_container_sample.py b/samples/model-builder/upload_model_explain_tabular_managed_container_sample.py new file mode 100644 index 0000000000..a1db4aac3e --- /dev/null +++ b/samples/model-builder/upload_model_explain_tabular_managed_container_sample.py @@ -0,0 +1,70 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from google.cloud import aiplatform +from typing import Optional, Sequence, Tuple, Dict +from google.cloud.aiplatform import explain + +# [START aiplatform_sdk_upload_model_explain_tabular_managed_container_sample] +def upload_model_explain_tabular_managed_container_sample( + project, + location, + model_display_name: str, + serving_container_image_uri: str, + artifact_uri: Optional[str] = None, + serving_container_predict_route: Optional[str] = None, + serving_container_health_route: Optional[str] = None, + description: Optional[str] = None, + serving_container_command: Optional[Sequence[str]] = None, + serving_container_args: Optional[Sequence[str]] = None, + serving_container_environment_variables: Optional[Dict[str, str]] = None, + serving_container_ports: Optional[Sequence[int]] = None, + instance_schema_uri: Optional[str] = None, + parameters_schema_uri: Optional[str] = None, + prediction_schema_uri: Optional[str] = None, + explanation_metadata: Optional[explain.ExplanationMetadata] = None, + explanation_parameters: Optional[explain.ExplanationParameters] = None, + sync: bool = True, +): + + aiplatform.init(project=project, location=location) + + model = aiplatform.Model.upload( + display_name=model_display_name, + serving_container_image_uri=serving_container_image_uri, + artifact_uri=artifact_uri, + serving_container_predict_route=serving_container_predict_route, + serving_container_health_route=serving_container_health_route, + description=description, + serving_container_command=serving_container_command, + serving_container_args=serving_container_args, + serving_container_environment_variables=serving_container_environment_variables, + serving_container_ports=serving_container_ports, + instance_schema_uri=instance_schema_uri, + parameters_schema_uri=parameters_schema_uri, + prediction_schema_uri=prediction_schema_uri, + explanation_metadata=explanation_metadata, + explanation_parameters=explanation_parameters, + sync=sync, + ) + + model.wait() + + print(model.display_name) + print(model.resource_name) + return model + + +# [END aiplatform_sdk_upload_model_explain_tabular_managed_container_sample] diff --git a/samples/model-builder/upload_model_explain_tabular_managed_container_sample_test.py b/samples/model-builder/upload_model_explain_tabular_managed_container_sample_test.py new file mode 100644 index 0000000000..7f9253a28f --- /dev/null +++ b/samples/model-builder/upload_model_explain_tabular_managed_container_sample_test.py @@ -0,0 +1,63 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import upload_model_explain_tabular_managed_container_sample +import test_constants as constants + + +def test_upload_model_explain_tabular_managed_container_sample(mock_sdk_init, mock_model, mock_init_model, mock_upload_model): + + upload_model_explain_tabular_managed_container_sample.upload_model_explain_tabular_managed_container_sample( + project=constants.PROJECT, + location=constants.LOCATION, + model_display_name=constants.MODEL_NAME, + serving_container_image_uri=constants.SERVING_CONTAINER_IMAGE_URI, + artifact_uri=constants.MODEL_ARTIFACT_URI, + serving_container_predict_route=constants.SERVING_CONTAINER_PREDICT_ROUTE, + serving_container_health_route=constants.SERVING_CONTAINER_HEALTH_ROUTE, + description=constants.DESCRIPTION, + serving_container_command=constants.SERVING_CONTAINER_COMMAND, + serving_container_args=constants.SERVING_CONTAINER_ARGS, + serving_container_environment_variables=constants.SERVING_CONTAINER_ENVIRONMENT_VARIABLES, + serving_container_ports=constants.SERVING_CONTAINER_PORTS, + instance_schema_uri=constants.INSTANCE_SCHEMA_URI, + parameters_schema_uri=constants.PARAMETERS_SCHEMA_URI, + prediction_schema_uri=constants.PREDICTION_SCHEMA_URI, + explanation_metadata=constants.EXPLANATION_METADATA, + explanation_parameters=constants.EXPLANATION_PARAMETERS, + ) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, location=constants.LOCATION + ) + + mock_upload_model.assert_called_once_with( + display_name=constants.MODEL_NAME, + serving_container_image_uri=constants.SERVING_CONTAINER_IMAGE_URI, + artifact_uri=constants.MODEL_ARTIFACT_URI, + serving_container_predict_route=constants.SERVING_CONTAINER_PREDICT_ROUTE, + serving_container_health_route=constants.SERVING_CONTAINER_HEALTH_ROUTE, + description=constants.DESCRIPTION, + serving_container_command=constants.SERVING_CONTAINER_COMMAND, + serving_container_args=constants.SERVING_CONTAINER_ARGS, + serving_container_environment_variables=constants.SERVING_CONTAINER_ENVIRONMENT_VARIABLES, + serving_container_ports=constants.SERVING_CONTAINER_PORTS, + instance_schema_uri=constants.INSTANCE_SCHEMA_URI, + parameters_schema_uri=constants.PARAMETERS_SCHEMA_URI, + prediction_schema_uri=constants.PREDICTION_SCHEMA_URI, + explanation_metadata=constants.EXPLANATION_METADATA, + explanation_parameters=constants.EXPLANATION_PARAMETERS, + sync=True, + ) From 76116c741620b16dc8d4377154f6cb20fd72b1bf Mon Sep 17 00:00:00 2001 From: ivanmkc Date: Fri, 30 Apr 2021 19:46:55 -0400 Subject: [PATCH 2/7] Cleaned up mocks --- samples/model-builder/conftest.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/samples/model-builder/conftest.py b/samples/model-builder/conftest.py index a1c2222d70..337d25a2ef 100644 --- a/samples/model-builder/conftest.py +++ b/samples/model-builder/conftest.py @@ -219,7 +219,7 @@ def mock_run_custom_training_job(): @pytest.fixture def mock_model(): - mock = MagicMock(aiplatform.models.Model) + mock = MagicMock(aiplatform.Model) yield mock @@ -235,11 +235,13 @@ def mock_batch_predict_model(mock_model): with patch.object(mock_model, "batch_predict") as mock: yield mock + @pytest.fixture def mock_upload_model(): with patch.object(aiplatform.Model, "upload") as mock: yield mock + @pytest.fixture def mock_deploy_model(mock_model, mock_endpoint): with patch.object(mock_model, "deploy") as mock: @@ -269,7 +271,7 @@ def mock_create_batch_prediction_job(): @pytest.fixture def mock_endpoint(): - mock = MagicMock(aiplatform.models.Endpoint) + mock = MagicMock(aiplatform.Endpoint) yield mock @@ -285,8 +287,9 @@ def mock_get_endpoint(mock_endpoint): mock_get_endpoint.return_value = mock_endpoint yield mock_get_endpoint + @pytest.fixture def mock_endpoint_explain(mock_endpoint): with patch.object(mock_endpoint, "explain") as mock_endpoint_explain: mock_get_endpoint.return_value = mock_endpoint - yield mock_endpoint_explain \ No newline at end of file + yield mock_endpoint_explain From 8fb4947ecde4e0283dbdf9b226f3ec5241d23910 Mon Sep 17 00:00:00 2001 From: ivanmkc Date: Mon, 3 May 2021 17:44:08 -0400 Subject: [PATCH 3/7] Ran linter --- .../create_and_import_dataset_video_sample_test.py | 1 - samples/model-builder/explain_tabular_sample.py | 6 +++++- .../model-builder/explain_tabular_sample_test.py | 13 ++++++------- .../import_data_video_classification_sample_test.py | 1 - ...odel_explain_tabular_managed_container_sample.py | 9 +++++---- 5 files changed, 16 insertions(+), 14 deletions(-) diff --git a/samples/model-builder/create_and_import_dataset_video_sample_test.py b/samples/model-builder/create_and_import_dataset_video_sample_test.py index 1ebbc7a3d0..e1d1ddeb19 100644 --- a/samples/model-builder/create_and_import_dataset_video_sample_test.py +++ b/samples/model-builder/create_and_import_dataset_video_sample_test.py @@ -16,7 +16,6 @@ from google.cloud.aiplatform import schema import create_and_import_dataset_video_sample - import test_constants as constants diff --git a/samples/model-builder/explain_tabular_sample.py b/samples/model-builder/explain_tabular_sample.py index 9206ee13c6..096eda15f9 100644 --- a/samples/model-builder/explain_tabular_sample.py +++ b/samples/model-builder/explain_tabular_sample.py @@ -14,10 +14,14 @@ from google.cloud import aiplatform + from typing import Dict + # [START aiplatform_sdk_explain_tabular_sample] -def explain_tabular_sample(project: str, location: str, endpoint_id: str, instance_dict: Dict): +def explain_tabular_sample( + project: str, location: str, endpoint_id: str, instance_dict: Dict +): aiplatform.init(project=project, location=location) diff --git a/samples/model-builder/explain_tabular_sample_test.py b/samples/model-builder/explain_tabular_sample_test.py index aca78ad995..d088da9658 100644 --- a/samples/model-builder/explain_tabular_sample_test.py +++ b/samples/model-builder/explain_tabular_sample_test.py @@ -17,7 +17,9 @@ import test_constants as constants -def test_explain_tabular_sample(mock_sdk_init, mock_endpoint, mock_get_endpoint, mock_endpoint_explain): +def test_explain_tabular_sample( + mock_sdk_init, mock_endpoint, mock_get_endpoint, mock_endpoint_explain +): explain_tabular_sample.explain_tabular_sample( project=constants.PROJECT, @@ -30,11 +32,8 @@ def test_explain_tabular_sample(mock_sdk_init, mock_endpoint, mock_get_endpoint, project=constants.PROJECT, location=constants.LOCATION ) - mock_get_endpoint.assert_called_once_with( - constants.ENDPOINT_NAME, - ) + mock_get_endpoint.assert_called_once_with(constants.ENDPOINT_NAME,) mock_endpoint_explain.assert_called_once_with( - instances=[constants.PREDICTION_TABULAR_INSTANCE], - parameters={} - ) \ No newline at end of file + instances=[constants.PREDICTION_TABULAR_INSTANCE], parameters={} + ) diff --git a/samples/model-builder/import_data_video_classification_sample_test.py b/samples/model-builder/import_data_video_classification_sample_test.py index cce5c0abd6..5e5e142533 100644 --- a/samples/model-builder/import_data_video_classification_sample_test.py +++ b/samples/model-builder/import_data_video_classification_sample_test.py @@ -17,7 +17,6 @@ import pytest import import_data_video_classification_sample - import test_constants as constants diff --git a/samples/model-builder/upload_model_explain_tabular_managed_container_sample.py b/samples/model-builder/upload_model_explain_tabular_managed_container_sample.py index a1db4aac3e..157db87d52 100644 --- a/samples/model-builder/upload_model_explain_tabular_managed_container_sample.py +++ b/samples/model-builder/upload_model_explain_tabular_managed_container_sample.py @@ -14,8 +14,9 @@ from google.cloud import aiplatform -from typing import Optional, Sequence, Tuple, Dict -from google.cloud.aiplatform import explain + +from typing import Optional, Sequence, Dict + # [START aiplatform_sdk_upload_model_explain_tabular_managed_container_sample] def upload_model_explain_tabular_managed_container_sample( @@ -34,8 +35,8 @@ def upload_model_explain_tabular_managed_container_sample( instance_schema_uri: Optional[str] = None, parameters_schema_uri: Optional[str] = None, prediction_schema_uri: Optional[str] = None, - explanation_metadata: Optional[explain.ExplanationMetadata] = None, - explanation_parameters: Optional[explain.ExplanationParameters] = None, + explanation_metadata: Optional[aiplatform.explain.ExplanationMetadata] = None, + explanation_parameters: Optional[aiplatform.explain.ExplanationParameters] = None, sync: bool = True, ): From 70ed28c1dac482d8de01525b26e6e8ca779d69a1 Mon Sep 17 00:00:00 2001 From: ivanmkc Date: Tue, 4 May 2021 15:01:53 -0400 Subject: [PATCH 4/7] Fixed mock and added explanation printing --- samples/model-builder/conftest.py | 5 +++-- samples/model-builder/explain_tabular_sample.py | 7 +++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/samples/model-builder/conftest.py b/samples/model-builder/conftest.py index 337d25a2ef..192878c0d8 100644 --- a/samples/model-builder/conftest.py +++ b/samples/model-builder/conftest.py @@ -219,7 +219,7 @@ def mock_run_custom_training_job(): @pytest.fixture def mock_model(): - mock = MagicMock(aiplatform.Model) + mock = MagicMock(aiplatform.models.Model) yield mock @@ -237,8 +237,9 @@ def mock_batch_predict_model(mock_model): @pytest.fixture -def mock_upload_model(): +def mock_upload_model(mock_model): with patch.object(aiplatform.Model, "upload") as mock: + mock.return_value = mock_model yield mock diff --git a/samples/model-builder/explain_tabular_sample.py b/samples/model-builder/explain_tabular_sample.py index 096eda15f9..b10f191f5c 100644 --- a/samples/model-builder/explain_tabular_sample.py +++ b/samples/model-builder/explain_tabular_sample.py @@ -29,8 +29,11 @@ def explain_tabular_sample( response = endpoint.explain(instances=[instance_dict], parameters={}) - for prediction_ in response.predictions: - print(prediction_) + for prediction in response.predictions: + print(prediction) + + for explanation in response.explanations: + print(explanation) # [END aiplatform_sdk_explain_tabular_sample] From 8200fcedecf07e4504057297c4b19cb3fd0b4c77 Mon Sep 17 00:00:00 2001 From: ivanmkc Date: Tue, 4 May 2021 15:07:05 -0400 Subject: [PATCH 5/7] Added more verbose explanations --- .../model-builder/explain_tabular_sample.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/samples/model-builder/explain_tabular_sample.py b/samples/model-builder/explain_tabular_sample.py index b10f191f5c..f28f085fd0 100644 --- a/samples/model-builder/explain_tabular_sample.py +++ b/samples/model-builder/explain_tabular_sample.py @@ -29,11 +29,23 @@ def explain_tabular_sample( response = endpoint.explain(instances=[instance_dict], parameters={}) + for explanation in response.explanations: + print(" explanation") + # Feature attributions. + attributions = explanation.attributions + for attribution in attributions: + print(" attribution") + print(" baseline_output_value:", attribution.baseline_output_value) + print(" instance_output_value:", attribution.instance_output_value) + print(" output_display_name:", attribution.output_display_name) + print(" approximation_error:", attribution.approximation_error) + print(" output_name:", attribution.output_name) + output_index = attribution.output_index + for output_index in output_index: + print(" output_index:", output_index) + for prediction in response.predictions: print(prediction) - for explanation in response.explanations: - print(explanation) - # [END aiplatform_sdk_explain_tabular_sample] From e526b5f1ef0cff4779b4459538e01f3dffd41370 Mon Sep 17 00:00:00 2001 From: ivanmkc Date: Tue, 4 May 2021 15:09:04 -0400 Subject: [PATCH 6/7] Fixed endpoint fixture --- samples/model-builder/conftest.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/samples/model-builder/conftest.py b/samples/model-builder/conftest.py index 192878c0d8..70431c9565 100644 --- a/samples/model-builder/conftest.py +++ b/samples/model-builder/conftest.py @@ -272,13 +272,13 @@ def mock_create_batch_prediction_job(): @pytest.fixture def mock_endpoint(): - mock = MagicMock(aiplatform.Endpoint) + mock = MagicMock(aiplatform.models.Endpoint) yield mock @pytest.fixture def mock_create_endpoint(): - with patch.object(aiplatform.Endpoint, "create") as mock: + with patch.object(aiplatform.models.Endpoint, "create") as mock: yield mock From 3d7baa207f97cc5f2076fcec63a16c1705026e90 Mon Sep 17 00:00:00 2001 From: ivanmkc Date: Tue, 4 May 2021 15:15:33 -0400 Subject: [PATCH 7/7] Fixed linting issues --- samples/model-builder/explain_tabular_sample.py | 3 +-- ...upload_model_explain_tabular_managed_container_sample.py | 3 +-- ...d_model_explain_tabular_managed_container_sample_test.py | 6 ++++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/samples/model-builder/explain_tabular_sample.py b/samples/model-builder/explain_tabular_sample.py index f28f085fd0..16d1204787 100644 --- a/samples/model-builder/explain_tabular_sample.py +++ b/samples/model-builder/explain_tabular_sample.py @@ -12,11 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Dict from google.cloud import aiplatform -from typing import Dict - # [START aiplatform_sdk_explain_tabular_sample] def explain_tabular_sample( diff --git a/samples/model-builder/upload_model_explain_tabular_managed_container_sample.py b/samples/model-builder/upload_model_explain_tabular_managed_container_sample.py index 157db87d52..bc676ba917 100644 --- a/samples/model-builder/upload_model_explain_tabular_managed_container_sample.py +++ b/samples/model-builder/upload_model_explain_tabular_managed_container_sample.py @@ -12,11 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Dict, Optional, Sequence from google.cloud import aiplatform -from typing import Optional, Sequence, Dict - # [START aiplatform_sdk_upload_model_explain_tabular_managed_container_sample] def upload_model_explain_tabular_managed_container_sample( diff --git a/samples/model-builder/upload_model_explain_tabular_managed_container_sample_test.py b/samples/model-builder/upload_model_explain_tabular_managed_container_sample_test.py index 7f9253a28f..653de93f74 100644 --- a/samples/model-builder/upload_model_explain_tabular_managed_container_sample_test.py +++ b/samples/model-builder/upload_model_explain_tabular_managed_container_sample_test.py @@ -12,12 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +import test_constants as constants import upload_model_explain_tabular_managed_container_sample -import test_constants as constants -def test_upload_model_explain_tabular_managed_container_sample(mock_sdk_init, mock_model, mock_init_model, mock_upload_model): +def test_upload_model_explain_tabular_managed_container_sample( + mock_sdk_init, mock_model, mock_init_model, mock_upload_model +): upload_model_explain_tabular_managed_container_sample.upload_model_explain_tabular_managed_container_sample( project=constants.PROJECT,