From 00102be345b03b5d47544570087f28f74b1ba092 Mon Sep 17 00:00:00 2001 From: ivanmkc Date: Wed, 21 Apr 2021 17:54:15 -0400 Subject: [PATCH] Added tabular explanation sample --- samples/model-builder/conftest.py | 10 ++- .../model-builder/explain_tabular_sample.py | 32 +++++++++ .../explain_tabular_sample_test.py | 40 +++++++++++ samples/model-builder/test_constants.py | 36 ++++++++-- ...xplain_tabular_managed_container_sample.py | 70 +++++++++++++++++++ ...n_tabular_managed_container_sample_test.py | 63 +++++++++++++++++ 6 files changed, 244 insertions(+), 7 deletions(-) create mode 100644 samples/model-builder/explain_tabular_sample.py create mode 100644 samples/model-builder/explain_tabular_sample_test.py create mode 100644 samples/model-builder/upload_model_explain_tabular_managed_container_sample.py create mode 100644 samples/model-builder/upload_model_explain_tabular_managed_container_sample_test.py diff --git a/samples/model-builder/conftest.py b/samples/model-builder/conftest.py index 718f4645ff4..00fe684de4f 100644 --- a/samples/model-builder/conftest.py +++ b/samples/model-builder/conftest.py @@ -231,13 +231,11 @@ def mock_batch_predict_model(mock_model): with patch.object(mock_model, "batch_predict") as mock: yield mock - @pytest.fixture def mock_upload_model(): - with patch.object(aiplatform.models.Model, "upload") as mock: + with patch.object(aiplatform.Model, "upload") as mock: yield mock - @pytest.fixture def mock_deploy_model(mock_model, mock_endpoint): with patch.object(mock_model, "deploy") as mock: @@ -282,3 +280,9 @@ def mock_get_endpoint(mock_endpoint): with patch.object(aiplatform, "Endpoint") as mock_get_endpoint: mock_get_endpoint.return_value = mock_endpoint yield mock_get_endpoint + +@pytest.fixture +def mock_endpoint_explain(mock_endpoint): + with patch.object(mock_endpoint, "explain") as mock_endpoint_explain: + mock_get_endpoint.return_value = mock_endpoint + yield mock_endpoint_explain \ No newline at end of file diff --git a/samples/model-builder/explain_tabular_sample.py b/samples/model-builder/explain_tabular_sample.py new file mode 100644 index 00000000000..9206ee13c66 --- /dev/null +++ b/samples/model-builder/explain_tabular_sample.py @@ -0,0 +1,32 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from google.cloud import aiplatform +from typing import Dict + +# [START aiplatform_sdk_explain_tabular_sample] +def explain_tabular_sample(project: str, location: str, endpoint_id: str, instance_dict: Dict): + + aiplatform.init(project=project, location=location) + + endpoint = aiplatform.Endpoint(endpoint_id) + + response = endpoint.explain(instances=[instance_dict], parameters={}) + + for prediction_ in response.predictions: + print(prediction_) + + +# [END aiplatform_sdk_explain_tabular_sample] diff --git a/samples/model-builder/explain_tabular_sample_test.py b/samples/model-builder/explain_tabular_sample_test.py new file mode 100644 index 00000000000..aca78ad995b --- /dev/null +++ b/samples/model-builder/explain_tabular_sample_test.py @@ -0,0 +1,40 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import explain_tabular_sample +import test_constants as constants + + +def test_explain_tabular_sample(mock_sdk_init, mock_endpoint, mock_get_endpoint, mock_endpoint_explain): + + explain_tabular_sample.explain_tabular_sample( + project=constants.PROJECT, + location=constants.LOCATION, + endpoint_id=constants.ENDPOINT_NAME, + instance_dict=constants.PREDICTION_TABULAR_INSTANCE, + ) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, location=constants.LOCATION + ) + + mock_get_endpoint.assert_called_once_with( + constants.ENDPOINT_NAME, + ) + + mock_endpoint_explain.assert_called_once_with( + instances=[constants.PREDICTION_TABULAR_INSTANCE], + parameters={} + ) \ No newline at end of file diff --git a/samples/model-builder/test_constants.py b/samples/model-builder/test_constants.py index 994a8724eef..641fa1c4903 100644 --- a/samples/model-builder/test_constants.py +++ b/samples/model-builder/test_constants.py @@ -131,15 +131,31 @@ inputs={ "features": { "input_tensor_name": "dense_input", - "encoding": "BAG_OF_FEATURES", + # Input is tabular data "modality": "numeric", - "index_feature_mapping": ["abc", "def", "ghj"], + # Assign feature names to the inputs for explanation + "encoding": "BAG_OF_FEATURES", + "index_feature_mapping": [ + "crim", + "zn", + "indus", + "chas", + "nox", + "rm", + "age", + "dis", + "rad", + "tax", + "ptratio", + "b", + "lstat", + ], } }, - outputs={"medv": {"output_tensor_name": "dense_2"}}, + outputs={"prediction": {"output_tensor_name": "dense_2"}}, ) EXPLANATION_PARAMETERS = aiplatform.explain.ExplanationParameters( - {"sampled_shapley_attribution": {"path_count": 10}} + {"xrai_attribution": {"step_count": 1}} ) # Endpoint constants @@ -148,4 +164,16 @@ TRAFFIC_SPLIT = {"a": 99, "b": 1} MIN_REPLICA_COUNT = 1 MAX_REPLICA_COUNT = 1 +ACCELERATOR_TYPE = "NVIDIA_TESLA_P100" +ACCELERATOR_COUNT = 2 ENDPOINT_DEPLOY_METADATA = () +PREDICTION_TABULAR_INSTANCE = { + "longitude": "-124.35", + "latitude": "40.54", + "housing_median_age": "52.0", + "total_rooms": "1820.0", + "total_bedrooms": "300.0", + "population": "806", + "households": "270.0", + "median_income": "3.014700", +} diff --git a/samples/model-builder/upload_model_explain_tabular_managed_container_sample.py b/samples/model-builder/upload_model_explain_tabular_managed_container_sample.py new file mode 100644 index 00000000000..a1db4aac3e9 --- /dev/null +++ b/samples/model-builder/upload_model_explain_tabular_managed_container_sample.py @@ -0,0 +1,70 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from google.cloud import aiplatform +from typing import Optional, Sequence, Tuple, Dict +from google.cloud.aiplatform import explain + +# [START aiplatform_sdk_upload_model_explain_tabular_managed_container_sample] +def upload_model_explain_tabular_managed_container_sample( + project, + location, + model_display_name: str, + serving_container_image_uri: str, + artifact_uri: Optional[str] = None, + serving_container_predict_route: Optional[str] = None, + serving_container_health_route: Optional[str] = None, + description: Optional[str] = None, + serving_container_command: Optional[Sequence[str]] = None, + serving_container_args: Optional[Sequence[str]] = None, + serving_container_environment_variables: Optional[Dict[str, str]] = None, + serving_container_ports: Optional[Sequence[int]] = None, + instance_schema_uri: Optional[str] = None, + parameters_schema_uri: Optional[str] = None, + prediction_schema_uri: Optional[str] = None, + explanation_metadata: Optional[explain.ExplanationMetadata] = None, + explanation_parameters: Optional[explain.ExplanationParameters] = None, + sync: bool = True, +): + + aiplatform.init(project=project, location=location) + + model = aiplatform.Model.upload( + display_name=model_display_name, + serving_container_image_uri=serving_container_image_uri, + artifact_uri=artifact_uri, + serving_container_predict_route=serving_container_predict_route, + serving_container_health_route=serving_container_health_route, + description=description, + serving_container_command=serving_container_command, + serving_container_args=serving_container_args, + serving_container_environment_variables=serving_container_environment_variables, + serving_container_ports=serving_container_ports, + instance_schema_uri=instance_schema_uri, + parameters_schema_uri=parameters_schema_uri, + prediction_schema_uri=prediction_schema_uri, + explanation_metadata=explanation_metadata, + explanation_parameters=explanation_parameters, + sync=sync, + ) + + model.wait() + + print(model.display_name) + print(model.resource_name) + return model + + +# [END aiplatform_sdk_upload_model_explain_tabular_managed_container_sample] diff --git a/samples/model-builder/upload_model_explain_tabular_managed_container_sample_test.py b/samples/model-builder/upload_model_explain_tabular_managed_container_sample_test.py new file mode 100644 index 00000000000..7f9253a28f9 --- /dev/null +++ b/samples/model-builder/upload_model_explain_tabular_managed_container_sample_test.py @@ -0,0 +1,63 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import upload_model_explain_tabular_managed_container_sample +import test_constants as constants + + +def test_upload_model_explain_tabular_managed_container_sample(mock_sdk_init, mock_model, mock_init_model, mock_upload_model): + + upload_model_explain_tabular_managed_container_sample.upload_model_explain_tabular_managed_container_sample( + project=constants.PROJECT, + location=constants.LOCATION, + model_display_name=constants.MODEL_NAME, + serving_container_image_uri=constants.SERVING_CONTAINER_IMAGE_URI, + artifact_uri=constants.MODEL_ARTIFACT_URI, + serving_container_predict_route=constants.SERVING_CONTAINER_PREDICT_ROUTE, + serving_container_health_route=constants.SERVING_CONTAINER_HEALTH_ROUTE, + description=constants.DESCRIPTION, + serving_container_command=constants.SERVING_CONTAINER_COMMAND, + serving_container_args=constants.SERVING_CONTAINER_ARGS, + serving_container_environment_variables=constants.SERVING_CONTAINER_ENVIRONMENT_VARIABLES, + serving_container_ports=constants.SERVING_CONTAINER_PORTS, + instance_schema_uri=constants.INSTANCE_SCHEMA_URI, + parameters_schema_uri=constants.PARAMETERS_SCHEMA_URI, + prediction_schema_uri=constants.PREDICTION_SCHEMA_URI, + explanation_metadata=constants.EXPLANATION_METADATA, + explanation_parameters=constants.EXPLANATION_PARAMETERS, + ) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, location=constants.LOCATION + ) + + mock_upload_model.assert_called_once_with( + display_name=constants.MODEL_NAME, + serving_container_image_uri=constants.SERVING_CONTAINER_IMAGE_URI, + artifact_uri=constants.MODEL_ARTIFACT_URI, + serving_container_predict_route=constants.SERVING_CONTAINER_PREDICT_ROUTE, + serving_container_health_route=constants.SERVING_CONTAINER_HEALTH_ROUTE, + description=constants.DESCRIPTION, + serving_container_command=constants.SERVING_CONTAINER_COMMAND, + serving_container_args=constants.SERVING_CONTAINER_ARGS, + serving_container_environment_variables=constants.SERVING_CONTAINER_ENVIRONMENT_VARIABLES, + serving_container_ports=constants.SERVING_CONTAINER_PORTS, + instance_schema_uri=constants.INSTANCE_SCHEMA_URI, + parameters_schema_uri=constants.PARAMETERS_SCHEMA_URI, + prediction_schema_uri=constants.PREDICTION_SCHEMA_URI, + explanation_metadata=constants.EXPLANATION_METADATA, + explanation_parameters=constants.EXPLANATION_PARAMETERS, + sync=True, + )