Fix DeepSpeedPlugin with IterableDataset #7362
Conversation
Codecov Report
@@            Coverage Diff            @@
##           master   #7362     +/-   ##
=========================================
- Coverage      92%     87%      -5%
=========================================
  Files         200     200
  Lines       12983   12992       +9
=========================================
- Hits        11937   11359     -578
- Misses       1046    1633     +587
Thanks @leezu! Let me know your thoughts on the proposed change, and I can help make a test for this to get this merged.
I added a test, and modified the code to now use the auto naming convention.
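As background for the hunk below (this is an illustration of the failure mode, not the test added in this PR, and the dataset class name is made up): with an IterableDataset and batch_size=None, PyTorch's DataLoader creates no BatchSampler, so a batch_sampler.batch_size lookup has nothing to read.

```python
# Hedged illustration only: with an iterable-style dataset and batch_size=None,
# the DataLoader exposes batch_sampler as None, so code that dereferences
# batch_sampler.batch_size would fail.
import torch
from torch.utils.data import DataLoader, IterableDataset


class RandomIterableDataset(IterableDataset):
    """Yields a fixed number of random feature vectors."""

    def __init__(self, size: int = 32, dim: int = 2):
        self.size, self.dim = size, dim

    def __iter__(self):
        for _ in range(self.size):
            yield torch.randn(self.dim)


loader = DataLoader(RandomIterableDataset(), batch_size=None)
assert loader.batch_sampler is None  # the attribute the batch-size lookup relies on
```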
if hasattr(self.lightning_module, 'train_dataloader'):
    train_dataloader = self.lightning_module.train_dataloader()
    if hasattr(train_dataloader, 'batch_sampler'):
        batch_size = train_dataloader.batch_sampler.batch_size
If anyone can suggest anything cleaner please do :)
Probably not much cleaner, as train_dataloader() is a callable, not just an attribute...
This could break if the user provides several dataloaders to the CombinedLoader.
@leezu ^^
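To make the concerns in this thread concrete, here is a minimal sketch of a lookup that degrades gracefully when the module has no usable train_dataloader or the loader has no batch sampler. It is my own illustration, not the merged implementation; the function name and fallback value are assumptions, and it does not address the CombinedLoader case raised above.

```python
# Illustrative sketch only; not the code merged in this PR.
def _infer_train_batch_size(lightning_module, fallback: int = 1) -> int:
    dataloader_fn = getattr(lightning_module, 'train_dataloader', None)
    if not callable(dataloader_fn):
        return fallback
    try:
        train_dataloader = dataloader_fn()
    except Exception:
        # The hook may be undefined or raise when data is provided elsewhere.
        return fallback
    # For an IterableDataset with batch_size=None the loader has no BatchSampler.
    batch_sampler = getattr(train_dataloader, 'batch_sampler', None)
    if batch_sampler is None or not hasattr(batch_sampler, 'batch_size'):
        return fallback
    return batch_sampler.batch_size
```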
@leezu please use and follow the bullet list from the template.
LGTM !
Thank you @SeanNaren. LGTM
Thanks @leezu :)
* deepspeed add train_micro_batch_size_per_gpu argument
* Update naming and doc
* Modify to use auto naming convention, add test
* Add iterable tests
* Fix tests, attempt by mocking
* Import correct package
* Fix comparison
* Set as special test
* Remove import
* Add Changelog

Co-authored-by: SeanNaren <[email protected]> (cherry picked from commit 98b94b8)
Fixes #7345
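For users hitting #7345, a hedged usage sketch follows. The plugin keyword arguments shown are assumptions and may differ from the merged API; train_micro_batch_size_per_gpu itself is a standard DeepSpeed config key. The idea is that when the batch size cannot be inferred from an IterableDataset, it can be stated explicitly in the DeepSpeed config passed to the plugin.

```python
# Hypothetical usage sketch; check the DeepSpeedPlugin docs for the exact
# arguments. The config keys below are standard DeepSpeed options.
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin

deepspeed_config = {
    "train_micro_batch_size_per_gpu": 8,  # set explicitly for IterableDataset inputs
    "zero_optimization": {"stage": 2},
}

trainer = Trainer(
    gpus=1,
    precision=16,
    plugins=[DeepSpeedPlugin(config=deepspeed_config)],
)
```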
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
Before you start reviewing, make sure you have read the Review guidelines. In short, see the following bullet list:
Did you have fun?
Make sure you had fun coding 🙃