[bugfix] Always return batch indices to prevent duplicated logic for the users #9432

Merged · 8 commits · Sep 14, 2021
Changes from 6 commits
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -358,6 +358,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed freeing data iterators in loop `on_run_end` ([#9386](https://github.com/PyTorchLightning/pytorch-lightning/pull/9386))


+- Fixed `BasePredictionWriter` not returning the batch_indices in a non-distributed setting ([#9432](https://github.com/PyTorchLightning/pytorch-lightning/pull/9432))
+
+
 - Fixed collision of user argument when using ShardedDDP ([#9512](https://github.com/PyTorchLightning/pytorch-lightning/pull/9512))

6 changes: 2 additions & 4 deletions pytorch_lightning/callbacks/prediction_writer.py
@@ -107,15 +107,13 @@ def on_predict_batch_end(
     ) -> None:
         if not self.interval.on_batch:
             return
-        is_distributed = trainer.accelerator_connector.is_distributed
-        batch_indices = trainer.predict_loop.epoch_loop.current_batch_indices if is_distributed else None
+        batch_indices = trainer.predict_loop.epoch_loop.current_batch_indices
         self.write_on_batch_end(trainer, pl_module, outputs, batch_indices, batch, batch_idx, dataloader_idx)

     def on_predict_epoch_end(
         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: Sequence[Any]
     ) -> None:
         if not self.interval.on_epoch:
             return
-        is_distributed = trainer.accelerator_connector.is_distributed
-        epoch_batch_indices = trainer.predict_loop.epoch_batch_indices if is_distributed else None
+        epoch_batch_indices = trainer.predict_loop.epoch_batch_indices
         self.write_on_epoch_end(trainer, pl_module, trainer.predict_loop.predictions, epoch_batch_indices)
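
With this change, a user-side `BasePredictionWriter` no longer needs to duplicate the `is_distributed` check: `batch_indices` is passed through in both single-device and distributed runs, and is simply empty when Lightning cannot infer the indices (see the loop change below). Below is a minimal sketch of such a writer, assuming the `BasePredictionWriter` API of this release line; the output directory and file layout are illustrative only, not part of this PR.

```python
import os

import torch
from pytorch_lightning.callbacks import BasePredictionWriter


class SimplePredictionWriter(BasePredictionWriter):
    """Illustrative writer that persists each prediction batch together with its dataset indices."""

    def __init__(self, output_dir: str, write_interval: str = "batch"):
        super().__init__(write_interval)
        self.output_dir = output_dir

    def write_on_batch_end(self, trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx):
        # `batch_indices` is now provided in every setting; it can still be empty when the
        # batch sampler is not an `IndexBatchSamplerWrapper` (e.g. iterable-style datasets).
        torch.save(
            {"prediction": prediction, "batch_indices": batch_indices},
            os.path.join(self.output_dir, f"predictions_{dataloader_idx}_{batch_idx}.pt"),
        )

    def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
        torch.save(
            {"predictions": predictions, "batch_indices": batch_indices},
            os.path.join(self.output_dir, "predictions.pt"),
        )
```

Usage stays the usual callback pattern, e.g. `Trainer(callbacks=SimplePredictionWriter(some_dir)).predict(model, dataloaders=loader)`.
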
7 changes: 6 additions & 1 deletion pytorch_lightning/loops/epoch/prediction_epoch_loop.py
@@ -10,6 +10,8 @@
 from pytorch_lightning.utilities.apply_func import move_data_to_device
 from pytorch_lightning.utilities.warnings import WarningCache

+warning_cache = WarningCache()
+

 class PredictionEpochLoop(Loop):
     """Loop performing prediction on arbitrary sequentially used dataloaders."""
@@ -160,8 +162,11 @@ def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Dict

     def _store_batch_indices(self, dataloader_idx: int) -> None:
         """Stores the batch indices if the predictions should be stored."""
-        batch_sampler = self.trainer.predict_dataloaders[dataloader_idx].batch_sampler
+        current_dataloader = self.trainer.predict_dataloaders[dataloader_idx]
+        batch_sampler = getattr(current_dataloader, "batch_sampler", None)
         if isinstance(batch_sampler, IndexBatchSamplerWrapper):
             self.current_batch_indices = batch_sampler.batch_indices
             if self.should_store_predictions:
                 self._all_batch_indices.append(batch_sampler.batch_indices)
+        else:
+            warning_cache.warn("Lightning couldn't infer the indices fetched for your dataloader.")
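
For context, the `getattr(..., "batch_sampler", None)` lookup is defensive: the predict loop can be fed loader-like objects that never define a `batch_sampler` attribute, and plain attribute access would raise instead of falling through to the new warning. A standalone sketch of that failure mode (the `FakeLoader` class is hypothetical and exists only for this illustration):

```python
import torch


class FakeLoader:
    """Duck-typed 'dataloader' that yields batches but has no batch_sampler attribute."""

    def __iter__(self):
        for _ in range(4):
            yield torch.rand(1, 32)


loader = FakeLoader()

# Direct attribute access (loader.batch_sampler) would raise AttributeError here;
# the guarded lookup degrades to None instead, which is not an IndexBatchSamplerWrapper
# and therefore lands in the warning branch above rather than crashing the loop.
batch_sampler = getattr(loader, "batch_sampler", None)
assert batch_sampler is None
```

Whenever the resolved batch sampler is not an `IndexBatchSamplerWrapper` (for example `None`, or the un-wrapped sampler a `DataLoader` builds for an iterable-style dataset), the loop now warns once (the message is deduplicated by `WarningCache`) and leaves the batch indices empty instead of raising; the new test below exercises exactly that case.
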
36 changes: 26 additions & 10 deletions tests/trainer/test_trainer.py
@@ -29,7 +29,7 @@
 from omegaconf import OmegaConf
 from torch.nn.parallel.distributed import DistributedDataParallel
 from torch.optim import SGD
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, IterableDataset

 import tests.helpers.utils as tutils
 from pytorch_lightning import Callback, LightningDataModule, LightningModule, Trainer
@@ -46,6 +46,7 @@
 from pytorch_lightning.utilities.seed import seed_everything
 from tests.base import EvalModelTemplate
 from tests.helpers import BoringModel, RandomDataset
+from tests.helpers.boring_model import RandomIterableDataset, RandomIterableDatasetWithLen
 from tests.helpers.runif import RunIf


@@ -1287,21 +1288,15 @@ def __init__(self, output_dir: str, *args, **kwargs):

     def write_on_batch_end(self, trainer, pl_module, prediction, batch_indices, *args, **kwargs):
         assert prediction.shape == torch.Size([1, 2])
-        if trainer.accelerator_connector.is_distributed:
-            assert len(batch_indices) == 1
-        else:
-            assert batch_indices is None
+        assert len(batch_indices) == 1
         self.write_on_batch_end_called = True

     def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
         expected = 1 if trainer.accelerator_connector.is_distributed else 2
         assert len(predictions) == 2
         assert len(predictions[0]) == expected
-        if trainer.accelerator_connector.is_distributed:
-            assert len(batch_indices) == 2
-            assert len(batch_indices[0]) == expected
-        else:
-            assert batch_indices is None
+        assert len(batch_indices) == 2
+        assert len(batch_indices[0]) == expected
         self.write_on_epoch_end_called = True

     def on_predict_epoch_end(self, trainer, pl_module, outputs):
@@ -1416,6 +1411,27 @@ def test_trainer_predict_ddp_cpu(tmpdir):
     predict(tmpdir, "ddp_cpu", 0, 2)


+@pytest.mark.parametrize("dataset_cls", [RandomDataset, RandomIterableDatasetWithLen, RandomIterableDataset])
+def test_index_batch_sampler_wrapper_with_iterable_dataset(dataset_cls, tmpdir):
+
+    ds = dataset_cls(32, 8)
+    loader = DataLoader(ds)
+    is_iterable_dataset = isinstance(ds, IterableDataset)
+
+    class CustomPredictionWriter(BasePredictionWriter):
+        def __init__(self, output_dir: str, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+            self.output_dir = output_dir
+
+        def write_on_batch_end(self, trainer, pl_module, prediction, batch_indices, *args, **kwargs):
+            assert not batch_indices if is_iterable_dataset else batch_indices
+
+    cb = CustomPredictionWriter(tmpdir)
+    trainer = Trainer(default_root_dir=tmpdir, callbacks=cb)
+    predictions = trainer.predict(BoringModel(), dataloaders=loader)
+    assert len(predictions) == 8
+
+
 @patch("torch.cuda.device_count", return_value=2)
 @patch("torch.cuda.is_available", return_value=True)
 def test_spawn_predict_return_predictions(*_):