Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a warning to deepspeed when inferring batch size #9221

Merged
merged 8 commits into from
Sep 7, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Add support for CPU AMP autocast ([#9084](https://github.com/PyTorchLightning/pytorch-lightning/pull/9084))


- Add a warning to deepspeed when inferring batch size ([#9221](https://github.com/PyTorchLightning/pytorch-lightning/pull/9221))


### Changed

- Parsing of the `gpus` Trainer argument has changed: `gpus="n"` (str) no longer selects the GPU index n and instead selects the first n devices. ([#8770](https://github.com/PyTorchLightning/pytorch-lightning/pull/8770))
Expand Down
5 changes: 5 additions & 0 deletions pytorch_lightning/plugins/training_type/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,11 @@ def _format_batch_size_and_grad_accum_config(self):
" as this will be set via accumulate_grad_batches=x argument passed via the Lightning Trainer."
)
if "train_micro_batch_size_per_gpu" not in self.config:
rank_zero_warn(
"Inferring the batch size for internal deepspeed logging from the `train_dataloader()`. "
"If you require skipping this, please pass "
"`Trainer(plugins=DeepSpeedPlugin(logging_batch_size_per_gpu=batch_size)`"
)
batch_size = self._auto_select_batch_size()
self.config["train_micro_batch_size_per_gpu"] = batch_size
self.config["gradient_accumulation_steps"] = self.lightning_module.trainer.accumulate_grad_batches
Expand Down
57 changes: 55 additions & 2 deletions tests/plugins/test_deepspeed_plugin.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
import os
from typing import Any, Dict
from typing import Any, Dict, Optional
from unittest import mock

import pytest
Expand All @@ -11,7 +11,7 @@
from torch.utils.data import DataLoader
from torchmetrics import Accuracy

from pytorch_lightning import LightningModule, seed_everything, Trainer
from pytorch_lightning import LightningDataModule, LightningModule, seed_everything, Trainer
from pytorch_lightning.callbacks import Callback, LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.plugins import DeepSpeedPlugin, DeepSpeedPrecisionPlugin
from pytorch_lightning.plugins.training_type.deepspeed import LightningDeepSpeedModule
Expand Down Expand Up @@ -830,3 +830,56 @@ def test_deepspeed_multigpu_no_schedulers(tmpdir):
trainer.fit(model)

_assert_save_model_is_equal(model, tmpdir, trainer)


@RunIf(min_gpus=1, deepspeed=True, special=True)
def test_deepspeed_warn_train_dataloader_called(tmpdir):
"""
Test DeepSpeed warns when it calls ``train_dataloader`` internally for logging batch size.
"""
model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
plugins=[DeepSpeedPlugin()],
gpus=1,
fast_dev_run=True,
)
with pytest.warns(UserWarning, match="Inferring the batch size for internal deepspeed logging"):
trainer.fit(model)


@RunIf(min_gpus=1, deepspeed=True, special=True)
def test_deepspeed_setup_train_dataloader(tmpdir):
"""
Test DeepSpeed works when setup is required to call, and the user passes the batch size manually.
"""

class TestSetupIsCalledDataModule(LightningDataModule):
def __init__(self):
super().__init__()
self._setup = False

def setup(self, stage: Optional[str] = None) -> None:
self._setup = True

def train_dataloader(self):
assert self._setup
return DataLoader(RandomDataset(32, 64), batch_size=2)

def val_dataloader(self):
assert self._setup
return DataLoader(RandomDataset(32, 64), batch_size=2)

def test_dataloader(self):
assert self._setup
return DataLoader(RandomDataset(32, 64), batch_size=2)

model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
plugins=[DeepSpeedPlugin(logging_batch_size_per_gpu=32)],
gpus=1,
fast_dev_run=True,
)
trainer.fit(model, datamodule=TestSetupIsCalledDataModule())
trainer.test(model)