From 0ca1a115d313b6567807162b9c23739c8b4c8a88 Mon Sep 17 00:00:00 2001
From: Ananth Subramaniam
Date: Wed, 19 Oct 2022 23:03:25 -0700
Subject: [PATCH] Support recreating dataloaders during the training loop

Summary:
Sometimes users need to recreate the dataloaders at the start of each epoch of the overall training loop. To support this, we allow users to register a creation function when initializing the state for training or fitting. The function receives the state as an argument, since users may want to use progress information such as the progress counters to reinitialize the dataloader.

We don't add this support for `evaluate` or `predict`, since those functions iterate through the corresponding dataloader just once. (The change to `_evaluate_impl` is still needed because `fit` reuses it for its evaluation phase.)

For `fit`, this allows the training and evaluation dataloaders to be reloaded independently, if desired.

Differential Revision: D40539580

fbshipit-source-id: 955f9d811aec000167a9d1ef5cb5a6e7ca0c62c6
---
 tests/runner/test_fit.py    | 36 +++++++++++++++++++++++++++++++++++-
 tests/runner/test_train.py  | 30 +++++++++++++++++++++++++++++-
 torchtnt/runner/evaluate.py |  4 ++++
 torchtnt/runner/fit.py      |  8 +++++++-
 torchtnt/runner/state.py    |  8 +++++++-
 torchtnt/runner/train.py    | 10 ++++++++--
 6 files changed, 90 insertions(+), 6 deletions(-)

diff --git a/tests/runner/test_fit.py b/tests/runner/test_fit.py
index 4b63cbb788..0d99b30fb4 100644
--- a/tests/runner/test_fit.py
+++ b/tests/runner/test_fit.py
@@ -6,7 +6,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import unittest
-from typing import Tuple
+from typing import Iterable, Tuple
 from unittest.mock import MagicMock
 
 import torch
@@ -277,3 +277,37 @@ def test_fit_with_callback(self) -> None:
         )
         self.assertEqual(callback_mock.on_eval_epoch_end.call_count, max_epochs)
         self.assertEqual(callback_mock.on_eval_end.call_count, max_epochs)
+
+    def test_fit_dataloader_func(self) -> None:
+        input_dim = 2
+        dataset_len = 8
+        batch_size = 2
+        max_epochs = 3
+
+        class DataloaderFunc:
+            def __init__(self) -> None:
+                self.call_count = 0
+
+            def __call__(
+                self, state: State
+            ) -> Iterable[Tuple[torch.Tensor, torch.Tensor]]:
+                self.call_count += 1
+                return generate_random_dataloader(dataset_len, input_dim, batch_size)
+
+        my_unit = MagicMock(spec=DummyFitUnit)
+        train_dl = generate_random_dataloader(dataset_len, input_dim, batch_size)
+        eval_dl = generate_random_dataloader(dataset_len, input_dim, batch_size)
+
+        train_dl_func = DataloaderFunc()
+        eval_dl_func = DataloaderFunc()
+
+        state = init_fit_state(
+            train_dataloader=train_dl,
+            eval_dataloader=eval_dl,
+            max_epochs=max_epochs,
+            train_dataloader_func=train_dl_func,
+            eval_dataloader_func=eval_dl_func,
+        )
+        fit(state, my_unit)
+        self.assertEqual(train_dl_func.call_count, max_epochs)
+        self.assertEqual(eval_dl_func.call_count, max_epochs)
diff --git a/tests/runner/test_train.py b/tests/runner/test_train.py
index f37f9c51b3..a648188ed6 100644
--- a/tests/runner/test_train.py
+++ b/tests/runner/test_train.py
@@ -6,7 +6,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import unittest
-from typing import Iterator, Tuple
+from typing import Iterable, Iterator, Tuple
 from unittest.mock import MagicMock
 
 import torch
@@ -258,6 +258,34 @@ def test_train_max_steps(self) -> None:
             my_unit.train_step.call_count, max_epochs * expected_steps_per_epoch
         )
 
+    def test_train_dataloader_func(self) -> None:
+        input_dim = 2
+        dataset_len = 8
+        batch_size = 2
+        max_epochs = 3
+
+        class DataloaderFunc:
+            def __init__(self) -> None:
+                self.call_count = 0
+
+            def __call__(
+                self, state: State
+            ) -> Iterable[Tuple[torch.Tensor, torch.Tensor]]:
+                self.call_count += 1
+                return generate_random_dataloader(dataset_len, input_dim, batch_size)
+
+        my_unit = MagicMock()
+        dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
+
+        dl_func = DataloaderFunc()
+        state = init_train_state(
+            dataloader=dataloader,
+            max_epochs=max_epochs,
+            dataloader_func=dl_func,
+        )
+        train(state, my_unit)
+        self.assertEqual(dl_func.call_count, max_epochs)
+
 
 class StopTrainUnit(TrainUnit[Tuple[torch.Tensor, torch.Tensor]]):
     def __init__(self, input_dim: int, steps_before_stopping: int) -> None:
diff --git a/torchtnt/runner/evaluate.py b/torchtnt/runner/evaluate.py
index a7b760b0c2..fbcab40326 100644
--- a/torchtnt/runner/evaluate.py
+++ b/torchtnt/runner/evaluate.py
@@ -113,6 +113,10 @@ def _evaluate_impl(
     eval_unit.on_eval_epoch_start(state)
     _run_callback_fn(callbacks, "on_eval_epoch_start", state, eval_unit)
 
+    dataloader_func = eval_state.dataloader_func
+    if dataloader_func:
+        eval_state._dataloader = dataloader_func(state)
+
     data_iter = iter(eval_state.dataloader)
     step_input = data_iter
 
diff --git a/torchtnt/runner/fit.py b/torchtnt/runner/fit.py
index d584d9f019..dfebd22b74 100644
--- a/torchtnt/runner/fit.py
+++ b/torchtnt/runner/fit.py
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import logging
-from typing import Iterable, List, Optional
+from typing import Callable, Iterable, List, Optional
 
 from torchtnt.runner.callback import Callback
 
@@ -27,6 +27,8 @@ def init_fit_state(
     max_eval_steps_per_epoch: Optional[int] = None,
     evaluate_every_n_steps: Optional[int] = None,
     evaluate_every_n_epochs: Optional[int] = 1,
+    train_dataloader_func: Optional[Callable[[State], Iterable[TTrainData]]] = None,
+    eval_dataloader_func: Optional[Callable[[State], Iterable[TEvalData]]] = None,
 ) -> State:
     """
     Helper function that initializes a state object for fitting.
@@ -39,6 +41,8 @@ def init_fit_state(
         max_train_steps_per_epoch: the max number of steps to run per epoch for training. None means train until the dataloader is exhausted.
         evaluate_every_n_steps: how often to run the evaluation loop in terms of training steps.
         evaluate_every_n_epochs: how often to run the evaluation loop in terms of training epochs.
+        train_dataloader_func: optional function to reinitialize the train dataloader at the start of each train epoch.
+        eval_dataloader_func: optional function to reinitialize the eval dataloader at the start of each eval epoch.
 
     Returns:
         An initialized state object containing metadata.
@@ -51,12 +55,14 @@ def init_fit_state(
             max_epochs=max_epochs,
             max_steps=max_steps,
             max_steps_per_epoch=max_train_steps_per_epoch,
+            dataloader_func=train_dataloader_func,
         ),
         eval_state=PhaseState(
             dataloader=eval_dataloader,
             max_steps_per_epoch=max_eval_steps_per_epoch,
             evaluate_every_n_steps=evaluate_every_n_steps,
             evaluate_every_n_epochs=evaluate_every_n_epochs,
+            dataloader_func=eval_dataloader_func,
         ),
     )
 
diff --git a/torchtnt/runner/state.py b/torchtnt/runner/state.py
index 366a828481..002bdc1d75 100644
--- a/torchtnt/runner/state.py
+++ b/torchtnt/runner/state.py
@@ -11,7 +11,7 @@
 import logging
 from enum import auto, Enum
-from typing import Any, Iterable, Optional
+from typing import Any, Callable, Iterable, Optional
 
 from torchtnt.runner.progress import Progress
 from torchtnt.utils.timer import Timer
 
@@ -73,6 +73,7 @@ def __init__(
         max_steps_per_epoch: Optional[int] = None,
         evaluate_every_n_steps: Optional[int] = None,  # used only for evaluate
         evaluate_every_n_epochs: Optional[int] = None,  # used only for evaluate
+        dataloader_func: Optional[Callable[["State"], Iterable[Any]]] = None,
     ) -> None:
         _check_loop_condition("max_epochs", max_epochs)
         _check_loop_condition("max_steps", max_steps)
@@ -88,11 +89,16 @@
         self._evaluate_every_n_steps = evaluate_every_n_steps
         self._evaluate_every_n_epochs = evaluate_every_n_epochs
         self._step_output: Any = None
+        self._dataloader_func = dataloader_func
 
     @property
     def dataloader(self) -> Iterable[Any]:
         return self._dataloader
 
+    @property
+    def dataloader_func(self) -> Optional[Callable[["State"], Iterable[Any]]]:
+        return self._dataloader_func
+
     @property
     def progress(self) -> Progress:
         return self._progress
diff --git a/torchtnt/runner/train.py b/torchtnt/runner/train.py
index 0434d305ca..7142b12b13 100644
--- a/torchtnt/runner/train.py
+++ b/torchtnt/runner/train.py
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import logging
-from typing import Iterable, List, Optional
+from typing import Callable, Iterable, List, Optional
 
 import torch
 from torchtnt.runner.callback import Callback
@@ -33,6 +33,7 @@ def init_train_state(
     max_epochs: Optional[int] = None,
     max_steps: Optional[int] = None,
     max_steps_per_epoch: Optional[int] = None,
+    dataloader_func: Optional[Callable[[State], Iterable[TTrainData]]] = None,
 ) -> State:
     """
     Helper function that initializes a state object for training.
@@ -42,7 +43,7 @@
         max_epochs: the max number of epochs to run. `None` means no limit (infinite training) unless stopped by max_steps.
         max_steps: the max number of steps to run. `None` means no limit (infinite training) unless stopped by max_epochs.
         max_steps_per_epoch: the max number of steps to run per epoch. None means train until the dataloader is exhausted.
-
+        dataloader_func: optional function to reinitialize the dataloader at the start of each train epoch.
     Returns:
         An initialized state object containing metadata.
     """
@@ -54,6 +55,7 @@
             max_epochs=max_epochs,
             max_steps=max_steps,
             max_steps_per_epoch=max_steps_per_epoch,
+            dataloader_func=dataloader_func,
         ),
     )
 
@@ -195,6 +197,10 @@ def _train_epoch_impl(
     train_unit.on_train_epoch_start(state)
     _run_callback_fn(callbacks, "on_train_epoch_start", state, train_unit)
 
+    dataloader_func = train_state.dataloader_func
+    if dataloader_func:
+        train_state._dataloader = dataloader_func(state)
+
     _maybe_set_distributed_sampler_epoch(
         train_state.dataloader, train_state.progress.num_epochs_completed
     )
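
Example usage of the new hook, as a minimal sketch: the unit, dataset, and `make_dataloader` helper below are illustrative only (they are not part of this diff), and the sketch assumes the existing `State.train_state` and `TrainUnit` surfaces.

```python
from typing import Tuple

import torch
from torch.utils.data import DataLoader, TensorDataset

from torchtnt.runner.state import State
from torchtnt.runner.train import init_train_state, train
from torchtnt.runner.unit import TrainUnit

Batch = Tuple[torch.Tensor, torch.Tensor]


class MyTrainUnit(TrainUnit[Batch]):
    # Illustrative unit: a linear model trained with SGD on an MSE loss.
    def __init__(self) -> None:
        super().__init__()
        self.module = torch.nn.Linear(2, 1)
        self.optimizer = torch.optim.SGD(self.module.parameters(), lr=0.1)

    def train_step(self, state: State, data: Batch) -> None:
        inputs, targets = data
        loss = torch.nn.functional.mse_loss(self.module(inputs), targets)
        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()


def make_dataloader(seed: int) -> DataLoader:
    # Illustrative dataset; real code would construct the user's actual data here.
    dataset = TensorDataset(torch.randn(64, 2), torch.randn(64, 1))
    generator = torch.Generator().manual_seed(seed)
    return DataLoader(dataset, batch_size=8, shuffle=True, generator=generator)


def dataloader_func(state: State) -> DataLoader:
    # The hook receives the loop state, so progress counters can drive how the
    # dataloader is rebuilt, e.g. reseeding the shuffle at each epoch boundary.
    # Assumes State exposes the train PhaseState via `state.train_state`.
    epoch = state.train_state.progress.num_epochs_completed
    return make_dataloader(seed=epoch)


state = init_train_state(
    dataloader=make_dataloader(seed=0),
    max_epochs=3,
    dataloader_func=dataloader_func,
)
train(state, MyTrainUnit())
```

For `fit`, the same pattern applies through the `train_dataloader_func` and `eval_dataloader_func` arguments of `init_fit_state`.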