-
Notifications
You must be signed in to change notification settings - Fork 19.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Enable us to use sklearn to do cv for functional api #9320
Conversation
keras/wrappers/scikit_learn.py
Outdated
if isinstance(self.model, Sequential): | ||
classes = self.model.predict_classes(x, **kwargs) | ||
else: | ||
classes = np.round(self.model.predict(x, **kwargs)).astype(int) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Rounding predictions to int does not, in fact, predict classes (unless in the specific case of a binary classification problem). Please check the implementation of predict_classes
on Sequential
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for taking time to review the commit! I modified the code again referring to the implementation of "predict_classes" on "Sequential".
keras/wrappers/scikit_learn.py
Outdated
@@ -247,7 +252,8 @@ def predict_proba(self, x, **kwargs): | |||
(instead of `(n_sample, 1)` as in Keras). | |||
""" | |||
kwargs = self.filter_sk_params(Sequential.predict_proba, kwargs) | |||
probs = self.model.predict_proba(x, **kwargs) | |||
# check if the model is sequential or functional |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Comment is no longer relevant?
* 'master' of github.com:fchollet/keras: (57 commits) Minor README edit Speed up Travis tests (keras-team#9386) fix typo (keras-team#9391) Fix style issue in docstring Prepare 2.1.4 release. Fix activity regularizer + model composition test Corrected copyright years (keras-team#9375) Change default interpolation from nearest to bilinear. (keras-team#8849) a capsule cnn on cifar-10 (keras-team#9193) Enable us to use sklearn to do cv for functional api (keras-team#9320) Add support for stateful metrics. (keras-team#9253) The type of list keys was float (keras-team#9324) Fix mnist sklearn wrapper example (keras-team#9317) keras-team#9287 Fix most of the file-handle resource leaks. (keras-team#9309) Pass current learning rate to schedule() in LearningRateScheduler (keras-team#8865) Simplify with from six.moves import input (keras-team#9216) fixed RemoteMonitor: Json to handle np.float32 and np.int32 types (keras-team#9261) Update tweet length from 140 to 280 in docs Add `depthconv_conv2d` tests (keras-team#9225) Remove `force` option in progbar ...
E.g. This modification can enable us to do cross validation for models created by functional api with sklearn.model_selection.cross_validate.