Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add XAI SDK integration to TensorFlow models with LIT integration #917

Merged
merged 5 commits into from
Jan 19, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 118 additions & 40 deletions google/cloud/aiplatform/explain/lit.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Tuple, Union
import os
taiseiak marked this conversation as resolved.
Show resolved Hide resolved
from typing import Dict, List, Optional, Tuple, Union

try:
from lit_nlp.api import dataset as lit_dataset
from lit_nlp.api import dtypes as lit_dtypes
from lit_nlp.api import model as lit_model
from lit_nlp.api import types as lit_types
from lit_nlp import notebook
Expand Down Expand Up @@ -82,6 +84,7 @@ def __init__(
model: str,
input_types: "OrderedDict[str, lit_types.LitType]", # noqa: F821
output_types: "OrderedDict[str, lit_types.LitType]", # noqa: F821
attribution_method: str = "sampled_shapley",
):
"""Construct a VertexLitModel.
Args:
Expand All @@ -94,39 +97,33 @@ def __init__(
output_types:
Required. An OrderedDict of string names matching the labels of the model
as the key, and the associated LitType of the label.
attribution_method:
Optional. A string to choose what attribution configuration to
set up the explainer with. Valid options are 'sampled_shapley'
or 'integrated_gradients'.
"""
self._loaded_model = tf.saved_model.load(model)
serving_default = self._loaded_model.signatures[
tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
]
_, self._kwargs_signature = serving_default.structured_input_signature
self._output_signature = serving_default.structured_outputs

if len(self._kwargs_signature) != 1:
raise ValueError("Please use a model with only one input tensor.")

if len(self._output_signature) != 1:
raise ValueError("Please use a model with only one output tensor.")

self._load_model(model)
self._input_types = input_types
self._output_types = output_types
self._input_tensor_name = next(iter(self._kwargs_signature))
self._attribution_explainer = None
if os.environ.get("LIT_PROXY_URL"):
self._set_up_attribution_explainer(model, attribution_method)

@property
def attribution_explainer(self,) -> Optional["AttributionExplainer"]: # noqa: F821
"""Gets the attribution explainer property if set."""
return self._attribution_explainer

def predict_minibatch(
self, inputs: List[lit_types.JsonDict]
) -> List[lit_types.JsonDict]:
"""Returns predictions for a single batch of examples.
Args:
inputs:
sequence of inputs, following model.input_spec()
Returns:
list of outputs, following model.output_spec()
"""
instances = []
for input in inputs:
instance = [input[feature] for feature in self._input_types]
instances.append(instance)
prediction_input_dict = {
next(iter(self._kwargs_signature)): tf.convert_to_tensor(instances)
self._input_tensor_name: tf.convert_to_tensor(instances)
}
prediction_dict = self._loaded_model.signatures[
tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
Expand All @@ -140,6 +137,15 @@ def predict_minibatch(
for label, value in zip(self._output_types.keys(), prediction)
}
)
# Get feature attributions
if self.attribution_explainer:
attributions = self.attribution_explainer.explain(
[{self._input_tensor_name: i} for i in instances]
)
for i, attribution in enumerate(attributions):
outputs[i]["feature_attribution"] = lit_dtypes.FeatureSalience(
attribution.feature_importance()
)
return outputs

def input_spec(self) -> lit_types.Spec:
Expand All @@ -148,7 +154,63 @@ def input_spec(self) -> lit_types.Spec:

def output_spec(self) -> lit_types.Spec:
"""Return a spec describing model outputs."""
return self._output_types
output_spec_dict = dict(self._output_types)
if self.attribution_explainer:
output_spec_dict["feature_attribution"] = lit_types.FeatureSalience(
signed=True
)
return output_spec_dict

def _load_model(self, model: str):
taiseiak marked this conversation as resolved.
Show resolved Hide resolved
"""Loads a TensorFlow saved model and populates the input and output signature attributes of the class.
Args:
model: Required. A string reference to a TensorFlow saved model directory.
Raises:
ValueError if the model has more than one input tensor or more than one output tensor.
"""
self._loaded_model = tf.saved_model.load(model)
serving_default = self._loaded_model.signatures[
tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
]
_, self._kwargs_signature = serving_default.structured_input_signature
self._output_signature = serving_default.structured_outputs

if len(self._kwargs_signature) != 1:
raise ValueError("Please use a model with only one input tensor.")

if len(self._output_signature) != 1:
raise ValueError("Please use a model with only one output tensor.")

def _set_up_attribution_explainer(
taiseiak marked this conversation as resolved.
Show resolved Hide resolved
self, model: str, attribution_method: str = "integrated_gradients"
):
"""Populates the attribution explainer attribute of the class."""
taiseiak marked this conversation as resolved.
Show resolved Hide resolved
try:
import explainable_ai_sdk
from explainable_ai_sdk.metadata.tf.v2 import SavedModelMetadataBuilder
except ImportError:
print(
"Skipping explanations because the Explainable AI SDK is not installed."
taiseiak marked this conversation as resolved.
Show resolved Hide resolved
'Please install the SDK using "pip install explainable-ai-sdk"'
)
return
taiseiak marked this conversation as resolved.
Show resolved Hide resolved

