Consume the prediction batch indices iteratively #16826
What does this PR do?

During prediction, the `IndexBatchSamplerWrapper` class was consuming the entire `batch_sampler` upfront in `__iter__`. This PR makes the batches get consumed on the go instead.
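A minimal sketch of the idea (not Lightning's actual implementation; class and attribute names here are illustrative): `__iter__` becomes a generator that records each batch's indices as it is yielded, rather than materializing the whole batch sampler into a list at the start of iteration.

```python
class _IndexBatchSamplerWrapper:
    """Wraps a batch sampler and records the indices of each yielded batch
    lazily, as iteration proceeds, instead of consuming everything upfront."""

    def __init__(self, batch_sampler):
        self._batch_sampler = batch_sampler
        self.seen_batch_indices = []

    def __iter__(self):
        self.seen_batch_indices = []
        for batch in self._batch_sampler:
            # Record the batch on the go, then hand it to the caller.
            self.seen_batch_indices.append(batch)
            yield batch

    def __len__(self):
        return len(self._batch_sampler)


# A tiny stand-in batch sampler for demonstration purposes.
def simple_batch_sampler(n, batch_size):
    batch = []
    for i in range(n):
        batch.append(i)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch


wrapper = _IndexBatchSamplerWrapper(list(simple_batch_sampler(5, 2)))
it = iter(wrapper)
first = next(it)  # only the first batch has been consumed so far
# wrapper.seen_batch_indices == [[0, 1]]
rest = list(it)
# wrapper.seen_batch_indices == [[0, 1], [2, 3], [4]]
```

Because the wrapper only records what has actually been yielded, callers that stop iterating early never pay for the remaining batches.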
It also fixes an issue where we were calling `batch_sampler.set_epoch` when it should have been `batch_sampler.sampler.set_epoch`.

Additionally, `IndexBatchSamplerWrapper` is now marked as protected.

cc @Borda @justusschock @awaelchli
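To illustrate the `set_epoch` fix: in PyTorch, `set_epoch` lives on the sampler (e.g. `DistributedSampler`), not on the `BatchSampler` that wraps it, so the call has to be forwarded one level down. A sketch with stand-in classes (the stubs below are illustrative, not torch's code):

```python
class _DistributedSamplerStub:
    """Stand-in for torch.utils.data.DistributedSampler, which exposes
    set_epoch to reshuffle data deterministically per epoch."""

    def __init__(self):
        self.epoch = 0

    def set_epoch(self, epoch):
        self.epoch = epoch


class _BatchSamplerStub:
    """Stand-in for torch.utils.data.BatchSampler: it holds a .sampler
    attribute but has no set_epoch method of its own."""

    def __init__(self, sampler):
        self.sampler = sampler


batch_sampler = _BatchSamplerStub(_DistributedSamplerStub())

# Wrong: the batch sampler has no set_epoch, so this raises AttributeError.
# batch_sampler.set_epoch(3)

# Right: forward the call to the wrapped sampler.
batch_sampler.sampler.set_epoch(3)
```

After the call, `batch_sampler.sampler.epoch` is `3`, which is what a `DistributedSampler` needs to vary its shuffling across epochs.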