Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update the logic to check for accumulation steps with deepspeed #9826

Merged
merged 4 commits into from
Oct 6, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Raise an exception if using `amp_level` with native `amp_backend` ([#9755](https://github.com/PyTorchLightning/pytorch-lightning/pull/9755))


- Update the logic to check for accumulation steps with deepspeed ([#9826](https://github.com/PyTorchLightning/pytorch-lightning/pull/9826))


### Deprecated

- Deprecated `LightningModule.summarize()` in favor of `pytorch_lightning.utilities.model_summary.summarize()`
Expand Down
8 changes: 4 additions & 4 deletions pytorch_lightning/plugins/training_type/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,11 +375,11 @@ def pre_dispatch(self):
self.barrier()

def init_deepspeed(self):
accumulate_grad_batches = self.lightning_module.trainer.accumulate_grad_batches
if not isinstance(accumulate_grad_batches, int):
accumulation_scheduler = self.lightning_module.trainer.accumulation_scheduler

if accumulation_scheduler.epochs != [0]:
raise MisconfigurationException(
"DeepSpeed currently only supports `Trainer.accumulate_grad_batches` being an integer."
f" Received {accumulate_grad_batches}"
"DeepSpeed currently does not support different `accumulate_grad_batches` at different epoch."
)

precision = self.lightning_module.trainer.accelerator.precision
Expand Down
2 changes: 1 addition & 1 deletion tests/accelerators/test_ipu.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, da


@RunIf(ipu=True)
def test_accumulate_grad_batches_dict_fails(tmpdir):
def test_different_accumulate_grad_batches_fails(tmpdir):
model = IPUModel()
trainer = Trainer(default_root_dir=tmpdir, ipus=1, accumulate_grad_batches={1: 2})
with pytest.raises(
Expand Down
10 changes: 10 additions & 0 deletions tests/plugins/test_deepspeed_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -961,3 +961,13 @@ def configure_optimizers(self):
else:
# assert called once at init and once during training
assert mock_step.call_count == 1 + (max_epoch * limit_train_batches)


@RunIf(min_gpus=1, deepspeed=True, special=True)
def test_different_accumulate_grad_batches_fails(tmpdir):
model = BoringModel()
trainer = Trainer(default_root_dir=tmpdir, accumulate_grad_batches={1: 2}, gpus=1, plugins="deepspeed")
with pytest.raises(
MisconfigurationException, match="DeepSpeed currently does not support different `accumulate_grad_batches`"
):
trainer.fit(model)