Skip to content

Commit

Permalink
feat: add explanation metadata object for reuse
Browse files Browse the repository at this point in the history
  • Loading branch information
ji-yaqi committed Sep 1, 2021
1 parent 179150a commit 753344d
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 8 deletions.
20 changes: 17 additions & 3 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ You can also create a batch prediction job asynchronously by including the `sync
batch_prediction_job.state
# block until job is complete
batch_prediction_job.wait()
batch_prediction_job.wait()
Endpoints
Expand Down Expand Up @@ -352,7 +352,7 @@ To delete an endpoint:
Explainable AI: Get Metadata
----------------------------

To get metadata from TensorFlow 1 models:
To get metadata in dictionary format from TensorFlow 1 models:

.. code-block:: Python
Expand All @@ -363,7 +363,7 @@ To get metadata from TensorFlow 1 models:
)
generated_md = builder.get_metadata()
To get metadata from TensorFlow 2 models:
To get metadata in dictionary format from TensorFlow 2 models:

.. code-block:: Python
Expand All @@ -372,6 +372,20 @@ To get metadata from TensorFlow 2 models:
builder = saved_model_metadata_builder.SavedModelMetadataBuilder('gs://python/to/my/model/dir')
generated_md = builder.get_metadata()
To use Explanation Metadata in endpoint deployment and model upload:

.. code-block:: Python
explanation_metadata = builder.get_metadata_object()
# To deploy a Model to an Endpoint with explanation
model.deploy(..., explanation_metadata=explanation_metadata)
# To deploy a model to a created endpoint with explanation
endpoint.deploy(..., explanation_metadata=explanation_metadata)
# To upload a model with explanation
aiplatform.Model.upload(..., explanation_metadata=explanation_metadata)
Next Steps
Expand Down
4 changes: 4 additions & 0 deletions google/cloud/aiplatform/explain/metadata/metadata_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,7 @@ class MetadataBuilder(_ABC):
@abc.abstractmethod
def get_metadata(self):
"""Returns the current metadata as a dictionary."""

@abc.abstractmethod
def get_metadata_object(self):
"""Returns the current metadata as ExplanationMetadata object"""
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,17 @@ def get_metadata(self) -> Dict[str, Any]:
Returns:
Json format of the explanation metadata.
"""
current_md = explanation_metadata.ExplanationMetadata(
return json_format.MessageToDict(self.get_metadata_object()._pb)

def get_metadata_object(self) -> explanation_metadata.ExplanationMetadata:
"""Returns the current metadata as an object.
Returns:
ExplanationMetadata object format of the explanation metadata.
"""
return explanation_metadata.ExplanationMetadata(
inputs=self._inputs, outputs=self._outputs,
)
return json_format.MessageToDict(current_md._pb)


def _create_input_metadata_from_signature(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,14 @@ def get_metadata(self) -> Dict[str, Any]:
Returns:
Json format of the explanation metadata.
"""
current_md = explanation_metadata.ExplanationMetadata(
return json_format.MessageToDict(self.get_metadata_object()._pb)

def get_metadata_object(self) -> explanation_metadata.ExplanationMetadata:
"""Returns the current metadata as an object.
Returns:
ExplanationMetadata object format of the explanation metadata.
"""
return explanation_metadata.ExplanationMetadata(
inputs=self._inputs, outputs=self._outputs,
)
return json_format.MessageToDict(current_md._pb)
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
import tensorflow.compat.v1 as tf

from google.cloud.aiplatform.explain.metadata.tf.v1 import saved_model_metadata_builder
from google.cloud.aiplatform.compat.types import (
explanation_metadata_v1beta1 as explanation_metadata,
)


class SavedModelMetadataBuilderTF1Test(tf.test.TestCase):
Expand Down Expand Up @@ -68,6 +71,18 @@ def test_get_metadata_correct_inputs(self):

assert md_builder.get_metadata() == expected_md

def test_get_metadata_object_correct_inputs(self):
self._set_up()
md_builder = saved_model_metadata_builder.SavedModelMetadataBuilder(
self.model_path, tags=[tf.saved_model.tag_constants.SERVING]
)
expected_object = explanation_metadata.ExplanationMetadata(
inputs={"x": {"input_tensor_name": "inp:0"}},
outputs={"y": {"output_tensor_name": "Relu:0"}},
)

assert md_builder.get_metadata_object() == expected_object

def test_get_metadata_double_output(self):
self._set_up()
md_builder = saved_model_metadata_builder.SavedModelMetadataBuilder(
Expand All @@ -80,3 +95,16 @@ def test_get_metadata_double_output(self):
}

assert md_builder.get_metadata() == expected_md

def test_get_metadata_object_double_output(self):
self._set_up()
md_builder = saved_model_metadata_builder.SavedModelMetadataBuilder(
self.model_path, signature_name="double", outputs_to_explain=["lin"]
)

expected_object = explanation_metadata.ExplanationMetadata(
inputs={"x": {"input_tensor_name": "inp:0"}},
outputs={"lin": {"output_tensor_name": "Add:0"}},
)

assert md_builder.get_metadata_object() == expected_object
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,13 @@
import numpy as np

from google.cloud.aiplatform.explain.metadata.tf.v2 import saved_model_metadata_builder
from google.cloud.aiplatform.compat.types import (
explanation_metadata_v1beta1 as explanation_metadata,
)


class SavedModelMetadataBuilderTF2Test(tf.test.TestCase):
def test_get_metadata_sequential(self):
def _set_up_sequential(self):
# Set up for the sequential.
self.seq_model = tf.keras.models.Sequential()
self.seq_model.add(tf.keras.layers.Dense(32, activation="relu", input_dim=10))
Expand All @@ -32,6 +35,9 @@ def test_get_metadata_sequential(self):
self.saved_model_path = self.get_temp_dir()
tf.saved_model.save(self.seq_model, self.saved_model_path)

def test_get_metadata_sequential(self):
self._set_up_sequential()

builder = saved_model_metadata_builder.SavedModelMetadataBuilder(
self.saved_model_path
)
Expand All @@ -42,6 +48,19 @@ def test_get_metadata_sequential(self):
}
assert expected_md == generated_md

def test_get_metadata_sequential(self):
self._set_up_sequential()

builder = saved_model_metadata_builder.SavedModelMetadataBuilder(
self.saved_model_path
)
generated_object = builder.get_metadata_object()
expected_object = explanation_metadata.ExplanationMetadata(
inputs={"dense_input": {"input_tensor_name": "dense_input"}},
outputs={"dense_2": {"output_tensor_name": "dense_2"}},
)
assert expected_object == generated_object

def test_get_metadata_functional(self):
inputs1 = tf.keras.Input(shape=(10,), name="model_input1")
inputs2 = tf.keras.Input(shape=(10,), name="model_input2")
Expand Down

0 comments on commit 753344d

Please sign in to comment.