From c7d22febb03b48ecb68e9813d02e3ef98578f2e0 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 20 Aug 2021 13:23:10 +0100 Subject: [PATCH 01/74] update --- .../processors/iterator_batch_processor.py | 27 +- .../trainer/connectors/data_connector.py | 3 +- pytorch_lightning/utilities/fetching.py | 238 ++++++++++++++++-- tests/utilities/test_fetching.py | 162 +++++++++++- 4 files changed, 385 insertions(+), 45 deletions(-) diff --git a/pytorch_lightning/loops/processors/iterator_batch_processor.py b/pytorch_lightning/loops/processors/iterator_batch_processor.py index 4dda4fe596698..7e12602aa9a2c 100644 --- a/pytorch_lightning/loops/processors/iterator_batch_processor.py +++ b/pytorch_lightning/loops/processors/iterator_batch_processor.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import itertools import logging from collections import OrderedDict from copy import copy @@ -75,11 +74,11 @@ def __init__(self, trainer: "pl.Trainer", model: "pl.LightningModule") -> None: "The model hook `tbptt_split_batch` is not compatible with " "taking a `dataloader_iter` argument in your `training_step`." ) - if model.automatic_optimization: - raise MisconfigurationException( - "`automatic_optimization` is not compatible with " - "taking a `dataloader_iter` argument in your `training_step`." - ) + # if model.automatic_optimization: + # raise MisconfigurationException( + # "`automatic_optimization` is not compatible with " + # "taking a `dataloader_iter` argument in your `training_step`." + # ) if trainer.accumulate_grad_batches != 1: raise MisconfigurationException( "`accumulate_grad_batches` can only be 1 when your " @@ -119,9 +118,11 @@ def run(self, dataloader_iter: Iterator) -> Optional[AttributeDict]: Args: dataloader_iter: the iterator over the dataloader producing the new batch """ - dataloader_iter = itertools.starmap( - lambda batch_idx, batch_with_is_last: batch_with_is_last[0], dataloader_iter - ) + # dataloader_iter = itertools.starmap( + # lambda batch_idx, batch_with_is_last: batch_with_is_last[0], dataloader_iter + # ) + + _, (dataloader_iter, batch_idx, is_last) = next(dataloader_iter) self.trainer.logger_connector.on_batch_start() response = self.trainer.call_hook("on_batch_start") @@ -137,19 +138,13 @@ def run(self, dataloader_iter: Iterator) -> Optional[AttributeDict]: # manually capture logged metrics model._current_fx_name = "training_step" with self.trainer.profiler.profile("training_step"): - step_kwargs = OrderedDict([("dataloader_iter", dataloader_iter)]) + step_kwargs = OrderedDict([("dataloader_iter", dataloader_iter), ("batch_idx", batch_idx)]) training_step_output = self.trainer.accelerator.training_step(step_kwargs) self.trainer.accelerator.post_training_step() training_step_output = self.trainer.call_hook("training_step_end", training_step_output) _check_training_step_output(self.trainer.lightning_module, training_step_output) - if training_step_output is None or "is_last" not in training_step_output: - raise MisconfigurationException( - "When `training_step` takes `dataloader_iter` as an argument, the result dict must " - "contain a `is_last` field to indicate whether there are more batches to be processed." 
- ) - is_last = training_step_output["is_last"] training_step_output, _ = _process_training_step_output(self.trainer, training_step_output) if self.trainer.terminate_on_nan: diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 629be97f29468..089305dca7c33 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -26,7 +26,7 @@ class DataConnector: def __init__(self, trainer: "pl.Trainer", multiple_trainloader_mode: str = "max_size_cycle"): self.trainer = trainer self.multiple_trainloader_mode = multiple_trainloader_mode - self.data_fetcher: Optional[DataFetcher] = None + self.data_fetcher: Optional[DataFetcher] = DataFetcher() def on_trainer_init( self, @@ -61,7 +61,6 @@ def on_trainer_init( self.trainer._is_data_prepared = False def get_profiled_train_dataloader(self, train_dataloader) -> Iterable: - self.data_fetcher = DataFetcher() self.data_fetcher.setup(train_dataloader) prefetcher_iter = iter(self.data_fetcher) profiled_dl = self.trainer.profiler.profile_iterable(enumerate(prefetcher_iter), "get_train_batch") diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py index 42de4f8571e3e..395e941fce653 100644 --- a/pytorch_lightning/utilities/fetching.py +++ b/pytorch_lightning/utilities/fetching.py @@ -14,12 +14,15 @@ from abc import ABC, abstractmethod from collections.abc import Iterable, Iterator +from contextlib import contextmanager from copy import deepcopy from functools import partial -from typing import Any, Generator, List, Optional, Tuple +from typing import Any, Callable, Generator, List, Optional, Tuple +import torch from torch.utils.data.dataloader import DataLoader +import pytorch_lightning as pl from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections from pytorch_lightning.utilities.auto_restart import ( @@ -35,18 +38,27 @@ class AbstractDataFetcher(ABC): """ - This class is used to control batch fetching flow. + This based class should be used to implement a fault tolerant `DataFetcher`. + It is required to override the ``fetching_function`` with fetching logic. 
+ Example:: + class SimpleDataFetcher(AbstractDataFetcher): + def fetching_function(self): + while True: + try: + yield next(self.dataloader_iter), False + except StopIteration: + yield None, True """ @abstractmethod def fetching_function(self) -> Generator: - pass + """Override with your own fetching logic.""" def __init__( self, prefetch_batches: int = 0, ) -> None: - if not isinstance(prefetch_batches, int) or (isinstance(prefetch_batches, int) and prefetch_batches < 0): + if prefetch_batches < 0: raise MisconfigurationException("`prefetch_batches` should at least be 0.") self.prefetch_batches = prefetch_batches + 1 @@ -54,15 +66,32 @@ def __init__( self.dataloader: Optional[Iterable] = None self.dataloader_iter: Optional[Iterator] = None + self.stage: Optional[str] + self.batch_to_device: Optional[Callable] + self.profiler: "Optional[pl.profiler.base.BaseProfiler]" + self.batches: List self.fetched: int self.done: bool self.reset() - def setup(self, dataloader: DataLoader, **kwargs) -> None: + def setup( + self, + dataloader: Iterable, + stage: Optional[str] = None, + batch_to_device: Optional[Callable] = None, + profiler: "Optional[pl.profiler.base.BaseProfiler]" = None, + ) -> None: self._add_capture_metadata_collate(dataloader) + self.dataloader = dataloader + self.stage = stage + self.batch_to_device = batch_to_device + self.profiler = profiler + + if self.profiler is not None and stage is None: + raise MisconfigurationException("When providing a profiler, the stage should be provided too.") @staticmethod def _add_capture_metadata_collate(dataloader: Iterable) -> None: @@ -78,10 +107,10 @@ def add_capture_metadata_collate(dataloader: DataLoader): apply_to_collection(dataloader, DataLoader, add_capture_metadata_collate) - def add_batch(self, batch) -> None: + def append_batch(self, batch) -> None: self.batches.append(batch) - def fetch_batch(self) -> Any: + def pop_batch(self) -> Any: return self.batches.pop(0) def _apply_patch(self): @@ -169,40 +198,201 @@ def reset(self) -> None: self.fetched: int = 0 self.done: bool = False + def teardown(self) -> None: + self.reset() + class DataFetcher(AbstractDataFetcher): """ This class is used to control batch fetching flow. + By default, the `fetching_function` will `prefetch` a batch in advance to detect the end of the iteration. + Args: + prefetch_batches: Number of batches to be pre-fetched. Lightning will pre-fetch + at least 1 batch for tracking the latest batch. + store_on_gpu: Whether to store the pre-fetched batches on device. 
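    Example (an illustrative sketch of how the fetcher is consumed; the random tensors below are arbitrary)::

        import torch
        from torch.utils.data import DataLoader, TensorDataset

        from pytorch_lightning.utilities.fetching import DataFetcher

        fetcher = DataFetcher()
        fetcher.setup(DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8))
        for batch, is_last in fetcher:
            # the one-batch lookahead is what lets `is_last` flag the final batch
            ...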
""" + def __init__( + self, + prefetch_batches: int = 0, + store_on_gpu: bool = False, + ) -> None: + super().__init__(prefetch_batches=prefetch_batches) + self.store_on_gpu = store_on_gpu + + @contextmanager + def fetching_context(self): + """Hook to override to add context logic around batch fetching""" + yield + + def on_fetch_start(self) -> None: + """Hook to override to handle the logic before fetching a batch""" + + def on_fetch_end(self, batch, on_fetch_start_output: Optional[Any] = None) -> None: + """Hook to extend which handles the logic after fetching a batch""" + if self.store_on_gpu: + batch = self.move_data_to_device(batch) + self.append_batch(batch) + + def wait(self) -> None: + """Hook to override to indicate the `DataFetcher` to wait for an event.""" + def fetching_function(self) -> Generator: self.done = False while not self.done: self._prefetching(self.prefetch_batches) - for batch in self.dataloader_iter: - yield_batch = self.fetch_batch() - self.add_batch(batch) - self.fetched += 1 - # yield last and has next - yield yield_batch, False + while self.batches: + try: + yield_batch = self.pop_batch() + self._fetch_next_batch() - yield from self._consume_prefetched_batches() + # yield last and has next + self.wait() - def _consume_prefetched_batches(self) -> Generator: - self.done = True - while self.batches: - if len(self.batches) == 1: - yield self.batches.pop(0), True - else: - yield self.batches.pop(0), False + yield (self.move_data_to_device(yield_batch) if not self.store_on_gpu else yield_batch, False) + except StopIteration: + self.batches.insert(0, yield_batch) + break + + yield from self._consume_prefetched_batches() def _prefetching(self, prefetch_batches: int) -> None: for _ in range(prefetch_batches): try: - batch = next(self.dataloader_iter) - self.fetched += 1 - self.add_batch(batch) + self._fetch_next_batch() except StopIteration: break + + @contextmanager + def apply_profiler(self, name: str) -> Generator: + if self.profiler: + with self.profiler.profile(name): + yield + else: + yield + + def _fetch_next_batch(self): + with self.apply_profiler(f"get_{self.stage}_batch"): + with self.fetching_context(): + data = self.on_fetch_start() + with self.apply_profiler(f"fetch_next_{self.stage}_batch"): + batch = next(self.dataloader_iter) + self.fetched += 1 + self.on_fetch_end(batch, data) + + def _consume_prefetched_batches(self) -> Generator: + self.done = True + while self.batches: + yield from self._yield_batch() + + def _yield_batch(self) -> Generator: + self.wait() + batch = self.batches.pop(0) + if not self.store_on_gpu: + batch = self.move_data_to_device(batch) + is_last = len(self.batches) == 0 + yield batch, is_last + + def move_data_to_device(self, batch: Any) -> Any: + if self.batch_to_device: + with self.apply_profiler(f"move_{self.stage}_batch_to_device"): + batch = self.batch_to_device(batch) + return batch + + +class InterBatchParallelismDataFetcher(DataFetcher): + + """ + This class implements `inter-batch-parallelism` algorithm which aims at hiding the latency of host-to-device copy + of input batches behind computational intensive operation. 
+ Without parallization: + batch 0: [HtoD][forward][backward] + batch 1: [HtoD][forward][backward] + batch 2: [HtoD][forward][backward] + With parallelization, the latency of HtoD copy can be hidden: + batch 0: [HtoD][forward][backward] + batch 1: [HtoD] [forward][backward] + batch 2: [HtoD] [forward][backward] + """ + + def __init__( + self, + prefetch_batches: int = 0, + ) -> None: + super().__init__(prefetch_batches=prefetch_batches, store_on_gpu=True) + + self.cuda_stream = torch.cuda.Stream() + self.events: List[torch.cuda.Event] = [] + + @contextmanager + def fetching_context(self): + """Wrap the batch fetching logic under a cuda stream""" + with torch.cuda.stream(self.cuda_stream): + yield + + def on_fetch_start(self) -> "torch.cuda.Event": + # create a cuda event used to record the async stream of data to device. + return torch.cuda.Event() + + def on_fetch_end(self, batch, event: torch.cuda.Event) -> None: + # move the batch to device and store it + super().on_fetch_end(batch) + + # record event and store the event + event.record() + self.events.append(event) + + def wait(self) -> None: + # pop first event from the queue and wait for the batch to be available on device. + event = self.events.pop(0) + event.wait() + + +class StepFuncDataLoaderIter: + + """ + This class is a wrapper to keep track of dataloader iterator fetching event + while left entirely to user control. + """ + + def __init__(self, iterator: Iterator, data_fetcher: "AbstractDataFetcher"): + self.iterator = iterator + self.data_fetcher = data_fetcher + + def __iter__(self) -> "StepFuncDataLoaderIter": + return self + + def __next__(self) -> Any: + try: + data = next(self.iterator) + # FIXME: Link this to `batch_idx`. + self.data_fetcher.fetched += 1 + return data + except StopIteration: + self.data_fetcher.done = True + raise StopIteration + + +class DataLoaderIterDataFetcher(AbstractDataFetcher): + + """ + This class is used to return directly the `dataloader_iter` to the ``LightningModule`` training_step + for users to implement their own pre-fetching logic. + This feature can be activated as follow: + Example:: + Class MyModel(LightningModule): + def __init__(self): + self.automatic_optimization = False + def training_step(self, dataloader_iter: Iterator, batch_idx: int) -> None: + # it is the user responsability to fetch and move the batch to the right device. + batch = next(dataloader_iter) + batch = batch.to(self.device) + ... + """ + + def fetching_function(self) -> Generator: + iterator = iter(StepFuncDataLoaderIter(self.dataloader_iter, self)) + while not self.done: + yield iterator, self.fetched, self.done diff --git a/tests/utilities/test_fetching.py b/tests/utilities/test_fetching.py index 35b309549fb7e..af715668c6ee8 100644 --- a/tests/utilities/test_fetching.py +++ b/tests/utilities/test_fetching.py @@ -11,16 +11,27 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
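For reference, a minimal sketch of how the `DataLoaderIterDataFetcher` introduced above is meant to be consumed (the `range(4)` dataset is only a stand-in; fetching and device placement stay in user code):

    from torch.utils.data import DataLoader

    from pytorch_lightning.utilities.fetching import DataLoaderIterDataFetcher

    fetcher = DataLoaderIterDataFetcher()
    fetcher.setup(DataLoader(range(4)))
    iterator, fetched, done = next(iter(fetcher))  # the fetcher yields its wrapped iterator
    batch = next(iterator)  # user code pulls (and moves) each batch itself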
+from time import time +from typing import Any + import pytest +import torch from torch import tensor -from torch.utils.data import DataLoader, IterableDataset +from torch.utils.data import DataLoader, Dataset, IterableDataset +from pytorch_lightning import Trainer from pytorch_lightning.trainer.supporters import CombinedLoader from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.fetching import DataFetcher +from pytorch_lightning.utilities.fetching import ( + DataFetcher, + DataLoaderIterDataFetcher, + InterBatchParallelismDataFetcher, +) +from tests.helpers.boring_model import BoringModel +from tests.helpers.runif import RunIf -@pytest.mark.parametrize("use_combined_loader", [False, True]) +@pytest.mark.parametrize("use_combined_loader", [False]) def test_prefetch_iterator(use_combined_loader): """Test the DataFetcher with PyTorch IterableDataset.""" @@ -84,3 +95,148 @@ def test_misconfiguration_error(): iter(fetcher) assert fetcher.loader_iters + + +def get_cycles_per_ms() -> float: + """ + Measure and return approximate number of cycles per millisecond for torch.cuda._sleep + Copied from: github.com/pytorch/pytorch/blob/master/test/test_cuda.py + """ + + def measure() -> float: + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + torch.cuda._sleep(1000000) + end.record() + end.synchronize() + cycles_per_ms = 1000000 / start.elapsed_time(end) + return cycles_per_ms + + # Get 10 values and remove the 2 max and 2 min and return the avg. + # This is to avoid system disturbance that skew the results, e.g. + # the very first cuda call likely does a bunch of init, which takes + # much longer than subsequent calls. + # + # Tested on both Tesla V100, Quadro GP100, Titan RTX, RTX 3090 GPUs + # and seems to return stable values. Therefore, we enable caching + # using lru_cache decorator above. 
+ num = 10 + vals = [] + for _ in range(num): + vals.append(measure()) + vals = sorted(vals) + stats = vals[2 : num - 2] + return sum(stats) / len(stats) + + +BATCH_SIZE = 128 +EMB_SZ = 100 +EMB_DIM = 64 + + +class RandomIndicesDataset(Dataset): + def __getitem__(self, index): + return torch.randint(EMB_DIM, [BATCH_SIZE]) + + def __len__(self): + return 16 + + +class RecommenderModel(BoringModel): + def __init__(self): + super().__init__() + self.layer = None + self.local_embedding = torch.nn.Embedding(EMB_SZ, EMB_DIM) + self.CYCLES_PER_MS = int(get_cycles_per_ms()) + + def forward(self, indices: torch.Tensor): + result = self.local_embedding(indices) + return result + + def on_after_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any: + # emulate heavy routine + torch.cuda._sleep(self.CYCLES_PER_MS * 100) + return batch + + def training_step_end(self, training_step_outputs): + # emulate heavy routine + torch.cuda._sleep(self.CYCLES_PER_MS * 100) + return training_step_outputs + + def configure_optimizers(self): + return torch.optim.SGD(self.parameters(), lr=0.1) + + def train_dataloader(self): + return DataLoader(RandomIndicesDataset(), batch_size=4) + + def val_dataloader(self): + return DataLoader(RandomIndicesDataset(), batch_size=4) + + def test_dataloader(self): + return DataLoader(RandomIndicesDataset(), batch_size=4) + + +@RunIf(min_gpus=1, min_torch="1.8.0") +def test_trainer_num_prefetch_batches(tmpdir): + + model = RecommenderModel() + + t0 = time() + trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, gpus=1) + trainer.data_connector.data_fetcher = InterBatchParallelismDataFetcher() + trainer.fit(model) + t1 = time() + global_step = trainer.global_step + assert isinstance(trainer.data_connector.train_data_fetcher, InterBatchParallelismDataFetcher) + + torch.cuda.synchronize() + + t2 = time() + trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, gpus=1) + trainer.fit(model) + t3 = time() + + assert global_step == trainer.global_step == 8 + assert isinstance(trainer.data_connector.train_data_fetcher, DataFetcher) + ratio = (t3 - t2) / (t1 - t0) + assert ratio > 1.2, ratio + + +@pytest.mark.parametrize("automatic_optimization", [False, True]) +@RunIf(min_torch="1.8.0") +def test_fetching_dataloader_iter(automatic_optimization, tmpdir): + class TestModel(BoringModel): + def __init__(self, *args, automatic_optimization: bool = False, **kwargs): + super().__init__(*args, **kwargs) + self.automatic_optimization = automatic_optimization + self.count = 0 + self.batches = [] + + def training_step(self, dataloader_iter, batch_idx): + assert self.count == batch_idx + assert isinstance(self.trainer.data_connector.data_fetcher, DataLoaderIterDataFetcher) + # fetch 2 batches + self.batches.append(next(dataloader_iter)) + self.batches.append(next(dataloader_iter)) + + batch = self.batches.pop(0) + assert isinstance(batch, torch.Tensor) or batch is None + self.count += 2 + if self.automatic_optimization: + return super().training_step(batch, 0) + else: + opt = self.optimizers() + output = self(batch) + loss = self.loss(batch, output) + opt.zero_grad() + loss.backward() + opt.step() + + training_epoch_end = None + + model = TestModel(automatic_optimization=automatic_optimization) + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) + trainer.data_connector.data_fetcher = DataLoaderIterDataFetcher() + trainer.fit(model) + assert model.count == 64 From 60df25a6cf7c4b0143dca6197d08fa6e363cf3fb Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 20 Aug 2021 18:52:30 +0100 
Subject: [PATCH 02/74] resolve tests --- .../loops/epoch/training_epoch_loop.py | 1 + .../processors/iterator_batch_processor.py | 31 +++++++++++++++++-- pytorch_lightning/trainer/trainer.py | 3 ++ tests/loops/test_iterator_batch_processor.py | 12 ------- 4 files changed, 32 insertions(+), 15 deletions(-) diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 09909eaa5e30a..beb22da475941 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -158,6 +158,7 @@ def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None: self.update_lr_schedulers("epoch", update_plateau_schedulers=False) batch_end_outputs = [opt_idx_out for opt_idx_out in batch_output.training_step_output if len(opt_idx_out)] + processed_batch_end_outputs = self._prepare_outputs(batch_end_outputs, batch_mode=True) # hook diff --git a/pytorch_lightning/loops/processors/iterator_batch_processor.py b/pytorch_lightning/loops/processors/iterator_batch_processor.py index 7e12602aa9a2c..4f58ee8523e45 100644 --- a/pytorch_lightning/loops/processors/iterator_batch_processor.py +++ b/pytorch_lightning/loops/processors/iterator_batch_processor.py @@ -14,7 +14,7 @@ import logging from collections import OrderedDict from copy import copy -from typing import Iterator, List, Optional, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple import torch @@ -29,6 +29,7 @@ from pytorch_lightning.utilities import AttributeDict from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden +from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature log = logging.getLogger(__name__) @@ -113,6 +114,29 @@ def get_active_optimizers(self, batch_idx: Optional[int] = None) -> List[Tuple[i """ return list(enumerate(self.trainer.optimizers)) + def _build_kwargs(self, dataloader_iter: Iterator, batch_idx: int) -> Dict[str, Any]: + """Builds the keyword arguments for training_step + + Args: + batch: the batch to train on + batch_idx: the index of the current batch + opt_idx: the index of the current optimizer + hiddens: the hidden state of the previous RNN iteration + + Returns: + the keyword arguments for the training step + """ + # enable not needing to add opt_idx to training_step + step_kwargs = OrderedDict({"dataloader_iter": dataloader_iter}) + + lightning_module = self.trainer.lightning_module + + training_step_fx = getattr(lightning_module, "training_step") + if is_param_in_hook_signature(training_step_fx, "batch_idx"): + step_kwargs["batch_idx"] = batch_idx + + return step_kwargs + def run(self, dataloader_iter: Iterator) -> Optional[AttributeDict]: """ Args: @@ -138,7 +162,7 @@ def run(self, dataloader_iter: Iterator) -> Optional[AttributeDict]: # manually capture logged metrics model._current_fx_name = "training_step" with self.trainer.profiler.profile("training_step"): - step_kwargs = OrderedDict([("dataloader_iter", dataloader_iter), ("batch_idx", batch_idx)]) + step_kwargs = self._build_kwargs(dataloader_iter, batch_idx) training_step_output = self.trainer.accelerator.training_step(step_kwargs) self.trainer.accelerator.post_training_step() @@ -152,7 +176,8 @@ def run(self, dataloader_iter: Iterator) -> Optional[AttributeDict]: batch_outputs = [[] for _ in range(len(self.trainer.optimizers))] - batch_outputs[0].append(copy(training_step_output)) + if training_step_output: + 
batch_outputs[0].append(copy(training_step_output)) return AttributeDict(signal=0, training_step_output=batch_outputs, is_last=is_last) def teardown(self) -> None: diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 887cdd46a9db2..7e84854528059 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -77,6 +77,7 @@ from pytorch_lightning.utilities.debugging import InternalDebugger from pytorch_lightning.utilities.distributed import distributed_available from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.fetching import DataLoaderIterDataFetcher from pytorch_lightning.utilities.imports import _fault_tolerant_training from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.model_summary import ModelSummary, summarize @@ -924,6 +925,8 @@ def _maybe_switch_to_iterator_batch_processor(self, model: "pl.LightningModule") ) batch_loop = IteratorBatchProcessor(self, model) self.fit_loop.epoch_loop.connect(batch_loop) + # FIXME: Move this logic to data_connector after removing `IteratorBatchProcessor` + self.data_connector.data_fetcher = DataLoaderIterDataFetcher() def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: # clean hparams diff --git a/tests/loops/test_iterator_batch_processor.py b/tests/loops/test_iterator_batch_processor.py index 7708ce483901a..2cd6a172f6941 100644 --- a/tests/loops/test_iterator_batch_processor.py +++ b/tests/loops/test_iterator_batch_processor.py @@ -106,18 +106,6 @@ def training_step(self, dataloader_iter: Iterator) -> STEP_OUTPUT: ), "Expect {EXPECT_NUM_BATCHES_PROCESSED} batches to be processed." -def test_automatic_optimization_enabled(tmpdir) -> None: - """ - Verify that a `MisconfigurationException` is raised when - `automatic_optimization` is enabled. - """ - trainer = Trainer(max_epochs=1, default_root_dir=tmpdir) - m = AsyncBoringModel() - m.automatic_optimization = True - with pytest.raises(MisconfigurationException): - trainer.fit(m) - - def test_on_train_batch_start_overridden(tmpdir) -> None: """ Verify that a `MisconfigurationException` is raised when From ea3311e894d94ce4ddd734cefdcbc956d4e35318 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Fri, 20 Aug 2021 14:06:32 -0400 Subject: [PATCH 03/74] update --- pytorch_lightning/loops/epoch/training_epoch_loop.py | 6 ++++-- pytorch_lightning/trainer/connectors/data_connector.py | 10 +++++++--- tests/utilities/test_fetching.py | 4 ++-- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index beb22da475941..bb8680d7d4fd8 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -135,8 +135,10 @@ def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None: # ------------------------------------ # TRAINING_STEP + TRAINING_STEP_END # ------------------------------------ - with self.trainer.profiler.profile("training_batch_to_device"): - batch = self.trainer.accelerator.batch_to_device(batch) + # FIXME: Remove with InterBatchProcessor. 
+ if not self.trainer.data_connector.data_fetcher.store_on_gpu: + with self.trainer.profiler.profile("training_batch_to_device"): + batch = self.trainer.accelerator.batch_to_device(batch) self.batch_progress.increment_ready() diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 089305dca7c33..7ab88d3c28801 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -13,11 +13,10 @@ # limitations under the License. from typing import Callable, Iterable, Optional, Union - import pytorch_lightning as pl from pytorch_lightning.utilities import rank_zero_deprecation from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.fetching import DataFetcher +from pytorch_lightning.utilities.fetching import DataFetcher, InterBatchParallelismDataFetcher from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS @@ -61,7 +60,12 @@ def on_trainer_init( self.trainer._is_data_prepared = False def get_profiled_train_dataloader(self, train_dataloader) -> Iterable: - self.data_fetcher.setup(train_dataloader) + # FIXME: Temporary hack + if isinstance(self.data_fetcher, InterBatchParallelismDataFetcher): + self.data_fetcher.setup( + train_dataloader, batch_to_device=self.trainer.accelerator.batch_to_device) + else: + self.data_fetcher.setup(train_dataloader) prefetcher_iter = iter(self.data_fetcher) profiled_dl = self.trainer.profiler.profile_iterable(enumerate(prefetcher_iter), "get_train_batch") return profiled_dl diff --git a/tests/utilities/test_fetching.py b/tests/utilities/test_fetching.py index af715668c6ee8..9082cd124397f 100644 --- a/tests/utilities/test_fetching.py +++ b/tests/utilities/test_fetching.py @@ -188,7 +188,7 @@ def test_trainer_num_prefetch_batches(tmpdir): trainer.fit(model) t1 = time() global_step = trainer.global_step - assert isinstance(trainer.data_connector.train_data_fetcher, InterBatchParallelismDataFetcher) + assert isinstance(trainer.data_connector.data_fetcher, InterBatchParallelismDataFetcher) torch.cuda.synchronize() @@ -198,7 +198,7 @@ def test_trainer_num_prefetch_batches(tmpdir): t3 = time() assert global_step == trainer.global_step == 8 - assert isinstance(trainer.data_connector.train_data_fetcher, DataFetcher) + assert isinstance(trainer.data_connector.data_fetcher, DataFetcher) ratio = (t3 - t2) / (t1 - t0) assert ratio > 1.2, ratio From 862349e1ce95117f2b68ebcb042fe2f879f056f8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 20 Aug 2021 18:08:00 +0000 Subject: [PATCH 04/74] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/trainer/connectors/data_connector.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 7ab88d3c28801..1bec7fd0700e5 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -13,6 +13,7 @@ # limitations under the License. 
from typing import Callable, Iterable, Optional, Union + import pytorch_lightning as pl from pytorch_lightning.utilities import rank_zero_deprecation from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -60,10 +61,9 @@ def on_trainer_init( self.trainer._is_data_prepared = False def get_profiled_train_dataloader(self, train_dataloader) -> Iterable: - # FIXME: Temporary hack + # FIXME: Temporary hack if isinstance(self.data_fetcher, InterBatchParallelismDataFetcher): - self.data_fetcher.setup( - train_dataloader, batch_to_device=self.trainer.accelerator.batch_to_device) + self.data_fetcher.setup(train_dataloader, batch_to_device=self.trainer.accelerator.batch_to_device) else: self.data_fetcher.setup(train_dataloader) prefetcher_iter = iter(self.data_fetcher) From d55185992a73a2a88a1e831d3e8ca03a22b641fc Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 20 Aug 2021 19:10:02 +0100 Subject: [PATCH 05/74] update --- .../loops/processors/iterator_batch_processor.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/pytorch_lightning/loops/processors/iterator_batch_processor.py b/pytorch_lightning/loops/processors/iterator_batch_processor.py index 4f58ee8523e45..51374d22b9240 100644 --- a/pytorch_lightning/loops/processors/iterator_batch_processor.py +++ b/pytorch_lightning/loops/processors/iterator_batch_processor.py @@ -75,11 +75,6 @@ def __init__(self, trainer: "pl.Trainer", model: "pl.LightningModule") -> None: "The model hook `tbptt_split_batch` is not compatible with " "taking a `dataloader_iter` argument in your `training_step`." ) - # if model.automatic_optimization: - # raise MisconfigurationException( - # "`automatic_optimization` is not compatible with " - # "taking a `dataloader_iter` argument in your `training_step`." 
- # ) if trainer.accumulate_grad_batches != 1: raise MisconfigurationException( "`accumulate_grad_batches` can only be 1 when your " From c190af555b55c4bc060531ee1c7eb6ef69c4de9f Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 20 Aug 2021 19:11:06 +0100 Subject: [PATCH 06/74] update --- pytorch_lightning/loops/processors/iterator_batch_processor.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pytorch_lightning/loops/processors/iterator_batch_processor.py b/pytorch_lightning/loops/processors/iterator_batch_processor.py index 51374d22b9240..2671c65faf300 100644 --- a/pytorch_lightning/loops/processors/iterator_batch_processor.py +++ b/pytorch_lightning/loops/processors/iterator_batch_processor.py @@ -115,8 +115,6 @@ def _build_kwargs(self, dataloader_iter: Iterator, batch_idx: int) -> Dict[str, Args: batch: the batch to train on batch_idx: the index of the current batch - opt_idx: the index of the current optimizer - hiddens: the hidden state of the previous RNN iteration Returns: the keyword arguments for the training step From 7d67e42c91110a3c90065f0a6b1af5e8aa3d3a32 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 20 Aug 2021 19:11:29 +0100 Subject: [PATCH 07/74] update --- .../loops/processors/iterator_batch_processor.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pytorch_lightning/loops/processors/iterator_batch_processor.py b/pytorch_lightning/loops/processors/iterator_batch_processor.py index 2671c65faf300..85cf43fcf836a 100644 --- a/pytorch_lightning/loops/processors/iterator_batch_processor.py +++ b/pytorch_lightning/loops/processors/iterator_batch_processor.py @@ -135,10 +135,6 @@ def run(self, dataloader_iter: Iterator) -> Optional[AttributeDict]: Args: dataloader_iter: the iterator over the dataloader producing the new batch """ - # dataloader_iter = itertools.starmap( - # lambda batch_idx, batch_with_is_last: batch_with_is_last[0], dataloader_iter - # ) - _, (dataloader_iter, batch_idx, is_last) = next(dataloader_iter) self.trainer.logger_connector.on_batch_start() From fadeddc150a6db51e67c8f5f0a1819ce7aea110e Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 20 Aug 2021 19:14:48 +0100 Subject: [PATCH 08/74] update --- CHANGELOG.md | 6 ++++++ .../loops/processors/iterator_batch_processor.py | 3 +-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2e3d9382a4271..495a56c28b95d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -57,6 +57,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added DeepSpeed Stage 1 support ([#8974](https://github.com/PyTorchLightning/pytorch-lightning/pull/8974)) +- Added `InterBatchParallelismDataFetcher` ([#9020](https://github.com/PyTorchLightning/pytorch-lightning/pull/9020)) + + +- Added `DataLoaderIterDataFetcher` ([#9020](https://github.com/PyTorchLightning/pytorch-lightning/pull/9020)) + + ### Changed - Parsing of the `gpus` Trainer argument has changed: `gpus="n"` (str) no longer selects the GPU index n and instead selects the first n devices. ([#8770](https://github.com/PyTorchLightning/pytorch-lightning/pull/8770)) diff --git a/pytorch_lightning/loops/processors/iterator_batch_processor.py b/pytorch_lightning/loops/processors/iterator_batch_processor.py index 85cf43fcf836a..004bdb4fb5cf3 100644 --- a/pytorch_lightning/loops/processors/iterator_batch_processor.py +++ b/pytorch_lightning/loops/processors/iterator_batch_processor.py @@ -13,7 +13,6 @@ # limitations under the License. 
import logging from collections import OrderedDict -from copy import copy from typing import Any, Dict, Iterator, List, Optional, Tuple import torch @@ -166,7 +165,7 @@ def run(self, dataloader_iter: Iterator) -> Optional[AttributeDict]: batch_outputs = [[] for _ in range(len(self.trainer.optimizers))] if training_step_output: - batch_outputs[0].append(copy(training_step_output)) + batch_outputs[0].append(training_step_output) return AttributeDict(signal=0, training_step_output=batch_outputs, is_last=is_last) def teardown(self) -> None: From 53a32c127f77a0bad4a71d3ab798a61951fa914c Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 23 Aug 2021 10:50:28 +0100 Subject: [PATCH 09/74] update on comments --- pytorch_lightning/utilities/fetching.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py index 395e941fce653..f1b4e40bfe420 100644 --- a/pytorch_lightning/utilities/fetching.py +++ b/pytorch_lightning/utilities/fetching.py @@ -38,7 +38,7 @@ class AbstractDataFetcher(ABC): """ - This based class should be used to implement a fault tolerant `DataFetcher`. + This based class should be used to implement a fault tolerant ``DataFetcher``. It is required to override the ``fetching_function`` with fetching logic. Example:: class SimpleDataFetcher(AbstractDataFetcher): @@ -206,7 +206,8 @@ class DataFetcher(AbstractDataFetcher): """ This class is used to control batch fetching flow. - By default, the `fetching_function` will `prefetch` a batch in advance to detect the end of the iteration. + By default, the ``fetching_function`` will pre-fetch a batch in advance to detect the end of the iteration. + Args: prefetch_batches: Number of batches to be pre-fetched. Lightning will pre-fetch at least 1 batch for tracking the latest batch. @@ -305,13 +306,17 @@ def move_data_to_device(self, batch: Any) -> Any: class InterBatchParallelismDataFetcher(DataFetcher): """ - This class implements `inter-batch-parallelism` algorithm which aims at hiding the latency of host-to-device copy - of input batches behind computational intensive operation. - Without parallization: + This class implements inter-batch parallelism, which aims at hiding the latency of host-to-device copy + of input batches behind computationally intensive operations. + + Without parallelization: + batch 0: [HtoD][forward][backward] batch 1: [HtoD][forward][backward] batch 2: [HtoD][forward][backward] + With parallelization, the latency of HtoD copy can be hidden: + batch 0: [HtoD][forward][backward] batch 1: [HtoD] [forward][backward] batch 2: [HtoD] [forward][backward] @@ -380,7 +385,8 @@ class DataLoaderIterDataFetcher(AbstractDataFetcher): """ This class is used to return directly the `dataloader_iter` to the ``LightningModule`` training_step for users to implement their own pre-fetching logic. 
- This feature can be activated as follow: + This feature can be activated as follows: + Example:: Class MyModel(LightningModule): def __init__(self): From 1bc068f689c0325ef2f69c1abe7765baa7ab7cb3 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 23 Aug 2021 11:07:21 +0100 Subject: [PATCH 10/74] update --- .../loops/dataloader/evaluation_loop.py | 14 +++-- pytorch_lightning/loops/fit_loop.py | 12 ++-- .../trainer/connectors/data_connector.py | 56 +++++++++++++++---- 3 files changed, 61 insertions(+), 21 deletions(-) diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index 52998676e1923..0a221b66fa5b4 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Union +from typing import Any, Iterator, List, Optional, Sequence, Union from deprecate.utils import void from torch.utils.data.dataloader import DataLoader @@ -20,7 +20,6 @@ from pytorch_lightning.loops.dataloader import DataLoaderLoop from pytorch_lightning.loops.epoch import EvaluationEpochLoop from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection -from pytorch_lightning.utilities.fetching import DataFetcher from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.types import EPOCH_OUTPUT @@ -98,10 +97,8 @@ def on_run_start(self, *args: Any, **kwargs: Any) -> None: def advance(self, *args: Any, **kwargs: Any) -> None: """Performs evaluation on one single dataloader""" void(*args, **kwargs) - dataloader = self.trainer.accelerator.process_dataloader(self.current_dataloader) - data_fetcher = DataFetcher() - data_fetcher.setup(dataloader) - dataloader_iter = enumerate(data_fetcher) + + dataloader_iter = self._prepare_dataloader_iter() dl_max_batches = self._max_batches[self.current_dataloader_idx] dl_outputs = self.epoch_loop.run( @@ -250,3 +247,8 @@ def on_evaluation_epoch_end(self) -> None: def teardown(self) -> None: self._results.cpu() self.epoch_loop.teardown() + + def _prepare_dataloader_iter(self) -> Iterator: + dataloader = self.trainer.accelerator.process_dataloader(self.current_dataloader) + dataloader = self.trainer.data_connector.get_profiled_dataloader(dataloader) + return iter(dataloader) diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index b77f186453c6a..830d89d5b32e6 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -14,7 +14,7 @@ import logging from contextlib import suppress -from typing import Optional +from typing import Iterator, Optional from pytorch_lightning.loops import Loop from pytorch_lightning.loops.epoch import TrainingEpochLoop @@ -192,12 +192,11 @@ def on_advance_start(self) -> None: def advance(self) -> None: """Runs one whole epoch.""" - train_dataloader = self.trainer.accelerator.process_dataloader(self.trainer.train_dataloader) - train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader) + dataloader_iter = self._prepare_dataloader_iter() with self.trainer.profiler.profile("run_training_epoch"): # run train epoch - epoch_output = self.epoch_loop.run(train_dataloader) + epoch_output = self.epoch_loop.run(dataloader_iter) if epoch_output is None: return @@ -234,3 +233,8 @@ def should_accumulate(self) 
-> bool: def teardown(self) -> None: self.epoch_loop.teardown() + + def _prepare_dataloader_iter(self) -> Iterator: + dataloader = self.trainer.accelerator.process_dataloader(self.trainer.train_dataloader) + dataloader = self.trainer.data_connector.get_profiled_dataloader(dataloader) + return iter(dataloader) diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 2704687069b7f..d05752c243361 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -11,22 +11,33 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import os +from functools import partial from typing import Callable, Iterable, Optional, Union import pytorch_lightning as pl from pytorch_lightning.utilities import rank_zero_deprecation from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.fetching import DataFetcher, InterBatchParallelismDataFetcher +from pytorch_lightning.utilities.fetching import ( + AbstractDataFetcher, + DataFetcher, + DataLoaderIterDataFetcher, + InterBatchParallelismDataFetcher, +) from pytorch_lightning.utilities.model_helpers import is_overridden +from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS +from pytorch_lightning.utilities.warnings import rank_zero_warn class DataConnector: def __init__(self, trainer: "pl.Trainer", multiple_trainloader_mode: str = "max_size_cycle"): self.trainer = trainer self.multiple_trainloader_mode = multiple_trainloader_mode - self.data_fetcher: Optional[DataFetcher] = DataFetcher() + + self.train_data_fetcher: Optional[AbstractDataFetcher] = None + self.validate_data_fetcher: Optional[AbstractDataFetcher] = None + self.test_data_fetcher: Optional[AbstractDataFetcher] = None def on_trainer_init( self, @@ -60,15 +71,38 @@ def on_trainer_init( self.trainer.reload_dataloaders_every_n_epochs = reload_dataloaders_every_n_epochs self.trainer._is_data_prepared = False - def get_profiled_train_dataloader(self, train_dataloader) -> Iterable: - # FIXME: Temporary hack - if isinstance(self.data_fetcher, InterBatchParallelismDataFetcher): - self.data_fetcher.setup(train_dataloader, batch_to_device=self.trainer.accelerator.batch_to_device) + def _check_training_step_requires_dataloader_iter(self) -> bool: + if not self.trainer.training: + return False + training_step_fx = getattr(self.trainer.lightning_module, "training_step") + contains_dataloader_iter = is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True) + return contains_dataloader_iter + + def _select_data_fetcher(self) -> AbstractDataFetcher: + if self._check_training_step_requires_dataloader_iter(): + rank_zero_warn( + "Found `dataloader_iter` argument in the `training_step`. Note that the support for " + "this signature is experimental and the behavior may subject to change." 
+ ) + return DataLoaderIterDataFetcher() + elif self.trainer.training_type_plugin.on_gpu and os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1": + # note: this is an experimental feature + return InterBatchParallelismDataFetcher() else: - self.data_fetcher.setup(train_dataloader) - prefetcher_iter = iter(self.data_fetcher) - profiled_dl = self.trainer.profiler.profile_iterable(enumerate(prefetcher_iter), "get_train_batch") - return profiled_dl + return DataFetcher() + + def get_profiled_dataloader(self, dataloader: Iterable, dataloader_idx: int = 0) -> Iterable: + stage: str = self.trainer.state.stage.value + data_fetcher = self._select_data_fetcher() + data_fetcher.setup( + dataloader, + stage=stage, + batch_to_device=partial(self.trainer.accelerator.batch_to_device, dataloader_idx=dataloader_idx), + profiler=self.trainer.profiler, + ) + # store to enable teardown and clean extra fetched batches + setattr(self, f"{stage}_data_fetcher", data_fetcher) + return enumerate(data_fetcher) def prepare_data(self) -> None: # on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0 From b094b34d4aebc86a60bfa09bda230871a4f9ec84 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 23 Aug 2021 11:09:23 +0100 Subject: [PATCH 11/74] update on comments --- pytorch_lightning/utilities/fetching.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py index f1b4e40bfe420..75bd8f3daefa6 100644 --- a/pytorch_lightning/utilities/fetching.py +++ b/pytorch_lightning/utilities/fetching.py @@ -38,9 +38,11 @@ class AbstractDataFetcher(ABC): """ - This based class should be used to implement a fault tolerant ``DataFetcher``. + This base class should be used to implement a fault tolerant ``DataFetcher``. It is required to override the ``fetching_function`` with fetching logic. + Example:: + class SimpleDataFetcher(AbstractDataFetcher): def fetching_function(self): while True: @@ -309,17 +311,19 @@ class InterBatchParallelismDataFetcher(DataFetcher): This class implements inter-batch parallelism, which aims at hiding the latency of host-to-device copy of input batches behind computationally intensive operations. - Without parallelization: + Example:: + + Without parallelization: - batch 0: [HtoD][forward][backward] - batch 1: [HtoD][forward][backward] - batch 2: [HtoD][forward][backward] + batch 0: [HtoD][forward][backward] + batch 1: [HtoD][forward][backward] + batch 2: [HtoD][forward][backward] - With parallelization, the latency of HtoD copy can be hidden: + With parallelization, the latency of HtoD copy can be hidden: - batch 0: [HtoD][forward][backward] - batch 1: [HtoD] [forward][backward] - batch 2: [HtoD] [forward][backward] + batch 0: [HtoD][forward][backward] + batch 1: [HtoD] [forward][backward] + batch 2: [HtoD] [forward][backward] """ def __init__( @@ -388,9 +392,12 @@ class DataLoaderIterDataFetcher(AbstractDataFetcher): This feature can be activated as follows: Example:: + Class MyModel(LightningModule): + def __init__(self): self.automatic_optimization = False + def training_step(self, dataloader_iter: Iterator, batch_idx: int) -> None: # it is the user responsability to fetch and move the batch to the right device. 
batch = next(dataloader_iter) From 27a924d2a47a40a0f9b77c1e1c89d084b00b9d19 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 23 Aug 2021 11:10:04 +0100 Subject: [PATCH 12/74] typo --- pytorch_lightning/utilities/fetching.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py index 75bd8f3daefa6..a7a9b4531398e 100644 --- a/pytorch_lightning/utilities/fetching.py +++ b/pytorch_lightning/utilities/fetching.py @@ -49,7 +49,7 @@ def fetching_function(self): try: yield next(self.dataloader_iter), False except StopIteration: - yield None, True + return None, True """ @abstractmethod From 617333dd38fe62665806719be7880d016746b2a4 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Mon, 23 Aug 2021 06:15:36 -0400 Subject: [PATCH 13/74] resolve bug --- .../loops/epoch/training_epoch_loop.py | 2 +- tests/utilities/test_fetching.py | 18 ++++++++++-------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index bb8680d7d4fd8..78331b8d236ce 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -136,7 +136,7 @@ def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None: # TRAINING_STEP + TRAINING_STEP_END # ------------------------------------ # FIXME: Remove with InterBatchProcessor. - if not self.trainer.data_connector.data_fetcher.store_on_gpu: + if not self.trainer.data_connector.train_data_fetcher.store_on_gpu: with self.trainer.profiler.profile("training_batch_to_device"): batch = self.trainer.accelerator.batch_to_device(batch) diff --git a/tests/utilities/test_fetching.py b/tests/utilities/test_fetching.py index 9082cd124397f..4920599411751 100644 --- a/tests/utilities/test_fetching.py +++ b/tests/utilities/test_fetching.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
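For context, the `PL_INTER_BATCH_PARALLELISM` opt-in that the patched test below exercises through `mock.patch.dict` can be sketched roughly as follows (a GPU is required, since `_select_data_fetcher` ignores the flag on CPU; `fast_dev_run` only keeps the run short):

    import os

    from pytorch_lightning import Trainer
    from tests.helpers.boring_model import BoringModel

    os.environ["PL_INTER_BATCH_PARALLELISM"] = "1"  # experimental opt-in
    trainer = Trainer(gpus=1, fast_dev_run=True)
    trainer.fit(BoringModel())
    # trainer.data_connector.train_data_fetcher is now an InterBatchParallelismDataFetcher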
+import os +from unittest import mock from time import time from typing import Any @@ -182,13 +184,13 @@ def test_trainer_num_prefetch_batches(tmpdir): model = RecommenderModel() - t0 = time() - trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, gpus=1) - trainer.data_connector.data_fetcher = InterBatchParallelismDataFetcher() - trainer.fit(model) - t1 = time() - global_step = trainer.global_step - assert isinstance(trainer.data_connector.data_fetcher, InterBatchParallelismDataFetcher) + with mock.patch.dict(os.environ, {"PL_INTER_BATCH_PARALLELISM": "1"}): + t0 = time() + trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, gpus=1) + trainer.fit(model) + t1 = time() + global_step = trainer.global_step + assert isinstance(trainer.data_connector.train_data_fetcher, InterBatchParallelismDataFetcher) torch.cuda.synchronize() @@ -198,7 +200,7 @@ def test_trainer_num_prefetch_batches(tmpdir): t3 = time() assert global_step == trainer.global_step == 8 - assert isinstance(trainer.data_connector.data_fetcher, DataFetcher) + assert isinstance(trainer.data_connector.train_data_fetcher, DataFetcher) ratio = (t3 - t2) / (t1 - t0) assert ratio > 1.2, ratio From 6eeba870a76eff4f60d13e8bd3b5471537079211 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Mon, 23 Aug 2021 06:29:16 -0400 Subject: [PATCH 14/74] update --- pytorch_lightning/loops/epoch/evaluation_epoch_loop.py | 4 +--- pytorch_lightning/loops/epoch/training_epoch_loop.py | 4 ---- tests/utilities/test_fetching.py | 4 +++- 3 files changed, 4 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index b03b15d820230..0f6a3569560d5 100644 --- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -91,9 +91,6 @@ def advance( if batch is None: raise StopIteration - with self.trainer.profiler.profile("evaluation_batch_to_device"): - batch = self.trainer.accelerator.batch_to_device(batch, dataloader_idx=dataloader_idx) - self.batch_progress.increment_ready() # hook @@ -229,3 +226,4 @@ def _track_output_for_epoch_end( output = output.cpu() outputs.append(output) return outputs + \ No newline at end of file diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 78331b8d236ce..49428ad257152 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -136,10 +136,6 @@ def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None: # TRAINING_STEP + TRAINING_STEP_END # ------------------------------------ # FIXME: Remove with InterBatchProcessor. 
- if not self.trainer.data_connector.train_data_fetcher.store_on_gpu: - with self.trainer.profiler.profile("training_batch_to_device"): - batch = self.trainer.accelerator.batch_to_device(batch) - self.batch_progress.increment_ready() with self.trainer.profiler.profile("run_training_batch"): diff --git a/tests/utilities/test_fetching.py b/tests/utilities/test_fetching.py index 4920599411751..afb1f745ec799 100644 --- a/tests/utilities/test_fetching.py +++ b/tests/utilities/test_fetching.py @@ -191,6 +191,7 @@ def test_trainer_num_prefetch_batches(tmpdir): t1 = time() global_step = trainer.global_step assert isinstance(trainer.data_connector.train_data_fetcher, InterBatchParallelismDataFetcher) + assert isinstance(trainer.data_connector.validate_data_fetcher, InterBatchParallelismDataFetcher) torch.cuda.synchronize() @@ -201,8 +202,9 @@ def test_trainer_num_prefetch_batches(tmpdir): assert global_step == trainer.global_step == 8 assert isinstance(trainer.data_connector.train_data_fetcher, DataFetcher) + assert isinstance(trainer.data_connector.validate_data_fetcher, DataFetcher) ratio = (t3 - t2) / (t1 - t0) - assert ratio > 1.2, ratio + assert ratio > 1.25, ratio @pytest.mark.parametrize("automatic_optimization", [False, True]) From 0326a33f7ebe6b164fd40d293f742f565508678b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 23 Aug 2021 10:30:25 +0000 Subject: [PATCH 15/74] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/loops/epoch/evaluation_epoch_loop.py | 1 - tests/utilities/test_fetching.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index 0f6a3569560d5..91680ae8dc2a7 100644 --- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -226,4 +226,3 @@ def _track_output_for_epoch_end( output = output.cpu() outputs.append(output) return outputs - \ No newline at end of file diff --git a/tests/utilities/test_fetching.py b/tests/utilities/test_fetching.py index afb1f745ec799..34814b6e79660 100644 --- a/tests/utilities/test_fetching.py +++ b/tests/utilities/test_fetching.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -from unittest import mock from time import time from typing import Any +from unittest import mock import pytest import torch From 06dae29a233a3bb75f6a5d4e6b1be4776d72e357 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Mon, 23 Aug 2021 06:44:43 -0400 Subject: [PATCH 16/74] update on comments --- CHANGELOG.md | 2 +- .../loops/epoch/training_epoch_loop.py | 3 +- .../processors/iterator_batch_processor.py | 43 ++++++++++--------- .../trainer/connectors/data_connector.py | 6 +-- pytorch_lightning/utilities/fetching.py | 25 ++++++----- tests/utilities/test_fetching.py | 20 +++++---- 6 files changed, 53 insertions(+), 46 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 39cdfac0cd333..61f8e9e4284b4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -57,7 +57,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Added DeepSpeed Stage 1 support ([#8974](https://github.com/PyTorchLightning/pytorch-lightning/pull/8974)) -- Added `InterBatchParallelismDataFetcher` ([#9020](https://github.com/PyTorchLightning/pytorch-lightning/pull/9020)) +- Added `InterBatchParallelDataFetcher` ([#9020](https://github.com/PyTorchLightning/pytorch-lightning/pull/9020)) - Added `DataLoaderIterDataFetcher` ([#9020](https://github.com/PyTorchLightning/pytorch-lightning/pull/9020)) diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index bb8680d7d4fd8..7242ccde42632 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -136,7 +136,7 @@ def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None: # TRAINING_STEP + TRAINING_STEP_END # ------------------------------------ # FIXME: Remove with InterBatchProcessor. - if not self.trainer.data_connector.data_fetcher.store_on_gpu: + if not self.trainer.data_connector.data_fetcher.store_on_device: with self.trainer.profiler.profile("training_batch_to_device"): batch = self.trainer.accelerator.batch_to_device(batch) @@ -160,7 +160,6 @@ def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None: self.update_lr_schedulers("epoch", update_plateau_schedulers=False) batch_end_outputs = [opt_idx_out for opt_idx_out in batch_output.training_step_output if len(opt_idx_out)] - processed_batch_end_outputs = self._prepare_outputs(batch_end_outputs, batch_mode=True) # hook diff --git a/pytorch_lightning/loops/processors/iterator_batch_processor.py b/pytorch_lightning/loops/processors/iterator_batch_processor.py index 004bdb4fb5cf3..035a8ebf6b5fb 100644 --- a/pytorch_lightning/loops/processors/iterator_batch_processor.py +++ b/pytorch_lightning/loops/processors/iterator_batch_processor.py @@ -108,27 +108,6 @@ def get_active_optimizers(self, batch_idx: Optional[int] = None) -> List[Tuple[i """ return list(enumerate(self.trainer.optimizers)) - def _build_kwargs(self, dataloader_iter: Iterator, batch_idx: int) -> Dict[str, Any]: - """Builds the keyword arguments for training_step - - Args: - batch: the batch to train on - batch_idx: the index of the current batch - - Returns: - the keyword arguments for the training step - """ - # enable not needing to add opt_idx to training_step - step_kwargs = OrderedDict({"dataloader_iter": dataloader_iter}) - - lightning_module = self.trainer.lightning_module - - training_step_fx = getattr(lightning_module, "training_step") - if is_param_in_hook_signature(training_step_fx, "batch_idx"): - step_kwargs["batch_idx"] = batch_idx - - return step_kwargs - def run(self, dataloader_iter: Iterator) -> Optional[AttributeDict]: """ Args: @@ -173,3 +152,25 @@ def teardown(self) -> None: No-op. Only defined to comply with FitLoop's expectation. """ pass + + # FIXME: To be deleted in next PR. 
+ def _build_kwargs(self, dataloader_iter: Iterator, batch_idx: int) -> Dict[str, Any]: + """Builds the keyword arguments for training_step + + Args: + batch: the batch to train on + batch_idx: the index of the current batch + + Returns: + the keyword arguments for the training step + """ + # enable not needing to add opt_idx to training_step + step_kwargs = OrderedDict({"dataloader_iter": dataloader_iter}) + + lightning_module = self.trainer.lightning_module + + training_step_fx = getattr(lightning_module, "training_step") + if is_param_in_hook_signature(training_step_fx, "batch_idx"): + step_kwargs["batch_idx"] = batch_idx + + return step_kwargs \ No newline at end of file diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 2704687069b7f..46af20c3ad751 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -17,7 +17,7 @@ import pytorch_lightning as pl from pytorch_lightning.utilities import rank_zero_deprecation from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.fetching import DataFetcher, InterBatchParallelismDataFetcher +from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataFetcher, InterBatchParallelDataFetcher from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS @@ -26,7 +26,7 @@ class DataConnector: def __init__(self, trainer: "pl.Trainer", multiple_trainloader_mode: str = "max_size_cycle"): self.trainer = trainer self.multiple_trainloader_mode = multiple_trainloader_mode - self.data_fetcher: Optional[DataFetcher] = DataFetcher() + self.data_fetcher: AbstractDataFetcher = DataFetcher() def on_trainer_init( self, @@ -62,7 +62,7 @@ def on_trainer_init( def get_profiled_train_dataloader(self, train_dataloader) -> Iterable: # FIXME: Temporary hack - if isinstance(self.data_fetcher, InterBatchParallelismDataFetcher): + if isinstance(self.data_fetcher, InterBatchParallelDataFetcher): self.data_fetcher.setup(train_dataloader, batch_to_device=self.trainer.accelerator.batch_to_device) else: self.data_fetcher.setup(train_dataloader) diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py index a7a9b4531398e..1ea323f201055 100644 --- a/pytorch_lightning/utilities/fetching.py +++ b/pytorch_lightning/utilities/fetching.py @@ -213,16 +213,16 @@ class DataFetcher(AbstractDataFetcher): Args: prefetch_batches: Number of batches to be pre-fetched. Lightning will pre-fetch at least 1 batch for tracking the latest batch. - store_on_gpu: Whether to store the pre-fetched batches on device. + store_on_device: Whether to store the pre-fetched batches on device. 
""" def __init__( self, prefetch_batches: int = 0, - store_on_gpu: bool = False, + store_on_device: bool = False, ) -> None: super().__init__(prefetch_batches=prefetch_batches) - self.store_on_gpu = store_on_gpu + self.store_on_device = store_on_device @contextmanager def fetching_context(self): @@ -234,7 +234,7 @@ def on_fetch_start(self) -> None: def on_fetch_end(self, batch, on_fetch_start_output: Optional[Any] = None) -> None: """Hook to extend which handles the logic after fetching a batch""" - if self.store_on_gpu: + if self.store_on_device: batch = self.move_data_to_device(batch) self.append_batch(batch) @@ -251,10 +251,15 @@ def fetching_function(self) -> Generator: yield_batch = self.pop_batch() self._fetch_next_batch() - # yield last and has next + # wait for batch to be available. self.wait() - yield (self.move_data_to_device(yield_batch) if not self.store_on_gpu else yield_batch, False) + # yield last and has next + yield ( + self.move_data_to_device(yield_batch) + if not self.store_on_device else yield_batch, + False + ) except StopIteration: self.batches.insert(0, yield_batch) break @@ -293,7 +298,7 @@ def _consume_prefetched_batches(self) -> Generator: def _yield_batch(self) -> Generator: self.wait() batch = self.batches.pop(0) - if not self.store_on_gpu: + if not self.store_on_device: batch = self.move_data_to_device(batch) is_last = len(self.batches) == 0 yield batch, is_last @@ -305,13 +310,13 @@ def move_data_to_device(self, batch: Any) -> Any: return batch -class InterBatchParallelismDataFetcher(DataFetcher): +class InterBatchParallelDataFetcher(DataFetcher): """ This class implements inter-batch parallelism, which aims at hiding the latency of host-to-device copy of input batches behind computationally intensive operations. - Example:: + code-block:: Without parallelization: @@ -330,7 +335,7 @@ def __init__( self, prefetch_batches: int = 0, ) -> None: - super().__init__(prefetch_batches=prefetch_batches, store_on_gpu=True) + super().__init__(prefetch_batches=prefetch_batches, store_on_device=True) self.cuda_stream = torch.cuda.Stream() self.events: List[torch.cuda.Event] = [] diff --git a/tests/utilities/test_fetching.py b/tests/utilities/test_fetching.py index 9082cd124397f..ecf808efcf549 100644 --- a/tests/utilities/test_fetching.py +++ b/tests/utilities/test_fetching.py @@ -25,7 +25,7 @@ from pytorch_lightning.utilities.fetching import ( DataFetcher, DataLoaderIterDataFetcher, - InterBatchParallelismDataFetcher, + InterBatchParallelDataFetcher, ) from tests.helpers.boring_model import BoringModel from tests.helpers.runif import RunIf @@ -156,12 +156,12 @@ def forward(self, indices: torch.Tensor): def on_after_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any: # emulate heavy routine - torch.cuda._sleep(self.CYCLES_PER_MS * 100) + torch.cuda._sleep(self.CYCLES_PER_MS * 50) return batch def training_step_end(self, training_step_outputs): # emulate heavy routine - torch.cuda._sleep(self.CYCLES_PER_MS * 100) + torch.cuda._sleep(self.CYCLES_PER_MS * 50) return training_step_outputs def configure_optimizers(self): @@ -181,26 +181,28 @@ def test_dataloader(self): def test_trainer_num_prefetch_batches(tmpdir): model = RecommenderModel() + trainer_kwargs = dict( + default_root_dir=tmpdir, max_epochs=1, gpus=1, limit_train_batches=3, limit_val_batches=0) t0 = time() - trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, gpus=1) - trainer.data_connector.data_fetcher = InterBatchParallelismDataFetcher() + trainer = Trainer(**trainer_kwargs) + 
trainer.data_connector.data_fetcher = InterBatchParallelDataFetcher() trainer.fit(model) t1 = time() global_step = trainer.global_step - assert isinstance(trainer.data_connector.data_fetcher, InterBatchParallelismDataFetcher) + assert isinstance(trainer.data_connector.data_fetcher, InterBatchParallelDataFetcher) torch.cuda.synchronize() t2 = time() - trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, gpus=1) + trainer = Trainer(**trainer_kwargs) trainer.fit(model) t3 = time() - assert global_step == trainer.global_step == 8 + assert global_step == trainer.global_step == 3 assert isinstance(trainer.data_connector.data_fetcher, DataFetcher) ratio = (t3 - t2) / (t1 - t0) - assert ratio > 1.2, ratio + assert ratio > 1.1, ratio @pytest.mark.parametrize("automatic_optimization", [False, True]) From 9ea19531a504fb76b94293e50b3a6254ab22b832 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 23 Aug 2021 10:45:47 +0000 Subject: [PATCH 17/74] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../loops/processors/iterator_batch_processor.py | 2 +- pytorch_lightning/utilities/fetching.py | 6 +----- tests/utilities/test_fetching.py | 9 ++------- 3 files changed, 4 insertions(+), 13 deletions(-) diff --git a/pytorch_lightning/loops/processors/iterator_batch_processor.py b/pytorch_lightning/loops/processors/iterator_batch_processor.py index 035a8ebf6b5fb..91962b6738fc3 100644 --- a/pytorch_lightning/loops/processors/iterator_batch_processor.py +++ b/pytorch_lightning/loops/processors/iterator_batch_processor.py @@ -173,4 +173,4 @@ def _build_kwargs(self, dataloader_iter: Iterator, batch_idx: int) -> Dict[str, if is_param_in_hook_signature(training_step_fx, "batch_idx"): step_kwargs["batch_idx"] = batch_idx - return step_kwargs \ No newline at end of file + return step_kwargs diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py index 1ea323f201055..59442730245e0 100644 --- a/pytorch_lightning/utilities/fetching.py +++ b/pytorch_lightning/utilities/fetching.py @@ -255,11 +255,7 @@ def fetching_function(self) -> Generator: self.wait() # yield last and has next - yield ( - self.move_data_to_device(yield_batch) - if not self.store_on_device else yield_batch, - False - ) + yield (self.move_data_to_device(yield_batch) if not self.store_on_device else yield_batch, False) except StopIteration: self.batches.insert(0, yield_batch) break diff --git a/tests/utilities/test_fetching.py b/tests/utilities/test_fetching.py index ecf808efcf549..71dbe9aee1c07 100644 --- a/tests/utilities/test_fetching.py +++ b/tests/utilities/test_fetching.py @@ -22,11 +22,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.trainer.supporters import CombinedLoader from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.fetching import ( - DataFetcher, - DataLoaderIterDataFetcher, - InterBatchParallelDataFetcher, -) +from pytorch_lightning.utilities.fetching import DataFetcher, DataLoaderIterDataFetcher, InterBatchParallelDataFetcher from tests.helpers.boring_model import BoringModel from tests.helpers.runif import RunIf @@ -181,8 +177,7 @@ def test_dataloader(self): def test_trainer_num_prefetch_batches(tmpdir): model = RecommenderModel() - trainer_kwargs = dict( - default_root_dir=tmpdir, max_epochs=1, gpus=1, limit_train_batches=3, limit_val_batches=0) + trainer_kwargs = dict(default_root_dir=tmpdir, 
max_epochs=1, gpus=1, limit_train_batches=3, limit_val_batches=0) t0 = time() trainer = Trainer(**trainer_kwargs) From 9bd8094e725b6f4fd331f88530dc578ee58f42ce Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 23 Aug 2021 10:54:47 +0000 Subject: [PATCH 18/74] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/utilities/test_fetching.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/utilities/test_fetching.py b/tests/utilities/test_fetching.py index 64d3488f9d192..300894522de94 100644 --- a/tests/utilities/test_fetching.py +++ b/tests/utilities/test_fetching.py @@ -180,7 +180,13 @@ def test_trainer_num_prefetch_batches(tmpdir): model = RecommenderModel() trainer_kwargs = dict( - default_root_dir=tmpdir, max_epochs=1, gpus=1, limit_train_batches=4, limit_val_batches=0, num_sanity_val_steps=0) + default_root_dir=tmpdir, + max_epochs=1, + gpus=1, + limit_train_batches=4, + limit_val_batches=0, + num_sanity_val_steps=0, + ) with mock.patch.dict(os.environ, {"PL_INTER_BATCH_PARALLELISM": "1"}): t0 = time() From e9d676037a52785d34c449169a9fa366b37125a1 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Mon, 23 Aug 2021 06:59:33 -0400 Subject: [PATCH 19/74] update --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 61f8e9e4284b4..26e3acb616f79 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -63,6 +63,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `DataLoaderIterDataFetcher` ([#9020](https://github.com/PyTorchLightning/pytorch-lightning/pull/9020)) +- Added `DataFetcher` within `Fit / Evaluation` Loop ([#9047](https://github.com/PyTorchLightning/pytorch-lightning/pull/9047)) + + + ### Changed - Parsing of the `gpus` Trainer argument has changed: `gpus="n"` (str) no longer selects the GPU index n and instead selects the first n devices. 
([#8770](https://github.com/PyTorchLightning/pytorch-lightning/pull/8770)) From 842a6ae86e2ca28d09c578f9f03bba75690972a2 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Mon, 23 Aug 2021 08:33:44 -0400 Subject: [PATCH 20/74] update --- pytorch_lightning/core/lightning.py | 4 + pytorch_lightning/loops/__init__.py | 1 - .../loops/batch/training_batch_loop.py | 8 +- .../loops/epoch/training_epoch_loop.py | 47 ++--- .../loops/processors/__init__.py | 15 -- .../processors/iterator_batch_processor.py | 176 ---------------- .../trainer/connectors/data_connector.py | 38 +++- .../logger_connector/logger_connector.py | 4 +- pytorch_lightning/trainer/trainer.py | 17 +- pytorch_lightning/utilities/fetching.py | 5 +- tests/loops/test_inter_batch_parallelism.py | 190 ------------------ tests/loops/test_iterator_batch_processor.py | 183 ----------------- tests/utilities/test_fetching.py | 155 +++++++++++++- 13 files changed, 210 insertions(+), 633 deletions(-) delete mode 100644 pytorch_lightning/loops/processors/__init__.py delete mode 100644 pytorch_lightning/loops/processors/iterator_batch_processor.py delete mode 100644 tests/loops/test_inter_batch_parallelism.py delete mode 100644 tests/loops/test_iterator_batch_processor.py diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index c847eea57ccd0..ac1194615da1e 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -454,6 +454,10 @@ def log( f" of {list(self._metric_attributes.values())}" ) + if is_param_in_hook_signature(self.training_step, "dataloader_iter") and batch_size is None: + raise MisconfigurationException( + "When the `dataloader_iter` is requested within the `training_step`, `batch_size` should be provided.") + results.log( self._current_fx_name, name, diff --git a/pytorch_lightning/loops/__init__.py b/pytorch_lightning/loops/__init__.py index c9775ed44155e..b7eb47167d26f 100644 --- a/pytorch_lightning/loops/__init__.py +++ b/pytorch_lightning/loops/__init__.py @@ -17,4 +17,3 @@ from pytorch_lightning.loops.dataloader import DataLoaderLoop, EvaluationLoop, PredictionLoop # noqa: F401 from pytorch_lightning.loops.epoch import EvaluationEpochLoop, PredictionEpochLoop, TrainingEpochLoop # noqa: F401 from pytorch_lightning.loops.fit_loop import FitLoop # noqa: F401 -from pytorch_lightning.loops.processors import IteratorBatchProcessor # noqa: F401 diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index a048bc0c3a91c..20eb5e61ec7b3 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -545,12 +545,16 @@ def _build_kwargs(self, batch: Any, batch_idx: int, opt_idx: int, hiddens: Optio the keyword arguments for the training step """ # enable not needing to add opt_idx to training_step - step_kwargs = OrderedDict([("batch", batch), ("batch_idx", batch_idx)]) + step_kwargs = OrderedDict({"batch": batch}) lightning_module = self.trainer.lightning_module + training_step_fx = getattr(self.trainer.lightning_module, "training_step") + + if is_param_in_hook_signature(training_step_fx, "batch_idx"): + step_kwargs["batch_idx"] = batch_idx + if len(self.trainer.optimizers) > 1: - training_step_fx = getattr(lightning_module, "training_step") has_opt_idx_in_train_step = is_param_in_hook_signature(training_step_fx, "optimizer_idx") if has_opt_idx_in_train_step: if not lightning_module.automatic_optimization: diff --git 
a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 6e5841c1e028b..8368fc6eeaaf8 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -17,18 +17,12 @@ from pytorch_lightning import loops # import as loops to avoid circular imports from pytorch_lightning.loops.batch import TrainingBatchLoop -from pytorch_lightning.loops.processors import IteratorBatchProcessor from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection from pytorch_lightning.trainer.progress import Progress, SchedulerProgress from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.types import STEP_OUTPUT -# TODO: currently, the batch processor is only a loop when tbptt is enabled. -# As we introduce more specialized batch processors, we may want to choose a -# more suitable abstraction for them. -BATCH_LOOP_TYPE = Optional[Tuple[TrainingBatchLoop, IteratorBatchProcessor]] - class TrainingEpochLoop(loops.Loop): """ @@ -50,7 +44,7 @@ def __init__(self, min_steps: int, max_steps: int): self.batch_progress = Progress() self.scheduler_progress = SchedulerProgress() - self.batch_loop: BATCH_LOOP_TYPE = None + self.batch_loop: TrainingBatchLoop = None self.val_loop: Optional["loops.EvaluationLoop"] = None self._results = ResultCollection(training=True) @@ -81,7 +75,7 @@ def done(self) -> bool: def connect( self, - batch_loop: BATCH_LOOP_TYPE = None, + batch_loop: TrainingBatchLoop = None, val_loop: Optional["loops.EvaluationLoop"] = None, ) -> None: """Optionally connect a custom batch or validation loop to this training epoch loop.""" @@ -118,29 +112,17 @@ def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None: Raises: StopIteration: When the epoch is canceled by the user returning -1 """ - if isinstance(self.batch_loop, IteratorBatchProcessor): - # By contract, when taking `dataloader_iter` as an argument, - # `training_step` is responsible for reporting `is_last` in the - # result dict, which is used to determine the stop condition for - # the epoch. So as long as `advance` is invoked, it's correct to - # assume that there are more batches to be processed. 
- self.batch_progress.increment_ready() - with self.trainer.profiler.profile("run_training_batch"): - batch_output = self.batch_loop.run(dataloader_iter) - self.batch_progress.increment_processed() - is_last = batch_output.is_last - else: - _, (batch, is_last) = next(dataloader_iter) - - # ------------------------------------ - # TRAINING_STEP + TRAINING_STEP_END - # ------------------------------------ - self.batch_progress.increment_ready() - - with self.trainer.profiler.profile("run_training_batch"): - batch_output = self.batch_loop.run(batch, self.batch_idx) - - self.batch_progress.increment_processed() + batch_idx, (batch, is_last) = next(dataloader_iter) + + # ------------------------------------ + # TRAINING_STEP + TRAINING_STEP_END + # ------------------------------------ + self.batch_progress.increment_ready() + + with self.trainer.profiler.profile("run_training_batch"): + batch_output = self.batch_loop.run(batch, batch_idx) + + self.batch_progress.increment_processed() self.is_last_batch = is_last @@ -158,8 +140,7 @@ def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None: processed_batch_end_outputs = self._prepare_outputs(batch_end_outputs, batch_mode=True) # hook - if not isinstance(self.batch_loop, IteratorBatchProcessor): - self.trainer.call_hook("on_train_batch_end", processed_batch_end_outputs, batch, self.batch_idx, 0) + self.trainer.call_hook("on_train_batch_end", processed_batch_end_outputs, batch, self.batch_idx, 0) self.trainer.call_hook("on_batch_end") self.trainer.logger_connector.on_batch_end() diff --git a/pytorch_lightning/loops/processors/__init__.py b/pytorch_lightning/loops/processors/__init__.py deleted file mode 100644 index 9fcbe9e82dca8..0000000000000 --- a/pytorch_lightning/loops/processors/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from pytorch_lightning.loops.processors.iterator_batch_processor import IteratorBatchProcessor # noqa: F401 diff --git a/pytorch_lightning/loops/processors/iterator_batch_processor.py b/pytorch_lightning/loops/processors/iterator_batch_processor.py deleted file mode 100644 index 91962b6738fc3..0000000000000 --- a/pytorch_lightning/loops/processors/iterator_batch_processor.py +++ /dev/null @@ -1,176 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import logging -from collections import OrderedDict -from typing import Any, Dict, Iterator, List, Optional, Tuple - -import torch - -import pytorch_lightning as pl -from pytorch_lightning.loops.utilities import ( - _check_training_step_output, - _process_training_step_output, - check_finite_loss, -) -from pytorch_lightning.trainer.progress import OptimizationProgress -from pytorch_lightning.trainer.supporters import TensorRunningAccum -from pytorch_lightning.utilities import AttributeDict -from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.model_helpers import is_overridden -from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature - -log = logging.getLogger(__name__) - - -class IteratorBatchProcessor: - """ - The processor for performing a training iteration when ``training_step`` needs access to the - dataloader. It is selected when the signature of ``training_step`` contains ``dataloader_iter``: - - def training_step(self, dataloader_iter: Iterator) -> STEP_OUTPUT: - - The ``training_step`` is allowed to fetch multiple batches during one training iteration. The - framework provides minimum amount of automation with regards to model optimization. The - flexibility allows for ease of experimentation with inter-batch parallelism techniques. - - This processor doesn't support ``automatic_optimization`` and ``tbptt``. An error will be thrown - if the ``LightningModule`` or the ``Trainer`` is configured to use these features. - - The ``training_step`` is responsible for reporting whether it has reached the last batch by - including an ``is_last`` field in the result dict. Failing to do so will result in an error. - - The ``training_step`` should only optimize the model with one batch for the sake of API and - reporting consistency (TODO: consider removing this limitation). - - Args: - trainer: a reference to the trainer - model: a reference to the lightning module (for config validation purposes only) - """ - - def __init__(self, trainer: "pl.Trainer", model: "pl.LightningModule") -> None: - if is_overridden("on_train_batch_start", model): - raise MisconfigurationException( - "The model hook `on_train_batch_start` is not compatible with " - "taking a `dataloader_iter` argument in your `training_step`." - ) - if is_overridden("on_train_batch_end", model): - raise MisconfigurationException( - "The model hook `on_train_batch_end` is not compatible with " - "taking a `dataloader_iter` argument in your `training_step`." - ) - if is_overridden("tbptt_split_batch", model): - raise MisconfigurationException( - "The model hook `tbptt_split_batch` is not compatible with " - "taking a `dataloader_iter` argument in your `training_step`." - ) - if trainer.accumulate_grad_batches != 1: - raise MisconfigurationException( - "`accumulate_grad_batches` can only be 1 when your " - "`training_step` takes `dataloader_iter` as an argument." - ) - - self.trainer = trainer - - # The following field is not used by the processor since it doesn't support automatic - # optimization and tbptt. Initializing them regardless since they are currently expected by - # `FitLoop` or `TrainingEpochLoop`. - # TODO: come up with an abstraction for "batch processors" so they can be better decoupled - # with parent loops. 
- self.accumulated_loss: Optional[torch.Tensor] = None - self.running_loss: TensorRunningAccum = TensorRunningAccum(window_length=1) - self.optim_progress = OptimizationProgress() - self.split_idx: int = 0 - self._skip_backward = False - - def num_active_optimizers(self, batch_idx: Optional[int] = None) -> int: - """ - Returns the number of active optimizers. - """ - return len(self.trainer.optimizers) - - def get_active_optimizers(self, batch_idx: Optional[int] = None) -> List[Tuple[int, torch.optim.Optimizer]]: - """ - Returns the currently active optimizers. - - Returns: - A list of tuples (opt_idx, optimizer) of currently active optimizers. - """ - return list(enumerate(self.trainer.optimizers)) - - def run(self, dataloader_iter: Iterator) -> Optional[AttributeDict]: - """ - Args: - dataloader_iter: the iterator over the dataloader producing the new batch - """ - _, (dataloader_iter, batch_idx, is_last) = next(dataloader_iter) - - self.trainer.logger_connector.on_batch_start() - response = self.trainer.call_hook("on_batch_start") - if response == -1: - return AttributeDict(signal=-1) - - self.trainer.fit_loop.epoch_loop.batch_progress.increment_started() - - # give the PL module a result for logging - model = self.trainer.lightning_module - - with self.trainer.profiler.profile("model_forward"): - # manually capture logged metrics - model._current_fx_name = "training_step" - with self.trainer.profiler.profile("training_step"): - step_kwargs = self._build_kwargs(dataloader_iter, batch_idx) - training_step_output = self.trainer.accelerator.training_step(step_kwargs) - self.trainer.accelerator.post_training_step() - - training_step_output = self.trainer.call_hook("training_step_end", training_step_output) - _check_training_step_output(self.trainer.lightning_module, training_step_output) - - training_step_output, _ = _process_training_step_output(self.trainer, training_step_output) - - if self.trainer.terminate_on_nan: - check_finite_loss(self.trainer.lightning_module, training_step_output.minimize) - - batch_outputs = [[] for _ in range(len(self.trainer.optimizers))] - - if training_step_output: - batch_outputs[0].append(training_step_output) - return AttributeDict(signal=0, training_step_output=batch_outputs, is_last=is_last) - - def teardown(self) -> None: - """ - No-op. Only defined to comply with FitLoop's expectation. - """ - pass - - # FIXME: To be deleted in next PR. 
- def _build_kwargs(self, dataloader_iter: Iterator, batch_idx: int) -> Dict[str, Any]: - """Builds the keyword arguments for training_step - - Args: - batch: the batch to train on - batch_idx: the index of the current batch - - Returns: - the keyword arguments for the training step - """ - # enable not needing to add opt_idx to training_step - step_kwargs = OrderedDict({"dataloader_iter": dataloader_iter}) - - lightning_module = self.trainer.lightning_module - - training_step_fx = getattr(lightning_module, "training_step") - if is_param_in_hook_signature(training_step_fx, "batch_idx"): - step_kwargs["batch_idx"] = batch_idx - - return step_kwargs diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index e7c731c8ab7b4..00105bf000fbd 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -72,13 +72,6 @@ def on_trainer_init( self.trainer.reload_dataloaders_every_n_epochs = reload_dataloaders_every_n_epochs self.trainer._is_data_prepared = False - def _check_training_step_requires_dataloader_iter(self) -> bool: - if not self.trainer.training: - return False - training_step_fx = getattr(self.trainer.lightning_module, "training_step") - contains_dataloader_iter = is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True) - return contains_dataloader_iter - def _select_data_fetcher(self) -> AbstractDataFetcher: if self._check_training_step_requires_dataloader_iter(): rank_zero_warn( @@ -103,6 +96,8 @@ def get_profiled_dataloader(self, dataloader: Iterable, dataloader_idx: int = 0) ) # store to enable teardown and clean extra fetched batches setattr(self, f"{stage}_data_fetcher", data_fetcher) + if isinstance(data_fetcher, DataLoaderIterDataFetcher): + return data_fetcher return enumerate(data_fetcher) def prepare_data(self) -> None: @@ -207,6 +202,35 @@ def detach_data(model: "pl.LightningModule") -> None: if isinstance(loader, _PatchDataLoader): loader.unpatch(model) + def _check_training_step_requires_dataloader_iter(self) -> bool: + if not self.trainer.training: + return False + + training_step_fx = getattr(self.trainer.lightning_module, "training_step") + contains_dataloader_iter = is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True) + + if contains_dataloader_iter: + + if is_overridden("on_train_batch_start", self.trainer.lightning_module): + raise MisconfigurationException( + "The model hook `on_train_batch_start` is not compatible with " + "taking a `dataloader_iter` argument in your `training_step`." + ) + + if is_overridden("on_train_batch_end", self.trainer.lightning_module): + raise MisconfigurationException( + "The model hook `on_train_batch_end` is not compatible with " + "taking a `dataloader_iter` argument in your `training_step`." + ) + + if self.trainer.lightning_module.truncated_bptt_steps > 0: + raise MisconfigurationException( + "The model taking a `dataloader_iter` argument in your `training_step` " + "is incompatible with `truncated_bptt_steps > 0`." 
+ ) + + return contains_dataloader_iter + class _PatchDataLoader: r""" diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 11d07f7f6ffc2..5be0e9aae6f98 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -15,7 +15,6 @@ from typing import Any, Dict, Iterable, Mapping, Optional, Union import torch - import pytorch_lightning as pl from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection, TensorBoardLogger from pytorch_lightning.trainer.connectors.logger_connector.result import _METRIC, MetricSource @@ -199,7 +198,8 @@ def update_eval_epoch_metrics(self) -> _EVALUATE_OUTPUT: """ def on_train_split_start(self, batch_idx: int, split_idx: int, split_batch: Any) -> None: - self.trainer._results.extract_batch_size(split_batch) + if isinstance(split_batch, pl.utilities.fetching.DataLoaderIterDataFetcher): + self.trainer._results.extract_batch_size(split_batch) self._batch_idx = batch_idx self._split_idx = split_idx diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 75bd00f971784..45a7f88deab11 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -28,7 +28,7 @@ from pytorch_lightning.callbacks import Callback from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.loggers import LightningLoggerBase -from pytorch_lightning.loops import IteratorBatchProcessor, TrainingBatchLoop, TrainingEpochLoop +from pytorch_lightning.loops import TrainingBatchLoop, TrainingEpochLoop from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop from pytorch_lightning.loops.dataloader.prediction_loop import PredictionLoop from pytorch_lightning.loops.fit_loop import FitLoop @@ -916,18 +916,6 @@ def _load_checkpoint_weights(self): rank_zero_info(f"Loading model weights from checkpoint at {self._ckpt_path}") self.checkpoint_connector.restore_model_weights(self._ckpt_path) - def _maybe_switch_to_iterator_batch_processor(self, model: "pl.LightningModule") -> None: - training_step_fx = getattr(model, "training_step") - if is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True): - log.warning( - "Found `dataloader_iter` argument in the `training_step`. Note that the support for " - "this signature is experimental and the behavior may subject to change." 
- ) - batch_loop = IteratorBatchProcessor(self, model) - self.fit_loop.epoch_loop.connect(batch_loop) - # FIXME: Move this logic to data_connector after removing `IteratorBatchProcessor` - self.data_connector.data_fetcher = DataLoaderIterDataFetcher() - def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: # clean hparams if hasattr(model, "hparams"): @@ -935,9 +923,6 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, self.config_validator.verify_loop_configurations(model) - if self.training: - self._maybe_switch_to_iterator_batch_processor(model) - # attach model log function to callback self.callback_connector.attach_model_logging_functions(model) diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py index 59442730245e0..d7cf8e6740d0b 100644 --- a/pytorch_lightning/utilities/fetching.py +++ b/pytorch_lightning/utilities/fetching.py @@ -17,6 +17,7 @@ from contextlib import contextmanager from copy import deepcopy from functools import partial +from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection from typing import Any, Callable, Generator, List, Optional, Tuple import torch @@ -377,8 +378,6 @@ def __iter__(self) -> "StepFuncDataLoaderIter": def __next__(self) -> Any: try: data = next(self.iterator) - # FIXME: Link this to `batch_idx`. - self.data_fetcher.fetched += 1 return data except StopIteration: self.data_fetcher.done = True @@ -409,4 +408,4 @@ def training_step(self, dataloader_iter: Iterator, batch_idx: int) -> None: def fetching_function(self) -> Generator: iterator = iter(StepFuncDataLoaderIter(self.dataloader_iter, self)) while not self.done: - yield iterator, self.fetched, self.done + yield self.fetched, (iterator, self.done) diff --git a/tests/loops/test_inter_batch_parallelism.py b/tests/loops/test_inter_batch_parallelism.py deleted file mode 100644 index 00bc7049b0a29..0000000000000 --- a/tests/loops/test_inter_batch_parallelism.py +++ /dev/null @@ -1,190 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import time -from statistics import mean -from typing import Iterator - -import torch -from torch.utils.data import DataLoader, IterableDataset - -from pytorch_lightning import LightningModule, Trainer -from pytorch_lightning.utilities.types import STEP_OUTPUT -from tests.helpers.runif import RunIf - - -def count_cycles_per_ms() -> float: - """ - Measure and return approximate number of cycles per millisecond for torch.cuda._sleep - - Copied from: github.com/pytorch/pytorch/blob/master/test/test_cuda.py - """ - - def measure() -> float: - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - start.record() - torch.cuda._sleep(1000000) - end.record() - end.synchronize() - cycles_per_ms = 1000000 / start.elapsed_time(end) - return cycles_per_ms - - # Get 10 values and remove the 2 max and 2 min and return the avg. 
- # This is to avoid system disturbance that skew the results, e.g. - # the very first cuda call likely does a bunch of init, which takes - # much longer than subsequent calls. - # - # Tested on both Tesla V100, Quadro GP100, Titan RTX, RTX 3090 GPUs - # and seems to return stable values. Therefore, we enable caching - # using lru_cache decorator above. - num = 10 - vals = [] - for _ in range(num): - vals.append(measure()) - vals = sorted(vals) - return mean(vals[2 : num - 2]) - - -_CYCLES_PER_MS = int(count_cycles_per_ms()) if torch.cuda.is_available() else 0 -_BATCH_SIZE = 128 -_EMB_SZ = 100 -_EMB_DIM = 64 - - -class RandomSparseDataset(IterableDataset): - def __init__(self, emb_dim: int, batch_size: int, count: int) -> None: - self.emb_dim = emb_dim - self.batch_size = batch_size - self.count = count - - def __iter__(self): - for _ in range(self.count): - yield torch.randint(self.emb_dim, [self.batch_size]) - - -class ToyDLRMModel(LightningModule): - """ - A toy model for mimicking the communication overhead of sharded embedding - modules in DLRM models. - - DLRM models can be trained in a DDP-like fashion, where each trainer - receives different batches (embedding indices in this example). Since the - embeddings are sharded across trainers, the lookup process involves (1) - routing the indices to the trainer that possesses the corresponding - embeddings (2) performing local lookup (3) routing the embedding lookup - result back. - - The toy model doesn't actually performs index/result routing. It simply - uses torch.cuda._sleep() to mimic the cost of the communication op (i.e. - a2a). - """ - - def __init__(self): - super().__init__() - self.automatic_optimization = False - self.local_embedding = torch.nn.Embedding(_EMB_SZ, _EMB_DIM) - - def _route_indices(self, batch: torch.Tensor, non_blocking=False): - """ - This can be parallelized across different batches since it's model - weight independent. - - Why not run this in dataloader/datamodule? 
- - The routing logic depends on how model is sharded - - Putting this in data preprocessor changes the semantic of the model - """ - torch.cuda._sleep(_CYCLES_PER_MS * 1_000) - if not non_blocking: - torch.cuda.synchronize() - return batch - - def _route_result(self, result: torch.Tensor, non_blocking=False): - torch.cuda._sleep(_CYCLES_PER_MS * 1_000) - if not non_blocking: - torch.cuda.synchronize() - return result - - def forward(self, indices: torch.Tensor): - local_indices = self._route_indices(indices) - result = self.local_embedding(local_indices) - return self._route_result(result) - - def training_step(self, batch: torch.Tensor, batch_idx: int) -> STEP_OUTPUT: - return self.forward(batch) - - def configure_optimizers(self): - return torch.optim.SGD(self.local_embedding.parameters(), lr=0.1) - - def train_dataloader(self): - return DataLoader(RandomSparseDataset(_EMB_DIM, _BATCH_SIZE, 5)) - - -class AsyncToyDLRMModel(ToyDLRMModel): - def __init__(self): - super().__init__() - self.comm_stream = torch.cuda.Stream() - self.batch_i = None - self.batch_i_ready = torch.cuda.Event() - - def training_step(self, dataloader_iter: Iterator) -> STEP_OUTPUT: - if self.batch_i is None: - self.batch_i = next(dataloader_iter) - with torch.cuda.stream(self.comm_stream): - self._route_indices(self.batch_i, non_blocking=True) - self.batch_i_ready.record() - - # Invariant: the routing for batch[i] has been kicked off - is_last = False - batch_ip1 = None - batch_ip1_ready = torch.cuda.Event() - try: - batch_ip1 = next(dataloader_iter) - with torch.cuda.stream(self.comm_stream): - self._route_indices(batch_ip1, non_blocking=True) - batch_ip1_ready.record() - except StopIteration: - is_last = True - - self.batch_i_ready.wait() - - result = self.local_embedding(self.batch_i) - self._route_result(result) - - self.batch_i = batch_ip1 - self.batch_i_ready = batch_ip1_ready - - return {"is_last": is_last} - - -@RunIf(min_gpus=1) -def test_inter_batch_parallelism(tmpdir): - """ - Verify the speedup of a simple inter-batch parallelization use case enabled - by exposing `dataloader_iter` to `training_step`. - """ - begin_time = time.time() - m = AsyncToyDLRMModel() - trainer = Trainer(max_epochs=1, default_root_dir=tmpdir) - trainer.fit(m) - async_duration = time.time() - begin_time - - begin_time = time.time() - m = ToyDLRMModel() - trainer = Trainer(max_epochs=1, default_root_dir=tmpdir) - trainer.fit(m) - sync_duration = time.time() - begin_time - - # We expect 2x speedup. However, we only assert that the async - # training_step is faster in order to avoid flaky tests - assert async_duration < sync_duration, "Expect `AsyncToyDLRMModel` to train faster than `ToyDLRMModel`." diff --git a/tests/loops/test_iterator_batch_processor.py b/tests/loops/test_iterator_batch_processor.py deleted file mode 100644 index 2cd6a172f6941..0000000000000 --- a/tests/loops/test_iterator_batch_processor.py +++ /dev/null @@ -1,183 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-from typing import Any, Iterator - -import pytest -from torch.utils.data import DataLoader - -from pytorch_lightning import Trainer -from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.types import STEP_OUTPUT -from tests.helpers import BoringModel, RandomDataset - -_BATCH_SIZE = 32 -_DATASET_LEN = 64 - - -class DummyWaitable: - def __init__(self, val: Any) -> None: - self.val = val - - def wait(self) -> Any: - return self.val - - -class AsyncBoringModel(BoringModel): - def __init__(self) -> None: - super().__init__() - self.automatic_optimization = False - self.batch_i_handle = None - self.num_batches_processed = 0 - - def _async_op(self, batch: Any) -> DummyWaitable: - return DummyWaitable(val=batch) - - def training_step(self, dataloader_iter: Iterator) -> STEP_OUTPUT: - if self.batch_i_handle is None: - batch_i_raw = next(dataloader_iter) - self.batch_i_handle = self._async_op(batch_i_raw) - - # Invariant: _async_op for batch[i] has been initiated - batch_ip1_handle = None - is_last = False - try: - batch_ip1_raw = next(dataloader_iter) - batch_ip1_handle = self._async_op(batch_ip1_raw) - except StopIteration: - is_last = True - - batch_i = self.batch_i_handle.wait() - - pred = self.layer(batch_i) - loss = self.loss(batch_i, pred) - loss.backward() - self.optimizers().step() - self.optimizers().zero_grad() - - self.batch_i_handle = batch_ip1_handle - self.num_batches_processed += 1 - - return {"loss": loss, "is_last": is_last} - - def train_dataloader(self): - return DataLoader(RandomDataset(_BATCH_SIZE, _DATASET_LEN)) - - -def test_training_step_with_dataloader_access(tmpdir) -> None: - """ - A baseline functional test for `training_step` with dataloader access. - """ - trainer = Trainer(max_epochs=1, default_root_dir=tmpdir) - m = AsyncBoringModel() - trainer.fit(m) - assert m.num_batches_processed == _DATASET_LEN, f"Expect all {_DATASET_LEN} batches to be processed." - - -def test_stop_iteration(tmpdir) -> None: - """ - Verify that when `StopIteration` is raised within `training_step`, `fit()` - terminiates as expected. - """ - EXPECT_NUM_BATCHES_PROCESSED = 2 - - class TestModel(AsyncBoringModel): - def training_step(self, dataloader_iter: Iterator) -> STEP_OUTPUT: - output = super().training_step(dataloader_iter) - if self.num_batches_processed == EXPECT_NUM_BATCHES_PROCESSED: - raise StopIteration() - return output - - trainer = Trainer(max_epochs=1, default_root_dir=tmpdir) - m = TestModel() - trainer.fit(m) - assert ( - m.num_batches_processed == EXPECT_NUM_BATCHES_PROCESSED - ), "Expect {EXPECT_NUM_BATCHES_PROCESSED} batches to be processed." - - -def test_on_train_batch_start_overridden(tmpdir) -> None: - """ - Verify that a `MisconfigurationException` is raised when - `on_train_batch_start` is overridden on the `LightningModule`. - """ - - class InvalidModel(AsyncBoringModel): - def on_train_batch_start(self, batch, batch_idx, dataloader_idx): - pass - - trainer = Trainer(max_epochs=1, default_root_dir=tmpdir) - m = InvalidModel() - with pytest.raises(MisconfigurationException): - trainer.fit(m) - - -def test_on_train_batch_end_overridden(tmpdir) -> None: - """ - Verify that a `MisconfigurationException` is raised when - `on_train_batch_end` is overridden on the `LightningModule`. 
- """ - - class InvalidModel(AsyncBoringModel): - def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): - pass - - trainer = Trainer(max_epochs=1, default_root_dir=tmpdir) - m = InvalidModel() - with pytest.raises(MisconfigurationException): - trainer.fit(m) - - -def test_tbptt_split_batch_overridden(tmpdir) -> None: - """ - Verify that a `MisconfigurationException` is raised when - `tbptt_split_batch` is overridden on the `LightningModule`. - """ - - class InvalidModel(AsyncBoringModel): - def tbptt_split_batch(self, batch, split_size): - pass - - trainer = Trainer(max_epochs=1, default_root_dir=tmpdir) - m = InvalidModel() - with pytest.raises(MisconfigurationException): - trainer.fit(m) - - -def test_accumulate_grad_batches(tmpdir) -> None: - """ - Verify that a `MisconfigurationException` is raised when - `accumulate_grad_batches` is not set to 1. - """ - trainer = Trainer(max_epochs=1, accumulate_grad_batches=2, default_root_dir=tmpdir) - m = AsyncBoringModel() - with pytest.raises(MisconfigurationException): - trainer.fit(m) - - -def test_is_last_not_set(tmpdir) -> None: - """ - Verify that a `MisconfigurationException` is raised when `training_step` - doesn't include "is_last" in the result dict. - """ - - class InvalidModel(AsyncBoringModel): - def training_step(self, dataloader_iter: Iterator) -> STEP_OUTPUT: - output = super().training_step(dataloader_iter) - del output["is_last"] - return output - - trainer = Trainer(max_epochs=1, accumulate_grad_batches=2, default_root_dir=tmpdir) - m = InvalidModel() - with pytest.raises(MisconfigurationException): - trainer.fit(m) diff --git a/tests/utilities/test_fetching.py b/tests/utilities/test_fetching.py index 300894522de94..b8944c137f118 100644 --- a/tests/utilities/test_fetching.py +++ b/tests/utilities/test_fetching.py @@ -15,7 +15,15 @@ from time import time from typing import Any from unittest import mock +from typing import Any, Iterator +import pytest +from torch.utils.data import DataLoader + +from pytorch_lightning import Trainer +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.types import STEP_OUTPUT +from tests.helpers import BoringModel, RandomDataset import pytest import torch from torch import tensor @@ -128,7 +136,8 @@ def measure() -> float: return sum(stats) / len(stats) -BATCH_SIZE = 128 +BATCH_SIZE = 32 +DATASET_LEN = 64 EMB_SZ = 100 EMB_DIM = 64 @@ -220,8 +229,8 @@ def __init__(self, *args, automatic_optimization: bool = False, **kwargs): self.batches = [] def training_step(self, dataloader_iter, batch_idx): - assert self.count == batch_idx - assert isinstance(self.trainer.data_connector.data_fetcher, DataLoaderIterDataFetcher) + # assert self.count == batch_idx + assert isinstance(self.trainer.data_connector.train_data_fetcher, DataLoaderIterDataFetcher) # fetch 2 batches self.batches.append(next(dataloader_iter)) self.batches.append(next(dataloader_iter)) @@ -230,7 +239,8 @@ def training_step(self, dataloader_iter, batch_idx): assert isinstance(batch, torch.Tensor) or batch is None self.count += 2 if self.automatic_optimization: - return super().training_step(batch, 0) + loss = super().training_step(batch, 0) + self.log("train_loss", loss["loss"]) else: opt = self.optimizers() output = self(batch) @@ -243,6 +253,141 @@ def training_step(self, dataloader_iter, batch_idx): model = TestModel(automatic_optimization=automatic_optimization) trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) - trainer.data_connector.data_fetcher = 
DataLoaderIterDataFetcher() trainer.fit(model) + + breakpoint() + + assert trainer.fit_loop.epoch_loop.batch_progress == 64 + assert trainer.data_connector.train_data_fetcher.fetched == 64 assert model.count == 64 + + +class DummyWaitable: + def __init__(self, val: Any) -> None: + self.val = val + + def wait(self) -> Any: + return self.val + + +class AsyncBoringModel(BoringModel): + def __init__(self) -> None: + super().__init__() + self.automatic_optimization = False + self.batch_i_handle = None + self.num_batches_processed = 0 + + def _async_op(self, batch: Any) -> DummyWaitable: + return DummyWaitable(val=batch) + + def training_step(self, dataloader_iter: Iterator) -> STEP_OUTPUT: + if self.batch_i_handle is None: + batch_i_raw = next(dataloader_iter) + self.batch_i_handle = self._async_op(batch_i_raw) + + # Invariant: _async_op for batch[i] has been initiated + batch_ip1_handle = None + is_last = False + try: + batch_ip1_raw = next(dataloader_iter) + batch_ip1_handle = self._async_op(batch_ip1_raw) + except StopIteration: + is_last = True + + batch_i = self.batch_i_handle.wait() + + pred = self.layer(batch_i) + loss = self.loss(batch_i, pred) + loss.backward() + self.optimizers().step() + self.optimizers().zero_grad() + + self.batch_i_handle = batch_ip1_handle + self.num_batches_processed += 1 + + return {"loss": loss, "is_last": is_last} + + def train_dataloader(self): + return DataLoader(RandomDataset(BATCH_SIZE, DATASET_LEN)) + + +def test_training_step_with_dataloader_access(tmpdir) -> None: + """ + A baseline functional test for `training_step` with dataloader access. + """ + trainer = Trainer(max_epochs=1, default_root_dir=tmpdir) + m = AsyncBoringModel() + trainer.fit(m) + assert m.num_batches_processed == DATASET_LEN, f"Expect all {DATASET_LEN} batches to be processed." + + +def test_stop_iteration(tmpdir) -> None: + """ + Verify that when `StopIteration` is raised within `training_step`, `fit()` + terminiates as expected. + """ + EXPECT_NUM_BATCHES_PROCESSED = 2 + + class TestModel(AsyncBoringModel): + def training_step(self, dataloader_iter: Iterator) -> STEP_OUTPUT: + output = super().training_step(dataloader_iter) + if self.num_batches_processed == EXPECT_NUM_BATCHES_PROCESSED: + raise StopIteration() + return output + + trainer = Trainer(max_epochs=1, default_root_dir=tmpdir) + m = TestModel() + trainer.fit(m) + assert ( + m.num_batches_processed == EXPECT_NUM_BATCHES_PROCESSED + ), "Expect {EXPECT_NUM_BATCHES_PROCESSED} batches to be processed." + + +def test_on_train_batch_start_overridden(tmpdir) -> None: + """ + Verify that a `MisconfigurationException` is raised when + `on_train_batch_start` is overridden on the `LightningModule`. + """ + + class InvalidModel(AsyncBoringModel): + def on_train_batch_start(self, batch, batch_idx, dataloader_idx): + pass + + trainer = Trainer(max_epochs=1, default_root_dir=tmpdir) + m = InvalidModel() + with pytest.raises(MisconfigurationException, match="The model hook `on_train_batch_start` is not compatible with"): + trainer.fit(m) + + +def test_on_train_batch_end_overridden(tmpdir) -> None: + """ + Verify that a `MisconfigurationException` is raised when + `on_train_batch_end` is overridden on the `LightningModule`. 
+ """ + + class InvalidModel(AsyncBoringModel): + def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): + pass + + trainer = Trainer(max_epochs=1, default_root_dir=tmpdir) + m = InvalidModel() + with pytest.raises(MisconfigurationException, match="The model hook `on_train_batch_end` is not compatible with"): + trainer.fit(m) + + +def test_tbptt_split_batch_overridden(tmpdir) -> None: + """ + Verify that a `MisconfigurationException` is raised when + `tbptt_split_batch` is overridden on the `LightningModule`. + """ + + class InvalidModel(AsyncBoringModel): + + def __init__(self) -> None: + super().__init__() + self.truncated_bptt_steps = 2 + + trainer = Trainer(max_epochs=1, default_root_dir=tmpdir) + m = InvalidModel() + with pytest.raises(MisconfigurationException, match="is incompatible with `truncated_bptt_steps > 0`."): + trainer.fit(m) From a70c052ae1529f6ace01e77dc448a4544f987a53 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 23 Aug 2021 12:42:47 +0000 Subject: [PATCH 21/74] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/core/lightning.py | 3 ++- .../loops/batch/training_batch_loop.py | 4 ++-- .../logger_connector/logger_connector.py | 1 + pytorch_lightning/utilities/fetching.py | 2 +- tests/utilities/test_fetching.py | 15 ++++----------- 5 files changed, 10 insertions(+), 15 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index ac1194615da1e..ad1c69b0ba34b 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -456,7 +456,8 @@ def log( if is_param_in_hook_signature(self.training_step, "dataloader_iter") and batch_size is None: raise MisconfigurationException( - "When the `dataloader_iter` is requested within the `training_step`, `batch_size` should be provided.") + "When the `dataloader_iter` is requested within the `training_step`, `batch_size` should be provided." 
+ ) results.log( self._current_fx_name, diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index 20eb5e61ec7b3..58b8a27344e50 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -545,12 +545,12 @@ def _build_kwargs(self, batch: Any, batch_idx: int, opt_idx: int, hiddens: Optio the keyword arguments for the training step """ # enable not needing to add opt_idx to training_step - step_kwargs = OrderedDict({"batch": batch}) + step_kwargs = OrderedDict({"batch": batch}) lightning_module = self.trainer.lightning_module training_step_fx = getattr(self.trainer.lightning_module, "training_step") - + if is_param_in_hook_signature(training_step_fx, "batch_idx"): step_kwargs["batch_idx"] = batch_idx diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 5be0e9aae6f98..6257da3725181 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -15,6 +15,7 @@ from typing import Any, Dict, Iterable, Mapping, Optional, Union import torch + import pytorch_lightning as pl from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection, TensorBoardLogger from pytorch_lightning.trainer.connectors.logger_connector.result import _METRIC, MetricSource diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py index d7cf8e6740d0b..96566dbf92bc5 100644 --- a/pytorch_lightning/utilities/fetching.py +++ b/pytorch_lightning/utilities/fetching.py @@ -17,13 +17,13 @@ from contextlib import contextmanager from copy import deepcopy from functools import partial -from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection from typing import Any, Callable, Generator, List, Optional, Tuple import torch from torch.utils.data.dataloader import DataLoader import pytorch_lightning as pl +from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections from pytorch_lightning.utilities.auto_restart import ( diff --git a/tests/utilities/test_fetching.py b/tests/utilities/test_fetching.py index b8944c137f118..0107f30b1a4dc 100644 --- a/tests/utilities/test_fetching.py +++ b/tests/utilities/test_fetching.py @@ -13,17 +13,9 @@ # limitations under the License. 
import os from time import time -from typing import Any -from unittest import mock from typing import Any, Iterator +from unittest import mock -import pytest -from torch.utils.data import DataLoader - -from pytorch_lightning import Trainer -from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.types import STEP_OUTPUT -from tests.helpers import BoringModel, RandomDataset import pytest import torch from torch import tensor @@ -33,6 +25,8 @@ from pytorch_lightning.trainer.supporters import CombinedLoader from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.fetching import DataFetcher, DataLoaderIterDataFetcher, InterBatchParallelDataFetcher +from pytorch_lightning.utilities.types import STEP_OUTPUT +from tests.helpers import BoringModel, RandomDataset from tests.helpers.boring_model import BoringModel from tests.helpers.runif import RunIf @@ -254,7 +248,7 @@ def training_step(self, dataloader_iter, batch_idx): model = TestModel(automatic_optimization=automatic_optimization) trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) trainer.fit(model) - + breakpoint() assert trainer.fit_loop.epoch_loop.batch_progress == 64 @@ -382,7 +376,6 @@ def test_tbptt_split_batch_overridden(tmpdir) -> None: """ class InvalidModel(AsyncBoringModel): - def __init__(self) -> None: super().__init__() self.truncated_bptt_steps = 2 From d08976fb4ab6575eac5e522418e3e7d2d3ac9b93 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Mon, 23 Aug 2021 09:19:13 -0400 Subject: [PATCH 22/74] resolve tests --- pytorch_lightning/loops/base.py | 15 +++++++++++---- .../loops/batch/training_batch_loop.py | 5 +++++ pytorch_lightning/utilities/fetching.py | 1 + tests/utilities/test_fetching.py | 15 +++++++-------- 4 files changed, 24 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index ee5c3a1b708f1..25f5f48cae02f 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -87,6 +87,12 @@ def on_skip(self) -> Optional[Any]: Returns: the default output value of :meth:`on_run_end` """ + + def _advance_step(self, *args, **kwargs): + self.on_advance_start(*args, **kwargs) + self.advance(*args, **kwargs) + self.on_advance_end() + self.restarting = False def run(self, *args: Any, **kwargs: Any) -> Optional[Any]: """ @@ -107,11 +113,9 @@ def run(self, *args: Any, **kwargs: Any) -> Optional[Any]: while not self.done: try: - self.on_advance_start(*args, **kwargs) - self.advance(*args, **kwargs) - self.on_advance_end() - self.restarting = False + self._advance_step(*args, **kwargs) except StopIteration: + self.on_stop_iteration() break output = self.on_run_end() @@ -160,6 +164,9 @@ def on_save_checkpoint(self) -> Dict: def on_load_checkpoint(self, state_dict: Dict) -> None: """Called when loading a model checkpoint, use to reload loop state.""" + def on_stop_iteration(self): + """Called when a stop iteration is being triggered.""" + def state_dict(self, destination: Optional[Dict] = None, prefix: Optional[str] = "") -> Dict: """ The state dict is determined by the state and progress of this loop and all its children. 
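The `loops/base.py` hunk above routes a `StopIteration` raised inside `advance` (for example by a `training_step` that drains `dataloader_iter`) through the new `on_stop_iteration` hook instead of erroring out mid-epoch. As a rough, standalone sketch of that control flow only — the names below (`ToyLoop`, `toy_training_step`) are illustrative and not Lightning APIs:

```python
# Standalone sketch of the StopIteration flow added to loops/base.py above.
# `ToyLoop` / `toy_training_step` are made-up names, not Lightning APIs.


class ToyLoop:
    def __init__(self, step_fn, iterator):
        self.step_fn = step_fn
        self.iterator = iterator
        self.outputs = []

    def advance(self):
        # Mirrors `_advance_step`: any StopIteration raised by the step escapes here.
        self.outputs.append(self.step_fn(self.iterator))

    def on_stop_iteration(self):
        # Hook equivalent: a nested loop could re-raise here to forward the signal
        # to its parent loop, as `TrainingBatchLoop.on_stop_iteration` does.
        pass

    def run(self):
        while True:
            try:
                self.advance()
            except StopIteration:
                self.on_stop_iteration()
                break
        return self.outputs


def toy_training_step(dataloader_iter):
    # Consumes two items per step, like the `DataLoaderIterDataFetcher` test model;
    # exhausting the iterator is what ends the "epoch".
    first = next(dataloader_iter)
    second = next(dataloader_iter)
    return first + second


if __name__ == "__main__":
    print(ToyLoop(toy_training_step, iter(range(10))).run())  # [1, 5, 9, 13, 17]
```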
diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index 20eb5e61ec7b3..9494d26926a2f 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -104,6 +104,11 @@ def run(self, batch: Any, batch_idx: int) -> AttributeDict: self.batch_outputs = None # free memory return output + def on_stop_iteration(self): + # if a stop iteration is being triggered within the ``training_step``, + # forward the expection to the parent loop. + raise StopIteration + def reset(self) -> None: """Resets the loop state""" self._hiddens = None diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py index d7cf8e6740d0b..802515eacb047 100644 --- a/pytorch_lightning/utilities/fetching.py +++ b/pytorch_lightning/utilities/fetching.py @@ -378,6 +378,7 @@ def __iter__(self) -> "StepFuncDataLoaderIter": def __next__(self) -> Any: try: data = next(self.iterator) + self.data_fetcher.fetched += 1 return data except StopIteration: self.data_fetcher.done = True diff --git a/tests/utilities/test_fetching.py b/tests/utilities/test_fetching.py index b8944c137f118..1b821563164bc 100644 --- a/tests/utilities/test_fetching.py +++ b/tests/utilities/test_fetching.py @@ -16,8 +16,6 @@ from typing import Any from unittest import mock from typing import Any, Iterator - -import pytest from torch.utils.data import DataLoader from pytorch_lightning import Trainer @@ -229,7 +227,7 @@ def __init__(self, *args, automatic_optimization: bool = False, **kwargs): self.batches = [] def training_step(self, dataloader_iter, batch_idx): - # assert self.count == batch_idx + assert self.count == batch_idx assert isinstance(self.trainer.data_connector.train_data_fetcher, DataLoaderIterDataFetcher) # fetch 2 batches self.batches.append(next(dataloader_iter)) @@ -240,7 +238,9 @@ def training_step(self, dataloader_iter, batch_idx): self.count += 2 if self.automatic_optimization: loss = super().training_step(batch, 0) - self.log("train_loss", loss["loss"]) + with pytest.raises(MisconfigurationException, match="`batch_size` should be provided"): + self.log("train_loss", loss["loss"]) + self.log("train_loss", loss["loss"], batch_size=1) else: opt = self.optimizers() output = self(batch) @@ -255,9 +255,8 @@ def training_step(self, dataloader_iter, batch_idx): trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) trainer.fit(model) - breakpoint() - - assert trainer.fit_loop.epoch_loop.batch_progress == 64 + # we don't sync batch_progress with user fetching + assert trainer.fit_loop.epoch_loop.batch_progress.current.ready == 33 assert trainer.data_connector.train_data_fetcher.fetched == 64 assert model.count == 64 @@ -332,7 +331,7 @@ class TestModel(AsyncBoringModel): def training_step(self, dataloader_iter: Iterator) -> STEP_OUTPUT: output = super().training_step(dataloader_iter) if self.num_batches_processed == EXPECT_NUM_BATCHES_PROCESSED: - raise StopIteration() + raise StopIteration return output trainer = Trainer(max_epochs=1, default_root_dir=tmpdir) From e7272e3754de56ae96032ae474fb70834c337b1a Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Mon, 23 Aug 2021 10:17:43 -0400 Subject: [PATCH 23/74] update --- .../processors/iterator_batch_processor.py | 2 +- .../trainer/connectors/data_connector.py | 18 +++++++----------- pytorch_lightning/utilities/fetching.py | 2 +- tests/loops/test_iterator_batch_processor.py | 2 ++ tests/utilities/test_fetching.py | 4 ++-- 5 files changed, 
13 insertions(+), 15 deletions(-) diff --git a/pytorch_lightning/loops/processors/iterator_batch_processor.py b/pytorch_lightning/loops/processors/iterator_batch_processor.py index 122f0bbd95440..1ac5de745e6e5 100644 --- a/pytorch_lightning/loops/processors/iterator_batch_processor.py +++ b/pytorch_lightning/loops/processors/iterator_batch_processor.py @@ -113,7 +113,7 @@ def run(self, dataloader_iter: Iterator) -> Optional[AttributeDict]: Args: dataloader_iter: the iterator over the dataloader producing the new batch """ - _, (dataloader_iter, batch_idx, is_last) = next(dataloader_iter) + batch_idx, (dataloader_iter, is_last) = next(dataloader_iter) self.trainer.logger_connector.on_batch_start() response = self.trainer.call_hook("on_batch_start") diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index b137b225cfd1e..3afa274330be5 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -31,14 +31,10 @@ class DataConnector: - def __init__(self, trainer: "pl.Trainer", multiple_trainloader_mode: str = "max_size_cycle"): + def __init__(self, trainer: "pl.Trainer", multiple_trainloader_mode: str = "max_size_cycle", data_fetcher: Optional[AbstractDataFetcher] = None): self.trainer = trainer self.multiple_trainloader_mode = multiple_trainloader_mode - - self.train_data_fetcher: Optional[AbstractDataFetcher] = None - self.validate_data_fetcher: Optional[AbstractDataFetcher] = None - self.test_data_fetcher: Optional[AbstractDataFetcher] = None - self.sanity_checking_data_fetcher: Optional[AbstractDataFetcher] = None + self.data_fetcher = data_fetcher def on_trainer_init( self, @@ -100,16 +96,16 @@ def _select_data_fetcher(self) -> AbstractDataFetcher: def get_profiled_dataloader(self, dataloader: Iterable, dataloader_idx: int = 0) -> Iterable: stage: str = self.trainer.state.stage.value - data_fetcher = self._select_data_fetcher() - data_fetcher.setup( + self.data_fetcher = self._select_data_fetcher() + self.data_fetcher.setup( dataloader, stage=stage, batch_to_device=partial(self.trainer.accelerator.batch_to_device, dataloader_idx=dataloader_idx), profiler=self.trainer.profiler, ) - # store to enable teardown and clean extra fetched batches - setattr(self, f"{stage}_data_fetcher", data_fetcher) - return enumerate(data_fetcher) + if isinstance(self.data_fetcher, DataLoaderIterDataFetcher): + return self.data_fetcher + return enumerate(self.data_fetcher) def prepare_data(self) -> None: # on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0 diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py index 59442730245e0..66de559d24878 100644 --- a/pytorch_lightning/utilities/fetching.py +++ b/pytorch_lightning/utilities/fetching.py @@ -409,4 +409,4 @@ def training_step(self, dataloader_iter: Iterator, batch_idx: int) -> None: def fetching_function(self) -> Generator: iterator = iter(StepFuncDataLoaderIter(self.dataloader_iter, self)) while not self.done: - yield iterator, self.fetched, self.done + yield self.fetched, (iterator, self.done) diff --git a/tests/loops/test_iterator_batch_processor.py b/tests/loops/test_iterator_batch_processor.py index 2cd6a172f6941..f5d4e4b9fe4cf 100644 --- a/tests/loops/test_iterator_batch_processor.py +++ b/tests/loops/test_iterator_batch_processor.py @@ -73,6 +73,8 @@ def training_step(self, dataloader_iter: Iterator) -> STEP_OUTPUT: def 
train_dataloader(self): return DataLoader(RandomDataset(_BATCH_SIZE, _DATASET_LEN)) + training_epoch_end = None + def test_training_step_with_dataloader_access(tmpdir) -> None: """ diff --git a/tests/utilities/test_fetching.py b/tests/utilities/test_fetching.py index 300894522de94..5415cc6b03b21 100644 --- a/tests/utilities/test_fetching.py +++ b/tests/utilities/test_fetching.py @@ -194,7 +194,7 @@ def test_trainer_num_prefetch_batches(tmpdir): trainer.fit(model) t1 = time() global_step = trainer.global_step - assert isinstance(trainer.data_connector.train_data_fetcher, InterBatchParallelDataFetcher) + assert isinstance(trainer.data_connector.data_fetcher, InterBatchParallelDataFetcher) torch.cuda.synchronize() @@ -204,7 +204,7 @@ def test_trainer_num_prefetch_batches(tmpdir): t3 = time() assert global_step == trainer.global_step == 4 - assert isinstance(trainer.data_connector.train_data_fetcher, DataFetcher) + assert isinstance(trainer.data_connector.data_fetcher, DataFetcher) ratio = (t3 - t2) / (t1 - t0) assert ratio > 1.1, ratio From 99a88534e1b74e8d877f8f2faf812da380978eb5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 23 Aug 2021 14:20:21 +0000 Subject: [PATCH 24/74] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/trainer/connectors/data_connector.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 3afa274330be5..3ea671bcb8fa2 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -31,7 +31,12 @@ class DataConnector: - def __init__(self, trainer: "pl.Trainer", multiple_trainloader_mode: str = "max_size_cycle", data_fetcher: Optional[AbstractDataFetcher] = None): + def __init__( + self, + trainer: "pl.Trainer", + multiple_trainloader_mode: str = "max_size_cycle", + data_fetcher: Optional[AbstractDataFetcher] = None, + ): self.trainer = trainer self.multiple_trainloader_mode = multiple_trainloader_mode self.data_fetcher = data_fetcher From c7b5ff0f6d0d64ac0627906882bc431007edddf0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 23 Aug 2021 14:25:45 +0000 Subject: [PATCH 25/74] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/loops/base.py | 2 +- pytorch_lightning/loops/batch/training_batch_loop.py | 2 +- pytorch_lightning/trainer/connectors/data_connector.py | 7 ++++++- tests/utilities/test_fetching.py | 8 ++------ 4 files changed, 10 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 25f5f48cae02f..899594db4b310 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -87,7 +87,7 @@ def on_skip(self) -> Optional[Any]: Returns: the default output value of :meth:`on_run_end` """ - + def _advance_step(self, *args, **kwargs): self.on_advance_start(*args, **kwargs) self.advance(*args, **kwargs) diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index fb22fdc1605be..5b89e0ed9197e 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -105,7 +105,7 @@ def 
run(self, batch: Any, batch_idx: int) -> AttributeDict: return output def on_stop_iteration(self): - # if a stop iteration is being triggered within the ``training_step``, + # if a stop iteration is being triggered within the ``training_step``, # forward the expection to the parent loop. raise StopIteration diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index eff12b2c29fc1..cdc52605e1a5d 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -31,7 +31,12 @@ class DataConnector: - def __init__(self, trainer: "pl.Trainer", multiple_trainloader_mode: str = "max_size_cycle", data_fetcher: Optional[AbstractDataFetcher] = None): + def __init__( + self, + trainer: "pl.Trainer", + multiple_trainloader_mode: str = "max_size_cycle", + data_fetcher: Optional[AbstractDataFetcher] = None, + ): self.trainer = trainer self.multiple_trainloader_mode = multiple_trainloader_mode self.data_fetcher = data_fetcher diff --git a/tests/utilities/test_fetching.py b/tests/utilities/test_fetching.py index 51be828ab3625..2dc6284b6311b 100644 --- a/tests/utilities/test_fetching.py +++ b/tests/utilities/test_fetching.py @@ -14,12 +14,8 @@ import os from time import time from typing import Any, Iterator -from torch.utils.data import DataLoader from unittest import mock -from pytorch_lightning import Trainer -from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.types import STEP_OUTPUT -from tests.helpers import BoringModel, RandomDataset + import pytest import torch from torch import tensor @@ -254,7 +250,7 @@ def training_step(self, dataloader_iter, batch_idx): model = TestModel(automatic_optimization=automatic_optimization) trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) trainer.fit(model) - + # we don't sync batch_progress with user fetching assert trainer.fit_loop.epoch_loop.batch_progress.current.ready == 33 assert trainer.data_connector.data_fetcher.fetched == 64 From f571f8cbd261c4727caa7f67cfcaa96b1f46de1b Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Mon, 23 Aug 2021 10:30:55 -0400 Subject: [PATCH 26/74] add teardown --- pytorch_lightning/trainer/connectors/data_connector.py | 5 +++++ pytorch_lightning/trainer/trainer.py | 1 + 2 files changed, 6 insertions(+) diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index eff12b2c29fc1..fc576005c783d 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -222,6 +222,11 @@ def detach_data(model: "pl.LightningModule") -> None: if isinstance(loader, _PatchDataLoader): loader.unpatch(model) + def teardown(self) -> None: + for attr in vars(self).values(): + if isinstance(attr, AbstractDataFetcher): + attr.teardown() + def _check_training_step_requires_dataloader_iter(self) -> bool: if not self.trainer.training: return False diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 02b0be775855a..c445c638f8126 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1062,6 +1062,7 @@ def _post_dispatch(self): # these `teardown` calls are here instead of in `_call_teardown_hook` since they are internal teardowns # which need to happen before. 
self.accelerator.teardown() + self.data_connector.teardown() self._active_loop.teardown() self.logger_connector.teardown() From 1b5b911ddfaae4d031d7b8ce44be62bb9c2138fd Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Mon, 23 Aug 2021 10:33:49 -0400 Subject: [PATCH 27/74] update --- pytorch_lightning/loops/processors/iterator_batch_processor.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pytorch_lightning/loops/processors/iterator_batch_processor.py b/pytorch_lightning/loops/processors/iterator_batch_processor.py index 1ac5de745e6e5..d33c23271e6dd 100644 --- a/pytorch_lightning/loops/processors/iterator_batch_processor.py +++ b/pytorch_lightning/loops/processors/iterator_batch_processor.py @@ -146,8 +146,6 @@ def run(self, dataloader_iter: Iterator) -> Optional[AttributeDict]: if training_step_output: batch_outputs[0].append(training_step_output) - if training_step_output: - batch_outputs[0].append(training_step_output) return AttributeDict(signal=0, training_step_output=batch_outputs, is_last=is_last) def teardown(self) -> None: From 403ef3cd2c5a1b12430d77d957dece537048d808 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Mon, 23 Aug 2021 10:47:14 -0400 Subject: [PATCH 28/74] update on comments --- pytorch_lightning/loops/dataloader/evaluation_loop.py | 10 ++++------ pytorch_lightning/loops/fit_loop.py | 9 +++------ .../loops/processors/iterator_batch_processor.py | 4 ++-- tests/loops/test_iterator_batch_processor.py | 2 -- tests/utilities/test_fetching.py | 10 +--------- 5 files changed, 10 insertions(+), 25 deletions(-) diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index 0a221b66fa5b4..ae1d4b08cd2e5 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -98,7 +98,10 @@ def advance(self, *args: Any, **kwargs: Any) -> None: """Performs evaluation on one single dataloader""" void(*args, **kwargs) - dataloader_iter = self._prepare_dataloader_iter() + dataloader = self.trainer.accelerator.process_dataloader(self.current_dataloader) + dataloader = self.trainer.data_connector.get_profiled_dataloader(dataloader) + dataloader_iter = iter(dataloader) + dl_max_batches = self._max_batches[self.current_dataloader_idx] dl_outputs = self.epoch_loop.run( @@ -247,8 +250,3 @@ def on_evaluation_epoch_end(self) -> None: def teardown(self) -> None: self._results.cpu() self.epoch_loop.teardown() - - def _prepare_dataloader_iter(self) -> Iterator: - dataloader = self.trainer.accelerator.process_dataloader(self.current_dataloader) - dataloader = self.trainer.data_connector.get_profiled_dataloader(dataloader) - return iter(dataloader) diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 830d89d5b32e6..8507809c361ec 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -192,7 +192,9 @@ def on_advance_start(self) -> None: def advance(self) -> None: """Runs one whole epoch.""" - dataloader_iter = self._prepare_dataloader_iter() + dataloader = self.trainer.accelerator.process_dataloader(self.trainer.train_dataloader) + dataloader = self.trainer.data_connector.get_profiled_dataloader(dataloader) + dataloader_iter = iter(dataloader) with self.trainer.profiler.profile("run_training_epoch"): # run train epoch @@ -233,8 +235,3 @@ def should_accumulate(self) -> bool: def teardown(self) -> None: self.epoch_loop.teardown() - - def _prepare_dataloader_iter(self) -> Iterator: - 
dataloader = self.trainer.accelerator.process_dataloader(self.trainer.train_dataloader) - dataloader = self.trainer.data_connector.get_profiled_dataloader(dataloader) - return iter(dataloader) diff --git a/pytorch_lightning/loops/processors/iterator_batch_processor.py b/pytorch_lightning/loops/processors/iterator_batch_processor.py index d33c23271e6dd..35f5426e96d31 100644 --- a/pytorch_lightning/loops/processors/iterator_batch_processor.py +++ b/pytorch_lightning/loops/processors/iterator_batch_processor.py @@ -130,7 +130,6 @@ def run(self, dataloader_iter: Iterator) -> Optional[AttributeDict]: with self.trainer.profiler.profile("model_forward"): with self.trainer.profiler.profile("training_step"): - step_kwargs = self._build_kwargs(dataloader_iter, batch_idx) training_step_output = self.trainer.accelerator.training_step(step_kwargs) self.trainer.accelerator.post_training_step() @@ -143,6 +142,7 @@ def run(self, dataloader_iter: Iterator) -> Optional[AttributeDict]: check_finite_loss(self.trainer.lightning_module, training_step_output.minimize) batch_outputs = [[] for _ in range(len(self.trainer.optimizers))] + if training_step_output: batch_outputs[0].append(training_step_output) @@ -159,7 +159,7 @@ def _build_kwargs(self, dataloader_iter: Iterator, batch_idx: int) -> Dict[str, """Builds the keyword arguments for training_step Args: - batch: the batch to train on + dataloader_iter: The dataloader to pass batch_idx: the index of the current batch Returns: diff --git a/tests/loops/test_iterator_batch_processor.py b/tests/loops/test_iterator_batch_processor.py index f5d4e4b9fe4cf..2cd6a172f6941 100644 --- a/tests/loops/test_iterator_batch_processor.py +++ b/tests/loops/test_iterator_batch_processor.py @@ -73,8 +73,6 @@ def training_step(self, dataloader_iter: Iterator) -> STEP_OUTPUT: def train_dataloader(self): return DataLoader(RandomDataset(_BATCH_SIZE, _DATASET_LEN)) - training_epoch_end = None - def test_training_step_with_dataloader_access(tmpdir) -> None: """ diff --git a/tests/utilities/test_fetching.py b/tests/utilities/test_fetching.py index 5415cc6b03b21..27d8175d37376 100644 --- a/tests/utilities/test_fetching.py +++ b/tests/utilities/test_fetching.py @@ -29,7 +29,7 @@ from tests.helpers.runif import RunIf -@pytest.mark.parametrize("use_combined_loader", [False]) +@pytest.mark.parametrize("use_combined_loader", [False, True]) def test_prefetch_iterator(use_combined_loader): """Test the DataFetcher with PyTorch IterableDataset.""" @@ -111,14 +111,6 @@ def measure() -> float: cycles_per_ms = 1000000 / start.elapsed_time(end) return cycles_per_ms - # Get 10 values and remove the 2 max and 2 min and return the avg. - # This is to avoid system disturbance that skew the results, e.g. - # the very first cuda call likely does a bunch of init, which takes - # much longer than subsequent calls. - # - # Tested on both Tesla V100, Quadro GP100, Titan RTX, RTX 3090 GPUs - # and seems to return stable values. Therefore, we enable caching - # using lru_cache decorator above. 
num = 10 vals = [] for _ in range(num): From 96ee94965c420f1f92b06fd583d2639a60b740ce Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Mon, 23 Aug 2021 10:50:11 -0400 Subject: [PATCH 29/74] add back comment --- .../loops/processors/iterator_batch_processor.py | 5 ++--- tests/utilities/test_fetching.py | 4 ++++ 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/loops/processors/iterator_batch_processor.py b/pytorch_lightning/loops/processors/iterator_batch_processor.py index 35f5426e96d31..eb98ee4f24328 100644 --- a/pytorch_lightning/loops/processors/iterator_batch_processor.py +++ b/pytorch_lightning/loops/processors/iterator_batch_processor.py @@ -142,7 +142,6 @@ def run(self, dataloader_iter: Iterator) -> Optional[AttributeDict]: check_finite_loss(self.trainer.lightning_module, training_step_output.minimize) batch_outputs = [[] for _ in range(len(self.trainer.optimizers))] - if training_step_output: batch_outputs[0].append(training_step_output) @@ -163,10 +162,10 @@ def _build_kwargs(self, dataloader_iter: Iterator, batch_idx: int) -> Dict[str, batch_idx: the index of the current batch Returns: - the keyword arguments for the training step + An ordered dict with the keyword arguments for the training step """ # enable not needing to add opt_idx to training_step - step_kwargs = OrderedDict({"dataloader_iter": dataloader_iter}) + step_kwargs = OrderedDict([("dataloader_iter", dataloader_iter)]) lightning_module = self.trainer.lightning_module diff --git a/tests/utilities/test_fetching.py b/tests/utilities/test_fetching.py index 27d8175d37376..db3a0e9b6534c 100644 --- a/tests/utilities/test_fetching.py +++ b/tests/utilities/test_fetching.py @@ -111,6 +111,10 @@ def measure() -> float: cycles_per_ms = 1000000 / start.elapsed_time(end) return cycles_per_ms + # Get 10 values and remove the 2 max and 2 min and return the avg. + # This is to avoid system disturbance that skew the results, e.g. the very first cuda call likely does a bunch of + # init, which takes much longer than subsequent calls. 
+ num = 10 vals = [] for _ in range(num): From 68acd44b64e74669cbbbc555f7d2b50bb46eb747 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 23 Aug 2021 17:37:43 +0200 Subject: [PATCH 30/74] Fix diff --- .../loops/processors/iterator_batch_processor.py | 4 +--- tests/utilities/test_fetching.py | 13 +++++++------ 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/loops/processors/iterator_batch_processor.py b/pytorch_lightning/loops/processors/iterator_batch_processor.py index eb98ee4f24328..c1981173215ae 100644 --- a/pytorch_lightning/loops/processors/iterator_batch_processor.py +++ b/pytorch_lightning/loops/processors/iterator_batch_processor.py @@ -167,9 +167,7 @@ def _build_kwargs(self, dataloader_iter: Iterator, batch_idx: int) -> Dict[str, # enable not needing to add opt_idx to training_step step_kwargs = OrderedDict([("dataloader_iter", dataloader_iter)]) - lightning_module = self.trainer.lightning_module - - training_step_fx = getattr(lightning_module, "training_step") + training_step_fx = getattr(self.trainer.lightning_module, "training_step") if is_param_in_hook_signature(training_step_fx, "batch_idx"): step_kwargs["batch_idx"] = batch_idx diff --git a/tests/utilities/test_fetching.py b/tests/utilities/test_fetching.py index db3a0e9b6534c..e012f6aa784d3 100644 --- a/tests/utilities/test_fetching.py +++ b/tests/utilities/test_fetching.py @@ -97,11 +97,16 @@ def test_misconfiguration_error(): def get_cycles_per_ms() -> float: """ - Measure and return approximate number of cycles per millisecond for torch.cuda._sleep - Copied from: github.com/pytorch/pytorch/blob/master/test/test_cuda.py + Get 10 values and remove the 2 max and 2 min and return the avg. + This is to avoid system disturbance that skew the results, e.g. the very first cuda call likely does a bunch of + init, which takes much longer than subsequent calls. """ def measure() -> float: + """ + Measure and return approximate number of cycles per millisecond for `torch.cuda._sleep` + Copied from: https://github.com/pytorch/pytorch/blob/v1.9.0/test/test_cuda.py#L81 + """ start = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True) start.record() @@ -111,10 +116,6 @@ def measure() -> float: cycles_per_ms = 1000000 / start.elapsed_time(end) return cycles_per_ms - # Get 10 values and remove the 2 max and 2 min and return the avg. - # This is to avoid system disturbance that skew the results, e.g. the very first cuda call likely does a bunch of - # init, which takes much longer than subsequent calls. - num = 10 vals = [] for _ in range(num): From a9957643f79de28d303e112086a83c8546db9206 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Mon, 23 Aug 2021 12:08:55 -0400 Subject: [PATCH 31/74] update on comments --- pytorch_lightning/trainer/connectors/data_connector.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 3ea671bcb8fa2..f57027cdbb154 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -93,8 +93,10 @@ def _select_data_fetcher(self) -> AbstractDataFetcher: "this signature is experimental and the behavior may subject to change." 
) return DataLoaderIterDataFetcher() - elif self.trainer.training_type_plugin.on_gpu and os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1": + elif os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1": # note: this is an experimental feature + if not self.trainer.training_type_plugin.on_gpu: + raise MisconfigurationException("Inter batch parallelism is available only when using Nvidia GPUs.") return InterBatchParallelDataFetcher() else: return DataFetcher() From 307f446e1014143056749c8b19016ab728da513f Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Mon, 23 Aug 2021 12:15:37 -0400 Subject: [PATCH 32/74] resolve tests --- pytorch_lightning/loops/dataloader/evaluation_loop.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index ae1d4b08cd2e5..9a7459871355b 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -99,7 +99,8 @@ def advance(self, *args: Any, **kwargs: Any) -> None: void(*args, **kwargs) dataloader = self.trainer.accelerator.process_dataloader(self.current_dataloader) - dataloader = self.trainer.data_connector.get_profiled_dataloader(dataloader) + dataloader = self.trainer.data_connector.get_profiled_dataloader( + dataloader, dataloader_idx=self.current_dataloader_idx) dataloader_iter = iter(dataloader) dl_max_batches = self._max_batches[self.current_dataloader_idx] From d69b6fb25314b1a1da0ea26fb21f80df5a3b4316 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 23 Aug 2021 16:16:38 +0000 Subject: [PATCH 33/74] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/loops/dataloader/evaluation_loop.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index 9a7459871355b..3c24c921da389 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -100,7 +100,8 @@ def advance(self, *args: Any, **kwargs: Any) -> None: dataloader = self.trainer.accelerator.process_dataloader(self.current_dataloader) dataloader = self.trainer.data_connector.get_profiled_dataloader( - dataloader, dataloader_idx=self.current_dataloader_idx) + dataloader, dataloader_idx=self.current_dataloader_idx + ) dataloader_iter = iter(dataloader) dl_max_batches = self._max_batches[self.current_dataloader_idx] From 104256c0468ec57b33a194265f531d3543bdc601 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Mon, 23 Aug 2021 12:19:04 -0400 Subject: [PATCH 34/74] update --- .../processors/iterator_batch_processor.py | 174 ----------------- tests/loops/test_iterator_batch_processor.py | 183 ------------------ 2 files changed, 357 deletions(-) delete mode 100644 pytorch_lightning/loops/processors/iterator_batch_processor.py delete mode 100644 tests/loops/test_iterator_batch_processor.py diff --git a/pytorch_lightning/loops/processors/iterator_batch_processor.py b/pytorch_lightning/loops/processors/iterator_batch_processor.py deleted file mode 100644 index c1981173215ae..0000000000000 --- a/pytorch_lightning/loops/processors/iterator_batch_processor.py +++ /dev/null @@ -1,174 +0,0 @@ -# Copyright The PyTorch Lightning team. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging -from collections import OrderedDict -from typing import Any, Dict, Iterator, List, Optional, Tuple - -import torch - -import pytorch_lightning as pl -from pytorch_lightning.loops.utilities import ( - _check_training_step_output, - _process_training_step_output, - check_finite_loss, -) -from pytorch_lightning.trainer.progress import OptimizationProgress -from pytorch_lightning.trainer.supporters import TensorRunningAccum -from pytorch_lightning.utilities import AttributeDict -from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.model_helpers import is_overridden -from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature - -log = logging.getLogger(__name__) - - -class IteratorBatchProcessor: - """ - The processor for performing a training iteration when ``training_step`` needs access to the - dataloader. It is selected when the signature of ``training_step`` contains ``dataloader_iter``: - - def training_step(self, dataloader_iter: Iterator) -> STEP_OUTPUT: - - The ``training_step`` is allowed to fetch multiple batches during one training iteration. The - framework provides minimum amount of automation with regards to model optimization. The - flexibility allows for ease of experimentation with inter-batch parallelism techniques. - - This processor doesn't support ``automatic_optimization`` and ``tbptt``. An error will be thrown - if the ``LightningModule`` or the ``Trainer`` is configured to use these features. - - The ``training_step`` is responsible for reporting whether it has reached the last batch by - including an ``is_last`` field in the result dict. Failing to do so will result in an error. - - The ``training_step`` should only optimize the model with one batch for the sake of API and - reporting consistency (TODO: consider removing this limitation). - - Args: - trainer: a reference to the trainer - model: a reference to the lightning module (for config validation purposes only) - """ - - def __init__(self, trainer: "pl.Trainer", model: "pl.LightningModule") -> None: - if is_overridden("on_train_batch_start", model): - raise MisconfigurationException( - "The model hook `on_train_batch_start` is not compatible with " - "taking a `dataloader_iter` argument in your `training_step`." - ) - if is_overridden("on_train_batch_end", model): - raise MisconfigurationException( - "The model hook `on_train_batch_end` is not compatible with " - "taking a `dataloader_iter` argument in your `training_step`." - ) - if is_overridden("tbptt_split_batch", model): - raise MisconfigurationException( - "The model hook `tbptt_split_batch` is not compatible with " - "taking a `dataloader_iter` argument in your `training_step`." - ) - if trainer.accumulate_grad_batches != 1: - raise MisconfigurationException( - "`accumulate_grad_batches` can only be 1 when your " - "`training_step` takes `dataloader_iter` as an argument." 
- ) - - self.trainer = trainer - - # The following field is not used by the processor since it doesn't support automatic - # optimization and tbptt. Initializing them regardless since they are currently expected by - # `FitLoop` or `TrainingEpochLoop`. - # TODO: come up with an abstraction for "batch processors" so they can be better decoupled - # with parent loops. - self.accumulated_loss: Optional[torch.Tensor] = None - self.running_loss: TensorRunningAccum = TensorRunningAccum(window_length=1) - self.optim_progress = OptimizationProgress() - self.split_idx: int = 0 - self._skip_backward = False - - def num_active_optimizers(self, batch_idx: Optional[int] = None) -> int: - """ - Returns the number of active optimizers. - """ - return len(self.trainer.optimizers) - - def get_active_optimizers(self, batch_idx: Optional[int] = None) -> List[Tuple[int, torch.optim.Optimizer]]: - """ - Returns the currently active optimizers. - - Returns: - A list of tuples (opt_idx, optimizer) of currently active optimizers. - """ - return list(enumerate(self.trainer.optimizers)) - - def run(self, dataloader_iter: Iterator) -> Optional[AttributeDict]: - """ - Args: - dataloader_iter: the iterator over the dataloader producing the new batch - """ - batch_idx, (dataloader_iter, is_last) = next(dataloader_iter) - - self.trainer.logger_connector.on_batch_start() - response = self.trainer.call_hook("on_batch_start") - if response == -1: - return AttributeDict(signal=-1) - - self.trainer.fit_loop.epoch_loop.batch_progress.increment_started() - - # give the PL module a result for logging - model = self.trainer.lightning_module - # manually capture logged metrics - model._current_fx_name = "training_step" - step_kwargs = self._build_kwargs(dataloader_iter, batch_idx) - - with self.trainer.profiler.profile("model_forward"): - with self.trainer.profiler.profile("training_step"): - training_step_output = self.trainer.accelerator.training_step(step_kwargs) - self.trainer.accelerator.post_training_step() - - training_step_output = self.trainer.call_hook("training_step_end", training_step_output) - _check_training_step_output(self.trainer.lightning_module, training_step_output) - - training_step_output, _ = _process_training_step_output(self.trainer, training_step_output) - - if self.trainer.terminate_on_nan: - check_finite_loss(self.trainer.lightning_module, training_step_output.minimize) - - batch_outputs = [[] for _ in range(len(self.trainer.optimizers))] - if training_step_output: - batch_outputs[0].append(training_step_output) - - return AttributeDict(signal=0, training_step_output=batch_outputs, is_last=is_last) - - def teardown(self) -> None: - """ - No-op. Only defined to comply with FitLoop's expectation. - """ - pass - - # FIXME: To be deleted in next PR. 
- def _build_kwargs(self, dataloader_iter: Iterator, batch_idx: int) -> Dict[str, Any]: - """Builds the keyword arguments for training_step - - Args: - dataloader_iter: The dataloader to pass - batch_idx: the index of the current batch - - Returns: - An ordered dict with the keyword arguments for the training step - """ - # enable not needing to add opt_idx to training_step - step_kwargs = OrderedDict([("dataloader_iter", dataloader_iter)]) - - training_step_fx = getattr(self.trainer.lightning_module, "training_step") - if is_param_in_hook_signature(training_step_fx, "batch_idx"): - step_kwargs["batch_idx"] = batch_idx - - return step_kwargs diff --git a/tests/loops/test_iterator_batch_processor.py b/tests/loops/test_iterator_batch_processor.py deleted file mode 100644 index 2cd6a172f6941..0000000000000 --- a/tests/loops/test_iterator_batch_processor.py +++ /dev/null @@ -1,183 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Any, Iterator - -import pytest -from torch.utils.data import DataLoader - -from pytorch_lightning import Trainer -from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.types import STEP_OUTPUT -from tests.helpers import BoringModel, RandomDataset - -_BATCH_SIZE = 32 -_DATASET_LEN = 64 - - -class DummyWaitable: - def __init__(self, val: Any) -> None: - self.val = val - - def wait(self) -> Any: - return self.val - - -class AsyncBoringModel(BoringModel): - def __init__(self) -> None: - super().__init__() - self.automatic_optimization = False - self.batch_i_handle = None - self.num_batches_processed = 0 - - def _async_op(self, batch: Any) -> DummyWaitable: - return DummyWaitable(val=batch) - - def training_step(self, dataloader_iter: Iterator) -> STEP_OUTPUT: - if self.batch_i_handle is None: - batch_i_raw = next(dataloader_iter) - self.batch_i_handle = self._async_op(batch_i_raw) - - # Invariant: _async_op for batch[i] has been initiated - batch_ip1_handle = None - is_last = False - try: - batch_ip1_raw = next(dataloader_iter) - batch_ip1_handle = self._async_op(batch_ip1_raw) - except StopIteration: - is_last = True - - batch_i = self.batch_i_handle.wait() - - pred = self.layer(batch_i) - loss = self.loss(batch_i, pred) - loss.backward() - self.optimizers().step() - self.optimizers().zero_grad() - - self.batch_i_handle = batch_ip1_handle - self.num_batches_processed += 1 - - return {"loss": loss, "is_last": is_last} - - def train_dataloader(self): - return DataLoader(RandomDataset(_BATCH_SIZE, _DATASET_LEN)) - - -def test_training_step_with_dataloader_access(tmpdir) -> None: - """ - A baseline functional test for `training_step` with dataloader access. - """ - trainer = Trainer(max_epochs=1, default_root_dir=tmpdir) - m = AsyncBoringModel() - trainer.fit(m) - assert m.num_batches_processed == _DATASET_LEN, f"Expect all {_DATASET_LEN} batches to be processed." 
- - -def test_stop_iteration(tmpdir) -> None: - """ - Verify that when `StopIteration` is raised within `training_step`, `fit()` - terminiates as expected. - """ - EXPECT_NUM_BATCHES_PROCESSED = 2 - - class TestModel(AsyncBoringModel): - def training_step(self, dataloader_iter: Iterator) -> STEP_OUTPUT: - output = super().training_step(dataloader_iter) - if self.num_batches_processed == EXPECT_NUM_BATCHES_PROCESSED: - raise StopIteration() - return output - - trainer = Trainer(max_epochs=1, default_root_dir=tmpdir) - m = TestModel() - trainer.fit(m) - assert ( - m.num_batches_processed == EXPECT_NUM_BATCHES_PROCESSED - ), "Expect {EXPECT_NUM_BATCHES_PROCESSED} batches to be processed." - - -def test_on_train_batch_start_overridden(tmpdir) -> None: - """ - Verify that a `MisconfigurationException` is raised when - `on_train_batch_start` is overridden on the `LightningModule`. - """ - - class InvalidModel(AsyncBoringModel): - def on_train_batch_start(self, batch, batch_idx, dataloader_idx): - pass - - trainer = Trainer(max_epochs=1, default_root_dir=tmpdir) - m = InvalidModel() - with pytest.raises(MisconfigurationException): - trainer.fit(m) - - -def test_on_train_batch_end_overridden(tmpdir) -> None: - """ - Verify that a `MisconfigurationException` is raised when - `on_train_batch_end` is overridden on the `LightningModule`. - """ - - class InvalidModel(AsyncBoringModel): - def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): - pass - - trainer = Trainer(max_epochs=1, default_root_dir=tmpdir) - m = InvalidModel() - with pytest.raises(MisconfigurationException): - trainer.fit(m) - - -def test_tbptt_split_batch_overridden(tmpdir) -> None: - """ - Verify that a `MisconfigurationException` is raised when - `tbptt_split_batch` is overridden on the `LightningModule`. - """ - - class InvalidModel(AsyncBoringModel): - def tbptt_split_batch(self, batch, split_size): - pass - - trainer = Trainer(max_epochs=1, default_root_dir=tmpdir) - m = InvalidModel() - with pytest.raises(MisconfigurationException): - trainer.fit(m) - - -def test_accumulate_grad_batches(tmpdir) -> None: - """ - Verify that a `MisconfigurationException` is raised when - `accumulate_grad_batches` is not set to 1. - """ - trainer = Trainer(max_epochs=1, accumulate_grad_batches=2, default_root_dir=tmpdir) - m = AsyncBoringModel() - with pytest.raises(MisconfigurationException): - trainer.fit(m) - - -def test_is_last_not_set(tmpdir) -> None: - """ - Verify that a `MisconfigurationException` is raised when `training_step` - doesn't include "is_last" in the result dict. 
- """ - - class InvalidModel(AsyncBoringModel): - def training_step(self, dataloader_iter: Iterator) -> STEP_OUTPUT: - output = super().training_step(dataloader_iter) - del output["is_last"] - return output - - trainer = Trainer(max_epochs=1, accumulate_grad_batches=2, default_root_dir=tmpdir) - m = InvalidModel() - with pytest.raises(MisconfigurationException): - trainer.fit(m) From 89fce204932b74775268b32b94d3da796587fa8a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 23 Aug 2021 16:20:24 +0000 Subject: [PATCH 35/74] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/loops/dataloader/evaluation_loop.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index 9a7459871355b..3c24c921da389 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -100,7 +100,8 @@ def advance(self, *args: Any, **kwargs: Any) -> None: dataloader = self.trainer.accelerator.process_dataloader(self.current_dataloader) dataloader = self.trainer.data_connector.get_profiled_dataloader( - dataloader, dataloader_idx=self.current_dataloader_idx) + dataloader, dataloader_idx=self.current_dataloader_idx + ) dataloader_iter = iter(dataloader) dl_max_batches = self._max_batches[self.current_dataloader_idx] From e415c604ee138def161ae055d16d52d4fddad6da Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Mon, 23 Aug 2021 12:20:34 -0400 Subject: [PATCH 36/74] update --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6f41e3a24afb4..685da8583c4b5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -178,6 +178,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Removed `on_train_epoch_end` from `Accelerator` ([#9035](https://github.com/PyTorchLightning/pytorch-lightning/pull/9035)) +- Removed `InterBatchProcessor` in favor of `DataLoaderIterDataFetcher` ([#9052](https://github.com/PyTorchLightning/pytorch-lightning/pull/9052)) + + ### Fixed - Ensure the existence of `DDPPlugin._sync_dir` in `reconciliate_processes` ([#8939](https://github.com/PyTorchLightning/pytorch-lightning/pull/8939)) From d81c9715226b4a03b4b3015cb27d9958ac02b77d Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Mon, 23 Aug 2021 12:25:22 -0400 Subject: [PATCH 37/74] cleanup --- pytorch_lightning/loops/base.py | 11 ++++------- pytorch_lightning/loops/epoch/training_epoch_loop.py | 2 +- .../trainer/connectors/data_connector.py | 2 ++ .../connectors/logger_connector/logger_connector.py | 3 +++ 4 files changed, 10 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 899594db4b310..1f0c7f8d69e4c 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -88,12 +88,6 @@ def on_skip(self) -> Optional[Any]: the default output value of :meth:`on_run_end` """ - def _advance_step(self, *args, **kwargs): - self.on_advance_start(*args, **kwargs) - self.advance(*args, **kwargs) - self.on_advance_end() - self.restarting = False - def run(self, *args: Any, **kwargs: Any) -> Optional[Any]: """ The main entry point to the loop. 
@@ -113,7 +107,10 @@ def run(self, *args: Any, **kwargs: Any) -> Optional[Any]: while not self.done: try: - self._advance_step(*args, **kwargs) + self.on_advance_start(*args, **kwargs) + self.advance(*args, **kwargs) + self.on_advance_end() + self.restarting = False except StopIteration: self.on_stop_iteration() break diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 8368fc6eeaaf8..1a89e5efc5c22 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -44,7 +44,7 @@ def __init__(self, min_steps: int, max_steps: int): self.batch_progress = Progress() self.scheduler_progress = SchedulerProgress() - self.batch_loop: TrainingBatchLoop = None + self.batch_loop: Optional[TrainingBatchLoop] = None self.val_loop: Optional["loops.EvaluationLoop"] = None self._results = ResultCollection(training=True) diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 8123e0c7a6bd7..e2e8f33ee7be0 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -235,6 +235,8 @@ def teardown(self) -> None: attr.teardown() def _check_training_step_requires_dataloader_iter(self) -> bool: + """Check if the current `training_step` is requesting `dataloader_iter`.""" + if not self.trainer.training: return False diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 6257da3725181..ae7bf69dde00d 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -199,8 +199,11 @@ def update_eval_epoch_metrics(self) -> _EVALUATE_OUTPUT: """ def on_train_split_start(self, batch_idx: int, split_idx: int, split_batch: Any) -> None: + # when the user request `dataloader_iter`, we can't track the batch_size + # and this is left to user responsability. if isinstance(split_batch, pl.utilities.fetching.DataLoaderIterDataFetcher): self.trainer._results.extract_batch_size(split_batch) + self._batch_idx = batch_idx self._split_idx = split_idx From d47390388b0f7fcfc13e886a6b22087211cd8869 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 23 Aug 2021 16:26:55 +0000 Subject: [PATCH 38/74] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../trainer/connectors/logger_connector/logger_connector.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index ae7bf69dde00d..7151a68362e4f 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -199,11 +199,11 @@ def update_eval_epoch_metrics(self) -> _EVALUATE_OUTPUT: """ def on_train_split_start(self, batch_idx: int, split_idx: int, split_batch: Any) -> None: - # when the user request `dataloader_iter`, we can't track the batch_size + # when the user request `dataloader_iter`, we can't track the batch_size # and this is left to user responsability. 
if isinstance(split_batch, pl.utilities.fetching.DataLoaderIterDataFetcher): self.trainer._results.extract_batch_size(split_batch) - + self._batch_idx = batch_idx self._split_idx = split_idx From 1ed2f3185b61986199dc04b63b176840e6e2f656 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Mon, 23 Aug 2021 12:46:15 -0400 Subject: [PATCH 39/74] update --- pytorch_lightning/trainer/connectors/data_connector.py | 9 +++------ tests/trainer/loops/test_evaluation_loop.py | 3 ++- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index f57027cdbb154..03f4176bae1ca 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -80,17 +80,15 @@ def on_trainer_init( self.trainer._is_data_prepared = False def _check_training_step_requires_dataloader_iter(self) -> bool: - if not self.trainer.training: - return False training_step_fx = getattr(self.trainer.lightning_module, "training_step") contains_dataloader_iter = is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True) return contains_dataloader_iter def _select_data_fetcher(self) -> AbstractDataFetcher: - if self._check_training_step_requires_dataloader_iter(): + if self.trainer.training and self._check_training_step_requires_dataloader_iter(): rank_zero_warn( "Found `dataloader_iter` argument in the `training_step`. Note that the support for " - "this signature is experimental and the behavior may subject to change." + "this signature is experimental and the behavior is subject to change." ) return DataLoaderIterDataFetcher() elif os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1": @@ -98,8 +96,7 @@ def _select_data_fetcher(self) -> AbstractDataFetcher: if not self.trainer.training_type_plugin.on_gpu: raise MisconfigurationException("Inter batch parallelism is available only when using Nvidia GPUs.") return InterBatchParallelDataFetcher() - else: - return DataFetcher() + return DataFetcher() def get_profiled_dataloader(self, dataloader: Iterable, dataloader_idx: int = 0) -> Iterable: stage: str = self.trainer.state.stage.value diff --git a/tests/trainer/loops/test_evaluation_loop.py b/tests/trainer/loops/test_evaluation_loop.py index d7acd7e65727e..1f0fa502c0416 100644 --- a/tests/trainer/loops/test_evaluation_loop.py +++ b/tests/trainer/loops/test_evaluation_loop.py @@ -95,7 +95,8 @@ def validation_step(self, batch, batch_idx): upper = 201 * self.num_params * 4 current = torch.cuda.memory_allocated(0) assert lower < current - assert current - initial_memory < upper + # FIXME: Where is the memory link coming from ? 
+ assert current - initial_memory < upper + 3000 return super().validation_step(batch, batch_idx) torch.cuda.empty_cache() From 87432e0efeb04ce77f59149634d0020a5d9d696e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 23 Aug 2021 16:47:20 +0000 Subject: [PATCH 40/74] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/trainer/loops/test_evaluation_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/trainer/loops/test_evaluation_loop.py b/tests/trainer/loops/test_evaluation_loop.py index 1f0fa502c0416..038ac2c28ec8d 100644 --- a/tests/trainer/loops/test_evaluation_loop.py +++ b/tests/trainer/loops/test_evaluation_loop.py @@ -95,7 +95,7 @@ def validation_step(self, batch, batch_idx): upper = 201 * self.num_params * 4 current = torch.cuda.memory_allocated(0) assert lower < current - # FIXME: Where is the memory link coming from ? + # FIXME: Where is the memory link coming from ? assert current - initial_memory < upper + 3000 return super().validation_step(batch, batch_idx) From fb98c86e17241d7b8915ed5399d08a5fe4633afe Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Mon, 23 Aug 2021 12:52:02 -0400 Subject: [PATCH 41/74] update --- pytorch_lightning/loops/batch/training_batch_loop.py | 2 +- pytorch_lightning/loops/epoch/training_epoch_loop.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index 5b89e0ed9197e..3274b2ccbc3bf 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -550,7 +550,7 @@ def _build_kwargs(self, batch: Any, batch_idx: int, opt_idx: int, hiddens: Optio the keyword arguments for the training step """ # enable not needing to add opt_idx to training_step - step_kwargs = OrderedDict({"batch": batch}) + step_kwargs = OrderedDict([("batch", batch)]) lightning_module = self.trainer.lightning_module diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 1a89e5efc5c22..a241eb1e666a1 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -114,10 +114,11 @@ def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None: """ batch_idx, (batch, is_last) = next(dataloader_iter) + self.batch_progress.increment_ready() + # ------------------------------------ # TRAINING_STEP + TRAINING_STEP_END # ------------------------------------ - self.batch_progress.increment_ready() with self.trainer.profiler.profile("run_training_batch"): batch_output = self.batch_loop.run(batch, batch_idx) From 5bf90900a61d31a72b77ba219136ddbec8a8af39 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Mon, 23 Aug 2021 14:23:43 -0400 Subject: [PATCH 42/74] resolve memory leak --- .../loops/batch/training_batch_loop.py | 2 + .../loops/epoch/evaluation_epoch_loop.py | 3 + .../loops/epoch/training_epoch_loop.py | 3 + .../trainer/connectors/data_connector.py | 33 ++++++-- pytorch_lightning/utilities/fetching.py | 78 ++++++++++++------- tests/trainer/loops/test_evaluation_loop.py | 3 +- tests/utilities/test_fetching.py | 4 +- 7 files changed, 85 insertions(+), 41 deletions(-) diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index 
a048bc0c3a91c..163d7681f29a8 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -280,6 +280,8 @@ def _training_step( training_step_output = self.trainer.accelerator.training_step(step_kwargs) self.trainer.accelerator.post_training_step() + del step_kwargs + training_step_output = self.trainer.call_hook("training_step_end", training_step_output) _check_training_step_output(self.trainer.lightning_module, training_step_output) diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index 91680ae8dc2a7..efef057cec4a6 100644 --- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -91,6 +91,9 @@ def advance( if batch is None: raise StopIteration + if not self.trainer.data_connector.evaluation_data_fetcher.store_on_device: + batch = self.trainer.accelerator.batch_to_device(batch) + self.batch_progress.increment_ready() # hook diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 6e5841c1e028b..da44188acd162 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -132,6 +132,9 @@ def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None: else: _, (batch, is_last) = next(dataloader_iter) + if not self.trainer.data_connector.train_data_fetcher.store_on_device: + batch = self.trainer.accelerator.batch_to_device(batch) + # ------------------------------------ # TRAINING_STEP + TRAINING_STEP_END # ------------------------------------ diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 03f4176bae1ca..b66cd48e4ac95 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -35,11 +35,23 @@ def __init__( self, trainer: "pl.Trainer", multiple_trainloader_mode: str = "max_size_cycle", - data_fetcher: Optional[AbstractDataFetcher] = None, + train_data_fetcher: Optional[AbstractDataFetcher] = None, + validate_data_fetcher: Optional[AbstractDataFetcher] = None, + test_data_fetcher: Optional[AbstractDataFetcher] = None, ): self.trainer = trainer self.multiple_trainloader_mode = multiple_trainloader_mode - self.data_fetcher = data_fetcher + + self.train_data_fetcher = train_data_fetcher + self.validate_data_fetcher = validate_data_fetcher + self.test_data_fetcher = test_data_fetcher + self.sanity_check_data_fetcher: Optional[AbstractDataFetcher] = None + + @property + def evaluation_data_fetcher(self) -> Optional[AbstractDataFetcher]: + if self.trainer.sanity_checking: + return self.sanity_check_data_fetcher + return self.test_data_fetcher if self.trainer.testing else self.validate_data_fetcher def on_trainer_init( self, @@ -85,31 +97,36 @@ def _check_training_step_requires_dataloader_iter(self) -> bool: return contains_dataloader_iter def _select_data_fetcher(self) -> AbstractDataFetcher: + if self.trainer.sanity_checking: + return DataFetcher() + if self.trainer.training and self._check_training_step_requires_dataloader_iter(): rank_zero_warn( "Found `dataloader_iter` argument in the `training_step`. Note that the support for " "this signature is experimental and the behavior is subject to change." 
) return DataLoaderIterDataFetcher() - elif os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1": + elif self.trainer.training and os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1": # note: this is an experimental feature if not self.trainer.training_type_plugin.on_gpu: raise MisconfigurationException("Inter batch parallelism is available only when using Nvidia GPUs.") return InterBatchParallelDataFetcher() + return DataFetcher() def get_profiled_dataloader(self, dataloader: Iterable, dataloader_idx: int = 0) -> Iterable: stage: str = self.trainer.state.stage.value - self.data_fetcher = self._select_data_fetcher() - self.data_fetcher.setup( + data_fetcher = setattr(self, f"{stage}_data_fetcher", None) or self._select_data_fetcher() + data_fetcher.setup( dataloader, stage=stage, batch_to_device=partial(self.trainer.accelerator.batch_to_device, dataloader_idx=dataloader_idx), profiler=self.trainer.profiler, ) - if isinstance(self.data_fetcher, DataLoaderIterDataFetcher): - return self.data_fetcher - return enumerate(self.data_fetcher) + setattr(self, f"{stage}_data_fetcher", data_fetcher) + if isinstance(data_fetcher, DataLoaderIterDataFetcher): + return data_fetcher + return enumerate(data_fetcher) def prepare_data(self) -> None: # on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0 diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py index 66de559d24878..b45ddad593bff 100644 --- a/pytorch_lightning/utilities/fetching.py +++ b/pytorch_lightning/utilities/fetching.py @@ -47,15 +47,19 @@ class SimpleDataFetcher(AbstractDataFetcher): def fetching_function(self): while True: try: - yield next(self.dataloader_iter), False + return next(self.dataloader_iter), False except StopIteration: return None, True """ @abstractmethod - def fetching_function(self) -> Generator: + def fetching_function(self) -> Any: """Override with your own fetching logic.""" + @abstractmethod + def prefetching(self, prefetch_batches: int) -> None: + """Override with your own pre-fetching logic.""" + def __init__( self, prefetch_batches: int = 0, @@ -63,6 +67,7 @@ def __init__( if prefetch_batches < 0: raise MisconfigurationException("`prefetch_batches` should at least be 0.") + self.store_on_device = False self.prefetch_batches = prefetch_batches + 1 self.dataloader: Optional[Iterable] = None @@ -192,6 +197,10 @@ def __iter__(self) -> Generator[Tuple[Any, bool], None, None]: self.reset() self.dataloader_iter = iter(self.dataloader) self._apply_patch() + self.prefetching(self.prefetch_batches) + return self + + def __next__(self): return self.fetching_function() def reset(self) -> None: @@ -241,34 +250,38 @@ def on_fetch_end(self, batch, on_fetch_start_output: Optional[Any] = None) -> No def wait(self) -> None: """Hook to override to indicate the `DataFetcher` to wait for an event.""" - def fetching_function(self) -> Generator: - self.done = False - while not self.done: - self._prefetching(self.prefetch_batches) - - while self.batches: - try: - yield_batch = self.pop_batch() - self._fetch_next_batch() - - # wait for batch to be available. 
- self.wait() - - # yield last and has next - yield (self.move_data_to_device(yield_batch) if not self.store_on_device else yield_batch, False) - except StopIteration: - self.batches.insert(0, yield_batch) - break - - yield from self._consume_prefetched_batches() - - def _prefetching(self, prefetch_batches: int) -> None: + def prefetching(self, prefetch_batches: int) -> None: for _ in range(prefetch_batches): try: self._fetch_next_batch() except StopIteration: break + def fetching_function(self) -> Optional[Tuple[Any, bool]]: + if self.done: + while self.batches: + return self._get_queued_batch() + raise StopIteration + else: + try: + yield_batch = self.pop_batch() + self._fetch_next_batch() + + # wait for batch to be available. + self.wait() + + # yield last and has next + return yield_batch, False + # FIXME: Why does this count as a python `referrers` ? + # return (self.move_data_to_device(yield_batch) if not self.store_on_device else yield_batch, False) + except StopIteration: + self.batches.insert(0, yield_batch) + self.done = True + return self._get_queued_batch() + + except IndexError: + raise StopIteration + @contextmanager def apply_profiler(self, name: str) -> Generator: if self.profiler: @@ -291,13 +304,13 @@ def _consume_prefetched_batches(self) -> Generator: while self.batches: yield from self._yield_batch() - def _yield_batch(self) -> Generator: + def _get_queued_batch(self) -> Tuple[Any, bool]: self.wait() batch = self.batches.pop(0) if not self.store_on_device: batch = self.move_data_to_device(batch) is_last = len(self.batches) == 0 - yield batch, is_last + return batch, is_last def move_data_to_device(self, batch: Any) -> Any: if self.batch_to_device: @@ -405,8 +418,15 @@ def training_step(self, dataloader_iter: Iterator, batch_idx: int) -> None: batch = batch.to(self.device) ... """ + def __init__(self): + super().__init__() + # prevent calling ``move_batch_to_device``` + self.store_on_device = True + + def prefetching(self, prefetch_batches: int) -> None: + self.iterator = iter(StepFuncDataLoaderIter(self.dataloader_iter, self)) - def fetching_function(self) -> Generator: - iterator = iter(StepFuncDataLoaderIter(self.dataloader_iter, self)) + def fetching_function(self): while not self.done: - yield self.fetched, (iterator, self.done) + return self.fetched, (self.iterator, self.done) + raise StopIteration diff --git a/tests/trainer/loops/test_evaluation_loop.py b/tests/trainer/loops/test_evaluation_loop.py index 038ac2c28ec8d..d7acd7e65727e 100644 --- a/tests/trainer/loops/test_evaluation_loop.py +++ b/tests/trainer/loops/test_evaluation_loop.py @@ -95,8 +95,7 @@ def validation_step(self, batch, batch_idx): upper = 201 * self.num_params * 4 current = torch.cuda.memory_allocated(0) assert lower < current - # FIXME: Where is the memory link coming from ? 
- assert current - initial_memory < upper + 3000 + assert current - initial_memory < upper return super().validation_step(batch, batch_idx) torch.cuda.empty_cache() diff --git a/tests/utilities/test_fetching.py b/tests/utilities/test_fetching.py index e012f6aa784d3..b351165e03fd8 100644 --- a/tests/utilities/test_fetching.py +++ b/tests/utilities/test_fetching.py @@ -191,7 +191,7 @@ def test_trainer_num_prefetch_batches(tmpdir): trainer.fit(model) t1 = time() global_step = trainer.global_step - assert isinstance(trainer.data_connector.data_fetcher, InterBatchParallelDataFetcher) + assert isinstance(trainer.data_connector.train_data_fetcher, InterBatchParallelDataFetcher) torch.cuda.synchronize() @@ -201,7 +201,7 @@ def test_trainer_num_prefetch_batches(tmpdir): t3 = time() assert global_step == trainer.global_step == 4 - assert isinstance(trainer.data_connector.data_fetcher, DataFetcher) + assert isinstance(trainer.data_connector.train_data_fetcher, DataFetcher) ratio = (t3 - t2) / (t1 - t0) assert ratio > 1.1, ratio From 992016299d733c6e48fb00b96d5cb5336fb3f62c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 23 Aug 2021 18:24:49 +0000 Subject: [PATCH 43/74] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/trainer/connectors/data_connector.py | 2 +- pytorch_lightning/utilities/fetching.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index b66cd48e4ac95..8d337b972dce2 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -41,7 +41,7 @@ def __init__( ): self.trainer = trainer self.multiple_trainloader_mode = multiple_trainloader_mode - + self.train_data_fetcher = train_data_fetcher self.validate_data_fetcher = validate_data_fetcher self.test_data_fetcher = test_data_fetcher diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py index b45ddad593bff..72f54a891cde3 100644 --- a/pytorch_lightning/utilities/fetching.py +++ b/pytorch_lightning/utilities/fetching.py @@ -418,6 +418,7 @@ def training_step(self, dataloader_iter: Iterator, batch_idx: int) -> None: batch = batch.to(self.device) ... 
""" + def __init__(self): super().__init__() # prevent calling ``move_batch_to_device``` From 899c34911bb920879cef946c9dca867c5834afc9 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Mon, 23 Aug 2021 14:31:47 -0400 Subject: [PATCH 44/74] update on comments --- pytorch_lightning/loops/epoch/evaluation_epoch_loop.py | 3 ++- pytorch_lightning/loops/epoch/training_epoch_loop.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index efef057cec4a6..b0095770a278e 100644 --- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -92,7 +92,8 @@ def advance( raise StopIteration if not self.trainer.data_connector.evaluation_data_fetcher.store_on_device: - batch = self.trainer.accelerator.batch_to_device(batch) + with self.trainer.profiler.profile("evaluation_batch_to_device"): + batch = self.trainer.accelerator.batch_to_device(batch) self.batch_progress.increment_ready() diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index da44188acd162..741a05cd5701e 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -133,7 +133,8 @@ def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None: _, (batch, is_last) = next(dataloader_iter) if not self.trainer.data_connector.train_data_fetcher.store_on_device: - batch = self.trainer.accelerator.batch_to_device(batch) + with self.trainer.profiler.profile("training_batch_to_device"): + batch = self.trainer.accelerator.batch_to_device(batch) # ------------------------------------ # TRAINING_STEP + TRAINING_STEP_END From 38fa806dc7cd082977d37a040a69bcac444023dc Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Mon, 23 Aug 2021 14:43:50 -0400 Subject: [PATCH 45/74] update --- pytorch_lightning/loops/batch/training_batch_loop.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index 7d3ee7a7fa6e7..5792fdf3f1805 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -555,7 +555,6 @@ def _build_kwargs(self, batch: Any, batch_idx: int, opt_idx: int, hiddens: Optio step_kwargs = OrderedDict([("batch", batch)]) lightning_module = self.trainer.lightning_module - training_step_fx = getattr(self.trainer.lightning_module, "training_step") if is_param_in_hook_signature(training_step_fx, "batch_idx"): From c2ec026c3e32674fa03b8413b907525e183b26ac Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 23 Aug 2021 18:43:57 +0000 Subject: [PATCH 46/74] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/utilities/test_fetching.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/utilities/test_fetching.py b/tests/utilities/test_fetching.py index 9929293081b86..c0eda726f1746 100644 --- a/tests/utilities/test_fetching.py +++ b/tests/utilities/test_fetching.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import os -from pytorch_lightning.callbacks.base import Callback from time import time from typing import Any, Iterator, Type from unittest import mock @@ -23,6 +22,7 @@ from torch.utils.data import DataLoader, Dataset, IterableDataset from pytorch_lightning import Trainer +from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.trainer.supporters import CombinedLoader from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.fetching import DataFetcher, DataLoaderIterDataFetcher, InterBatchParallelDataFetcher @@ -182,7 +182,6 @@ def test_trainer_num_prefetch_batches(tmpdir): model = RecommenderModel() class CheckDataFetcher(Callback): - def __init__(self, data_fetcher_cls: Type): self.data_fetcher_cls = data_fetcher_cls From 71755caff28aec4af5b3f0d8d74cbb01e322bbba Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Mon, 23 Aug 2021 16:14:21 -0400 Subject: [PATCH 47/74] update --- pytorch_lightning/core/lightning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index ad1c69b0ba34b..beb2797ea70ed 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -454,7 +454,7 @@ def log( f" of {list(self._metric_attributes.values())}" ) - if is_param_in_hook_signature(self.training_step, "dataloader_iter") and batch_size is None: + if is_param_in_hook_signature(self.training_step, "dataloader_iter", explicit=True) and batch_size is None: raise MisconfigurationException( "When the `dataloader_iter` is requested within the `training_step`, `batch_size` should be provided." ) From 8bbb957b8b78731d69ff0a0be062bbac2b04da8f Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Tue, 24 Aug 2021 04:47:06 -0400 Subject: [PATCH 48/74] update on comments --- pytorch_lightning/loops/base.py | 4 -- .../loops/batch/training_batch_loop.py | 5 --- .../loops/epoch/training_epoch_loop.py | 4 -- .../trainer/configuration_validator.py | 29 ++++++++++++ .../trainer/connectors/data_connector.py | 45 +++++-------------- .../logger_connector/logger_connector.py | 2 +- tests/utilities/test_fetching.py | 22 +++------ 7 files changed, 46 insertions(+), 65 deletions(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 1f0c7f8d69e4c..ee5c3a1b708f1 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -112,7 +112,6 @@ def run(self, *args: Any, **kwargs: Any) -> Optional[Any]: self.on_advance_end() self.restarting = False except StopIteration: - self.on_stop_iteration() break output = self.on_run_end() @@ -161,9 +160,6 @@ def on_save_checkpoint(self) -> Dict: def on_load_checkpoint(self, state_dict: Dict) -> None: """Called when loading a model checkpoint, use to reload loop state.""" - def on_stop_iteration(self): - """Called when a stop iteration is being triggered.""" - def state_dict(self, destination: Optional[Dict] = None, prefix: Optional[str] = "") -> Dict: """ The state dict is determined by the state and progress of this loop and all its children. 
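For context on the detection these patches rely on (a minimal sketch, not part of the patch series): the checks being moved into `ConfigValidator` and used by `_select_data_fetcher` both ask whether `training_step` explicitly names a `dataloader_iter` parameter. The series itself uses `pytorch_lightning.utilities.signature_utils.is_param_in_hook_signature(..., explicit=True)`; the standalone equivalent below uses plain `inspect` and purely illustrative names.

import inspect
from typing import Iterator

def wants_dataloader_iter(hook) -> bool:
    # True when the hook explicitly declares a `dataloader_iter` parameter,
    # which is what routes the loop onto the DataLoaderIterDataFetcher path.
    return "dataloader_iter" in inspect.signature(hook).parameters

class _Demo:
    # Mirrors the experimental signature: the step pulls batches itself.
    def training_step(self, dataloader_iter: Iterator, batch_idx: int):
        return next(dataloader_iter)

assert wants_dataloader_iter(_Demo.training_step)
assert not wants_dataloader_iter(lambda self, batch, batch_idx: None)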
diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index 5792fdf3f1805..12b53d361140c 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -104,11 +104,6 @@ def run(self, batch: Any, batch_idx: int) -> AttributeDict: self.batch_outputs = None # free memory return output - def on_stop_iteration(self): - # if a stop iteration is being triggered within the ``training_step``, - # forward the expection to the parent loop. - raise StopIteration - def reset(self) -> None: """Resets the loop state""" self._hiddens = None diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index a6310ed3e75ac..c8161cbeada08 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -120,10 +120,6 @@ def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None: self.batch_progress.increment_ready() - # ------------------------------------ - # TRAINING_STEP + TRAINING_STEP_END - # ------------------------------------ - with self.trainer.profiler.profile("run_training_batch"): batch_output = self.batch_loop.run(batch, batch_idx) diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index 07548f9c49074..72d494349c7e5 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -16,6 +16,7 @@ from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden +from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature class ConfigValidator: @@ -34,6 +35,7 @@ def verify_loop_configurations(self, model: "pl.LightningModule") -> None: self.__verify_train_loop_configuration(model) self.__verify_eval_loop_configuration(model, "val") self.__verify_manual_optimization_support(model) + self.__check_training_step_requires_dataloader_iter(model) elif self.trainer.state.fn == TrainerFn.VALIDATING: self.__verify_eval_loop_configuration(model, "val") elif self.trainer.state.fn == TrainerFn.TESTING: @@ -128,3 +130,30 @@ def __verify_manual_optimization_support(self, model: "pl.LightningModule") -> N f" Remove `Trainer(accumulate_grad_batches={self.trainer.accumulate_grad_batches})`" " or switch to automatic optimization." ) + + def __check_training_step_requires_dataloader_iter(self, model: "pl.LightningModule") -> bool: + """Check if the current `training_step` is requesting `dataloader_iter`.""" + training_step_fx = getattr(model, "training_step") + contains_dataloader_iter = is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True) + + if contains_dataloader_iter: + + if is_overridden("on_train_batch_start", model): + raise MisconfigurationException( + "The model hook `on_train_batch_start` is not compatible with " + "taking a `dataloader_iter` argument in your `training_step`." + ) + + if is_overridden("on_train_batch_end", model): + raise MisconfigurationException( + "The model hook `on_train_batch_end` is not compatible with " + "taking a `dataloader_iter` argument in your `training_step`." 
+ ) + + if model.truncated_bptt_steps > 0: + raise MisconfigurationException( + "The model taking a `dataloader_iter` argument in your `training_step` " + "is incompatible with `truncated_bptt_steps > 0`." + ) + + return contains_dataloader_iter diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 5ab1051fb73cd..d722b128858cb 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -95,12 +95,14 @@ def _select_data_fetcher(self) -> AbstractDataFetcher: if self.trainer.sanity_checking: return DataFetcher() - if self.trainer.training and self._check_training_step_requires_dataloader_iter(): + training_step_fx = getattr(self.trainer.lightning_module, "training_step") + if self.trainer.training and is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True): rank_zero_warn( "Found `dataloader_iter` argument in the `training_step`. Note that the support for " "this signature is experimental and the behavior is subject to change." ) return DataLoaderIterDataFetcher() + elif self.trainer.training and os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1": # note: this is an experimental feature if not self.trainer.training_type_plugin.on_gpu: @@ -246,39 +248,14 @@ def detach_data(model: "pl.LightningModule") -> None: loader.unpatch(model) def teardown(self) -> None: - items = list(vars(self).items()) - for attr_name, attr in items: - if isinstance(attr, AbstractDataFetcher): - attr.teardown() - delattr(self, attr_name) - - def _check_training_step_requires_dataloader_iter(self) -> bool: - """Check if the current `training_step` is requesting `dataloader_iter`.""" - training_step_fx = getattr(self.trainer.lightning_module, "training_step") - contains_dataloader_iter = is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True) - - if contains_dataloader_iter: - - if is_overridden("on_train_batch_start", self.trainer.lightning_module): - raise MisconfigurationException( - "The model hook `on_train_batch_start` is not compatible with " - "taking a `dataloader_iter` argument in your `training_step`." - ) - - if is_overridden("on_train_batch_end", self.trainer.lightning_module): - raise MisconfigurationException( - "The model hook `on_train_batch_end` is not compatible with " - "taking a `dataloader_iter` argument in your `training_step`." - ) - - if self.trainer.lightning_module.truncated_bptt_steps > 0: - raise MisconfigurationException( - "The model taking a `dataloader_iter` argument in your `training_step` " - "is incompatible with `truncated_bptt_steps > 0`." 
- ) - - return contains_dataloader_iter - + if self.train_data_fetcher: + self.train_data_fetcher.teardown() + if self.validate_data_fetcher: + self.validate_data_fetcher.teardown() + if self.test_data_fetcher: + self.test_data_fetcher.teardown() + if self.sanity_check_data_fetcher: + self.sanity_check_data_fetcher.teardown() class _PatchDataLoader: r""" diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 7151a68362e4f..12ed72df6bda3 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -200,7 +200,7 @@ def update_eval_epoch_metrics(self) -> _EVALUATE_OUTPUT: def on_train_split_start(self, batch_idx: int, split_idx: int, split_batch: Any) -> None: # when the user request `dataloader_iter`, we can't track the batch_size - # and this is left to user responsability. + # and this is left to user responsibility. if isinstance(split_batch, pl.utilities.fetching.DataLoaderIterDataFetcher): self.trainer._results.extract_batch_size(split_batch) diff --git a/tests/utilities/test_fetching.py b/tests/utilities/test_fetching.py index c0eda726f1746..008374d4badac 100644 --- a/tests/utilities/test_fetching.py +++ b/tests/utilities/test_fetching.py @@ -181,13 +181,6 @@ def test_trainer_num_prefetch_batches(tmpdir): model = RecommenderModel() - class CheckDataFetcher(Callback): - def __init__(self, data_fetcher_cls: Type): - self.data_fetcher_cls = data_fetcher_cls - - def on_batch_start(self, trainer, *_) -> None: - assert isinstance(trainer.data_connector.train_data_fetcher, self.data_fetcher_cls) - trainer_kwargs = dict( default_root_dir=tmpdir, max_epochs=1, @@ -199,18 +192,18 @@ def on_batch_start(self, trainer, *_) -> None: with mock.patch.dict(os.environ, {"PL_INTER_BATCH_PARALLELISM": "1"}): t0 = time() - trainer_kwargs["callbacks"] = CheckDataFetcher(InterBatchParallelDataFetcher) trainer = Trainer(**trainer_kwargs) trainer.fit(model) + assert isinstance(trainer.data_connector.train_data_fetcher, InterBatchParallelDataFetcher) t1 = time() global_step = trainer.global_step torch.cuda.synchronize() t2 = time() - trainer_kwargs["callbacks"] = CheckDataFetcher(DataFetcher) trainer = Trainer(**trainer_kwargs) trainer.fit(model) + assert isinstance(trainer.data_connector.train_data_fetcher, DataFetcher) t3 = time() assert global_step == trainer.global_step == 4 @@ -260,9 +253,6 @@ def training_epoch_end(self, *_): trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) trainer.fit(model) - # should be cleaned out ! 
- assert not hasattr(trainer.data_connector, "train_data_fetcher") - class DummyWaitable: def __init__(self, val: Any) -> None: @@ -331,11 +321,9 @@ def test_stop_iteration(tmpdir) -> None: EXPECT_NUM_BATCHES_PROCESSED = 2 class TestModel(AsyncBoringModel): - def training_step(self, dataloader_iter: Iterator) -> STEP_OUTPUT: - output = super().training_step(dataloader_iter) - if self.num_batches_processed == EXPECT_NUM_BATCHES_PROCESSED: - raise StopIteration - return output + + def train_dataloader(self): + return DataLoader(RandomDataset(BATCH_SIZE, EXPECT_NUM_BATCHES_PROCESSED)) trainer = Trainer(max_epochs=1, default_root_dir=tmpdir) m = TestModel() From 5ba95ba3161d6024f28f083978eac35cb211615c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 24 Aug 2021 08:48:16 +0000 Subject: [PATCH 49/74] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/trainer/connectors/data_connector.py | 1 + tests/utilities/test_fetching.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index d722b128858cb..912da485f6202 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -257,6 +257,7 @@ def teardown(self) -> None: if self.sanity_check_data_fetcher: self.sanity_check_data_fetcher.teardown() + class _PatchDataLoader: r""" Callable object for patching dataloaders passed into trainer.fit(). diff --git a/tests/utilities/test_fetching.py b/tests/utilities/test_fetching.py index 008374d4badac..9ffa1139fe527 100644 --- a/tests/utilities/test_fetching.py +++ b/tests/utilities/test_fetching.py @@ -321,7 +321,6 @@ def test_stop_iteration(tmpdir) -> None: EXPECT_NUM_BATCHES_PROCESSED = 2 class TestModel(AsyncBoringModel): - def train_dataloader(self): return DataLoader(RandomDataset(BATCH_SIZE, EXPECT_NUM_BATCHES_PROCESSED)) From 63c11e1066ccc2270df0341e1cdfb527d95dd233 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 24 Aug 2021 10:05:16 +0100 Subject: [PATCH 50/74] update --- pytorch_lightning/loops/batch/training_batch_loop.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index 12b53d361140c..887887c745512 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -547,13 +547,14 @@ def _build_kwargs(self, batch: Any, batch_idx: int, opt_idx: int, hiddens: Optio the keyword arguments for the training step """ # enable not needing to add opt_idx to training_step - step_kwargs = OrderedDict([("batch", batch)]) + step_kwargs = OrderedDict([("batch", batch), ("batch_idx", batch_idx)]) lightning_module = self.trainer.lightning_module training_step_fx = getattr(self.trainer.lightning_module, "training_step") - if is_param_in_hook_signature(training_step_fx, "batch_idx"): - step_kwargs["batch_idx"] = batch_idx + if is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True): + if not is_param_in_hook_signature(training_step_fx, "batch_idx", explicit=True): + del step_kwargs["batch_idx"] if len(self.trainer.optimizers) > 1: has_opt_idx_in_train_step = is_param_in_hook_signature(training_step_fx, "optimizer_idx") From fe92512af6e5f798a222cac8adef759cecb0e763 
Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Tue, 24 Aug 2021 05:28:48 -0400 Subject: [PATCH 51/74] update --- pytorch_lightning/loops/base.py | 2 +- pytorch_lightning/loops/fit_loop.py | 4 ++-- pytorch_lightning/trainer/connectors/data_connector.py | 4 ++-- tests/core/test_metric_result_integration.py | 4 ++++ 4 files changed, 9 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index ee5c3a1b708f1..4114b287c3266 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -237,4 +237,4 @@ def _load_from_state_dict( v.reset(metrics=False) self.on_load_checkpoint(state_dict[prefix + "state_dict"]) - self.restarting = True + self.restarting = True \ No newline at end of file diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 8507809c361ec..8f6b79e17b343 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -14,7 +14,7 @@ import logging from contextlib import suppress -from typing import Iterator, Optional +from typing import Optional from pytorch_lightning.loops import Loop from pytorch_lightning.loops.epoch import TrainingEpochLoop @@ -193,7 +193,7 @@ def on_advance_start(self) -> None: def advance(self) -> None: """Runs one whole epoch.""" dataloader = self.trainer.accelerator.process_dataloader(self.trainer.train_dataloader) - dataloader = self.trainer.data_connector.get_profiled_dataloader(dataloader) + dataloader = self.trainer.data_connector.get_profiled_dataloader(dataloader, batch_idx=self.batch_idx + 1) dataloader_iter = iter(dataloader) with self.trainer.profiler.profile("run_training_epoch"): diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 912da485f6202..27becf6a91d6a 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -111,7 +111,7 @@ def _select_data_fetcher(self) -> AbstractDataFetcher: return DataFetcher() - def get_profiled_dataloader(self, dataloader: Iterable, dataloader_idx: int = 0) -> Iterable: + def get_profiled_dataloader(self, dataloader: Iterable, dataloader_idx: int = 0, batch_idx: int = 0) -> Iterable: stage: str = self.trainer.state.stage.value data_fetcher = setattr(self, f"{stage}_data_fetcher", None) or self._select_data_fetcher() data_fetcher.setup( @@ -123,7 +123,7 @@ def get_profiled_dataloader(self, dataloader: Iterable, dataloader_idx: int = 0) setattr(self, f"{stage}_data_fetcher", data_fetcher) if isinstance(data_fetcher, DataLoaderIterDataFetcher): return data_fetcher - return enumerate(data_fetcher) + return enumerate(data_fetcher, batch_idx) def prepare_data(self) -> None: # on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0 diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 7c6a985d09f33..d4b2e06b7a324 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -411,6 +411,10 @@ def training_step(self, batch, batch_idx): # However, below we will simulate a failure on `batch_idx=3`. 
if self.trainer.fit_loop.restarting: + + print() + print(self.results, batch_idx) + self.log("tracking", batch_idx, on_step=True, on_epoch=True) self.log("tracking_2", batch_idx, on_step=True, on_epoch=True, sync_dist=True) From 5a738de1136c77b4c96ee02266b357c384d512fb Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Tue, 24 Aug 2021 05:29:55 -0400 Subject: [PATCH 52/74] update --- pytorch_lightning/loops/dataloader/evaluation_loop.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index 3c24c921da389..32246cbf6c8cb 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -100,7 +100,9 @@ def advance(self, *args: Any, **kwargs: Any) -> None: dataloader = self.trainer.accelerator.process_dataloader(self.current_dataloader) dataloader = self.trainer.data_connector.get_profiled_dataloader( - dataloader, dataloader_idx=self.current_dataloader_idx + dataloader, + dataloader_idx=self.current_dataloader_idx, + batch_idx=self.epoch_loop.batch_progress.current.ready + 1 ) dataloader_iter = iter(dataloader) From a4a21715e3779ff5e4099091a06f5605ef4da82a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 24 Aug 2021 09:31:04 +0000 Subject: [PATCH 53/74] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/loops/base.py | 2 +- pytorch_lightning/loops/dataloader/evaluation_loop.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 4114b287c3266..ee5c3a1b708f1 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -237,4 +237,4 @@ def _load_from_state_dict( v.reset(metrics=False) self.on_load_checkpoint(state_dict[prefix + "state_dict"]) - self.restarting = True \ No newline at end of file + self.restarting = True diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index 32246cbf6c8cb..d63dbf05e42a6 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -100,9 +100,9 @@ def advance(self, *args: Any, **kwargs: Any) -> None: dataloader = self.trainer.accelerator.process_dataloader(self.current_dataloader) dataloader = self.trainer.data_connector.get_profiled_dataloader( - dataloader, + dataloader, dataloader_idx=self.current_dataloader_idx, - batch_idx=self.epoch_loop.batch_progress.current.ready + 1 + batch_idx=self.epoch_loop.batch_progress.current.ready + 1, ) dataloader_iter = iter(dataloader) From 29a72443dbee4f0297d894a0ba7b6539682a6cdc Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 24 Aug 2021 10:35:45 +0100 Subject: [PATCH 54/74] improve test --- tests/utilities/test_fetching.py | 32 ++++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/tests/utilities/test_fetching.py b/tests/utilities/test_fetching.py index 9ffa1139fe527..36a964ffb2b60 100644 --- a/tests/utilities/test_fetching.py +++ b/tests/utilities/test_fetching.py @@ -13,7 +13,7 @@ # limitations under the License. 
import os from time import time -from typing import Any, Iterator, Type +from typing import Any, Iterator from unittest import mock import pytest @@ -22,13 +22,11 @@ from torch.utils.data import DataLoader, Dataset, IterableDataset from pytorch_lightning import Trainer -from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.trainer.supporters import CombinedLoader from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.fetching import DataFetcher, DataLoaderIterDataFetcher, InterBatchParallelDataFetcher from pytorch_lightning.utilities.types import STEP_OUTPUT from tests.helpers import BoringModel, RandomDataset -from tests.helpers.boring_model import BoringModel from tests.helpers.runif import RunIf @@ -313,23 +311,37 @@ def test_training_step_with_dataloader_access(tmpdir) -> None: assert m.num_batches_processed == DATASET_LEN, f"Expect all {DATASET_LEN} batches to be processed." -def test_stop_iteration(tmpdir) -> None: +@pytest.mark.parametrize("trigger_stop_iteration", [False, True]) +def test_stop_iteration(trigger_stop_iteration, tmpdir): """ - Verify that when `StopIteration` is raised within `training_step`, `fit()` - terminiates as expected. + Verify that StopIteration properly terminates the training when this is trigged + from the current `dataloader_iter` """ EXPECT_NUM_BATCHES_PROCESSED = 2 class TestModel(AsyncBoringModel): + def __init__(self, trigger_stop_iteration) -> None: + super().__init__() + self.trigger_stop_iteration = trigger_stop_iteration + + def training_step(self, dataloader_iter: Iterator, batch_idx: int) -> STEP_OUTPUT: + output = super().training_step(dataloader_iter) + if self.trigger_stop_iteration and batch_idx == EXPECT_NUM_BATCHES_PROCESSED: + raise StopIteration + return output + def train_dataloader(self): + if self.trigger_stop_iteration: + return DataLoader(RandomDataset(BATCH_SIZE, 2 * EXPECT_NUM_BATCHES_PROCESSED)) return DataLoader(RandomDataset(BATCH_SIZE, EXPECT_NUM_BATCHES_PROCESSED)) trainer = Trainer(max_epochs=1, default_root_dir=tmpdir) - m = TestModel() + m = TestModel(trigger_stop_iteration) trainer.fit(m) - assert ( - m.num_batches_processed == EXPECT_NUM_BATCHES_PROCESSED - ), "Expect {EXPECT_NUM_BATCHES_PROCESSED} batches to be processed." 
+ expected = EXPECT_NUM_BATCHES_PROCESSED + if trigger_stop_iteration: + expected *= 2 + assert m.num_batches_processed == expected def test_on_train_batch_start_overridden(tmpdir) -> None: From 985e47a4636714ac44c11b4e2f665d59c67e79b9 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 24 Aug 2021 10:38:04 +0100 Subject: [PATCH 55/74] update --- pytorch_lightning/loops/batch/training_batch_loop.py | 2 +- tests/utilities/test_fetching.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index 887887c745512..55bc0343ab6a1 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -553,7 +553,7 @@ def _build_kwargs(self, batch: Any, batch_idx: int, opt_idx: int, hiddens: Optio training_step_fx = getattr(self.trainer.lightning_module, "training_step") if is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True): - if not is_param_in_hook_signature(training_step_fx, "batch_idx", explicit=True): + if not is_param_in_hook_signature(training_step_fx, "batch_idx"): del step_kwargs["batch_idx"] if len(self.trainer.optimizers) > 1: diff --git a/tests/utilities/test_fetching.py b/tests/utilities/test_fetching.py index 36a964ffb2b60..88dda1791160c 100644 --- a/tests/utilities/test_fetching.py +++ b/tests/utilities/test_fetching.py @@ -324,9 +324,9 @@ def __init__(self, trigger_stop_iteration) -> None: super().__init__() self.trigger_stop_iteration = trigger_stop_iteration - def training_step(self, dataloader_iter: Iterator, batch_idx: int) -> STEP_OUTPUT: + def training_step(self, dataloader_iter: Iterator, *args) -> STEP_OUTPUT: output = super().training_step(dataloader_iter) - if self.trigger_stop_iteration and batch_idx == EXPECT_NUM_BATCHES_PROCESSED: + if self.trigger_stop_iteration and args[0] == EXPECT_NUM_BATCHES_PROCESSED: raise StopIteration return output From 95726a896fcacb71f0cb5ca8d325a65efb31472b Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 24 Aug 2021 10:39:33 +0100 Subject: [PATCH 56/74] cleanup --- pytorch_lightning/trainer/configuration_validator.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index 72d494349c7e5..d9c341c5dfaeb 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -131,12 +131,10 @@ def __verify_manual_optimization_support(self, model: "pl.LightningModule") -> N " or switch to automatic optimization." ) - def __check_training_step_requires_dataloader_iter(self, model: "pl.LightningModule") -> bool: + def __check_training_step_requires_dataloader_iter(self, model: "pl.LightningModule"): """Check if the current `training_step` is requesting `dataloader_iter`.""" training_step_fx = getattr(model, "training_step") - contains_dataloader_iter = is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True) - - if contains_dataloader_iter: + if is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True): if is_overridden("on_train_batch_start", model): raise MisconfigurationException( @@ -155,5 +153,3 @@ def __check_training_step_requires_dataloader_iter(self, model: "pl.LightningMod "The model taking a `dataloader_iter` argument in your `training_step` " "is incompatible with `truncated_bptt_steps > 0`." 
) - - return contains_dataloader_iter From fa75fb2ef8242c8f630ec1c188f1c2e8f7c66697 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 24 Aug 2021 10:40:18 +0100 Subject: [PATCH 57/74] more cleanup --- tests/core/test_metric_result_integration.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index d4b2e06b7a324..c5a4d2f65a863 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -412,9 +412,6 @@ def training_step(self, batch, batch_idx): if self.trainer.fit_loop.restarting: - print() - print(self.results, batch_idx) - self.log("tracking", batch_idx, on_step=True, on_epoch=True) self.log("tracking_2", batch_idx, on_step=True, on_epoch=True, sync_dist=True) From 0f74e4fa462bcbc618297d2db68f0aafe149e639 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Tue, 24 Aug 2021 07:02:25 -0400 Subject: [PATCH 58/74] resolve batch_idx sync --- .../loops/dataloader/evaluation_loop.py | 11 ++++------- .../loops/epoch/evaluation_epoch_loop.py | 7 ++++--- .../loops/epoch/training_epoch_loop.py | 15 ++++++++++++--- pytorch_lightning/loops/fit_loop.py | 2 +- .../trainer/connectors/data_connector.py | 6 ++---- 5 files changed, 23 insertions(+), 18 deletions(-) diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index d63dbf05e42a6..a2a2350138ea4 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -98,18 +98,15 @@ def advance(self, *args: Any, **kwargs: Any) -> None: """Performs evaluation on one single dataloader""" void(*args, **kwargs) + dataloader_idx: int = self.current_dataloader_idx dataloader = self.trainer.accelerator.process_dataloader(self.current_dataloader) - dataloader = self.trainer.data_connector.get_profiled_dataloader( - dataloader, - dataloader_idx=self.current_dataloader_idx, - batch_idx=self.epoch_loop.batch_progress.current.ready + 1, - ) + dataloader = self.trainer.data_connector.get_profiled_dataloader(dataloader, dataloader_idx=dataloader_idx) dataloader_iter = iter(dataloader) - dl_max_batches = self._max_batches[self.current_dataloader_idx] + dl_max_batches = self._max_batches[dataloader_idx] dl_outputs = self.epoch_loop.run( - dataloader_iter, self.current_dataloader_idx, dl_max_batches, self.num_dataloaders + dataloader_iter, dataloader_idx, dl_max_batches, self.num_dataloaders ) # store batch level output per dataloader diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index eb3f9dad58bcf..abe0f1ce4f867 100644 --- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -17,7 +17,6 @@ from deprecate import void from torch import Tensor - from pytorch_lightning.loops.base import Loop from pytorch_lightning.trainer.progress import Progress from pytorch_lightning.utilities.memory import recursive_detach @@ -70,6 +69,8 @@ def on_run_start( self._dl_max_batches = dl_max_batches self._num_dataloaders = num_dataloaders + self.dataloader_iter = enumerate(dataloader_iter, self.batch_progress.current.ready) + def advance( self, dataloader_iter: Iterator, dataloader_idx: int, dl_max_batches: int, num_dataloaders: int ) -> None: @@ -86,7 +87,7 @@ def advance( """ void(dl_max_batches, num_dataloaders) - batch_idx, (batch, _) = next(dataloader_iter) + 
batch_idx, (batch, _) = next(self.dataloader_iter) if batch is None: raise StopIteration @@ -229,4 +230,4 @@ def _track_output_for_epoch_end( elif isinstance(output, Tensor) and output.is_cuda and self.trainer.move_metrics_to_cpu: output = output.cpu() outputs.append(output) - return outputs + return outputs \ No newline at end of file diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index c8161cbeada08..26da89b7d4322 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher from typing import Any, Dict, Iterator, List, Optional, Tuple, Union import torch @@ -96,14 +97,16 @@ def reset(self) -> None: self.scheduler_progress.current.reset() self.batch_loop.optim_progress.reset_on_epoch() - def on_run_start(self, *args: Any, **kwargs: Any) -> None: + def on_run_start(self, dataloader_iter: Iterator, **kwargs: Any) -> None: # hook self.trainer.logger_connector.on_epoch_start() self.trainer.call_hook("on_epoch_start") self.trainer.call_hook("on_train_epoch_start") self.trainer.fit_loop.epoch_progress.increment_started() + + self._prepare_dataloader_iter(dataloader_iter) - def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None: + def advance(self, *args: Any, **kwargs: Any) -> None: """Runs a single training batch. Args: @@ -112,7 +115,7 @@ def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None: Raises: StopIteration: When the epoch is canceled by the user returning -1 """ - batch_idx, (batch, is_last) = next(dataloader_iter) + batch_idx, (batch, is_last) = next(self.dataloader_iter) if not self.trainer.data_connector.train_data_fetcher.store_on_device: with self.trainer.profiler.profile("training_batch_to_device"): @@ -382,3 +385,9 @@ def _save_loggers_on_train_batch_end(self) -> None: should_flush_logs = self.trainer.logger_connector.should_flush_logs if should_flush_logs and self.trainer.is_global_zero and self.trainer.logger is not None: self.trainer.logger.save() + + def _prepare_dataloader_iter(self, dataloader_iter: AbstractDataFetcher) -> None: + if not isinstance(dataloader_iter, DataLoaderIterDataFetcher): + dataloader_iter = enumerate(dataloader_iter, self.batch_idx + 1) + # restore iteration + self.dataloader_iter = dataloader_iter \ No newline at end of file diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 8f6b79e17b343..49af10d4b2c0d 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -193,7 +193,7 @@ def on_advance_start(self) -> None: def advance(self) -> None: """Runs one whole epoch.""" dataloader = self.trainer.accelerator.process_dataloader(self.trainer.train_dataloader) - dataloader = self.trainer.data_connector.get_profiled_dataloader(dataloader, batch_idx=self.batch_idx + 1) + dataloader = self.trainer.data_connector.get_profiled_dataloader(dataloader) dataloader_iter = iter(dataloader) with self.trainer.profiler.profile("run_training_epoch"): diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 27becf6a91d6a..cdf4ce1c471bf 100644 --- 
a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -111,7 +111,7 @@ def _select_data_fetcher(self) -> AbstractDataFetcher: return DataFetcher() - def get_profiled_dataloader(self, dataloader: Iterable, dataloader_idx: int = 0, batch_idx: int = 0) -> Iterable: + def get_profiled_dataloader(self, dataloader: Iterable, dataloader_idx: int = 0) -> Iterable: stage: str = self.trainer.state.stage.value data_fetcher = setattr(self, f"{stage}_data_fetcher", None) or self._select_data_fetcher() data_fetcher.setup( @@ -121,9 +121,7 @@ def get_profiled_dataloader(self, dataloader: Iterable, dataloader_idx: int = 0, profiler=self.trainer.profiler, ) setattr(self, f"{stage}_data_fetcher", data_fetcher) - if isinstance(data_fetcher, DataLoaderIterDataFetcher): - return data_fetcher - return enumerate(data_fetcher, batch_idx) + return data_fetcher def prepare_data(self) -> None: # on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0 From 6667588d410a4db882048f55b004387b576d1b87 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 24 Aug 2021 11:04:06 +0000 Subject: [PATCH 59/74] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/loops/dataloader/evaluation_loop.py | 4 +--- pytorch_lightning/loops/epoch/evaluation_epoch_loop.py | 3 ++- pytorch_lightning/loops/epoch/training_epoch_loop.py | 6 +++--- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index a2a2350138ea4..730a1944c6e93 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -105,9 +105,7 @@ def advance(self, *args: Any, **kwargs: Any) -> None: dl_max_batches = self._max_batches[dataloader_idx] - dl_outputs = self.epoch_loop.run( - dataloader_iter, dataloader_idx, dl_max_batches, self.num_dataloaders - ) + dl_outputs = self.epoch_loop.run(dataloader_iter, dataloader_idx, dl_max_batches, self.num_dataloaders) # store batch level output per dataloader if self.should_track_batch_outputs_for_epoch_end: diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index abe0f1ce4f867..3753ca3e2737f 100644 --- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -17,6 +17,7 @@ from deprecate import void from torch import Tensor + from pytorch_lightning.loops.base import Loop from pytorch_lightning.trainer.progress import Progress from pytorch_lightning.utilities.memory import recursive_detach @@ -230,4 +231,4 @@ def _track_output_for_epoch_end( elif isinstance(output, Tensor) and output.is_cuda and self.trainer.move_metrics_to_cpu: output = output.cpu() outputs.append(output) - return outputs \ No newline at end of file + return outputs diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 26da89b7d4322..84faf02a2f4aa 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. -from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher from typing import Any, Dict, Iterator, List, Optional, Tuple, Union import torch @@ -21,6 +20,7 @@ from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection from pytorch_lightning.trainer.progress import Progress, SchedulerProgress from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.types import STEP_OUTPUT @@ -103,7 +103,7 @@ def on_run_start(self, dataloader_iter: Iterator, **kwargs: Any) -> None: self.trainer.call_hook("on_epoch_start") self.trainer.call_hook("on_train_epoch_start") self.trainer.fit_loop.epoch_progress.increment_started() - + self._prepare_dataloader_iter(dataloader_iter) def advance(self, *args: Any, **kwargs: Any) -> None: @@ -390,4 +390,4 @@ def _prepare_dataloader_iter(self, dataloader_iter: AbstractDataFetcher) -> None if not isinstance(dataloader_iter, DataLoaderIterDataFetcher): dataloader_iter = enumerate(dataloader_iter, self.batch_idx + 1) # restore iteration - self.dataloader_iter = dataloader_iter \ No newline at end of file + self.dataloader_iter = dataloader_iter From 570edffc5d086ee28429fc877856ab91fe4e9e22 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Tue, 24 Aug 2021 07:28:38 -0400 Subject: [PATCH 60/74] drop weird test --- tests/checkpointing/test_trainer_checkpoint.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/checkpointing/test_trainer_checkpoint.py b/tests/checkpointing/test_trainer_checkpoint.py index f76e76b2f9dd9..f9ef2fd5c20d7 100644 --- a/tests/checkpointing/test_trainer_checkpoint.py +++ b/tests/checkpointing/test_trainer_checkpoint.py @@ -75,9 +75,6 @@ def validation_step(self, batch, batch_idx): results.append(deepcopy(trainer.callback_metrics)) best_model_paths.append(trainer.checkpoint_callback.best_model_path) - for idx in range(len(results) - 1): - assert results[idx]["val_loss"] > results[idx + 1]["val_loss"] - for idx, best_model_path in enumerate(best_model_paths): if idx == 0: assert best_model_path.endswith(f"epoch=0{idx}.ckpt") From 89f372ac4b71922381ac50008716d145b562ed3b Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Tue, 24 Aug 2021 10:09:16 -0400 Subject: [PATCH 61/74] update --- tests/trainer/optimization/test_manual_optimization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/trainer/optimization/test_manual_optimization.py b/tests/trainer/optimization/test_manual_optimization.py index 670e8b4842a89..80896f6fa450c 100644 --- a/tests/trainer/optimization/test_manual_optimization.py +++ b/tests/trainer/optimization/test_manual_optimization.py @@ -419,7 +419,7 @@ class TestModel(ManualOptModel): called = False - def on_after_backward(self): + def on_before_optimizer_step(self, *args): self.called = True norm = torch.nn.utils.clip_grad_norm_(self.parameters(), 2) if not (torch.isinf(norm) or torch.isnan(norm)): From 3fddde7e4c8c090cd11be65bb55dd2e289ba5db1 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Tue, 24 Aug 2021 10:19:53 -0400 Subject: [PATCH 62/74] update on comments --- pytorch_lightning/loops/epoch/evaluation_epoch_loop.py | 1 + tests/core/test_metric_result_integration.py | 1 - 2 files changed, 1 insertion(+), 1 
deletion(-) diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index 3753ca3e2737f..ec3d9901f6dbb 100644 --- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -37,6 +37,7 @@ def __init__(self) -> None: self._num_dataloaders: Optional[int] = None self.outputs: List[STEP_OUTPUT] = [] self.batch_progress = Progress() + self.dataloader_iter: Optional[Iterator] = None @property def done(self) -> bool: diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index c5a4d2f65a863..7c6a985d09f33 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -411,7 +411,6 @@ def training_step(self, batch, batch_idx): # However, below we will simulate a failure on `batch_idx=3`. if self.trainer.fit_loop.restarting: - self.log("tracking", batch_idx, on_step=True, on_epoch=True) self.log("tracking_2", batch_idx, on_step=True, on_epoch=True, sync_dist=True) From cd1c8c16e74a3513bb10dd27bbc873065943e6f9 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 24 Aug 2021 17:26:16 +0100 Subject: [PATCH 63/74] resolve typing --- pytorch_lightning/loops/dataloader/evaluation_loop.py | 2 +- pytorch_lightning/loops/epoch/training_epoch_loop.py | 2 +- pytorch_lightning/trainer/trainer.py | 2 -- pytorch_lightning/utilities/fetching.py | 1 - 4 files changed, 2 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index 730a1944c6e93..7f06f5cd4ff63 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Iterator, List, Optional, Sequence, Union +from typing import Any, List, Optional, Sequence, Union from deprecate.utils import void from torch.utils.data.dataloader import DataLoader diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 84faf02a2f4aa..9a960e44e5705 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Dict, Iterator, List, Optional, Tuple, Union +from typing import Any, Dict, Iterator, List, Optional, Union import torch diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index c445c638f8126..bec516f2d8ff5 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -77,12 +77,10 @@ from pytorch_lightning.utilities.debugging import InternalDebugger from pytorch_lightning.utilities.distributed import distributed_available from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.fetching import DataLoaderIterDataFetcher from pytorch_lightning.utilities.imports import _fault_tolerant_training from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.model_summary import ModelSummary, summarize from pytorch_lightning.utilities.seed import reset_seed -from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT, EVAL_DATALOADERS, TRAIN_DATALOADERS log = logging.getLogger(__name__) diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py index 2d646a7d230e0..d37cd3a9c1e6f 100644 --- a/pytorch_lightning/utilities/fetching.py +++ b/pytorch_lightning/utilities/fetching.py @@ -23,7 +23,6 @@ from torch.utils.data.dataloader import DataLoader import pytorch_lightning as pl -from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections from pytorch_lightning.utilities.auto_restart import ( From 3fb109db53073afad4dd8e0193040f69a6c1fe27 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Tue, 24 Aug 2021 12:43:50 -0400 Subject: [PATCH 64/74] updte --- .../processors/iterator_batch_processor.py | 174 ------------------ 1 file changed, 174 deletions(-) delete mode 100644 pytorch_lightning/loops/processors/iterator_batch_processor.py diff --git a/pytorch_lightning/loops/processors/iterator_batch_processor.py b/pytorch_lightning/loops/processors/iterator_batch_processor.py deleted file mode 100644 index c1981173215ae..0000000000000 --- a/pytorch_lightning/loops/processors/iterator_batch_processor.py +++ /dev/null @@ -1,174 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import logging -from collections import OrderedDict -from typing import Any, Dict, Iterator, List, Optional, Tuple - -import torch - -import pytorch_lightning as pl -from pytorch_lightning.loops.utilities import ( - _check_training_step_output, - _process_training_step_output, - check_finite_loss, -) -from pytorch_lightning.trainer.progress import OptimizationProgress -from pytorch_lightning.trainer.supporters import TensorRunningAccum -from pytorch_lightning.utilities import AttributeDict -from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.model_helpers import is_overridden -from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature - -log = logging.getLogger(__name__) - - -class IteratorBatchProcessor: - """ - The processor for performing a training iteration when ``training_step`` needs access to the - dataloader. It is selected when the signature of ``training_step`` contains ``dataloader_iter``: - - def training_step(self, dataloader_iter: Iterator) -> STEP_OUTPUT: - - The ``training_step`` is allowed to fetch multiple batches during one training iteration. The - framework provides minimum amount of automation with regards to model optimization. The - flexibility allows for ease of experimentation with inter-batch parallelism techniques. - - This processor doesn't support ``automatic_optimization`` and ``tbptt``. An error will be thrown - if the ``LightningModule`` or the ``Trainer`` is configured to use these features. - - The ``training_step`` is responsible for reporting whether it has reached the last batch by - including an ``is_last`` field in the result dict. Failing to do so will result in an error. - - The ``training_step`` should only optimize the model with one batch for the sake of API and - reporting consistency (TODO: consider removing this limitation). - - Args: - trainer: a reference to the trainer - model: a reference to the lightning module (for config validation purposes only) - """ - - def __init__(self, trainer: "pl.Trainer", model: "pl.LightningModule") -> None: - if is_overridden("on_train_batch_start", model): - raise MisconfigurationException( - "The model hook `on_train_batch_start` is not compatible with " - "taking a `dataloader_iter` argument in your `training_step`." - ) - if is_overridden("on_train_batch_end", model): - raise MisconfigurationException( - "The model hook `on_train_batch_end` is not compatible with " - "taking a `dataloader_iter` argument in your `training_step`." - ) - if is_overridden("tbptt_split_batch", model): - raise MisconfigurationException( - "The model hook `tbptt_split_batch` is not compatible with " - "taking a `dataloader_iter` argument in your `training_step`." - ) - if trainer.accumulate_grad_batches != 1: - raise MisconfigurationException( - "`accumulate_grad_batches` can only be 1 when your " - "`training_step` takes `dataloader_iter` as an argument." - ) - - self.trainer = trainer - - # The following field is not used by the processor since it doesn't support automatic - # optimization and tbptt. Initializing them regardless since they are currently expected by - # `FitLoop` or `TrainingEpochLoop`. - # TODO: come up with an abstraction for "batch processors" so they can be better decoupled - # with parent loops. 
- self.accumulated_loss: Optional[torch.Tensor] = None - self.running_loss: TensorRunningAccum = TensorRunningAccum(window_length=1) - self.optim_progress = OptimizationProgress() - self.split_idx: int = 0 - self._skip_backward = False - - def num_active_optimizers(self, batch_idx: Optional[int] = None) -> int: - """ - Returns the number of active optimizers. - """ - return len(self.trainer.optimizers) - - def get_active_optimizers(self, batch_idx: Optional[int] = None) -> List[Tuple[int, torch.optim.Optimizer]]: - """ - Returns the currently active optimizers. - - Returns: - A list of tuples (opt_idx, optimizer) of currently active optimizers. - """ - return list(enumerate(self.trainer.optimizers)) - - def run(self, dataloader_iter: Iterator) -> Optional[AttributeDict]: - """ - Args: - dataloader_iter: the iterator over the dataloader producing the new batch - """ - batch_idx, (dataloader_iter, is_last) = next(dataloader_iter) - - self.trainer.logger_connector.on_batch_start() - response = self.trainer.call_hook("on_batch_start") - if response == -1: - return AttributeDict(signal=-1) - - self.trainer.fit_loop.epoch_loop.batch_progress.increment_started() - - # give the PL module a result for logging - model = self.trainer.lightning_module - # manually capture logged metrics - model._current_fx_name = "training_step" - step_kwargs = self._build_kwargs(dataloader_iter, batch_idx) - - with self.trainer.profiler.profile("model_forward"): - with self.trainer.profiler.profile("training_step"): - training_step_output = self.trainer.accelerator.training_step(step_kwargs) - self.trainer.accelerator.post_training_step() - - training_step_output = self.trainer.call_hook("training_step_end", training_step_output) - _check_training_step_output(self.trainer.lightning_module, training_step_output) - - training_step_output, _ = _process_training_step_output(self.trainer, training_step_output) - - if self.trainer.terminate_on_nan: - check_finite_loss(self.trainer.lightning_module, training_step_output.minimize) - - batch_outputs = [[] for _ in range(len(self.trainer.optimizers))] - if training_step_output: - batch_outputs[0].append(training_step_output) - - return AttributeDict(signal=0, training_step_output=batch_outputs, is_last=is_last) - - def teardown(self) -> None: - """ - No-op. Only defined to comply with FitLoop's expectation. - """ - pass - - # FIXME: To be deleted in next PR. 
- def _build_kwargs(self, dataloader_iter: Iterator, batch_idx: int) -> Dict[str, Any]: - """Builds the keyword arguments for training_step - - Args: - dataloader_iter: The dataloader to pass - batch_idx: the index of the current batch - - Returns: - An ordered dict with the keyword arguments for the training step - """ - # enable not needing to add opt_idx to training_step - step_kwargs = OrderedDict([("dataloader_iter", dataloader_iter)]) - - training_step_fx = getattr(self.trainer.lightning_module, "training_step") - if is_param_in_hook_signature(training_step_fx, "batch_idx"): - step_kwargs["batch_idx"] = batch_idx - - return step_kwargs From c00278da200cb656d3dfd2f993ee22bf4eb9368a Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Tue, 24 Aug 2021 13:08:40 -0400 Subject: [PATCH 65/74] update on comments --- pytorch_lightning/core/lightning.py | 2 +- pytorch_lightning/loops/batch/training_batch_loop.py | 9 ++++----- pytorch_lightning/loops/epoch/evaluation_epoch_loop.py | 4 ++-- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index beb2797ea70ed..b0c55529a623c 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -456,7 +456,7 @@ def log( if is_param_in_hook_signature(self.training_step, "dataloader_iter", explicit=True) and batch_size is None: raise MisconfigurationException( - "When the `dataloader_iter` is requested within the `training_step`, `batch_size` should be provided." + "With `def training_step(self, dataloader_iter)`, `self.log(..., batch_size=...)` should be provided." ) results.log( diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index 55bc0343ab6a1..8f1c6ffa48288 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -547,14 +547,13 @@ def _build_kwargs(self, batch: Any, batch_idx: int, opt_idx: int, hiddens: Optio the keyword arguments for the training step """ # enable not needing to add opt_idx to training_step - step_kwargs = OrderedDict([("batch", batch), ("batch_idx", batch_idx)]) + step_kwargs = OrderedDict([("batch", batch)]) lightning_module = self.trainer.lightning_module - training_step_fx = getattr(self.trainer.lightning_module, "training_step") + training_step_fx = getattr(lightning_module, "training_step") - if is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True): - if not is_param_in_hook_signature(training_step_fx, "batch_idx"): - del step_kwargs["batch_idx"] + if is_param_in_hook_signature(training_step_fx, "batch_idx"): + step_kwargs["batch_idx"] = batch_idx if len(self.trainer.optimizers) > 1: has_opt_idx_in_train_step = is_param_in_hook_signature(training_step_fx, "optimizer_idx") diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index ec3d9901f6dbb..690a1461c134f 100644 --- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -67,7 +67,7 @@ def on_run_start( dl_max_batches: maximum number of batches the dataloader can produce num_dataloaders: the total number of dataloaders """ - void(dataloader_iter, dataloader_idx) + void(dataloader_idx) self._dl_max_batches = dl_max_batches self._num_dataloaders = num_dataloaders @@ -87,7 +87,7 @@ def advance( Raises: StopIteration: If the current batch is None """ - 
void(dl_max_batches, num_dataloaders) + void(dataloader_iter, dl_max_batches, num_dataloaders) batch_idx, (batch, _) = next(self.dataloader_iter) From edaf099cb69f208173700cd76a162863f1979ca9 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Tue, 24 Aug 2021 13:22:35 -0400 Subject: [PATCH 66/74] update --- .../loops/epoch/evaluation_epoch_loop.py | 3 ++- pytorch_lightning/loops/epoch/training_epoch_loop.py | 10 ++++++---- pytorch_lightning/loops/utilities.py | 11 ++++++++++- 3 files changed, 18 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index 690a1461c134f..c46d252d8fbdc 100644 --- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -22,6 +22,7 @@ from pytorch_lightning.trainer.progress import Progress from pytorch_lightning.utilities.memory import recursive_detach from pytorch_lightning.utilities.types import STEP_OUTPUT +from pytorch_lightning.loops.utilities import _prepare_dataloader_iter class EvaluationEpochLoop(Loop): @@ -71,7 +72,7 @@ def on_run_start( self._dl_max_batches = dl_max_batches self._num_dataloaders = num_dataloaders - self.dataloader_iter = enumerate(dataloader_iter, self.batch_progress.current.ready) + _prepare_dataloader_iter(self, dataloader_iter, self.batch_progress.current.ready + 1) def advance( self, dataloader_iter: Iterator, dataloader_idx: int, dl_max_batches: int, num_dataloaders: int diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 9a960e44e5705..37953156910e7 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from pytorch_lightning.loops.base import Loop from typing import Any, Dict, Iterator, List, Optional, Union import torch @@ -23,6 +24,7 @@ from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.types import STEP_OUTPUT +from pytorch_lightning.loops.utilities import _prepare_dataloader_iter class TrainingEpochLoop(loops.Loop): @@ -104,7 +106,7 @@ def on_run_start(self, dataloader_iter: Iterator, **kwargs: Any) -> None: self.trainer.call_hook("on_train_epoch_start") self.trainer.fit_loop.epoch_progress.increment_started() - self._prepare_dataloader_iter(dataloader_iter) + _prepare_dataloader_iter(self, dataloader_iter, self.batch_idx + 1) def advance(self, *args: Any, **kwargs: Any) -> None: """Runs a single training batch. 
@@ -386,8 +388,8 @@ def _save_loggers_on_train_batch_end(self) -> None: if should_flush_logs and self.trainer.is_global_zero and self.trainer.logger is not None: self.trainer.logger.save() - def _prepare_dataloader_iter(self, dataloader_iter: AbstractDataFetcher) -> None: + def _prepare_dataloader_iter(loop: Loop, dataloader_iter: AbstractDataFetcher, batch_idx: int) -> None: if not isinstance(dataloader_iter, DataLoaderIterDataFetcher): - dataloader_iter = enumerate(dataloader_iter, self.batch_idx + 1) + dataloader_iter = enumerate(dataloader_iter, batch_idx) # restore iteration - self.dataloader_iter = dataloader_iter + loop.dataloader_iter = dataloader_iter diff --git a/pytorch_lightning/loops/utilities.py b/pytorch_lightning/loops/utilities.py index 1f0636d8b6cda..e7c04fcf527c5 100644 --- a/pytorch_lightning/loops/utilities.py +++ b/pytorch_lightning/loops/utilities.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Mapping, Optional, Tuple +from typing import Mapping, Optional, Tuple, Iterator import torch @@ -22,6 +22,8 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.finite_checks import detect_nan_parameters from pytorch_lightning.utilities.types import STEP_OUTPUT +from pytorch_lightning.loops.base import Loop +from pytorch_lightning.utilities.fetching import DataLoaderIterDataFetcher def check_finite_loss(model: "pl.LightningModule", loss: torch.Tensor) -> None: @@ -102,3 +104,10 @@ def _process_training_step_output( if trainer.move_metrics_to_cpu: results.cpu() return results, hiddens + + +def _prepare_dataloader_iter(loop: Loop, dataloader_iter: Iterator, batch_idx: int) -> None: + if not isinstance(dataloader_iter, DataLoaderIterDataFetcher): + dataloader_iter = enumerate(dataloader_iter, batch_idx) + # restore iteration + loop.dataloader_iter = dataloader_iter \ No newline at end of file From 9706854ee2345abf83d5ff5799a1011bb519eb2b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 24 Aug 2021 17:23:45 +0000 Subject: [PATCH 67/74] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/loops/epoch/evaluation_epoch_loop.py | 2 +- pytorch_lightning/loops/epoch/training_epoch_loop.py | 4 ++-- pytorch_lightning/loops/utilities.py | 8 ++++---- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index c46d252d8fbdc..85c3fb245e48a 100644 --- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -19,10 +19,10 @@ from torch import Tensor from pytorch_lightning.loops.base import Loop +from pytorch_lightning.loops.utilities import _prepare_dataloader_iter from pytorch_lightning.trainer.progress import Progress from pytorch_lightning.utilities.memory import recursive_detach from pytorch_lightning.utilities.types import STEP_OUTPUT -from pytorch_lightning.loops.utilities import _prepare_dataloader_iter class EvaluationEpochLoop(Loop): diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 37953156910e7..ad4700ccb6500 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ 
-11,20 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from pytorch_lightning.loops.base import Loop from typing import Any, Dict, Iterator, List, Optional, Union import torch from pytorch_lightning import loops # import as loops to avoid circular imports +from pytorch_lightning.loops.base import Loop from pytorch_lightning.loops.batch import TrainingBatchLoop +from pytorch_lightning.loops.utilities import _prepare_dataloader_iter from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection from pytorch_lightning.trainer.progress import Progress, SchedulerProgress from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.types import STEP_OUTPUT -from pytorch_lightning.loops.utilities import _prepare_dataloader_iter class TrainingEpochLoop(loops.Loop): diff --git a/pytorch_lightning/loops/utilities.py b/pytorch_lightning/loops/utilities.py index e7c04fcf527c5..a73865abb3eaa 100644 --- a/pytorch_lightning/loops/utilities.py +++ b/pytorch_lightning/loops/utilities.py @@ -12,18 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Mapping, Optional, Tuple, Iterator +from typing import Iterator, Mapping, Optional, Tuple import torch import pytorch_lightning as pl +from pytorch_lightning.loops.base import Loop from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.fetching import DataLoaderIterDataFetcher from pytorch_lightning.utilities.finite_checks import detect_nan_parameters from pytorch_lightning.utilities.types import STEP_OUTPUT -from pytorch_lightning.loops.base import Loop -from pytorch_lightning.utilities.fetching import DataLoaderIterDataFetcher def check_finite_loss(model: "pl.LightningModule", loss: torch.Tensor) -> None: @@ -110,4 +110,4 @@ def _prepare_dataloader_iter(loop: Loop, dataloader_iter: Iterator, batch_idx: i if not isinstance(dataloader_iter, DataLoaderIterDataFetcher): dataloader_iter = enumerate(dataloader_iter, batch_idx) # restore iteration - loop.dataloader_iter = dataloader_iter \ No newline at end of file + loop.dataloader_iter = dataloader_iter From 1ad63bca504e2ac186ef9a6aea058be776b5c996 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 24 Aug 2021 18:24:47 +0100 Subject: [PATCH 68/74] resolve comments --- .../loops/epoch/evaluation_epoch_loop.py | 4 ++-- pytorch_lightning/loops/epoch/training_epoch_loop.py | 12 ++---------- pytorch_lightning/loops/utilities.py | 10 +++++----- 3 files changed, 9 insertions(+), 17 deletions(-) diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index c46d252d8fbdc..0fee360eb1504 100644 --- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -19,10 +19,10 @@ from torch import Tensor from pytorch_lightning.loops.base import Loop +from pytorch_lightning.loops.utilities import _prepare_dataloader_iter from 
pytorch_lightning.trainer.progress import Progress from pytorch_lightning.utilities.memory import recursive_detach from pytorch_lightning.utilities.types import STEP_OUTPUT -from pytorch_lightning.loops.utilities import _prepare_dataloader_iter class EvaluationEpochLoop(Loop): @@ -72,7 +72,7 @@ def on_run_start( self._dl_max_batches = dl_max_batches self._num_dataloaders = num_dataloaders - _prepare_dataloader_iter(self, dataloader_iter, self.batch_progress.current.ready + 1) + self.dataloader_iter = _prepare_dataloader_iter(dataloader_iter, self.batch_progress.current.ready + 1) def advance( self, dataloader_iter: Iterator, dataloader_idx: int, dl_max_batches: int, num_dataloaders: int diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 37953156910e7..43d51fe0027c6 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -11,20 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from pytorch_lightning.loops.base import Loop from typing import Any, Dict, Iterator, List, Optional, Union import torch from pytorch_lightning import loops # import as loops to avoid circular imports from pytorch_lightning.loops.batch import TrainingBatchLoop +from pytorch_lightning.loops.utilities import _prepare_dataloader_iter from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection from pytorch_lightning.trainer.progress import Progress, SchedulerProgress from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.types import STEP_OUTPUT -from pytorch_lightning.loops.utilities import _prepare_dataloader_iter class TrainingEpochLoop(loops.Loop): @@ -106,7 +104,7 @@ def on_run_start(self, dataloader_iter: Iterator, **kwargs: Any) -> None: self.trainer.call_hook("on_train_epoch_start") self.trainer.fit_loop.epoch_progress.increment_started() - _prepare_dataloader_iter(self, dataloader_iter, self.batch_idx + 1) + self.dataloader_iter = _prepare_dataloader_iter(dataloader_iter, self.batch_idx + 1) def advance(self, *args: Any, **kwargs: Any) -> None: """Runs a single training batch. @@ -387,9 +385,3 @@ def _save_loggers_on_train_batch_end(self) -> None: should_flush_logs = self.trainer.logger_connector.should_flush_logs if should_flush_logs and self.trainer.is_global_zero and self.trainer.logger is not None: self.trainer.logger.save() - - def _prepare_dataloader_iter(loop: Loop, dataloader_iter: AbstractDataFetcher, batch_idx: int) -> None: - if not isinstance(dataloader_iter, DataLoaderIterDataFetcher): - dataloader_iter = enumerate(dataloader_iter, batch_idx) - # restore iteration - loop.dataloader_iter = dataloader_iter diff --git a/pytorch_lightning/loops/utilities.py b/pytorch_lightning/loops/utilities.py index e7c04fcf527c5..89ba5cd07d459 100644 --- a/pytorch_lightning/loops/utilities.py +++ b/pytorch_lightning/loops/utilities.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Mapping, Optional, Tuple, Iterator +from typing import Iterator, Mapping, Optional, Tuple import torch @@ -20,10 +20,9 @@ from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.fetching import DataLoaderIterDataFetcher from pytorch_lightning.utilities.finite_checks import detect_nan_parameters from pytorch_lightning.utilities.types import STEP_OUTPUT -from pytorch_lightning.loops.base import Loop -from pytorch_lightning.utilities.fetching import DataLoaderIterDataFetcher def check_finite_loss(model: "pl.LightningModule", loss: torch.Tensor) -> None: @@ -106,8 +105,9 @@ def _process_training_step_output( return results, hiddens -def _prepare_dataloader_iter(loop: Loop, dataloader_iter: Iterator, batch_idx: int) -> None: +def _prepare_dataloader_iter(dataloader_iter: Iterator, batch_idx: int) -> Iterator: + """Attach the dataloader""" if not isinstance(dataloader_iter, DataLoaderIterDataFetcher): dataloader_iter = enumerate(dataloader_iter, batch_idx) # restore iteration - loop.dataloader_iter = dataloader_iter \ No newline at end of file + return dataloader_iter From 496f4199d65593714ed1f4b9dc620866e5398158 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 24 Aug 2021 18:35:16 +0100 Subject: [PATCH 69/74] remove + 1 --- pytorch_lightning/loops/epoch/evaluation_epoch_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index 0fee360eb1504..e4770084c84cd 100644 --- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -72,7 +72,7 @@ def on_run_start( self._dl_max_batches = dl_max_batches self._num_dataloaders = num_dataloaders - self.dataloader_iter = _prepare_dataloader_iter(dataloader_iter, self.batch_progress.current.ready + 1) + self.dataloader_iter = _prepare_dataloader_iter(dataloader_iter, self.batch_progress.current.ready) def advance( self, dataloader_iter: Iterator, dataloader_idx: int, dl_max_batches: int, num_dataloaders: int From 28e855a6b40653ca64f64de0aaa46b439cdb3f76 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Tue, 24 Aug 2021 13:38:09 -0400 Subject: [PATCH 70/74] update on comments --- pytorch_lightning/core/lightning.py | 2 +- tests/utilities/test_fetching.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index b0c55529a623c..fd9f0be4c3bb0 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -454,7 +454,7 @@ def log( f" of {list(self._metric_attributes.values())}" ) - if is_param_in_hook_signature(self.training_step, "dataloader_iter", explicit=True) and batch_size is None: + if self.trainer.training and is_param_in_hook_signature(self.training_step, "dataloader_iter", explicit=True) and batch_size is None: raise MisconfigurationException( "With `def training_step(self, dataloader_iter)`, `self.log(..., batch_size=...)` should be provided." 
) diff --git a/tests/utilities/test_fetching.py b/tests/utilities/test_fetching.py index 88dda1791160c..eda4a6506270d 100644 --- a/tests/utilities/test_fetching.py +++ b/tests/utilities/test_fetching.py @@ -192,8 +192,8 @@ def test_trainer_num_prefetch_batches(tmpdir): t0 = time() trainer = Trainer(**trainer_kwargs) trainer.fit(model) - assert isinstance(trainer.data_connector.train_data_fetcher, InterBatchParallelDataFetcher) t1 = time() + assert isinstance(trainer.data_connector.train_data_fetcher, InterBatchParallelDataFetcher) global_step = trainer.global_step torch.cuda.synchronize() @@ -201,8 +201,8 @@ def test_trainer_num_prefetch_batches(tmpdir): t2 = time() trainer = Trainer(**trainer_kwargs) trainer.fit(model) - assert isinstance(trainer.data_connector.train_data_fetcher, DataFetcher) t3 = time() + assert isinstance(trainer.data_connector.train_data_fetcher, DataFetcher) assert global_step == trainer.global_step == 4 ratio = (t3 - t2) / (t1 - t0) From 240a7b78b856f48a020f2a608c063dccbb6e8f3e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 24 Aug 2021 17:39:26 +0000 Subject: [PATCH 71/74] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/core/lightning.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index fd9f0be4c3bb0..4a21365614f49 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -454,7 +454,11 @@ def log( f" of {list(self._metric_attributes.values())}" ) - if self.trainer.training and is_param_in_hook_signature(self.training_step, "dataloader_iter", explicit=True) and batch_size is None: + if ( + self.trainer.training + and is_param_in_hook_signature(self.training_step, "dataloader_iter", explicit=True) + and batch_size is None + ): raise MisconfigurationException( "With `def training_step(self, dataloader_iter)`, `self.log(..., batch_size=...)` should be provided." 
) From 33f13233bef861f27f685ce32e4d313b71e4c609 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 24 Aug 2021 19:42:32 +0200 Subject: [PATCH 72/74] No longer need to seed --- tests/checkpointing/test_trainer_checkpoint.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/checkpointing/test_trainer_checkpoint.py b/tests/checkpointing/test_trainer_checkpoint.py index f9ef2fd5c20d7..6a8192ef0149e 100644 --- a/tests/checkpointing/test_trainer_checkpoint.py +++ b/tests/checkpointing/test_trainer_checkpoint.py @@ -17,7 +17,7 @@ import torch import pytorch_lightning as pl -from pytorch_lightning import seed_everything, Trainer +from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint from tests.helpers import BoringModel @@ -27,8 +27,6 @@ def test_finetuning_with_resume_from_checkpoint(tmpdir): This test validates that generated ModelCheckpoint is pointing to the right best_model_path during test """ - seed_everything(4) - checkpoint_callback = ModelCheckpoint(monitor="val_loss", dirpath=tmpdir, filename="{epoch:02d}", save_top_k=-1) class ExtendedBoringModel(BoringModel): From 5b9d81ea55ab7c42de367aecd8dde7fb4de881e5 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 24 Aug 2021 19:06:48 +0100 Subject: [PATCH 73/74] update --- .../loops/batch/training_batch_loop.py | 2 +- pytorch_lightning/utilities/signature_utils.py | 13 ++++++++++--- tests/utilities/test_fetching.py | 2 +- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index 8f1c6ffa48288..d77918c2405bc 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -552,7 +552,7 @@ def _build_kwargs(self, batch: Any, batch_idx: int, opt_idx: int, hiddens: Optio lightning_module = self.trainer.lightning_module training_step_fx = getattr(lightning_module, "training_step") - if is_param_in_hook_signature(training_step_fx, "batch_idx"): + if is_param_in_hook_signature(training_step_fx, "batch_idx", min_size=2): step_kwargs["batch_idx"] = batch_idx if len(self.trainer.optimizers) > 1: diff --git a/pytorch_lightning/utilities/signature_utils.py b/pytorch_lightning/utilities/signature_utils.py index 5c7e468d84738..05045e98d3af6 100644 --- a/pytorch_lightning/utilities/signature_utils.py +++ b/pytorch_lightning/utilities/signature_utils.py @@ -12,15 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. 
 import inspect
-from typing import Callable
+from typing import Callable, Optional
 
 
-def is_param_in_hook_signature(hook_fx: Callable, param: str, explicit: bool = False) -> bool:
+def is_param_in_hook_signature(
+    hook_fx: Callable, param: str, explicit: bool = False, min_args: Optional[int] = None
+) -> bool:
     """
     Args:
         hook_fx: the hook callable
         param: the name of the parameter to check
         explicit: whether the parameter has to be explicitly declared
+        min_args: whether the signature has at least `min_args` parameters
     """
     hook_params = list(inspect.signature(hook_fx).parameters)
-    return param in hook_params or (not explicit and "args" in hook_params)
+    return (
+        param in hook_params
+        or (not explicit and "args" in hook_params)
+        or (isinstance(min_args, int) and len(hook_params) >= min_args)
+    )
diff --git a/tests/utilities/test_fetching.py b/tests/utilities/test_fetching.py
index eda4a6506270d..bf54bbae83568 100644
--- a/tests/utilities/test_fetching.py
+++ b/tests/utilities/test_fetching.py
@@ -231,7 +231,7 @@ def training_step(self, dataloader_iter, batch_idx):
         self.count += 2
         if self.automatic_optimization:
             loss = super().training_step(batch, 0)
-            with pytest.raises(MisconfigurationException, match="`batch_size` should be provided"):
+            with pytest.raises(MisconfigurationException, match="dataloader_iter"):
                 self.log("train_loss", loss["loss"])
             self.log("train_loss", loss["loss"], batch_size=1)
         else:

From 2c44e7bb17ca5f87e006706315515b6be0fdb9ff Mon Sep 17 00:00:00 2001
From: tchaton
Date: Tue, 24 Aug 2021 19:07:35 +0100
Subject: [PATCH 74/74] typo

---
 pytorch_lightning/loops/batch/training_batch_loop.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py
index d77918c2405bc..29517ad306eba 100644
--- a/pytorch_lightning/loops/batch/training_batch_loop.py
+++ b/pytorch_lightning/loops/batch/training_batch_loop.py
@@ -552,7 +552,7 @@ def _build_kwargs(self, batch: Any, batch_idx: int, opt_idx: int, hiddens: Optio
         lightning_module = self.trainer.lightning_module
         training_step_fx = getattr(lightning_module, "training_step")
 
-        if is_param_in_hook_signature(training_step_fx, "batch_idx", min_size=2):
+        if is_param_in_hook_signature(training_step_fx, "batch_idx", min_args=2):
             step_kwargs["batch_idx"] = batch_idx
 
         if len(self.trainer.optimizers) > 1:
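
Note on usage: the patches above converge on a `training_step` that receives `dataloader_iter` (and optionally `batch_idx`) instead of a pre-fetched batch, and on `self.log` requiring an explicit `batch_size` in that mode. Below is a minimal illustrative sketch of that pattern, not code from any patch in this series; the module, dataset, and sizes are invented, and `next(dataloader_iter)` is assumed to yield a plain batch, which may differ at this point in the series.

    import torch
    from torch.utils.data import DataLoader, Dataset

    from pytorch_lightning import LightningModule, Trainer


    class RandomDataset(Dataset):
        # toy dataset, invented for this example
        def __len__(self):
            return 64

        def __getitem__(self, idx):
            return torch.randn(32)


    class IterStepModel(LightningModule):
        # hypothetical module: only the training_step signature and logging matter here
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def training_step(self, dataloader_iter, batch_idx):
            # the loop hands over the iterator itself; fetch the batch manually
            batch = next(dataloader_iter)
            loss = self.layer(batch).sum()
            # the batch size cannot be inferred from an iterator, so it must be
            # passed explicitly (see the check added to core/lightning.py above)
            self.log("train_loss", loss, batch_size=batch.size(0))
            return loss

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)

        def train_dataloader(self):
            return DataLoader(RandomDataset(), batch_size=8)


    if __name__ == "__main__":
        Trainer(max_steps=2).fit(IterStepModel())

Under the `min_args=2` check added to `is_param_in_hook_signature`, the two-parameter form above still receives `batch_idx`, while a single-parameter `training_step(self, dataloader_iter)` would not have `batch_idx` injected.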