feat: add basic explain get_metadata function for tf2. #507
/hold
/unhold
```python
self,
model_path: str,
signature_name: Optional[str] = None,
outputs_to_explain: Optional[List[str]] = (),
```
The type hint is `Optional[List[str]]`, but the default is an empty tuple. It seems like the default could be `None` instead.
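A minimal sketch of the suggested signature. `MetadataBuilderSketch` is a hypothetical stand-in, not the class in this PR, and it only shows how the arguments could be stored: `None` marks "not specified" and is normalized to an empty list internally, so the default now matches the `List[str]` hint.

```python
from typing import List, Optional


class MetadataBuilderSketch:
    """Hypothetical stand-in; shows only how the arguments are stored."""

    def __init__(
        self,
        model_path: str,
        signature_name: Optional[str] = None,
        outputs_to_explain: Optional[List[str]] = None,
    ):
        self._model_path = model_path
        self._signature_name = signature_name
        # None means "not specified"; store an empty list in that case.
        self._outputs_to_explain = list(outputs_to_explain or [])


# "/tmp/saved_model" is a placeholder path for illustration.
builder = MetadataBuilderSketch("/tmp/saved_model")
print(builder._outputs_to_explain)  # []
```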
google/cloud/aiplatform/explain/metadata/tf/v2/saved_model_metadata_builder.py
```python
if tensor_spec.dtype.is_floating:
    input_mds[
        name
    ] = explanation_metadata.ExplanationMetadata.InputMetadata(
        input_tensor_name=name
    )
else:
    input_mds[
        name
    ] = explanation_metadata.ExplanationMetadata.InputMetadata(
        input_tensor_name=name, modality="categorical",
    )
```
Suggestion to shorten this block:
```diff
-if tensor_spec.dtype.is_floating:
-    input_mds[
-        name
-    ] = explanation_metadata.ExplanationMetadata.InputMetadata(
-        input_tensor_name=name
-    )
-else:
-    input_mds[
-        name
-    ] = explanation_metadata.ExplanationMetadata.InputMetadata(
-        input_tensor_name=name, modality="categorical",
-    )
+input_mds[name] = explanation_metadata.ExplanationMetadata.InputMetadata(
+    input_tensor_name=name,
+    modality=None if tensor_spec.dtype.is_floating else "categorical"
+)
```
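The dtype-based branching above can be sketched without TensorFlow. This is only an illustration under stated assumptions: `build_input_metadata` is a hypothetical helper, a plain `bool` stands in for `tensor_spec.dtype.is_floating`, and a plain dict stands in for `ExplanationMetadata.InputMetadata`.

```python
from typing import Dict


def build_input_metadata(is_floating_by_name: Dict[str, bool]) -> Dict[str, dict]:
    """Map each input name to a plain-dict stand-in for InputMetadata."""
    input_mds = {}
    for name, is_floating in is_floating_by_name.items():
        md = {"input_tensor_name": name}
        # Non-floating inputs (e.g. ints, strings) are marked categorical;
        # floating inputs leave modality unset, as in the original branch.
        if not is_floating:
            md["modality"] = "categorical"
        input_mds[name] = md
    return input_mds


print(build_input_metadata({"age": True, "country": False}))
```

Leaving the key out for floating inputs mirrors the original `if` branch, which does not pass `modality` at all.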
setup.py
```
@@ -31,7 +31,10 @@
tensorboard_extra_require = ["tensorflow >=2.3.0, <=2.5.0"]
metadata_extra_require = ["pandas >= 1.0.0"]
full_extra_require = tensorboard_extra_require + metadata_extra_require
xai_extra_require = ["tensorflow-cpu>=2.3.0, <=2.5.0"]
```
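Consider using "tensorflow" instead of "tensorflow-cpu": the "tensorflow" package also runs on CPU-only machines (https://www.tensorflow.org/install/pip). Using the same package will also make it easier to manage alongside TensorBoard's TF dependency.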
setup.py
```
@@ -31,7 +31,10 @@
tensorboard_extra_require = ["tensorflow >=2.3.0, <=2.5.0"]
metadata_extra_require = ["pandas >= 1.0.0"]
full_extra_require = tensorboard_extra_require + metadata_extra_require
xai_extra_require = ["tensorflow-cpu>=2.3.0, <=2.5.0"]
full_extra_require = (
```
`tensorboard_extra_require` and `xai_extra_require` have the same dependencies, so we may need to deduplicate this with `list(set(...))`.
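A minimal sketch of the deduplication, assuming `xai_extra_require` is switched to exactly the same requirement string as `tensorboard_extra_require`. Deduplication only removes identical strings, so `"tensorflow >=2.3.0, <=2.5.0"` and `"tensorflow>=2.3.0, <=2.5.0"` would both survive. The sketch uses `dict.fromkeys` rather than `set` so that the resulting order is stable from run to run.

```python
tensorboard_extra_require = ["tensorflow >=2.3.0, <=2.5.0"]
metadata_extra_require = ["pandas >= 1.0.0"]
# Assumption: xai now uses the same requirement string as tensorboard.
xai_extra_require = ["tensorflow >=2.3.0, <=2.5.0"]

# dict.fromkeys keeps the first occurrence of each string in order,
# unlike list(set(...)), whose order can vary between interpreter runs.
full_extra_require = list(
    dict.fromkeys(
        tensorboard_extra_require + metadata_extra_require + xai_extra_require
    )
)
print(full_extra_require)
# ['tensorflow >=2.3.0, <=2.5.0', 'pandas >= 1.0.0']
```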
```python
class SavedModelMetadataBuilderTest(tf.test.TestCase):
    def _set_up(self):
```
Since this is only used by one test, it should perhaps not be a shared method.
Basic explain `get_metadata` function for TF2.
Usage:
The main differences between the Vertex SDK's and the XAI SDK's `get_metadata` functions:

In the Vertex SDK, variables use camelCase, while in the XAI SDK they use snake_case, e.g. Vertex: `inputTensorName`, XAI: `input_tensor_name`.
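A rough sketch of the naming difference. It uses hypothetical helpers `snake_to_camel` and `to_camel_case_fields`, which are not functions from either SDK, to rename the field names of one flat, XAI-style metadata dict to Vertex-style camelCase. Only field names are renamed; values such as tensor names are left untouched.

```python
def snake_to_camel(name: str) -> str:
    """Convert a snake_case identifier to lowerCamelCase."""
    first, *rest = name.split("_")
    return first + "".join(part.capitalize() for part in rest)


def to_camel_case_fields(fields: dict) -> dict:
    """Rename the field names of one flat metadata dict; values are unchanged."""
    return {snake_to_camel(key): value for key, value in fields.items()}


xai_style = {"input_tensor_name": "x", "modality": "categorical"}
print(to_camel_case_fields(xai_style))
# {'inputTensorName': 'x', 'modality': 'categorical'}
```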
If `modality` and `encoding` are not set in the inputs, the XAI SDK populates default values. The Vertex SDK doesn't populate defaults; it only populates the user's settings. The two approaches make no difference to the actual settings.
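The claim that both approaches yield the same effective settings can be sketched as follows. The defaults here (`numeric` modality, `IDENTITY` encoding) are assumed for illustration, and `xai_style`, `vertex_style` and `effective` are hypothetical helpers, not SDK functions. The sketch assumes the service fills any unset field with the same defaults that the XAI SDK writes explicitly.

```python
from typing import Dict

# Assumed defaults, used only for illustration.
ASSUMED_DEFAULTS = {"modality": "numeric", "encoding": "IDENTITY"}


def xai_style(user: Dict[str, str]) -> Dict[str, str]:
    """Write defaults for unset fields explicitly."""
    return {**ASSUMED_DEFAULTS, **user}


def vertex_style(user: Dict[str, str]) -> Dict[str, str]:
    """Keep only what the user set."""
    return dict(user)


def effective(settings: Dict[str, str]) -> Dict[str, str]:
    """Resolve unset fields the way the service is assumed to."""
    return {**ASSUMED_DEFAULTS, **settings}


user = {"input_tensor_name": "x"}
print(effective(xai_style(user)) == effective(vertex_style(user)))  # True
```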