Skip to content

Commit

Permalink
Added tabular explanation sample
Browse files Browse the repository at this point in the history
  • Loading branch information
ivanmkc committed Apr 30, 2021
1 parent ef4f6f8 commit 00102be
Show file tree
Hide file tree
Showing 6 changed files with 244 additions and 7 deletions.
10 changes: 7 additions & 3 deletions samples/model-builder/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,13 +231,11 @@ def mock_batch_predict_model(mock_model):
with patch.object(mock_model, "batch_predict") as mock:
yield mock


@pytest.fixture
def mock_upload_model():
with patch.object(aiplatform.models.Model, "upload") as mock:
with patch.object(aiplatform.Model, "upload") as mock:
yield mock


@pytest.fixture
def mock_deploy_model(mock_model, mock_endpoint):
with patch.object(mock_model, "deploy") as mock:
Expand Down Expand Up @@ -282,3 +280,9 @@ def mock_get_endpoint(mock_endpoint):
with patch.object(aiplatform, "Endpoint") as mock_get_endpoint:
mock_get_endpoint.return_value = mock_endpoint
yield mock_get_endpoint

@pytest.fixture
def mock_endpoint_explain(mock_endpoint):
with patch.object(mock_endpoint, "explain") as mock_endpoint_explain:
mock_get_endpoint.return_value = mock_endpoint
yield mock_endpoint_explain
32 changes: 32 additions & 0 deletions samples/model-builder/explain_tabular_sample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from google.cloud import aiplatform
from typing import Dict

# [START aiplatform_sdk_explain_tabular_sample]
def explain_tabular_sample(project: str, location: str, endpoint_id: str, instance_dict: Dict):

aiplatform.init(project=project, location=location)

endpoint = aiplatform.Endpoint(endpoint_id)

response = endpoint.explain(instances=[instance_dict], parameters={})

for prediction_ in response.predictions:
print(prediction_)


# [END aiplatform_sdk_explain_tabular_sample]
40 changes: 40 additions & 0 deletions samples/model-builder/explain_tabular_sample_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import explain_tabular_sample
import test_constants as constants


def test_explain_tabular_sample(mock_sdk_init, mock_endpoint, mock_get_endpoint, mock_endpoint_explain):

explain_tabular_sample.explain_tabular_sample(
project=constants.PROJECT,
location=constants.LOCATION,
endpoint_id=constants.ENDPOINT_NAME,
instance_dict=constants.PREDICTION_TABULAR_INSTANCE,
)

mock_sdk_init.assert_called_once_with(
project=constants.PROJECT, location=constants.LOCATION
)

mock_get_endpoint.assert_called_once_with(
constants.ENDPOINT_NAME,
)

mock_endpoint_explain.assert_called_once_with(
instances=[constants.PREDICTION_TABULAR_INSTANCE],
parameters={}
)
36 changes: 32 additions & 4 deletions samples/model-builder/test_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,15 +131,31 @@
inputs={
"features": {
"input_tensor_name": "dense_input",
"encoding": "BAG_OF_FEATURES",
# Input is tabular data
"modality": "numeric",
"index_feature_mapping": ["abc", "def", "ghj"],
# Assign feature names to the inputs for explanation
"encoding": "BAG_OF_FEATURES",
"index_feature_mapping": [
"crim",
"zn",
"indus",
"chas",
"nox",
"rm",
"age",
"dis",
"rad",
"tax",
"ptratio",
"b",
"lstat",
],
}
},
outputs={"medv": {"output_tensor_name": "dense_2"}},
outputs={"prediction": {"output_tensor_name": "dense_2"}},
)
EXPLANATION_PARAMETERS = aiplatform.explain.ExplanationParameters(
{"sampled_shapley_attribution": {"path_count": 10}}
{"xrai_attribution": {"step_count": 1}}
)

