Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable us to use sklearn to do cv for functional api #9320

Merged
merged 4 commits into from
Feb 8, 2018

Conversation

XinsongDu
Copy link
Contributor

E.g. This modification can enable us to do cross validation for models created by functional api with sklearn.model_selection.cross_validate.

if isinstance(self.model, Sequential):
classes = self.model.predict_classes(x, **kwargs)
else:
classes = np.round(self.model.predict(x, **kwargs)).astype(int)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rounding predictions to int does not, in fact, predict classes (unless in the specific case of a binary classification problem). Please check the implementation of predict_classes on Sequential.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for taking time to review the commit! I modified the code again referring to the implementation of "predict_classes" on "Sequential".

@@ -247,7 +252,8 @@ def predict_proba(self, x, **kwargs):
(instead of `(n_sample, 1)` as in Keras).
"""
kwargs = self.filter_sk_params(Sequential.predict_proba, kwargs)
probs = self.model.predict_proba(x, **kwargs)
# check if the model is sequential or functional
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment is no longer relevant?

@fchollet fchollet merged commit d9f26a9 into keras-team:master Feb 8, 2018
ahundt added a commit to ahundt/keras that referenced this pull request Feb 16, 2018
* 'master' of github.com:fchollet/keras: (57 commits)
  Minor README edit
  Speed up Travis tests (keras-team#9386)
  fix typo (keras-team#9391)
  Fix style issue in docstring
  Prepare 2.1.4 release.
  Fix activity regularizer + model composition test
  Corrected copyright years (keras-team#9375)
  Change default interpolation from nearest to bilinear. (keras-team#8849)
  a capsule cnn on cifar-10 (keras-team#9193)
  Enable us to use sklearn to do cv for functional api (keras-team#9320)
  Add support for stateful metrics. (keras-team#9253)
  The type of list keys was float (keras-team#9324)
  Fix mnist sklearn wrapper example (keras-team#9317)
  keras-team#9287 Fix most of the file-handle resource leaks. (keras-team#9309)
  Pass current learning rate to schedule() in LearningRateScheduler (keras-team#8865)
  Simplify with from six.moves import input (keras-team#9216)
  fixed RemoteMonitor: Json to handle np.float32 and np.int32 types (keras-team#9261)
  Update tweet length from 140 to 280 in docs
  Add `depthconv_conv2d` tests (keras-team#9225)
  Remove `force` option in progbar
  ...
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants