From a636d894a45fc249ab12d532dade858e478b6645 Mon Sep 17 00:00:00 2001 From: Ananth Subramaniam Date: Wed, 26 Oct 2022 15:48:45 -0700 Subject: [PATCH] Support recreating dataloaders during loop (#248) Summary: Pull Request resolved: https://github.com/pytorch/tnt/pull/248 Sometimes users need to recreate the dataloaders at the start of the epoch during the overall training loop. To support this, we allow users to register a creation function when creating the state for training or fitting. We pass the state as an argument since users may want to use progress information such as the progress counters to reinitialize the dataloader. We don't add this support for `evaluate` or `predict` since those functions iterate through the corresponding dataloader just once. For `fit`, this allows flexibility to reload training & evaluation dataloaders independently during if desired Differential Revision: D40539580 fbshipit-source-id: c67d7be4d20ac25f0b65927a20b525d86f8b56a4 --- tests/runner/test_fit.py | 32 +++++++++++++++++++++++++++++++- tests/runner/test_train.py | 28 +++++++++++++++++++++++++++- torchtnt/runner/evaluate.py | 4 ++++ torchtnt/runner/fit.py | 31 ++++++++++++++++++++++++++----- torchtnt/runner/state.py | 8 +++++++- torchtnt/runner/train.py | 17 ++++++++++++++--- 6 files changed, 109 insertions(+), 11 deletions(-) diff --git a/tests/runner/test_fit.py b/tests/runner/test_fit.py index 4b63cbb788..b172739473 100644 --- a/tests/runner/test_fit.py +++ b/tests/runner/test_fit.py @@ -6,7 +6,7 @@ # LICENSE file in the root directory of this source tree. import unittest -from typing import Tuple +from typing import Iterable, Tuple from unittest.mock import MagicMock import torch @@ -277,3 +277,33 @@ def test_fit_with_callback(self) -> None: ) self.assertEqual(callback_mock.on_eval_epoch_end.call_count, max_epochs) self.assertEqual(callback_mock.on_eval_end.call_count, max_epochs) + + def test_fit_dataloader_func(self) -> None: + input_dim = 2 + dataset_len = 8 + batch_size = 2 + max_epochs = 3 + + class DataloaderFunc: + def __init__(self) -> None: + self.call_count = 0 + + def __call__( + self, state: State + ) -> Iterable[Tuple[torch.Tensor, torch.Tensor]]: + self.call_count += 1 + return generate_random_dataloader(dataset_len, input_dim, batch_size) + + my_unit = MagicMock(spec=DummyFitUnit) + + train_dl_func = DataloaderFunc() + eval_dl_func = DataloaderFunc() + + state = init_fit_state( + train_dataloader=train_dl_func, + eval_dataloader=eval_dl_func, + max_epochs=max_epochs, + ) + fit(state, my_unit) + self.assertEqual(train_dl_func.call_count, max_epochs) + self.assertEqual(eval_dl_func.call_count, max_epochs) diff --git a/tests/runner/test_train.py b/tests/runner/test_train.py index f37f9c51b3..585927f9de 100644 --- a/tests/runner/test_train.py +++ b/tests/runner/test_train.py @@ -6,7 +6,7 @@ # LICENSE file in the root directory of this source tree. import unittest -from typing import Iterator, Tuple +from typing import Iterable, Iterator, Tuple from unittest.mock import MagicMock import torch @@ -258,6 +258,32 @@ def test_train_max_steps(self) -> None: my_unit.train_step.call_count, max_epochs * expected_steps_per_epoch ) + def test_train_dataloader_func(self) -> None: + input_dim = 2 + dataset_len = 8 + batch_size = 2 + max_epochs = 3 + + class DataloaderFunc: + def __init__(self) -> None: + self.call_count = 0 + + def __call__( + self, state: State + ) -> Iterable[Tuple[torch.Tensor, torch.Tensor]]: + self.call_count += 1 + return generate_random_dataloader(dataset_len, input_dim, batch_size) + + my_unit = MagicMock() + + dl_func = DataloaderFunc() + state = init_train_state( + dataloader=dl_func, + max_epochs=max_epochs, + ) + train(state, my_unit) + self.assertEqual(dl_func.call_count, max_epochs) + class StopTrainUnit(TrainUnit[Tuple[torch.Tensor, torch.Tensor]]): def __init__(self, input_dim: int, steps_before_stopping: int) -> None: diff --git a/torchtnt/runner/evaluate.py b/torchtnt/runner/evaluate.py index e0296f92d0..88ddbcf5db 100644 --- a/torchtnt/runner/evaluate.py +++ b/torchtnt/runner/evaluate.py @@ -104,6 +104,10 @@ def _evaluate_impl( eval_unit.on_eval_start(state) _run_callback_fn(callbacks, "on_eval_start", state, eval_unit) + dataloader_func = eval_state.dataloader_func + if dataloader_func: + eval_state._dataloader = dataloader_func(state) + # Conditionally run this to avoid running this multiple times # in the case of resuming from a checkpoint mid-epoch if eval_state.progress.num_steps_completed_in_epoch == 0: diff --git a/torchtnt/runner/fit.py b/torchtnt/runner/fit.py index 3ad66ae91a..99ee2e1cbb 100644 --- a/torchtnt/runner/fit.py +++ b/torchtnt/runner/fit.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. import logging -from typing import Iterable, List, Optional +from typing import Callable, Iterable, List, Optional, Union from torchtnt.runner.callback import Callback @@ -19,8 +19,10 @@ def init_fit_state( - train_dataloader: Iterable[TTrainData], - eval_dataloader: Iterable[TEvalData], + train_dataloader: Union[ + Iterable[TTrainData], Callable[[State], Iterable[TTrainData]] + ], + eval_dataloader: Union[Iterable[TEvalData], Callable[[State], Iterable[TEvalData]]], max_epochs: Optional[int] = None, max_steps: Optional[int] = None, max_train_steps_per_epoch: Optional[int] = None, @@ -33,30 +35,49 @@ def init_fit_state( Args: train_dataloader: dataloader to be used during training. + This can either be any iterable, or a callable which accepts State as an argument and returns an iterable. eval_dataloader: dataloader to be used during evaluation. + This can either be any iterable, or a callable which accepts State as an argument and returns an iterable. max_epochs: the max number of epochs to run for training. ``None`` means no limit (infinite training) unless stopped by max_steps. max_steps: the max number of steps to run for training. ``None`` means no limit (infinite training) unless stopped by max_epochs. max_train_steps_per_epoch: the max number of steps to run per epoch for training. None means train until the dataloader is exhausted. evaluate_every_n_steps: how often to run the evaluation loop in terms of training steps. evaluate_every_n_epochs: how often to run the evaluation loop in terms of training epochs. + train_dataloader_func: optional function to reinitialize the train dataloader at the start of the train epoch. + eval_dataloader_func: optional function to reinitialize the eval dataloader at the start of the eval epoch. Returns: An initialized state object containing metadata. """ + if callable(train_dataloader): + train_dl = None + train_dl_func = train_dataloader + else: + train_dl = train_dataloader + train_dl_func = None + + if callable(eval_dataloader): + eval_dl = None + eval_dl_func = eval_dataloader + else: + eval_dl = eval_dataloader + eval_dl_func = None return State( entry_point=EntryPoint.FIT, train_state=PhaseState( - dataloader=train_dataloader, + dataloader=train_dl, max_epochs=max_epochs, max_steps=max_steps, max_steps_per_epoch=max_train_steps_per_epoch, + dataloader_func=train_dl_func, ), eval_state=PhaseState( - dataloader=eval_dataloader, + dataloader=eval_dl, max_steps_per_epoch=max_eval_steps_per_epoch, evaluate_every_n_steps=evaluate_every_n_steps, evaluate_every_n_epochs=evaluate_every_n_epochs, + dataloader_func=eval_dl_func, ), ) diff --git a/torchtnt/runner/state.py b/torchtnt/runner/state.py index 92c697bde3..e2b8a02d1e 100644 --- a/torchtnt/runner/state.py +++ b/torchtnt/runner/state.py @@ -11,7 +11,7 @@ import logging from enum import auto, Enum -from typing import Any, Iterable, Optional +from typing import Any, Callable, Iterable, Optional from torchtnt.runner.progress import Progress from torchtnt.utils.timer import Timer @@ -73,6 +73,7 @@ def __init__( max_steps_per_epoch: Optional[int] = None, evaluate_every_n_steps: Optional[int] = None, # used only for evaluate evaluate_every_n_epochs: Optional[int] = None, # used only for evaluate + dataloader_func: Optional[Callable[["State"], Iterable[Any]]] = None, ) -> None: _check_loop_condition("max_epochs", max_epochs) _check_loop_condition("max_steps", max_steps) @@ -88,11 +89,16 @@ def __init__( self._evaluate_every_n_steps = evaluate_every_n_steps self._evaluate_every_n_epochs = evaluate_every_n_epochs self._step_output: Any = None + self._dataloader_func = dataloader_func @property def dataloader(self) -> Iterable[Any]: return self._dataloader + @property + def dataloader_func(self) -> Optional[Callable[["State"], Iterable[Any]]]: + return self._dataloader_func + @property def progress(self) -> Progress: return self._progress diff --git a/torchtnt/runner/train.py b/torchtnt/runner/train.py index ccb3064c71..fd02716157 100644 --- a/torchtnt/runner/train.py +++ b/torchtnt/runner/train.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. import logging -from typing import Iterable, List, Optional +from typing import Callable, Iterable, List, Optional, Union import torch from torchtnt.runner.callback import Callback @@ -29,7 +29,7 @@ def init_train_state( *, - dataloader: Iterable[TTrainData], + dataloader: Union[Iterable[TTrainData], Callable[[State], Iterable[TTrainData]]], max_epochs: Optional[int] = None, max_steps: Optional[int] = None, max_steps_per_epoch: Optional[int] = None, @@ -39,6 +39,7 @@ def init_train_state( Args: dataloader: dataloader to be used during training. + This can either be any iterable, or a callable which accepts State as an argument and returns an iterable. max_epochs: the max number of epochs to run. ``None`` means no limit (infinite training) unless stopped by max_steps. max_steps: the max number of steps to run. ``None`` means no limit (infinite training) unless stopped by max_epochs. max_steps_per_epoch: the max number of steps to run per epoch. None means train until the dataloader is exhausted. @@ -46,14 +47,20 @@ def init_train_state( Returns: An initialized state object containing metadata. """ + dl, dataloader_func = None, None + if callable(dataloader): + dataloader_func = dataloader + else: + dl = dataloader return State( entry_point=EntryPoint.TRAIN, train_state=PhaseState( - dataloader=dataloader, + dataloader=dl, max_epochs=max_epochs, max_steps=max_steps, max_steps_per_epoch=max_steps_per_epoch, + dataloader_func=dataloader_func, ), ) @@ -196,6 +203,10 @@ def _train_epoch_impl( if state.eval_state.evaluate_every_n_epochs: evaluate_every_n_epochs = state.eval_state.evaluate_every_n_epochs + dataloader_func = train_state.dataloader_func + if dataloader_func: + train_state._dataloader = dataloader_func(state) + # Check the progress to conditionally run this # to avoid running this multiple times # in the case of resuming from a checkpoint mid-epoch