Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Support recreating dataloaders during loop #248

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion tests/runner/test_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# LICENSE file in the root directory of this source tree.

import unittest
from typing import Tuple
from typing import Iterable, Tuple
from unittest.mock import MagicMock

import torch
Expand Down Expand Up @@ -277,3 +277,33 @@ def test_fit_with_callback(self) -> None:
)
self.assertEqual(callback_mock.on_eval_epoch_end.call_count, max_epochs)
self.assertEqual(callback_mock.on_eval_end.call_count, max_epochs)

def test_fit_dataloader_func(self) -> None:
input_dim = 2
dataset_len = 8
batch_size = 2
max_epochs = 3

class DataloaderFunc:
def __init__(self) -> None:
self.call_count = 0

def __call__(
self, state: State
) -> Iterable[Tuple[torch.Tensor, torch.Tensor]]:
self.call_count += 1
return generate_random_dataloader(dataset_len, input_dim, batch_size)

my_unit = MagicMock(spec=DummyFitUnit)

train_dl_func = DataloaderFunc()
eval_dl_func = DataloaderFunc()

state = init_fit_state(
train_dataloader=train_dl_func,
eval_dataloader=eval_dl_func,
max_epochs=max_epochs,
)
fit(state, my_unit)
self.assertEqual(train_dl_func.call_count, max_epochs)
self.assertEqual(eval_dl_func.call_count, max_epochs)
28 changes: 27 additions & 1 deletion tests/runner/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# LICENSE file in the root directory of this source tree.

import unittest
from typing import Iterator, Tuple
from typing import Iterable, Iterator, Tuple
from unittest.mock import MagicMock

import torch
Expand Down Expand Up @@ -258,6 +258,32 @@ def test_train_max_steps(self) -> None:
my_unit.train_step.call_count, max_epochs * expected_steps_per_epoch
)

def test_train_dataloader_func(self) -> None:
input_dim = 2
dataset_len = 8
batch_size = 2
max_epochs = 3

class DataloaderFunc:
def __init__(self) -> None:
self.call_count = 0

def __call__(
self, state: State
) -> Iterable[Tuple[torch.Tensor, torch.Tensor]]:
self.call_count += 1
return generate_random_dataloader(dataset_len, input_dim, batch_size)

my_unit = MagicMock()

dl_func = DataloaderFunc()
state = init_train_state(
dataloader=dl_func,
max_epochs=max_epochs,
)
train(state, my_unit)
self.assertEqual(dl_func.call_count, max_epochs)


