diff --git a/CHANGELOG.md b/CHANGELOG.md
index aa3669c40ccf3..b39e4dab4a6f9 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -106,6 +106,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added `on_exception` callback hook ([#9183](https://github.com/PyTorchLightning/pytorch-lightning/pull/9183))
 
+- Add a warning to deepspeed when inferring batch size ([#9221](https://github.com/PyTorchLightning/pytorch-lightning/pull/9221))
+
+
 ### Changed
 
 - Parsing of the `gpus` Trainer argument has changed: `gpus="n"` (str) no longer selects the GPU index n and instead selects the first n devices. ([#8770](https://github.com/PyTorchLightning/pytorch-lightning/pull/8770))
diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py
index 36d487e3419c7..ca10b47bd9fd2 100644
--- a/pytorch_lightning/plugins/training_type/deepspeed.py
+++ b/pytorch_lightning/plugins/training_type/deepspeed.py
@@ -546,6 +546,11 @@ def _format_batch_size_and_grad_accum_config(self):
                 " as this will be set via accumulate_grad_batches=x argument passed via the Lightning Trainer."
             )
         if "train_micro_batch_size_per_gpu" not in self.config:
+            rank_zero_warn(
+                "Inferring the batch size for internal deepspeed logging from the `train_dataloader()`. "
+                "If you require skipping this, please pass "
+                "`Trainer(plugins=DeepSpeedPlugin(logging_batch_size_per_gpu=batch_size))`"
+            )
             batch_size = self._auto_select_batch_size()
             self.config["train_micro_batch_size_per_gpu"] = batch_size
         self.config["gradient_accumulation_steps"] = self.lightning_module.trainer.accumulate_grad_batches
diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py
index 29086bd72a053..de4bb3ea987f9 100644
--- a/tests/plugins/test_deepspeed_plugin.py
+++ b/tests/plugins/test_deepspeed_plugin.py
@@ -1,6 +1,6 @@
 import json
 import os
-from typing import Any, Dict
+from typing import Any, Dict, Optional
 from unittest import mock
 
 import pytest
@@ -11,7 +11,7 @@
 from torch.utils.data import DataLoader
 from torchmetrics import Accuracy
 
-from pytorch_lightning import LightningModule, seed_everything, Trainer
+from pytorch_lightning import LightningDataModule, LightningModule, seed_everything, Trainer
 from pytorch_lightning.callbacks import Callback, LearningRateMonitor, ModelCheckpoint
 from pytorch_lightning.plugins import DeepSpeedPlugin, DeepSpeedPrecisionPlugin
 from pytorch_lightning.plugins.training_type.deepspeed import LightningDeepSpeedModule
@@ -808,3 +808,53 @@ def training_step(self, batch, batch_idx):
     trainer = Trainer(default_root_dir=tmpdir, plugins=[DeepSpeedPlugin()], gpus=1, fast_dev_run=True, precision=16)
     with pytest.raises(MisconfigurationException, match="returning `None` .* is not supported"):
         trainer.fit(model)
+
+
+@RunIf(min_gpus=1, deepspeed=True, special=True)
+def test_deepspeed_warn_train_dataloader_called(tmpdir):
+    """Test DeepSpeed warns when it calls ``lightning_module.train_dataloader`` internally for logging batch
+    size."""
+    model = BoringModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        plugins=[DeepSpeedPlugin()],
+        gpus=1,
+        fast_dev_run=True,
+    )
+    with pytest.warns(UserWarning, match="Inferring the batch size for internal deepspeed logging"):
+        trainer.fit(model)
+
+
+@RunIf(min_gpus=1, deepspeed=True, special=True)
+def test_deepspeed_setup_train_dataloader(tmpdir):
+    """Test DeepSpeed works when ``setup`` must run before the dataloaders and the batch size is passed manually."""
+
+    class TestSetupIsCalledDataModule(LightningDataModule):
+        def __init__(self):
+            super().__init__()
+            self._setup = False
+
+        def setup(self, stage: Optional[str] = None) -> None:
+            self._setup = True
+
+        def train_dataloader(self):
+            assert self._setup
+            return DataLoader(RandomDataset(32, 64), batch_size=2)
+
+        def val_dataloader(self):
+            assert self._setup
+            return DataLoader(RandomDataset(32, 64), batch_size=2)
+
+        def test_dataloader(self):
+            assert self._setup
+            return DataLoader(RandomDataset(32, 64), batch_size=2)
+
+    model = BoringModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        plugins=[DeepSpeedPlugin(logging_batch_size_per_gpu=32)],
+        gpus=1,
+        fast_dev_run=True,
+    )
+    trainer.fit(model, datamodule=TestSetupIsCalledDataModule())
+    trainer.test(model)
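
Outside the patch itself, a minimal usage sketch of the opt-out the new warning points to, assuming a single-GPU machine with deepspeed installed; `32` is a placeholder batch size:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin

# Passing the logging batch size explicitly, so DeepSpeed does not need to
# infer it from `train_dataloader()` and the new warning is not emitted.
trainer = Trainer(
    gpus=1,
    precision=16,
    plugins=[DeepSpeedPlugin(logging_batch_size_per_gpu=32)],
)
```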