Support recreating dataloaders during loop
Summary:
Users sometimes need to recreate their dataloaders at the start of each epoch within the overall training loop. To support this, we allow users to register a creation function when initializing the state for training or fitting. The function receives the state as an argument, since users may want to use progress information, such as the progress counters, to reinitialize the dataloader.

We don't add this support for `evaluate` or `predict` since those functions iterate through the corresponding dataloader just once.

For `fit`, this allows the flexibility to reload the training and evaluation dataloaders independently during fitting, if desired.
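
As a usage sketch (not part of this diff): the creation function receives the `State`, so it can key off progress counters when rebuilding the dataloader. A minimal example for the train entry point, assuming the module layout in this commit and that `State` exposes the train `PhaseState` as `state.train_state`; `make_dataloader` and `my_unit` are illustrative stand-ins (`my_unit` would be any `TrainUnit`):

    from typing import Iterable, Tuple

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from torchtnt.runner.state import State
    from torchtnt.runner.train import init_train_state, train

    def make_dataloader(epoch: int) -> DataLoader:
        # Illustrative helper: rebuild the loader with an epoch-dependent seed
        # so each epoch trains on freshly generated/shuffled data.
        generator = torch.Generator().manual_seed(epoch)
        dataset = TensorDataset(torch.rand(8, 2), torch.rand(8, 1))
        return DataLoader(dataset, batch_size=2, shuffle=True, generator=generator)

    def dataloader_func(state: State) -> Iterable[Tuple[torch.Tensor, torch.Tensor]]:
        # The state is passed in so progress information can drive reinitialization.
        epoch = state.train_state.progress.num_epochs_completed
        return make_dataloader(epoch)

    state = init_train_state(
        dataloader=make_dataloader(0),
        max_epochs=3,
        dataloader_func=dataloader_func,
    )
    train(state, my_unit)  # dataloader_func is invoked once at the start of each epoch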

Differential Revision: D40539580

fbshipit-source-id: 955f9d811aec000167a9d1ef5cb5a6e7ca0c62c6
ananthsub authored and facebook-github-bot committed Oct 20, 2022
1 parent 4109143 commit 0ca1a11
Showing 6 changed files with 90 additions and 6 deletions.
36 changes: 35 additions & 1 deletion tests/runner/test_fit.py
@@ -6,7 +6,7 @@
# LICENSE file in the root directory of this source tree.

import unittest
from typing import Tuple
from typing import Iterable, Tuple
from unittest.mock import MagicMock

import torch
@@ -277,3 +277,37 @@ def test_fit_with_callback(self) -> None:
)
self.assertEqual(callback_mock.on_eval_epoch_end.call_count, max_epochs)
self.assertEqual(callback_mock.on_eval_end.call_count, max_epochs)

def test_fit_dataloader_func(self) -> None:
input_dim = 2
dataset_len = 8
batch_size = 2
max_epochs = 3

class DataloaderFunc:
def __init__(self) -> None:
self.call_count = 0

def __call__(
self, state: State
) -> Iterable[Tuple[torch.Tensor, torch.Tensor]]:
self.call_count += 1
return generate_random_dataloader(dataset_len, input_dim, batch_size)

my_unit = MagicMock(spec=DummyFitUnit)
train_dl = generate_random_dataloader(dataset_len, input_dim, batch_size)
eval_dl = generate_random_dataloader(dataset_len, input_dim, batch_size)

train_dl_func = DataloaderFunc()
eval_dl_func = DataloaderFunc()

state = init_fit_state(
train_dataloader=train_dl,
eval_dataloader=eval_dl,
max_epochs=max_epochs,
train_dataloader_func=train_dl_func,
eval_dataloader_func=eval_dl_func,
)
fit(state, my_unit)
self.assertEqual(train_dl_func.call_count, max_epochs)
self.assertEqual(eval_dl_func.call_count, max_epochs)
30 changes: 29 additions & 1 deletion tests/runner/test_train.py
@@ -6,7 +6,7 @@
# LICENSE file in the root directory of this source tree.

import unittest
from typing import Iterator, Tuple
from typing import Iterable, Iterator, Tuple
from unittest.mock import MagicMock

import torch
@@ -258,6 +258,34 @@ def test_train_max_steps(self) -> None:
my_unit.train_step.call_count, max_epochs * expected_steps_per_epoch
)

def test_train_dataloader_func(self) -> None:
input_dim = 2
dataset_len = 8
batch_size = 2
max_epochs = 3

class DataloaderFunc:
def __init__(self) -> None:
self.call_count = 0

def __call__(
self, state: State
) -> Iterable[Tuple[torch.Tensor, torch.Tensor]]:
self.call_count += 1
return generate_random_dataloader(dataset_len, input_dim, batch_size)

my_unit = MagicMock()
dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)

dl_func = DataloaderFunc()
state = init_train_state(
dataloader=dataloader,
max_epochs=max_epochs,
dataloader_func=dl_func,
)
train(state, my_unit)
self.assertEqual(dl_func.call_count, max_epochs)


class StopTrainUnit(TrainUnit[Tuple[torch.Tensor, torch.Tensor]]):
def __init__(self, input_dim: int, steps_before_stopping: int) -> None:
4 changes: 4 additions & 0 deletions torchtnt/runner/evaluate.py
@@ -113,6 +113,10 @@ def _evaluate_impl(
eval_unit.on_eval_epoch_start(state)
_run_callback_fn(callbacks, "on_eval_epoch_start", state, eval_unit)

dataloader_func = eval_state.dataloader_func
if dataloader_func:
eval_state._dataloader = dataloader_func(state)

data_iter = iter(eval_state.dataloader)
step_input = data_iter

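Note: the loop assigns to the private `_dataloader` attribute because `PhaseState.dataloader` is exposed as a read-only property (see torchtnt/runner/state.py below).
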
8 changes: 7 additions & 1 deletion torchtnt/runner/fit.py
@@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.

import logging
from typing import Iterable, List, Optional
from typing import Callable, Iterable, List, Optional

from torchtnt.runner.callback import Callback

@@ -27,6 +27,8 @@ def init_fit_state(
max_eval_steps_per_epoch: Optional[int] = None,
evaluate_every_n_steps: Optional[int] = None,
evaluate_every_n_epochs: Optional[int] = 1,
train_dataloader_func: Optional[Callable[[State], Iterable[TTrainData]]] = None,
eval_dataloader_func: Optional[Callable[[State], Iterable[TEvalData]]] = None,
) -> State:
"""
Helper function that initializes a state object for fitting.
@@ -39,6 +41,8 @@
max_train_steps_per_epoch: the max number of steps to run per epoch for training. None means train until the dataloader is exhausted.
evaluate_every_n_steps: how often to run the evaluation loop in terms of training steps.
evaluate_every_n_epochs: how often to run the evaluation loop in terms of training epochs.
train_dataloader_func: optional function to reinitialize the train dataloader at the start of each train epoch.
eval_dataloader_func: optional function to reinitialize the eval dataloader at the start of each eval epoch.
Returns:
An initialized state object containing metadata.
@@ -51,12 +55,14 @@
max_epochs=max_epochs,
max_steps=max_steps,
max_steps_per_epoch=max_train_steps_per_epoch,
dataloader_func=train_dataloader_func,
),
eval_state=PhaseState(
dataloader=eval_dataloader,
max_steps_per_epoch=max_eval_steps_per_epoch,
evaluate_every_n_steps=evaluate_every_n_steps,
evaluate_every_n_epochs=evaluate_every_n_epochs,
dataloader_func=eval_dataloader_func,
),
)

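A brief sketch of independent reloading with `fit` (illustrative; `train_dl`, `eval_dl`, `my_unit`, and `make_eval_dataloader` are assumed to already exist, and `fit` is assumed importable from the module shown above): here only the eval dataloader is recreated each eval epoch, while the train dataloader is reused as-is.

    from torchtnt.runner.fit import fit, init_fit_state

    state = init_fit_state(
        train_dataloader=train_dl,
        eval_dataloader=eval_dl,
        max_epochs=2,
        # Recreate only the eval loader; train_dataloader_func is left unset.
        eval_dataloader_func=lambda state: make_eval_dataloader(state),
    )
    fit(state, my_unit)
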
8 changes: 7 additions & 1 deletion torchtnt/runner/state.py
@@ -11,7 +11,7 @@

import logging
from enum import auto, Enum
from typing import Any, Iterable, Optional
from typing import Any, Callable, Iterable, Optional

from torchtnt.runner.progress import Progress
from torchtnt.utils.timer import Timer
@@ -73,6 +73,7 @@ def __init__(
max_steps_per_epoch: Optional[int] = None,
evaluate_every_n_steps: Optional[int] = None, # used only for evaluate
evaluate_every_n_epochs: Optional[int] = None, # used only for evaluate
dataloader_func: Optional[Callable[["State"], Iterable[Any]]] = None,
) -> None:
_check_loop_condition("max_epochs", max_epochs)
_check_loop_condition("max_steps", max_steps)
@@ -88,11 +89,16 @@
self._evaluate_every_n_steps = evaluate_every_n_steps
self._evaluate_every_n_epochs = evaluate_every_n_epochs
self._step_output: Any = None
self._dataloader_func = dataloader_func

@property
def dataloader(self) -> Iterable[Any]:
return self._dataloader

@property
def dataloader_func(self) -> Optional[Callable[["State"], Iterable[Any]]]:
return self._dataloader_func

@property
def progress(self) -> Progress:
return self._progress
10 changes: 8 additions & 2 deletions torchtnt/runner/train.py
@@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.

import logging
from typing import Iterable, List, Optional
from typing import Callable, Iterable, List, Optional

import torch
from torchtnt.runner.callback import Callback
@@ -33,6 +33,7 @@ def init_train_state(
max_epochs: Optional[int] = None,
max_steps: Optional[int] = None,
max_steps_per_epoch: Optional[int] = None,
dataloader_func: Optional[Callable[[State], Iterable[TTrainData]]] = None,
) -> State:
"""
Helper function that initializes a state object for training.
@@ -42,7 +43,7 @@
max_epochs: the max number of epochs to run. `None` means no limit (infinite training) unless stopped by max_steps.
max_steps: the max number of steps to run. `None` means no limit (infinite training) unless stopped by max_epochs.
max_steps_per_epoch: the max number of steps to run per epoch. None means train until the dataloader is exhausted.
dataloader_func: optional function to reinitialize the dataloader at the start of each train epoch.
Returns:
An initialized state object containing metadata.
"""
@@ -54,6 +55,7 @@
max_epochs=max_epochs,
max_steps=max_steps,
max_steps_per_epoch=max_steps_per_epoch,
dataloader_func=dataloader_func,
),
)

@@ -195,6 +197,10 @@ def _train_epoch_impl(
train_unit.on_train_epoch_start(state)
_run_callback_fn(callbacks, "on_train_epoch_start", state, train_unit)

dataloader_func = train_state.dataloader_func
if dataloader_func:
train_state._dataloader = dataloader_func(state)

_maybe_set_distributed_sampler_epoch(
train_state.dataloader, train_state.progress.num_epochs_completed
)
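One detail worth noting from the ordering above: `dataloader_func` runs before `_maybe_set_distributed_sampler_epoch`, so a recreated dataloader backed by a `DistributedSampler` still has its sampler epoch set each epoch. A sketch under that assumption (requires an initialized process group; names are illustrative):

    from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

    def distributed_dataloader_func(state: State) -> DataLoader:
        # Rebuild the dataset and sampler from scratch; the training loop then
        # calls _maybe_set_distributed_sampler_epoch on the returned loader.
        dataset = TensorDataset(torch.rand(8, 2), torch.rand(8, 1))
        sampler = DistributedSampler(dataset)
        return DataLoader(dataset, batch_size=2, sampler=sampler)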
