diff --git a/tests/runner/test_fit.py b/tests/runner/test_fit.py index 4b63cbb788..0d99b30fb4 100644 --- a/tests/runner/test_fit.py +++ b/tests/runner/test_fit.py @@ -6,7 +6,7 @@ # LICENSE file in the root directory of this source tree. import unittest -from typing import Tuple +from typing import Iterable, Tuple from unittest.mock import MagicMock import torch @@ -277,3 +277,37 @@ def test_fit_with_callback(self) -> None: ) self.assertEqual(callback_mock.on_eval_epoch_end.call_count, max_epochs) self.assertEqual(callback_mock.on_eval_end.call_count, max_epochs) + + def test_fit_dataloader_func(self) -> None: + input_dim = 2 + dataset_len = 8 + batch_size = 2 + max_epochs = 3 + + class DataloaderFunc: + def __init__(self) -> None: + self.call_count = 0 + + def __call__( + self, state: State + ) -> Iterable[Tuple[torch.Tensor, torch.Tensor]]: + self.call_count += 1 + return generate_random_dataloader(dataset_len, input_dim, batch_size) + + my_unit = MagicMock(spec=DummyFitUnit) + train_dl = generate_random_dataloader(dataset_len, input_dim, batch_size) + eval_dl = generate_random_dataloader(dataset_len, input_dim, batch_size) + + train_dl_func = DataloaderFunc() + eval_dl_func = DataloaderFunc() + + state = init_fit_state( + train_dataloader=train_dl, + eval_dataloader=eval_dl, + max_epochs=max_epochs, + train_dataloader_func=train_dl_func, + eval_dataloader_func=eval_dl_func, + ) + fit(state, my_unit) + self.assertEqual(train_dl_func.call_count, max_epochs) + self.assertEqual(eval_dl_func.call_count, max_epochs) diff --git a/tests/runner/test_train.py b/tests/runner/test_train.py index f37f9c51b3..a648188ed6 100644 --- a/tests/runner/test_train.py +++ b/tests/runner/test_train.py @@ -6,7 +6,7 @@ # LICENSE file in the root directory of this source tree. import unittest -from typing import Iterator, Tuple +from typing import Iterable, Iterator, Tuple from unittest.mock import MagicMock import torch @@ -258,6 +258,34 @@ def test_train_max_steps(self) -> None: my_unit.train_step.call_count, max_epochs * expected_steps_per_epoch ) + def test_train_dataloader_func(self) -> None: + input_dim = 2 + dataset_len = 8 + batch_size = 2 + max_epochs = 3 + + class DataloaderFunc: + def __init__(self) -> None: + self.call_count = 0 + + def __call__( + self, state: State + ) -> Iterable[Tuple[torch.Tensor, torch.Tensor]]: + self.call_count += 1 + return generate_random_dataloader(dataset_len, input_dim, batch_size) + + my_unit = MagicMock() + dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size) + + dl_func = DataloaderFunc() + state = init_train_state( + dataloader=dataloader, + max_epochs=max_epochs, + dataloader_func=dl_func, + ) + train(state, my_unit) + self.assertEqual(dl_func.call_count, max_epochs) + class StopTrainUnit(TrainUnit[Tuple[torch.Tensor, torch.Tensor]]): def __init__(self, input_dim: int, steps_before_stopping: int) -> None: diff --git a/torchtnt/runner/evaluate.py b/torchtnt/runner/evaluate.py index a7b760b0c2..fbcab40326 100644 --- a/torchtnt/runner/evaluate.py +++ b/torchtnt/runner/evaluate.py @@ -113,6 +113,10 @@ def _evaluate_impl( eval_unit.on_eval_epoch_start(state) _run_callback_fn(callbacks, "on_eval_epoch_start", state, eval_unit) + dataloader_func = eval_state.dataloader_func + if dataloader_func: + eval_state._dataloader = dataloader_func(state) + data_iter = iter(eval_state.dataloader) step_input = data_iter diff --git a/torchtnt/runner/fit.py b/torchtnt/runner/fit.py index d584d9f019..dfebd22b74 100644 --- a/torchtnt/runner/fit.py +++ b/torchtnt/runner/fit.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. import logging -from typing import Iterable, List, Optional +from typing import Callable, Iterable, List, Optional from torchtnt.runner.callback import Callback @@ -27,6 +27,8 @@ def init_fit_state( max_eval_steps_per_epoch: Optional[int] = None, evaluate_every_n_steps: Optional[int] = None, evaluate_every_n_epochs: Optional[int] = 1, + train_dataloader_func: Optional[Callable[[State], Iterable[TTrainData]]] = None, + eval_dataloader_func: Optional[Callable[[State], Iterable[TEvalData]]] = None, ) -> State: """ Helper function that initializes a state object for fitting. @@ -39,6 +41,8 @@ def init_fit_state( max_train_steps_per_epoch: the max number of steps to run per epoch for training. None means train until the dataloader is exhausted. evaluate_every_n_steps: how often to run the evaluation loop in terms of training steps. evaluate_every_n_epochs: how often to run the evaluation loop in terms of training epochs. + train_dataloader_func: optional function to reinitialize the train dataloader at the start of the train epoch. + eval_dataloader_func: optional function to reinitialize the eval dataloader at the start of the eval epoch. Returns: An initialized state object containing metadata. @@ -51,12 +55,14 @@ def init_fit_state( max_epochs=max_epochs, max_steps=max_steps, max_steps_per_epoch=max_train_steps_per_epoch, + dataloader_func=train_dataloader_func, ), eval_state=PhaseState( dataloader=eval_dataloader, max_steps_per_epoch=max_eval_steps_per_epoch, evaluate_every_n_steps=evaluate_every_n_steps, evaluate_every_n_epochs=evaluate_every_n_epochs, + dataloader_func=eval_dataloader_func, ), ) diff --git a/torchtnt/runner/state.py b/torchtnt/runner/state.py index 366a828481..002bdc1d75 100644 --- a/torchtnt/runner/state.py +++ b/torchtnt/runner/state.py @@ -11,7 +11,7 @@ import logging from enum import auto, Enum -from typing import Any, Iterable, Optional +from typing import Any, Callable, Iterable, Optional from torchtnt.runner.progress import Progress from torchtnt.utils.timer import Timer @@ -73,6 +73,7 @@ def __init__( max_steps_per_epoch: Optional[int] = None, evaluate_every_n_steps: Optional[int] = None, # used only for evaluate evaluate_every_n_epochs: Optional[int] = None, # used only for evaluate + dataloader_func: Optional[Callable[["State"], Iterable[Any]]] = None, ) -> None: _check_loop_condition("max_epochs", max_epochs) _check_loop_condition("max_steps", max_steps) @@ -88,11 +89,16 @@ def __init__( self._evaluate_every_n_steps = evaluate_every_n_steps self._evaluate_every_n_epochs = evaluate_every_n_epochs self._step_output: Any = None + self._dataloader_func = dataloader_func @property def dataloader(self) -> Iterable[Any]: return self._dataloader + @property + def dataloader_func(self) -> Optional[Callable[["State"], Iterable[Any]]]: + return self._dataloader_func + @property def progress(self) -> Progress: return self._progress diff --git a/torchtnt/runner/train.py b/torchtnt/runner/train.py index 0434d305ca..7142b12b13 100644 --- a/torchtnt/runner/train.py +++ b/torchtnt/runner/train.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. import logging -from typing import Iterable, List, Optional +from typing import Callable, Iterable, List, Optional import torch from torchtnt.runner.callback import Callback @@ -33,6 +33,7 @@ def init_train_state( max_epochs: Optional[int] = None, max_steps: Optional[int] = None, max_steps_per_epoch: Optional[int] = None, + dataloader_func: Optional[Callable[[State], Iterable[TTrainData]]] = None, ) -> State: """ Helper function that initializes a state object for training. @@ -42,7 +43,7 @@ def init_train_state( max_epochs: the max number of epochs to run. `None` means no limit (infinite training) unless stopped by max_steps. max_steps: the max number of steps to run. `None` means no limit (infinite training) unless stopped by max_epochs. max_steps_per_epoch: the max number of steps to run per epoch. None means train until the dataloader is exhausted. - + dataloader_func: optional function to reinitialize the dataloader during training at the start of the epoch. Returns: An initialized state object containing metadata. """ @@ -54,6 +55,7 @@ def init_train_state( max_epochs=max_epochs, max_steps=max_steps, max_steps_per_epoch=max_steps_per_epoch, + dataloader_func=dataloader_func, ), ) @@ -195,6 +197,10 @@ def _train_epoch_impl( train_unit.on_train_epoch_start(state) _run_callback_fn(callbacks, "on_train_epoch_start", state, train_unit) + dataloader_func = train_state.dataloader_func + if dataloader_func: + train_state._dataloader = dataloader_func(state) + _maybe_set_distributed_sampler_epoch( train_state.dataloader, train_state.progress.num_epochs_completed )