diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index 94b46fcf00afd..0a08bd1c25234 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -155,9 +155,7 @@ def teardown(self) -> None:
         """
         self.training_type_plugin.teardown()
 
-    def batch_to_device(
-        self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: Optional[int] = None
-    ) -> Any:
+    def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0) -> Any:
         """Moves the batch to the correct device. The returned batch is of the same type as the input batch, just
         having all tensors on the correct device.
 
@@ -171,7 +169,7 @@ def batch_to_device(
 
         if model is not None and not isinstance(self.training_type_plugin, DataParallelPlugin):
             # no need to transfer batch to device in DP mode
-            return model._apply_batch_transfer_handler(batch, device, dataloader_idx)
+            return model._apply_batch_transfer_handler(batch, device=device, dataloader_idx=dataloader_idx)
 
         return move_data_to_device(batch, device)
 
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 4238dc124285c..21a07add45eaa 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -272,7 +272,7 @@ def logger(self):
         return self.trainer.logger if self.trainer else None
 
     def _apply_batch_transfer_handler(
-        self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: Optional[int] = None
+        self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0
     ) -> Any:
         device = device or self.device
         batch = self.on_before_batch_transfer(batch, dataloader_idx)
diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py
index 4850e715e1840..274e1e16d6206 100644
--- a/pytorch_lightning/loops/batch/training_batch_loop.py
+++ b/pytorch_lightning/loops/batch/training_batch_loop.py
@@ -73,13 +73,12 @@ def optimizer_freq_cumsum(self) -> int:
     def connect(self, **kwargs: "Loop") -> None:
         raise NotImplementedError(f"{self.__class__.__name__} does not connect any child loops.")
 
-    def run(self, batch: Any, batch_idx: int, dataloader_idx: int) -> AttributeDict:
+    def run(self, batch: Any, batch_idx: int) -> AttributeDict:
         """Runs all the data splits and the ``on_batch_start`` and ``on_train_batch_start`` hooks
 
         Args:
             batch: the current batch to run the train step on
             batch_idx: the index of the current batch
-            dataloader_idx: the index of the dataloader producing the current batch
         """
         if batch is None:
             self._warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...")
@@ -92,13 +91,13 @@ def run(self, batch: Any, batch_idx: int, dataloader_idx: int) -> AttributeDict:
             return AttributeDict(signal=-1)
 
         # hook
-        response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, dataloader_idx)
+        response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, 0)
         if response == -1:
             return AttributeDict(signal=-1)
 
         self.trainer.fit_loop.epoch_loop.batch_progress.increment_started()
 
-        super().run(batch, batch_idx, dataloader_idx)
+        super().run(batch, batch_idx)
         output = AttributeDict(signal=0, training_step_output=self.batch_outputs)
         self.batch_outputs = None  # free memory
         return output
@@ -108,26 +107,24 @@ def reset(self) -> None:
         self._hiddens = None
         self.batch_outputs = [[] for _ in range(len(self.trainer.optimizers))]
 
-    def on_run_start(self, batch: Any, batch_idx: int, dataloader_idx: int):
+    def on_run_start(self, batch: Any, batch_idx: int):
         """Splits the data into tbptt splits
 
         Args:
             batch: the current batch to run the trainstep on
             batch_idx: the index of the current batch
-            dataloader_idx: the index of the dataloader producing the current batch
         """
-        void(batch_idx, dataloader_idx)
+        void(batch_idx)
         self._remaining_splits = list(enumerate(self._tbptt_split_batch(batch)))
 
-    def advance(self, batch, batch_idx, dataloader_idx):
+    def advance(self, batch, batch_idx):
         """Runs the train step together with optimization (if necessary) on the current batch split
 
         Args:
             batch: the current batch to run the training on (this is not the split!)
             batch_idx: the index of the current batch
-            dataloader_idx: the index of the dataloader producing the current batch
         """
-        void(batch, dataloader_idx)
+        void(batch)
 
         split_idx, split_batch = self._remaining_splits.pop(0)
         self.split_idx = split_idx
diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py
index 894f4e9197c9c..d0253f18cb496 100644
--- a/pytorch_lightning/loops/epoch/training_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -22,7 +22,6 @@
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.types import STEP_OUTPUT
-from pytorch_lightning.utilities.warnings import WarningCache
 
 
 class TrainingEpochLoop(loops.Loop):
@@ -48,8 +47,6 @@ def __init__(self, min_steps: int, max_steps: int):
         self.val_loop: Optional["loops.EvaluationLoop"] = None
 
         self._results = ResultCollection(training=True)
-        self._dataloader_idx: Optional[int] = None
-        self._warning_cache: WarningCache = WarningCache()
         self._epoch_output: Optional[List[List[STEP_OUTPUT]]] = None
 
     @property
@@ -87,7 +84,6 @@ def connect(
     def reset(self) -> None:
         """Resets the internal state of the loop for a new run"""
         self.is_last_batch = False
-        self._dataloader_idx = 0
 
         # track epoch output
         self._epoch_output = [[] for _ in range(self.batch_loop.num_active_optimizers(self.total_batch_idx))]
@@ -120,12 +116,12 @@ def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None:
         # TRAINING_STEP + TRAINING_STEP_END
         # ------------------------------------
         with self.trainer.profiler.profile("training_batch_to_device"):
-            batch = self.trainer.accelerator.batch_to_device(batch, dataloader_idx=self._dataloader_idx)
+            batch = self.trainer.accelerator.batch_to_device(batch)
 
         self.batch_progress.increment_ready()
 
         with self.trainer.profiler.profile("run_training_batch"):
-            batch_output = self.batch_loop.run(batch, self.batch_idx, self._dataloader_idx)
+            batch_output = self.batch_loop.run(batch, self.batch_idx)
 
         self.batch_progress.increment_processed()
 
@@ -143,9 +139,7 @@ def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None:
         processed_batch_end_outputs = self._prepare_outputs(batch_end_outputs, batch_mode=True)
 
         # hook
-        self.trainer.call_hook(
-            "on_train_batch_end", processed_batch_end_outputs, batch, self.batch_idx, self._dataloader_idx
-        )
+        self.trainer.call_hook("on_train_batch_end", processed_batch_end_outputs, batch, self.batch_idx, 0)
         self.trainer.call_hook("on_batch_end")
         self.trainer.logger_connector.on_batch_end()
 
diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py
index 2ce27a1758533..f7275f87f8a46 100644
--- a/tests/core/test_datamodules.py
+++ b/tests/core/test_datamodules.py
@@ -392,14 +392,14 @@ class CurrentTestDM(LightningDataModule):
         on_after_batch_transfer_hook_rank = None
 
         def on_before_batch_transfer(self, batch, dataloader_idx):
-            assert dataloader_idx is None
+            assert dataloader_idx == 0
             self.on_before_batch_transfer_hook_rank = self.rank
             self.rank += 1
             batch.samples += 1
             return batch
 
         def on_after_batch_transfer(self, batch, dataloader_idx):
-            assert dataloader_idx is None
+            assert dataloader_idx == 0
             assert batch.samples.device == batch.targets.device == expected_device
             self.on_after_batch_transfer_hook_rank = self.rank
             self.rank += 1
@@ -407,7 +407,7 @@ def on_after_batch_transfer(self, batch, dataloader_idx):
             return batch
 
         def transfer_batch_to_device(self, batch, device, dataloader_idx):
-            assert dataloader_idx is None
+            assert dataloader_idx == 0
             self.transfer_batch_to_device_hook_rank = self.rank
             self.rank += 1
             batch.samples = batch.samples.to(device)
diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py
index 2f5fa71ede9f4..990178b09d07f 100644
--- a/tests/models/test_hooks.py
+++ b/tests/models/test_hooks.py
@@ -128,14 +128,14 @@ class CurrentTestModel(BoringModel):
         on_after_batch_transfer_hook_rank = None
 
         def on_before_batch_transfer(self, batch, dataloader_idx):
-            assert dataloader_idx is None
+            assert dataloader_idx == 0
             self.on_before_batch_transfer_hook_rank = self.rank
             self.rank += 1
             batch.samples += 1
             return batch
 
         def on_after_batch_transfer(self, batch, dataloader_idx):
-            assert dataloader_idx is None
+            assert dataloader_idx == 0
             assert batch.samples.device == batch.targets.device == expected_device
             self.on_after_batch_transfer_hook_rank = self.rank
             self.rank += 1
@@ -143,7 +143,7 @@ def on_after_batch_transfer(self, batch, dataloader_idx):
             return batch
 
         def transfer_batch_to_device(self, batch, device, dataloader_idx):
-            assert dataloader_idx is None
+            assert dataloader_idx == 0
             self.transfer_batch_to_device_hook_rank = self.rank
             self.rank += 1
             batch.samples = batch.samples.to(device)
diff --git a/tests/trainer/loops/test_evaluation_loop_flow.py b/tests/trainer/loops/test_evaluation_loop_flow.py
index 37f6a7021c76e..6c0b7b7ed40d7 100644
--- a/tests/trainer/loops/test_evaluation_loop_flow.py
+++ b/tests/trainer/loops/test_evaluation_loop_flow.py
@@ -68,7 +68,7 @@ def backward(self, loss, optimizer, optimizer_idx):
     # simulate training manually
     trainer.state.stage = RunningStage.TRAINING
    batch_idx, batch = 0, next(iter(model.train_dataloader()))
-    out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx, 0)
+    out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)
     assert out.signal == 0
 
     train_step_out = out.training_step_output
@@ -134,7 +134,7 @@ def backward(self, loss, optimizer, optimizer_idx):
     trainer.state.stage = RunningStage.TRAINING
     # make sure training outputs what is expected
     batch_idx, batch = 0, next(iter(model.train_dataloader()))
-    out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx, 0)
+    out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)
     assert out.signal == 0
 
     train_step_out = out.training_step_output
diff --git a/tests/trainer/loops/test_training_loop_flow_scalar.py b/tests/trainer/loops/test_training_loop_flow_scalar.py
index 2660b3ad13094..4ee9d858d44c9 100644
--- a/tests/trainer/loops/test_training_loop_flow_scalar.py
+++ b/tests/trainer/loops/test_training_loop_flow_scalar.py
@@ -146,7 +146,7 @@ def backward(self, loss, optimizer, optimizer_idx):
     trainer.state.stage = RunningStage.TRAINING
     # make sure training outputs what is expected
     batch_idx, batch = 0, next(iter(model.train_dataloader()))
-    out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx, 0)
+    out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)
     assert out.signal == 0
 
     train_step_out = out.training_step_output
@@ -219,7 +219,7 @@ def backward(self, loss, optimizer, optimizer_idx):
     trainer.state.stage = RunningStage.TRAINING
     # make sure training outputs what is expected
     batch_idx, batch = 0, next(iter(model.train_dataloader()))
-    out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx, 0)
+    out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)
     assert out.signal == 0
 
     train_step_out = out.training_step_output
@@ -300,7 +300,7 @@ def training_step(self, batch, batch_idx):
 
     # manually check a few batches
     for batch_idx, batch in enumerate(model.train_dataloader()):
-        out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx, 0)
+        out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)
         if not batch_idx % 2:
             assert out.training_step_output == [[]]
         assert out.signal == 0
@@ -344,7 +344,7 @@ def train_dataloader(self):
 
     # manually check a few batches
     for batch_idx, batch in enumerate(model.train_dataloader()):
-        out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx, 0)
+        out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)
         if not batch_idx % 2:
             assert out.training_step_output == [[]]
         assert out.signal == 0
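
A minimal sketch of how user code observes this change, assuming PyTorch Lightning with this patch applied: during training the batch-transfer hooks now receive dataloader_idx as an int defaulting to 0 rather than None, so hooks that previously checked "is None" should compare against 0, as the updated tests do. The class name below is hypothetical and for illustration only, not part of the patch.

from pytorch_lightning import LightningModule


class BatchTransferDebugModel(LightningModule):
    # hypothetical example model: only the batch-transfer hooks are shown
    def on_before_batch_transfer(self, batch, dataloader_idx):
        # with this patch, dataloader_idx is 0 during training (previously None)
        assert dataloader_idx == 0
        return batch

    def on_after_batch_transfer(self, batch, dataloader_idx):
        assert dataloader_idx == 0
        return batch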