Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for custom prediction function in colab and jupyter notebook modes. #1842

Merged
merged 7 commits into from
Feb 14, 2019

Conversation

tolga-b
Copy link
Contributor

@tolga-b tolga-b commented Feb 13, 2019

  • Motivation for features / changes

Add custom prediction function support in colab and jupyter mode so it is easier to load non-TensorFlow models into WIT.

  • Technical description of changes

WITConfigBuilder now has a "set_custom_predict_fn" method to pass a python function for prediction purposes. Parts of notebook backend and inference_utils are updated to use the custom function when provided.

  • Screenshots of UI changes

No UI changes.

  • Detailed steps to verify changes work correctly (as executed by you)

Ran WIT Toxicity Text Model Comparison and WIT Model Comparison colabs with custom_predict_fn to verify the tool is getting inferences correctly.

  • Alternate designs / implementations considered

@tolga-b
Copy link
Contributor Author

tolga-b commented Feb 13, 2019

@jameswex please review.

@@ -57,6 +57,11 @@ You can use the What-If Tool to analyze a classification or regression
that takes TensorFlow Example or SequenceExample protos
(data points) as inputs directly in a jupyter or colab notebook.

You can also use What-If-Tool with your own custom prediction function that takes
Tensorflow Example and produces predictions. In this mode, you can load any model
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would say "takes TensorFlow examples and produces predictions" (that way it includes Examples and SequenceExamples)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done!

@@ -181,14 +181,16 @@ class ServingBundle(object):
Predict API.
estimator: An estimator to use instead of calling an external model.
feature_spec: A feature spec for use with the estimator.
custom_predict_fn: A custom prediction function
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add period

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

@@ -336,14 +340,18 @@ def set_estimator_and_feature_spec(self, estimator, feature_spec):
Returns:
self, in order to enabled method chaining.
"""
# If custom function is set, remove it before setting estimator
self.delete('custom_predict_fn')
self.delete('compare_custom_predict_fn')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think this should jsut delete the custom_predict_fn, not the compare one

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, makes sense if you want to compare an estimator to a predict function.

@@ -363,10 +371,77 @@ def set_compare_estimator_and_feature_spec(self, estimator, feature_spec):
Returns:
self, in order to enabled method chaining.
"""
# If custom function is set, remove it before setting estimator
self.delete('custom_predict_fn')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think this should only delete the custom compare predict fn

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

"""
# If estimator is set, remove it before setting predict_fn
self.delete('estimator_and_spec')
self.delete('compare_estimator_and_spec')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same comment

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

self, in order to enabled method chaining.
"""
# If estimator is set, remove it before setting predict_fn
self.delete('estimator_and_spec')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same comment

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Copy link
Contributor

@jameswex jameswex left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just some small comments, otherwise looks good

@tolga-b
Copy link
Contributor Author

tolga-b commented Feb 13, 2019

Thanks James, I added all your comments.

@jameswex jameswex merged commit 85519e7 into tensorflow:master Feb 14, 2019
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants