Reset dataloaders on failure in tuner (#14372)
rohitgr7 authored and lexierule committed Sep 13, 2022
1 parent 582b8cc commit 6bd71be
Showing 3 changed files with 49 additions and 13 deletions.
1 change: 0 additions & 1 deletion src/pytorch_lightning/CHANGELOG.md
@@ -8,7 +8,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Changed
 
-- When using multiple loggers, by default checkpoints and profiler output now get saved to the log dir of the first logger in the list ([#14325](https://github.com/Lightning-AI/lightning/pull/14325))
 - Improved the error messaging when passing `Trainer.method(model, x_dataloader=None)` with no module-method implementations available ([#14614](https://github.com/Lightning-AI/lightning/pull/14614))
 
 ### Fixed
38 changes: 26 additions & 12 deletions src/pytorch_lightning/tuner/batch_size_scaling.py
@@ -126,6 +126,9 @@ def _run_power_scaling(
trainer: "pl.Trainer", model: "pl.LightningModule", new_size: int, batch_arg_name: str, max_trials: int
) -> int:
"""Batch scaling mode where the size is doubled at each iteration until an OOM error is encountered."""
# this flag is used to determine whether the previously scaled batch size, right before OOM, was a success or not
# if it was we exit, else we continue downscaling in case we haven't encountered a single optimal batch size
any_success = False
for _ in range(max_trials):
garbage_collection_cuda()

@@ -137,22 +140,28 @@ def _run_power_scaling(
             trainer.tuner._run(model)
             # Double in size
             new_size, changed = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc="succeeded")
+
+            if not changed:
+                break
+
+            # Force the train dataloader to reset as the batch size has changed
+            trainer.reset_train_dataloader(model)
+            trainer.reset_val_dataloader(model)
+            any_success = True
         except RuntimeError as exception:
             # Only these errors should trigger an adjustment
             if is_oom_error(exception):
                 # If we fail in power mode, half the size and return
                 garbage_collection_cuda()
                 new_size, _ = _adjust_batch_size(trainer, batch_arg_name, factor=0.5, desc="failed")
-                break
+                # Force the train dataloader to reset as the batch size has changed
+                trainer.reset_train_dataloader(model)
+                trainer.reset_val_dataloader(model)
+                if any_success:
+                    break
             else:
                 raise  # some other error not memory related
 
-        if changed:
-            # Force the train dataloader to reset as the batch size has changed
-            trainer.reset_train_dataloader(model)
-            trainer.reset_val_dataloader(model)
-        else:
-            break
     return new_size
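The new control flow is easier to see outside the diff. Below is a minimal, self-contained sketch of the power-scaling loop, with hypothetical stand-ins (try_fit, reset_dataloaders) for the Trainer calls; it illustrates the pattern under the simplifying assumption that every RuntimeError is an OOM, and is not the library code itself.

def scale_power(try_fit, reset_dataloaders, size: int, max_trials: int = 10) -> int:
    any_success = False
    for _ in range(max_trials):
        try:
            try_fit(size)
            size *= 2                # succeeded: try a larger batch next
            reset_dataloaders(size)  # dataloaders must pick up the new size
            any_success = True
        except RuntimeError:         # treated as OOM in this sketch
            size //= 2               # back off
            reset_dataloaders(size)  # the fix: reset on failure as well
            if any_success:
                break                # the previous size worked; stop here
    return size

def fake_fit(size: int) -> None:
    if size > 100:  # pretend any batch above 100 samples OOMs
        raise RuntimeError("CUDA out of memory")

print(scale_power(fake_fit, lambda size: None, 16))  # 16 -> 32 -> 64 -> 128 fails -> 64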


@@ -189,13 +198,13 @@ def _run_binsearch_scaling(
             else:
                 new_size, changed = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc="succeeded")
 
-            if changed:
-                # Force the train dataloader to reset as the batch size has changed
-                trainer.reset_train_dataloader(model)
-                trainer.reset_val_dataloader(model)
-            else:
+            if not changed:
                 break
 
+            # Force the train dataloader to reset as the batch size has changed
+            trainer.reset_train_dataloader(model)
+            trainer.reset_val_dataloader(model)
+
         except RuntimeError as exception:
             # Only these errors should trigger an adjustment
             if is_oom_error(exception):
@@ -204,6 +213,11 @@ def _run_binsearch_scaling(
                 high = new_size
                 midval = (high + low) // 2
                 new_size, _ = _adjust_batch_size(trainer, batch_arg_name, value=midval, desc="failed")
+
+                # Force the train dataloader to reset as the batch size has changed
+                trainer.reset_train_dataloader(model)
+                trainer.reset_val_dataloader(model)
+
                 if high - low <= 1:
                     break
             else:
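The binsearch mode applies the same reset-on-failure idea at every midpoint adjustment. A rough sketch of that search, with the same hypothetical helpers as above and simplified bookkeeping (the real implementation also tracks a trial count and a changed flag):

def scale_binsearch(try_fit, reset_dataloaders, size: int, max_trials: int = 10) -> int:
    low, high = 1, None
    for _ in range(max_trials):
        try:
            try_fit(size)
            low = size
            if high is not None and high - low <= 1:
                break
            # grow: double until the first failure, then bisect toward high
            size = (high + low) // 2 if high is not None else size * 2
            reset_dataloaders(size)
        except RuntimeError:  # treated as OOM in this sketch
            high = size
            size = (high + low) // 2
            reset_dataloaders(size)  # reset even though this trial failed
            if high - low <= 1:
                break
    return size

print(scale_binsearch(fake_fit, lambda size: None, 500))  # reusing fake_fit above; converges to 100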
23 changes: 23 additions & 0 deletions tests/tests_pytorch/tuner/test_scale_batch_size.py
@@ -319,3 +319,26 @@ def test_dataloader_reset_with_scale_batch_size(tmpdir, scale_method):

     assert trainer.train_dataloader.loaders.batch_size == new_batch_size
     assert trainer.val_dataloaders[0].batch_size == new_batch_size
+
+
+@pytest.mark.parametrize("scale_method, expected_batch_size", [("power", 62), ("binsearch", 100)])
+@patch("pytorch_lightning.tuner.batch_size_scaling.is_oom_error", return_value=True)
+def test_dataloader_batch_size_updated_on_failure(_, tmpdir, scale_method, expected_batch_size):
+    class CustomBatchSizeModel(BatchSizeModel):
+        def training_step(self, *_, **__):
+            if self.batch_size > 100:
+                raise RuntimeError
+
+        def train_dataloader(self):
+            return DataLoader(RandomDataset(32, 1000), batch_size=self.batch_size)
+
+    model = CustomBatchSizeModel(batch_size=16)
+    model.validation_step = None
+    model.training_epoch_end = None
+    scale_batch_size_kwargs = {"max_trials": 10, "steps_per_trial": 1, "init_val": 500, "mode": scale_method}
+
+    trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, auto_scale_batch_size=True)
+    new_batch_size = trainer.tune(model, scale_batch_size_kwargs=scale_batch_size_kwargs)["scale_batch_size"]
+    assert new_batch_size == model.batch_size
+    assert new_batch_size == expected_batch_size
+    assert trainer.train_dataloader.loaders.batch_size == expected_batch_size
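A rough trace of the expected values, assuming _adjust_batch_size truncates via int(batch_size * factor) and the bisection uses (high + low) // 2 with low starting at 1 (assumptions read off the diff; is_oom_error is patched so every RuntimeError counts as OOM, and sizes above 100 raise):

# power, init_val=500:
#   500 fails -> 250 fails -> 125 fails -> int(125 * 0.5) = 62 succeeds
#   -> doubled to 124, fails -> halved back to 62; any_success is set, so stop  => 62
# binsearch, init_val=500:
#   500, 250, 125 all fail, (125 + 1) // 2 = 63 succeeds, then bisection between
#   the last good and last bad size tightens: 94, 109, 101, 97, 99, 100  => 100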