class StopTrainUnit(TrainUnit[Tuple[torch.Tensor, torch.Tensor]]):
def __init__(self, input_dim: int, steps_before_stopping: int) -> None:
Expand Down
4 changes: 4 additions & 0 deletions torchtnt/runner/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ def _evaluate_impl(
eval_unit.on_eval_start(state)
_run_callback_fn(callbacks, "on_eval_start", state, eval_unit)

dataloader_func = eval_state.dataloader_func
if dataloader_func:
eval_state._dataloader = dataloader_func(state)

# Conditionally run this to avoid running this multiple times
# in the case of resuming from a checkpoint mid-epoch
if eval_state.progress.num_steps_completed_in_epoch == 0:
Expand Down
31 changes: 26 additions & 5 deletions torchtnt/runner/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.

import logging
from typing import Iterable, List, Optional
from typing import Callable, Iterable, List, Optional, Union

from torchtnt.runner.callback import Callback

Expand All @@ -19,8 +19,10 @@


def init_fit_state(
train_dataloader: Iterable[TTrainData],
eval_dataloader: Iterable[TEvalData],
train_dataloader: Union[
Iterable[TTrainData], Callable[[State], Iterable[TTrainData]]
],
eval_dataloader: Union[Iterable[TEvalData], Callable[[State], Iterable[TEvalData]]],
max_epochs: Optional[int] = None,
max_steps: Optional[int] = None,
max_train_steps_per_epoch: Optional[int] = None,
Expand All @@ -33,30 +35,49 @@ def init_fit_state(

Args:
train_dataloader: dataloader to be used during training.
This can either be any iterable, or a callable which accepts State as an argument and returns an iterable.
eval_dataloader: dataloader to be used during evaluation.
This can either be any iterable, or a callable which accepts State as an argument and returns an iterable.
max_epochs: the max number of epochs to run for training. ``None`` means no limit (infinite training) unless stopped by max_steps.
max_steps: the max number of steps to run for training. ``None`` means no limit (infinite training) unless stopped by max_epochs.
max_train_steps_per_epoch: the max number of steps to run per epoch for training. None means train until the dataloader is exhausted.
evaluate_every_n_steps: how often to run the evaluation loop in terms of training steps.
evaluate_every_n_epochs: how often to run the evaluation loop in terms of training epochs.
train_dataloader_func: optional function to reinitialize the train dataloader at the start of the train epoch.
eval_dataloader_func: optional function to reinitialize the eval dataloader at the start of the eval epoch.

Returns:
An initialized state object containing metadata.
"""
if callable(train_dataloader):
train_dl = None
train_dl_func = train_dataloader
else:
train_dl = train_dataloader
train_dl_func = None

if callable(eval_dataloader):
eval_dl = None
eval_dl_func = eval_dataloader
else:
eval_dl = eval_dataloader
eval_dl_func = None

return State(
entry_point=EntryPoint.FIT,
train_state=PhaseState(
dataloader=train_dataloader,
dataloader=train_dl,
max_epochs=max_epochs,
max_steps=max_steps,
max_steps_per_epoch=max_train_steps_per_epoch,
dataloader_func=train_dl_func,
),
eval_state=PhaseState(
dataloader=eval_dataloader,
dataloader=eval_dl,
max_steps_per_epoch=max_eval_steps_per_epoch,
evaluate_every_n_steps=evaluate_every_n_steps,
evaluate_every_n_epochs=evaluate_every_n_epochs,
dataloader_func=eval_dl_func,
),
)

Expand Down
8 changes: 7 additions & 1 deletion torchtnt/runner/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import logging
from enum import auto, Enum
from typing import Any, Iterable, Optional
from typing import Any, Callable, Iterable, Optional

from torchtnt.runner.progress import Progress
from torchtnt.utils.timer import Timer
Expand Down Expand Up @@ -73,6 +73,7 @@ def __init__(
max_steps_per_epoch: Optional[int] = None,
evaluate_every_n_steps: Optional[int] = None, # used only for evaluate
evaluate_every_n_epochs: Optional[int] = None, # used only for evaluate
dataloader_func: Optional[Callable[["State"], Iterable[Any]]] = None,
) -> None:
_check_loop_condition("max_epochs", max_epochs)
_check_loop_condition("max_steps", max_steps)
Expand All @@ -88,11 +89,16 @@ def __init__(
self._evaluate_every_n_steps = evaluate_every_n_steps
self._evaluate_every_n_epochs = evaluate_every_n_epochs
self._step_output: Any = None
self._dataloader_func = dataloader_func

    @property
    def dataloader(self) -> Iterable[Any]:
        """The dataloader backing this phase.

        NOTE(review): when a ``dataloader_func`` was supplied instead of a
        static dataloader, this may be ``None`` until the loop invokes the
        callable to (re)create it — confirm against the loop implementations.
        """
        return self._dataloader

    @property
    def dataloader_func(self) -> Optional[Callable[["State"], Iterable[Any]]]:
        """Optional callable that, given the loop ``State``, returns a fresh
        dataloader; ``None`` when a static dataloader was supplied instead."""
        return self._dataloader_func

    @property
    def progress(self) -> Progress:
        """Progress tracker for this phase (e.g. steps completed in the epoch)."""
        return self._progress
Expand Down
17 changes: 14 additions & 3 deletions torchtnt/runner/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.

import logging
from typing import Iterable, List, Optional
from typing import Callable, Iterable, List, Optional, Union

import torch
from torchtnt.runner.callback import Callback
Expand All @@ -29,7 +29,7 @@

def init_train_state(
*,
dataloader: Iterable[TTrainData],
dataloader: Union[Iterable[TTrainData], Callable[[State], Iterable[TTrainData]]],
max_epochs: Optional[int] = None,
max_steps: Optional[int] = None,
max_steps_per_epoch: Optional[int] = None,
Expand All @@ -39,21 +39,28 @@ def init_train_state(

Args:
dataloader: dataloader to be used during training.
This can either be any iterable, or a callable which accepts State as an argument and returns an iterable.
max_epochs: the max number of epochs to run. ``None`` means no limit (infinite training) unless stopped by max_steps.
max_steps: the max number of steps to run. ``None`` means no limit (infinite training) unless stopped by max_epochs.
max_steps_per_epoch: the max number of steps to run per epoch. None means train until the dataloader is exhausted.

Returns:
An initialized state object containing metadata.
"""
dl, dataloader_func = None, None
if callable(dataloader):
dataloader_func = dataloader
else:
dl = dataloader

return State(
entry_point=EntryPoint.TRAIN,
train_state=PhaseState(
dataloader=dataloader,
dataloader=dl,
max_epochs=max_epochs,
max_steps=max_steps,
max_steps_per_epoch=max_steps_per_epoch,
dataloader_func=dataloader_func,
),
)

Expand Down Expand Up @@ -196,6 +203,10 @@ def _train_epoch_impl(
if state.eval_state.evaluate_every_n_epochs:
evaluate_every_n_epochs = state.eval_state.evaluate_every_n_epochs

dataloader_func = train_state.dataloader_func
if dataloader_func:
train_state._dataloader = dataloader_func(state)

# Check the progress to conditionally run this
# to avoid running this multiple times
# in the case of resuming from a checkpoint mid-epoch
Expand Down