builder = SavedModelMetadataBuilder(model)
builder.get_metadata()
builder.set_numeric_metadata(
self._input_tensor_name,
index_feature_mapping=list(self._input_types.keys()),
)
builder.save_metadata(model)
if attribution_method == "integrated_gradients":
explainer_config = explainable_ai_sdk.IntegratedGradientsConfig()
else:
explainer_config = explainable_ai_sdk.SampledShapleyConfig()

self._attribution_explainer = explainable_ai_sdk.load_model_from_local_path(
model, explainer_config
)
self._load_model(model)


def create_lit_dataset(
Expand All @@ -172,22 +234,27 @@ def create_lit_model(
model: str,
input_types: "OrderedDict[str, lit_types.LitType]", # noqa: F821
output_types: "OrderedDict[str, lit_types.LitType]", # noqa: F821
attribution_method: str = "sampled_shapley",
) -> lit_model.Model:
"""Creates a LIT Model object.
Args:
model:
Required. A string reference to a local TensorFlow saved model directory.
The model must have at most one input and one output tensor.
Required. A string reference to a local TensorFlow saved model directory.
The model must have at most one input and one output tensor.
input_types:
Required. An OrderedDict of string names matching the features of the model
as the key, and the associated LitType of the feature.
Required. An OrderedDict of string names matching the features of the model
as the key, and the associated LitType of the feature.
output_types:
Required. An OrderedDict of string names matching the labels of the model
as the key, and the associated LitType of the label.
Required. An OrderedDict of string names matching the labels of the model
as the key, and the associated LitType of the label.
attribution_method:
Optional. A string to choose what attribution configuration to
set up the explainer with. Valid options are 'sampled_shapley'
or 'integrated_gradients'.
Returns:
A LIT Model object that has the same functionality as the model provided.
"""
return _VertexLitModel(model, input_types, output_types)
return _VertexLitModel(model, input_types, output_types, attribution_method)


def open_lit(
Expand All @@ -198,11 +265,11 @@ def open_lit(
"""Open LIT from the provided models and datasets.
Args:
models:
Required. A list of LIT models to open LIT with.
Required. A list of LIT models to open LIT with.
input_types:
Required. A lit of LIT datasets to open LIT with.
Required. A lit of LIT datasets to open LIT with.
open_in_new_tab:
Optional. A boolean to choose if LIT open in a new tab or not.
Optional. A boolean to choose if LIT open in a new tab or not.
Raises:
ImportError if LIT is not installed.
"""
Expand All @@ -216,24 +283,31 @@ def set_up_and_open_lit(
model: Union[str, lit_model.Model],
input_types: Union[List[str], Dict[str, lit_types.LitType]],
output_types: Union[str, List[str], Dict[str, lit_types.LitType]],
attribution_method: str = "sampled_shapley",
open_in_new_tab: bool = True,
) -> Tuple[lit_dataset.Dataset, lit_model.Model]:
"""Creates a LIT dataset and model and opens LIT.
Args:
dataset:
dataset:
Required. A Pandas DataFrame that includes feature column names and data.
column_types:
column_types:
Required. An OrderedDict of string names matching the columns of the dataset
as the key, and the associated LitType of the column.
model:
model:
Required. A string reference to a TensorFlow saved model directory.
The model must have at most one input and one output tensor.
input_types:
input_types:
Required. An OrderedDict of string names matching the features of the model
as the key, and the associated LitType of the feature.
output_types:
output_types:
Required. An OrderedDict of string names matching the labels of the model
as the key, and the associated LitType of the label.
attribution_method:
Optional. A string to choose what attribution configuration to
set up the explainer with. Valid options are 'sampled_shapley'
or 'integrated_gradients'.
open_in_new_tab:
Optional. A boolean to choose if LIT open in a new tab or not.
Returns:
A Tuple of the LIT dataset and model created.
Raises:
Expand All @@ -244,8 +318,12 @@ def set_up_and_open_lit(
dataset = create_lit_dataset(dataset, column_types)

if not isinstance(model, lit_model.Model):
model = create_lit_model(model, input_types, output_types)
model = create_lit_model(
model, input_types, output_types, attribution_method=attribution_method
)

open_lit({"model": model}, {"dataset": dataset}, open_in_new_tab=open_in_new_tab)
open_lit(
{"model": model}, {"dataset": dataset}, open_in_new_tab=open_in_new_tab,
)

return dataset, model
7 changes: 6 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,12 @@
tensorboard_extra_require = ["tensorflow >=2.3.0, <=2.5.0"]
metadata_extra_require = ["pandas >= 1.0.0"]
xai_extra_require = ["tensorflow >=2.3.0, <=2.5.0"]
lit_extra_require = ["tensorflow >= 2.3.0", "pandas >= 1.0.0", "lit-nlp >= 0.4.0"]
lit_extra_require = [
"tensorflow >= 2.3.0",
"pandas >= 1.0.0",
"lit-nlp >= 0.4.0",
"explainable-ai-sdk >= 1.0.0",
]
profiler_extra_require = [
"tensorboard-plugin-profile >= 2.4.0",
"werkzeug >= 2.0.0",
Expand Down
Loading