diff --git a/tests/runner/test_fit.py b/tests/runner/test_fit.py index 4b63cbb788..b172739473 100644 --- a/tests/runner/test_fit.py +++ b/tests/runner/test_fit.py @@ -6,7 +6,7 @@ # LICENSE file in the root directory of this source tree. import unittest -from typing import Tuple +from typing import Iterable, Tuple from unittest.mock import MagicMock import torch @@ -277,3 +277,33 @@ def test_fit_with_callback(self) -> None: ) self.assertEqual(callback_mock.on_eval_epoch_end.call_count, max_epochs) self.assertEqual(callback_mock.on_eval_end.call_count, max_epochs) + + def test_fit_dataloader_func(self) -> None: + input_dim = 2 + dataset_len = 8 + batch_size = 2 + max_epochs = 3 + + class DataloaderFunc: + def __init__(self) -> None: + self.call_count = 0 + + def __call__( + self, state: State + ) -> Iterable[Tuple[torch.Tensor, torch.Tensor]]: + self.call_count += 1 + return generate_random_dataloader(dataset_len, input_dim, batch_size) + + my_unit = MagicMock(spec=DummyFitUnit) + + train_dl_func = DataloaderFunc() + eval_dl_func = DataloaderFunc() + + state = init_fit_state( + train_dataloader=train_dl_func, + eval_dataloader=eval_dl_func, + max_epochs=max_epochs, + ) + fit(state, my_unit) + self.assertEqual(train_dl_func.call_count, max_epochs) + self.assertEqual(eval_dl_func.call_count, max_epochs) diff --git a/tests/runner/test_train.py b/tests/runner/test_train.py index f37f9c51b3..585927f9de 100644 --- a/tests/runner/test_train.py +++ b/tests/runner/test_train.py @@ -6,7 +6,7 @@ # LICENSE file in the root directory of this source tree. import unittest -from typing import Iterator, Tuple +from typing import Iterable, Iterator, Tuple from unittest.mock import MagicMock import torch @@ -258,6 +258,32 @@ def test_train_max_steps(self) -> None: my_unit.train_step.call_count, max_epochs * expected_steps_per_epoch ) + def test_train_dataloader_func(self) -> None: + input_dim = 2 + dataset_len = 8 + batch_size = 2 + max_epochs = 3 + + class DataloaderFunc: + def __init__(self) -> None: + self.call_count = 0 + + def __call__( + self, state: State + ) -> Iterable[Tuple[torch.Tensor, torch.Tensor]]: + self.call_count += 1 + return generate_random_dataloader(dataset_len, input_dim, batch_size) + + my_unit = MagicMock() + + dl_func = DataloaderFunc() + state = init_train_state( + dataloader=dl_func, + max_epochs=max_epochs, + ) + train(state, my_unit) + self.assertEqual(dl_func.call_count, max_epochs) + class StopTrainUnit(TrainUnit[Tuple[torch.Tensor, torch.Tensor]]): def __init__(self, input_dim: int, steps_before_stopping: int) -> None: diff --git a/torchtnt/runner/evaluate.py b/torchtnt/runner/evaluate.py index e0296f92d0..88ddbcf5db 100644 --- a/torchtnt/runner/evaluate.py +++ b/torchtnt/runner/evaluate.py @@ -104,6 +104,10 @@ def _evaluate_impl( eval_unit.on_eval_start(state) _run_callback_fn(callbacks, "on_eval_start", state, eval_unit) + dataloader_func = eval_state.dataloader_func + if dataloader_func: + eval_state._dataloader = dataloader_func(state) + # Conditionally run this to avoid running this multiple times # in the case of resuming from a checkpoint mid-epoch if eval_state.progress.num_steps_completed_in_epoch == 0: diff --git a/torchtnt/runner/fit.py b/torchtnt/runner/fit.py index 3ad66ae91a..99ee2e1cbb 100644 --- a/torchtnt/runner/fit.py +++ b/torchtnt/runner/fit.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. import logging -from typing import Iterable, List, Optional +from typing import Callable, Iterable, List, Optional, Union from torchtnt.runner.callback import Callback @@ -19,8 +19,10 @@ def init_fit_state( - train_dataloader: Iterable[TTrainData], - eval_dataloader: Iterable[TEvalData], + train_dataloader: Union[ + Iterable[TTrainData], Callable[[State], Iterable[TTrainData]] + ], + eval_dataloader: Union[Iterable[TEvalData], Callable[[State], Iterable[TEvalData]]], max_epochs: Optional[int] = None, max_steps: Optional[int] = None, max_train_steps_per_epoch: Optional[int] = None, @@ -33,30 +35,49 @@ def init_fit_state( Args: train_dataloader: dataloader to be used during training. + This can either be any iterable, or a callable which accepts State as an argument and returns an iterable. eval_dataloader: dataloader to be used during evaluation. + This can either be any iterable, or a callable which accepts State as an argument and returns an iterable. max_epochs: the max number of epochs to run for training. ``None`` means no limit (infinite training) unless stopped by max_steps. max_steps: the max number of steps to run for training. ``None`` means no limit (infinite training) unless stopped by max_epochs. max_train_steps_per_epoch: the max number of steps to run per epoch for training. None means train until the dataloader is exhausted. evaluate_every_n_steps: how often to run the evaluation loop in terms of training steps. evaluate_every_n_epochs: how often to run the evaluation loop in terms of training epochs. + train_dataloader_func: optional function to reinitialize the train dataloader at the start of the train epoch. + eval_dataloader_func: optional function to reinitialize the eval dataloader at the start of the eval epoch. Returns: An initialized state object containing metadata. """ + if callable(train_dataloader): + train_dl = None + train_dl_func = train_dataloader + else: + train_dl = train_dataloader + train_dl_func = None + + if callable(eval_dataloader): + eval_dl = None + eval_dl_func = eval_dataloader + else: + eval_dl = eval_dataloader + eval_dl_func = None return State( entry_point=EntryPoint.FIT, train_state=PhaseState( - dataloader=train_dataloader, + dataloader=train_dl, max_epochs=max_epochs, max_steps=max_steps, max_steps_per_epoch=max_train_steps_per_epoch, + dataloader_func=train_dl_func, ), eval_state=PhaseState( - dataloader=eval_dataloader, + dataloader=eval_dl, max_steps_per_epoch=max_eval_steps_per_epoch, evaluate_every_n_steps=evaluate_every_n_steps, evaluate_every_n_epochs=evaluate_every_n_epochs, + dataloader_func=eval_dl_func, ), ) diff --git a/torchtnt/runner/state.py b/torchtnt/runner/state.py index 92c697bde3..e2b8a02d1e 100644 --- a/torchtnt/runner/state.py +++ b/torchtnt/runner/state.py @@ -11,7 +11,7 @@ import logging from enum import auto, Enum -from typing import Any, Iterable, Optional +from typing import Any, Callable, Iterable, Optional from torchtnt.runner.progress import Progress from torchtnt.utils.timer import Timer @@ -73,6 +73,7 @@ def __init__( max_steps_per_epoch: Optional[int] = None, evaluate_every_n_steps: Optional[int] = None, # used only for evaluate evaluate_every_n_epochs: Optional[int] = None, # used only for evaluate + dataloader_func: Optional[Callable[["State"], Iterable[Any]]] = None, ) -> None: _check_loop_condition("max_epochs", max_epochs) _check_loop_condition("max_steps", max_steps) @@ -88,11 +89,16 @@ def __init__( self._evaluate_every_n_steps = evaluate_every_n_steps self._evaluate_every_n_epochs = evaluate_every_n_epochs self._step_output: Any = None + self._dataloader_func = dataloader_func @property def dataloader(self) -> Iterable[Any]: return self._dataloader + @property + def dataloader_func(self) -> Optional[Callable[["State"], Iterable[Any]]]: + return self._dataloader_func + @property def progress(self) -> Progress: return self._progress diff --git a/torchtnt/runner/train.py b/torchtnt/runner/train.py index ccb3064c71..fd02716157 100644 --- a/torchtnt/runner/train.py +++ b/torchtnt/runner/train.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. import logging -from typing import Iterable, List, Optional +from typing import Callable, Iterable, List, Optional, Union import torch from torchtnt.runner.callback import Callback @@ -29,7 +29,7 @@ def init_train_state( *, - dataloader: Iterable[TTrainData], + dataloader: Union[Iterable[TTrainData], Callable[[State], Iterable[TTrainData]]], max_epochs: Optional[int] = None, max_steps: Optional[int] = None, max_steps_per_epoch: Optional[int] = None, @@ -39,6 +39,7 @@ def init_train_state( Args: dataloader: dataloader to be used during training. + This can either be any iterable, or a callable which accepts State as an argument and returns an iterable. max_epochs: the max number of epochs to run. ``None`` means no limit (infinite training) unless stopped by max_steps. max_steps: the max number of steps to run. ``None`` means no limit (infinite training) unless stopped by max_epochs. max_steps_per_epoch: the max number of steps to run per epoch. None means train until the dataloader is exhausted. @@ -46,14 +47,20 @@ def init_train_state( Returns: An initialized state object containing metadata. """ + dl, dataloader_func = None, None + if callable(dataloader): + dataloader_func = dataloader + else: + dl = dataloader return State( entry_point=EntryPoint.TRAIN, train_state=PhaseState( - dataloader=dataloader, + dataloader=dl, max_epochs=max_epochs, max_steps=max_steps, max_steps_per_epoch=max_steps_per_epoch, + dataloader_func=dataloader_func, ), ) @@ -196,6 +203,10 @@ def _train_epoch_impl( if state.eval_state.evaluate_every_n_epochs: evaluate_every_n_epochs = state.eval_state.evaluate_every_n_epochs + dataloader_func = train_state.dataloader_func + if dataloader_func: + train_state._dataloader = dataloader_func(state) + # Check the progress to conditionally run this # to avoid running this multiple times # in the case of resuming from a checkpoint mid-epoch