# Endpoint constants
Expand All @@ -148,4 +164,16 @@
TRAFFIC_SPLIT = {"a": 99, "b": 1}
MIN_REPLICA_COUNT = 1
MAX_REPLICA_COUNT = 1
ACCELERATOR_TYPE = "NVIDIA_TESLA_P100"
ACCELERATOR_COUNT = 2
ENDPOINT_DEPLOY_METADATA = ()
PREDICTION_TABULAR_INSTANCE = {
"longitude": "-124.35",
"latitude": "40.54",
"housing_median_age": "52.0",
"total_rooms": "1820.0",
"total_bedrooms": "300.0",
"population": "806",
"households": "270.0",
"median_income": "3.014700",
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from google.cloud import aiplatform
from typing import Optional, Sequence, Tuple, Dict
from google.cloud.aiplatform import explain

# [START aiplatform_sdk_upload_model_explain_tabular_managed_container_sample]
def upload_model_explain_tabular_managed_container_sample(
project,
location,
model_display_name: str,
serving_container_image_uri: str,
artifact_uri: Optional[str] = None,
serving_container_predict_route: Optional[str] = None,
serving_container_health_route: Optional[str] = None,
description: Optional[str] = None,
serving_container_command: Optional[Sequence[str]] = None,
serving_container_args: Optional[Sequence[str]] = None,
serving_container_environment_variables: Optional[Dict[str, str]] = None,
serving_container_ports: Optional[Sequence[int]] = None,
instance_schema_uri: Optional[str] = None,
parameters_schema_uri: Optional[str] = None,
prediction_schema_uri: Optional[str] = None,
explanation_metadata: Optional[explain.ExplanationMetadata] = None,
explanation_parameters: Optional[explain.ExplanationParameters] = None,
sync: bool = True,
):

aiplatform.init(project=project, location=location)

model = aiplatform.Model.upload(
display_name=model_display_name,
serving_container_image_uri=serving_container_image_uri,
artifact_uri=artifact_uri,
serving_container_predict_route=serving_container_predict_route,
serving_container_health_route=serving_container_health_route,
description=description,
serving_container_command=serving_container_command,
serving_container_args=serving_container_args,
serving_container_environment_variables=serving_container_environment_variables,
serving_container_ports=serving_container_ports,
instance_schema_uri=instance_schema_uri,
parameters_schema_uri=parameters_schema_uri,
prediction_schema_uri=prediction_schema_uri,
explanation_metadata=explanation_metadata,
explanation_parameters=explanation_parameters,
sync=sync,
)

model.wait()

print(model.display_name)
print(model.resource_name)
return model


# [END aiplatform_sdk_upload_model_explain_tabular_managed_container_sample]
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import upload_model_explain_tabular_managed_container_sample
import test_constants as constants


def test_upload_model_explain_tabular_managed_container_sample(mock_sdk_init, mock_model, mock_init_model, mock_upload_model):

upload_model_explain_tabular_managed_container_sample.upload_model_explain_tabular_managed_container_sample(
project=constants.PROJECT,
location=constants.LOCATION,
model_display_name=constants.MODEL_NAME,
serving_container_image_uri=constants.SERVING_CONTAINER_IMAGE_URI,
artifact_uri=constants.MODEL_ARTIFACT_URI,
serving_container_predict_route=constants.SERVING_CONTAINER_PREDICT_ROUTE,
serving_container_health_route=constants.SERVING_CONTAINER_HEALTH_ROUTE,
description=constants.DESCRIPTION,
serving_container_command=constants.SERVING_CONTAINER_COMMAND,
serving_container_args=constants.SERVING_CONTAINER_ARGS,
serving_container_environment_variables=constants.SERVING_CONTAINER_ENVIRONMENT_VARIABLES,
serving_container_ports=constants.SERVING_CONTAINER_PORTS,
instance_schema_uri=constants.INSTANCE_SCHEMA_URI,
parameters_schema_uri=constants.PARAMETERS_SCHEMA_URI,
prediction_schema_uri=constants.PREDICTION_SCHEMA_URI,
explanation_metadata=constants.EXPLANATION_METADATA,
explanation_parameters=constants.EXPLANATION_PARAMETERS,
)

mock_sdk_init.assert_called_once_with(
project=constants.PROJECT, location=constants.LOCATION
)

mock_upload_model.assert_called_once_with(
display_name=constants.MODEL_NAME,
serving_container_image_uri=constants.SERVING_CONTAINER_IMAGE_URI,
artifact_uri=constants.MODEL_ARTIFACT_URI,
serving_container_predict_route=constants.SERVING_CONTAINER_PREDICT_ROUTE,
serving_container_health_route=constants.SERVING_CONTAINER_HEALTH_ROUTE,
description=constants.DESCRIPTION,
serving_container_command=constants.SERVING_CONTAINER_COMMAND,
serving_container_args=constants.SERVING_CONTAINER_ARGS,
serving_container_environment_variables=constants.SERVING_CONTAINER_ENVIRONMENT_VARIABLES,
serving_container_ports=constants.SERVING_CONTAINER_PORTS,
instance_schema_uri=constants.INSTANCE_SCHEMA_URI,
parameters_schema_uri=constants.PARAMETERS_SCHEMA_URI,
prediction_schema_uri=constants.PREDICTION_SCHEMA_URI,
explanation_metadata=constants.EXPLANATION_METADATA,
explanation_parameters=constants.EXPLANATION_PARAMETERS,
sync=True,
)

0 comments on commit 00102be

Please sign in to comment.