From 1a2f5a118ea87c8928c39294b94c89ea482e1ec9 Mon Sep 17 00:00:00 2001 From: Gal Rotem Date: Tue, 26 Sep 2023 10:13:00 -0700 Subject: [PATCH] add StepOutput generic to Unit Summary: Add a second generic to `Unit`, denoting the `step_output` type. Differential Revision: D49391712 --- examples/auto_unit_example.py | 14 +++---- examples/mingpt/main.py | 3 +- examples/mnist/main.py | 3 +- examples/torchdata_train_example.py | 2 +- examples/torchrec/main.py | 11 ++---- examples/train_unit_example.py | 2 +- tests/framework/callbacks/test_csv_writer.py | 6 +-- tests/framework/callbacks/test_lambda.py | 2 +- tests/framework/test_auto_unit.py | 29 +++++++------- tests/framework/test_evaluate.py | 20 ++++++---- tests/framework/test_fit.py | 11 ++---- tests/framework/test_predict.py | 10 +++-- tests/framework/test_train.py | 15 +++++--- torchtnt/framework/_test_utils.py | 26 ++++++------- torchtnt/framework/auto_unit.py | 40 +++++++++----------- torchtnt/framework/unit.py | 28 ++++++++------ 16 files changed, 112 insertions(+), 110 deletions(-) diff --git a/examples/auto_unit_example.py b/examples/auto_unit_example.py index eef55c1102..06f8339e3c 100644 --- a/examples/auto_unit_example.py +++ b/examples/auto_unit_example.py @@ -28,6 +28,7 @@ Batch = Tuple[torch.Tensor, torch.Tensor] +ModelStepOutput = torch.Tensor NUM_PROCESSES = 2 @@ -60,7 +61,7 @@ def prepare_dataloader( ) -class MyUnit(AutoUnit[Batch]): +class MyUnit(AutoUnit[Batch, ModelStepOutput]): # pyre-fixme[3]: Return type must be annotated. def __init__( self, @@ -110,8 +111,9 @@ def configure_optimizers_and_lr_scheduler( lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9) return optimizer, lr_scheduler - # pyre-fixme[3]: Return annotation cannot contain `Any`. - def compute_loss(self, state: State, data: Batch) -> Tuple[torch.Tensor, Any]: + def compute_loss( + self, state: State, data: Batch + ) -> Tuple[torch.Tensor, ModelStepOutput]: inputs, targets = data # convert targets to float Tensor for binary_cross_entropy_with_logits targets = targets.float() @@ -127,8 +129,7 @@ def on_train_step_end( data: Batch, step: int, loss: torch.Tensor, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - outputs: Any, + outputs: ModelStepOutput, ) -> None: _, targets = data self.train_accuracy.update(outputs, targets) @@ -143,8 +144,7 @@ def on_eval_step_end( data: Batch, step: int, loss: torch.Tensor, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. 
- outputs: Any, + outputs: ModelStepOutput, ) -> None: _, targets = data self.eval_accuracy.update(outputs, targets) diff --git a/examples/mingpt/main.py b/examples/mingpt/main.py index f489ab8917..27904ec979 100644 --- a/examples/mingpt/main.py +++ b/examples/mingpt/main.py @@ -32,6 +32,7 @@ logging.basicConfig(level=logging.INFO) Batch = Tuple[torch.Tensor, torch.Tensor] +ModelStepOutput = torch.Tensor PATH: str = parutil.get_file_path("data/input.txt", pkg=__package__) @@ -59,7 +60,7 @@ def get_datasets( return train_set, eval_set, dataset -class MinGPTUnit(AutoUnit[Batch]): +class MinGPTUnit(AutoUnit[Batch, ModelStepOutput]): def __init__( self, tb_logger: TensorBoardLogger, diff --git a/examples/mnist/main.py b/examples/mnist/main.py index f301a1b996..20d42c5903 100644 --- a/examples/mnist/main.py +++ b/examples/mnist/main.py @@ -26,6 +26,7 @@ from torchvision import datasets, transforms Batch = Tuple[torch.Tensor, torch.Tensor] +ModelStepOutput = torch.Tensor class Net(nn.Module): @@ -56,7 +57,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return output -class MyUnit(AutoUnit[Batch]): +class MyUnit(AutoUnit[Batch, ModelStepOutput]): def __init__( self, *, diff --git a/examples/torchdata_train_example.py b/examples/torchdata_train_example.py index 5df4c62b3e..a4c6f17c6b 100644 --- a/examples/torchdata_train_example.py +++ b/examples/torchdata_train_example.py @@ -78,7 +78,7 @@ def prepare_dataloader( return dataloader -class MyTrainUnit(TrainUnit[Batch]): +class MyTrainUnit(TrainUnit[Batch, None]): def __init__( self, module: torch.nn.Module, diff --git a/examples/torchrec/main.py b/examples/torchrec/main.py index d724bdfcfa..413b77ede0 100644 --- a/examples/torchrec/main.py +++ b/examples/torchrec/main.py @@ -180,10 +180,10 @@ def parse_args(argv: List[str]) -> argparse.Namespace: return parser.parse_args(argv) -Batch = Tuple[torch.Tensor, torch.Tensor] +Batch = Iterator[Tuple[torch.Tensor, torch.Tensor]] -class MyUnit(TrainUnit[Batch], EvalUnit[Batch]): +class MyUnit(TrainUnit[Batch, None], EvalUnit[Batch, None]): def __init__( self, module: torch.nn.Module, @@ -205,9 +205,7 @@ def __init__( self.tb_logger = tb_logger self.log_every_n_steps = log_every_n_steps - # pyre-fixme[14]: `train_step` overrides method defined in `TrainUnit` - # inconsistently. - def train_step(self, state: State, data: Iterator[Batch]) -> None: + def train_step(self, state: State, data: Batch) -> None: step = self.train_progress.num_steps_completed loss, logits, labels = self.pipeline.progress(data) preds = torch.sigmoid(logits) @@ -222,8 +220,7 @@ def on_train_epoch_end(self, state: State) -> None: # reset the metric every epoch self.train_auroc.reset() - # pyre-fixme[14]: `eval_step` overrides method defined in `EvalUnit` inconsistently. 
- def eval_step(self, state: State, data: Iterator[Batch]) -> None: + def eval_step(self, state: State, data: Batch) -> None: step = self.eval_progress.num_steps_completed loss, _, _ = self.pipeline.progress(data) if step % self.log_every_n_steps == 0: diff --git a/examples/train_unit_example.py b/examples/train_unit_example.py index aeff258b9f..afbfa5e215 100644 --- a/examples/train_unit_example.py +++ b/examples/train_unit_example.py @@ -57,7 +57,7 @@ def prepare_dataloader( ) -class MyTrainUnit(TrainUnit[Batch]): +class MyTrainUnit(TrainUnit[Batch, None]): def __init__( self, module: torch.nn.Module, diff --git a/tests/framework/callbacks/test_csv_writer.py b/tests/framework/callbacks/test_csv_writer.py index c9b4aee8f0..15eee1866d 100644 --- a/tests/framework/callbacks/test_csv_writer.py +++ b/tests/framework/callbacks/test_csv_writer.py @@ -13,7 +13,7 @@ from torchtnt.framework.callbacks.base_csv_writer import BaseCSVWriter from torchtnt.framework.predict import predict from torchtnt.framework.state import State -from torchtnt.framework.unit import PredictUnit, TPredictData +from torchtnt.framework.unit import TPredictUnit _HEADER_ROW = ["output"] _FILENAME = "test_csv_writer.csv" @@ -23,7 +23,7 @@ class CustomCSVWriter(BaseCSVWriter): def get_step_output_rows( self, state: State, - unit: PredictUnit[TPredictData], + unit: TPredictUnit, # pyre-fixme[2]: Parameter annotation cannot be `Any`. step_output: Any, ) -> Union[List[str], List[List[str]]]: @@ -34,7 +34,7 @@ class CustomCSVWriterSingleRow(BaseCSVWriter): def get_step_output_rows( self, state: State, - unit: PredictUnit[TPredictData], + unit: TPredictUnit, # pyre-fixme[2]: Parameter annotation cannot be `Any`. step_output: Any, ) -> Union[List[str], List[List[str]]]: diff --git a/tests/framework/callbacks/test_lambda.py b/tests/framework/callbacks/test_lambda.py index 3492c537e8..502d75a0ff 100644 --- a/tests/framework/callbacks/test_lambda.py +++ b/tests/framework/callbacks/test_lambda.py @@ -28,7 +28,7 @@ Batch = Tuple[torch.Tensor, torch.Tensor] -class DummyTrainExceptUnit(TrainUnit[Batch]): +class DummyTrainExceptUnit(TrainUnit[Batch, None]): def __init__(self, input_dim: int) -> None: super().__init__() diff --git a/tests/framework/test_auto_unit.py b/tests/framework/test_auto_unit.py index 37622e9eed..27c21ecdae 100644 --- a/tests/framework/test_auto_unit.py +++ b/tests/framework/test_auto_unit.py @@ -738,14 +738,15 @@ def test_predict_detect_anomaly(self, mock_detect_anomaly) -> None: Batch = Tuple[torch.Tensor, torch.Tensor] -class DummyLRSchedulerAutoUnit(AutoUnit[Batch]): +class DummyLRSchedulerAutoUnit(AutoUnit[Batch, torch.Tensor]): # pyre-fixme[3]: Return type must be annotated. # pyre-fixme[2]: Parameter must be annotated. def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - # pyre-fixme[3]: Return annotation cannot contain `Any`. - def compute_loss(self, state: State, data: Batch) -> Tuple[torch.Tensor, Any]: + def compute_loss( + self, state: State, data: Batch + ) -> Tuple[torch.Tensor, torch.Tensor]: inputs, targets = data outputs = self.module(inputs) loss = torch.nn.functional.cross_entropy(outputs, targets) @@ -760,15 +761,16 @@ def configure_optimizers_and_lr_scheduler( return my_optimizer, my_lr_scheduler -class DummyComplexAutoUnit(AutoUnit[Batch]): +class DummyComplexAutoUnit(AutoUnit[Batch, torch.Tensor]): # pyre-fixme[3]: Return type must be annotated. # pyre-fixme[2]: Parameter must be annotated. 
def __init__(self, lr: float, *args, **kwargs): super().__init__(*args, **kwargs) self.lr = lr - # pyre-fixme[3]: Return annotation cannot contain `Any`. - def compute_loss(self, state: State, data: Batch) -> Tuple[torch.Tensor, Any]: + def compute_loss( + self, state: State, data: Batch + ) -> Tuple[torch.Tensor, torch.Tensor]: inputs, targets = data outputs = self.module(inputs) loss = torch.nn.functional.cross_entropy(outputs, targets) @@ -783,7 +785,7 @@ def configure_optimizers_and_lr_scheduler( return my_optimizer, my_lr_scheduler -class LastBatchAutoUnit(AutoUnit[Batch]): +class LastBatchAutoUnit(AutoUnit[Batch, torch.Tensor]): def __init__(self, module: torch.nn.Module, expected_steps_per_epoch: int) -> None: super().__init__(module=module) self.expected_steps_per_epoch = expected_steps_per_epoch @@ -812,7 +814,7 @@ def configure_optimizers_and_lr_scheduler( return my_optimizer, my_lr_scheduler -class TimingAutoUnit(AutoUnit[Batch]): +class TimingAutoUnit(AutoUnit[Batch, torch.Tensor]): def __init__(self, module: torch.nn.Module) -> None: super().__init__(module=module) self.loss_fn = torch.nn.CrossEntropyLoss() @@ -839,8 +841,7 @@ def on_train_step_end( data: Batch, step: int, loss: torch.Tensor, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - outputs: Any, + outputs: torch.Tensor, ) -> None: assert state.train_state if self.train_progress.num_steps_completed_in_epoch == 1: @@ -871,8 +872,7 @@ def on_eval_step_end( data: Batch, step: int, loss: torch.Tensor, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - outputs: Any, + outputs: torch.Tensor, ) -> None: if self.eval_progress.num_steps_completed_in_epoch == 1: tc = unittest.TestCase() @@ -918,7 +918,7 @@ def on_predict_step_end( tc.assertNotIn("TimingAutoUnit.predict_step", recorded_timer_keys) -class TimingAutoPredictUnit(AutoPredictUnit[Batch]): +class TimingAutoPredictUnit(AutoPredictUnit[Batch, torch.Tensor]): def __init__(self, module: torch.nn.Module) -> None: super().__init__(module=module) self.loss_fn = torch.nn.CrossEntropyLoss() @@ -944,8 +944,7 @@ def on_predict_step_end( state: State, data: TPredictData, step: int, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. 
- outputs: Any, + outputs: torch.Tensor, ) -> None: if self.predict_progress.num_steps_completed_in_epoch == 1: tc = unittest.TestCase() diff --git a/tests/framework/test_evaluate.py b/tests/framework/test_evaluate.py index 46120e30b6..253de47a4c 100644 --- a/tests/framework/test_evaluate.py +++ b/tests/framework/test_evaluate.py @@ -85,7 +85,12 @@ def test_evaluate_stop(self) -> None: self.assertEqual(my_unit.steps_processed, steps_before_stopping) def test_evaluate_data_iter_step(self) -> None: - class EvalIteratorUnit(EvalUnit[Iterator[Tuple[torch.Tensor, torch.Tensor]]]): + class EvalIteratorUnit( + EvalUnit[ + Iterator[Tuple[torch.Tensor, torch.Tensor]], + Tuple[torch.Tensor, torch.Tensor], + ] + ): def __init__(self, input_dim: int) -> None: super().__init__() self.module = nn.Linear(input_dim, 2) @@ -204,7 +209,11 @@ def test_evaluate_timing(self) -> None: self.assertIn("evaluate.next(data_iter)", timer.recorded_durations.keys()) -class StopEvalUnit(EvalUnit[Tuple[torch.Tensor, torch.Tensor]]): +Batch = Tuple[torch.Tensor, torch.Tensor] +StepOutput = Tuple[torch.Tensor, torch.Tensor] + + +class StopEvalUnit(EvalUnit[Batch, StepOutput]): def __init__(self, input_dim: int, steps_before_stopping: int) -> None: super().__init__() # initialize module & loss_fn @@ -213,9 +222,7 @@ def __init__(self, input_dim: int, steps_before_stopping: int) -> None: self.steps_processed = 0 self.steps_before_stopping = steps_before_stopping - def eval_step( - self, state: State, data: Tuple[torch.Tensor, torch.Tensor] - ) -> Tuple[torch.Tensor, torch.Tensor]: + def eval_step(self, state: State, data: Batch) -> StepOutput: inputs, targets = data @@ -231,6 +238,3 @@ def eval_step( self.steps_processed += 1 return loss, outputs - - -Batch = Tuple[torch.Tensor, torch.Tensor] diff --git a/tests/framework/test_fit.py b/tests/framework/test_fit.py index c25e4fd7c6..22a98f650d 100644 --- a/tests/framework/test_fit.py +++ b/tests/framework/test_fit.py @@ -123,8 +123,9 @@ def test_fit_evaluate_every_n_steps(self) -> None: def test_fit_stop(self) -> None: Batch = Tuple[torch.Tensor, torch.Tensor] + StepOutput = Tuple[torch.Tensor, torch.Tensor] - class FitStop(TrainUnit[Batch], EvalUnit[Batch]): + class FitStop(TrainUnit[Batch, StepOutput], EvalUnit[Batch, StepOutput]): def __init__(self, input_dim: int, steps_before_stopping: int) -> None: super().__init__() # initialize module, loss_fn, & optimizer @@ -134,9 +135,7 @@ def __init__(self, input_dim: int, steps_before_stopping: int) -> None: self.steps_processed = 0 self.steps_before_stopping = steps_before_stopping - def train_step( - self, state: State, data: Batch - ) -> Tuple[torch.Tensor, torch.Tensor]: + def train_step(self, state: State, data: Batch) -> StepOutput: inputs, targets = data outputs = self.module(inputs) @@ -156,9 +155,7 @@ def train_step( self.steps_processed += 1 return loss, outputs - def eval_step( - self, state: State, data: Batch - ) -> Tuple[torch.Tensor, torch.Tensor]: + def eval_step(self, state: State, data: Batch) -> StepOutput: inputs, targets = data outputs = self.module(inputs) loss = self.loss_fn(outputs, targets) diff --git a/tests/framework/test_predict.py b/tests/framework/test_predict.py index 041411d1df..84f29f6627 100644 --- a/tests/framework/test_predict.py +++ b/tests/framework/test_predict.py @@ -158,7 +158,10 @@ def on_predict_end(self, state: State, unit: TPredictUnit) -> None: def test_predict_data_iter_step(self) -> None: class PredictIteratorUnit( - PredictUnit[Iterator[Tuple[torch.Tensor, torch.Tensor]]] + 
PredictUnit[ + Iterator[Tuple[torch.Tensor, torch.Tensor]], + Tuple[torch.Tensor, torch.Tensor], + ] ): def __init__(self, input_dim: int) -> None: super().__init__() @@ -213,9 +216,10 @@ def test_predict_timing(self) -> None: Batch = Tuple[torch.Tensor, torch.Tensor] +StepOutput = torch.Tensor -class StopPredictUnit(PredictUnit[Batch]): +class StopPredictUnit(PredictUnit[Batch, StepOutput]): def __init__(self, input_dim: int, steps_before_stopping: int) -> None: super().__init__() # initialize module @@ -223,7 +227,7 @@ def __init__(self, input_dim: int, steps_before_stopping: int) -> None: self.steps_processed = 0 self.steps_before_stopping = steps_before_stopping - def predict_step(self, state: State, data: Batch) -> torch.Tensor: + def predict_step(self, state: State, data: Batch) -> StepOutput: inputs, targets = data outputs = self.module(inputs) diff --git a/tests/framework/test_train.py b/tests/framework/test_train.py index 931639782e..2e388b323a 100644 --- a/tests/framework/test_train.py +++ b/tests/framework/test_train.py @@ -179,7 +179,12 @@ def on_train_end(self, state: State, unit: TTrainUnit) -> None: ) def test_train_data_iter_step(self) -> None: - class TrainIteratorUnit(TrainUnit[Iterator[Tuple[torch.Tensor, torch.Tensor]]]): + class TrainIteratorUnit( + TrainUnit[ + Iterator[Tuple[torch.Tensor, torch.Tensor]], + Tuple[torch.Tensor, torch.Tensor], + ] + ): def __init__(self, input_dim: int) -> None: super().__init__() self.module = nn.Linear(input_dim, 2) @@ -263,9 +268,10 @@ def test_train_timing(self) -> None: Batch = Tuple[torch.Tensor, torch.Tensor] +StepOutput = torch.Tensor -class StopTrainUnit(TrainUnit[Batch]): +class StopTrainUnit(TrainUnit[Batch, StepOutput]): def __init__(self, input_dim: int, steps_before_stopping: int) -> None: super().__init__() # initialize module, loss_fn, & optimizer @@ -275,7 +281,7 @@ def __init__(self, input_dim: int, steps_before_stopping: int) -> None: self.steps_processed = 0 self.steps_before_stopping = steps_before_stopping - def train_step(self, state: State, data: Batch) -> torch.Tensor: + def train_step(self, state: State, data: Batch) -> StepOutput: inputs, targets = data outputs = self.module(inputs) @@ -295,6 +301,3 @@ def train_step(self, state: State, data: Batch) -> torch.Tensor: self.steps_processed += 1 # pyre-fixme[7]: Expected `Tensor` but got `Tuple[typing.Any, typing.Any]`. 
return loss, outputs - - -Batch = Tuple[torch.Tensor, torch.Tensor] diff --git a/torchtnt/framework/_test_utils.py b/torchtnt/framework/_test_utils.py index a37b2fdfb9..40a88ab807 100644 --- a/torchtnt/framework/_test_utils.py +++ b/torchtnt/framework/_test_utils.py @@ -16,6 +16,8 @@ from torchtnt.utils.lr_scheduler import TLRScheduler Batch = Tuple[torch.Tensor, torch.Tensor] +StepOutput = Tuple[torch.Tensor, torch.Tensor] +PredictStepOutput = torch.Tensor def get_dummy_train_state(dataloader: Optional[Iterable[object]] = None) -> State: @@ -31,14 +33,14 @@ def get_dummy_train_state(dataloader: Optional[Iterable[object]] = None) -> Stat ) -class DummyEvalUnit(EvalUnit[Batch]): +class DummyEvalUnit(EvalUnit[Batch, StepOutput]): def __init__(self, input_dim: int) -> None: super().__init__() # initialize module & loss_fn self.module = nn.Linear(input_dim, 2) self.loss_fn = nn.CrossEntropyLoss() - def eval_step(self, state: State, data: Batch) -> Tuple[torch.Tensor, torch.Tensor]: + def eval_step(self, state: State, data: Batch) -> StepOutput: inputs, targets = data outputs = self.module(inputs) @@ -46,20 +48,20 @@ def eval_step(self, state: State, data: Batch) -> Tuple[torch.Tensor, torch.Tens return loss, outputs -class DummyPredictUnit(PredictUnit[Batch]): +class DummyPredictUnit(PredictUnit[Batch, PredictStepOutput]): def __init__(self, input_dim: int) -> None: super().__init__() # initialize module self.module = nn.Linear(input_dim, 2) - def predict_step(self, state: State, data: Batch) -> torch.Tensor: + def predict_step(self, state: State, data: Batch) -> PredictStepOutput: inputs, targets = data outputs = self.module(inputs) return outputs -class DummyTrainUnit(TrainUnit[Batch]): +class DummyTrainUnit(TrainUnit[Batch, StepOutput]): def __init__(self, input_dim: int) -> None: super().__init__() # initialize module, loss_fn, & optimizer @@ -67,9 +69,7 @@ def __init__(self, input_dim: int) -> None: self.loss_fn = nn.CrossEntropyLoss() self.optimizer = torch.optim.SGD(self.module.parameters(), lr=0.01) - def train_step( - self, state: State, data: Batch - ) -> Tuple[torch.Tensor, torch.Tensor]: + def train_step(self, state: State, data: Batch) -> StepOutput: inputs, targets = data outputs = self.module(inputs) @@ -82,7 +82,7 @@ def train_step( return loss, outputs -class DummyFitUnit(TrainUnit[Batch], EvalUnit[Batch]): +class DummyFitUnit(TrainUnit[Batch, StepOutput], EvalUnit[Batch, StepOutput]): def __init__(self, input_dim: int) -> None: super().__init__() # initialize module, loss_fn, & optimizer @@ -90,9 +90,7 @@ def __init__(self, input_dim: int) -> None: self.loss_fn = nn.CrossEntropyLoss() self.optimizer = torch.optim.SGD(self.module.parameters(), lr=0.01) - def train_step( - self, state: State, data: Batch - ) -> Tuple[torch.Tensor, torch.Tensor]: + def train_step(self, state: State, data: Batch) -> StepOutput: inputs, targets = data outputs = self.module(inputs) @@ -104,7 +102,7 @@ def train_step( return loss, outputs - def eval_step(self, state: State, data: Batch) -> Tuple[torch.Tensor, torch.Tensor]: + def eval_step(self, state: State, data: Batch) -> StepOutput: inputs, targets = data outputs = self.module(inputs) @@ -147,7 +145,7 @@ def generate_random_iterable_dataloader( ) -class DummyAutoUnit(AutoUnit[Batch]): +class DummyAutoUnit(AutoUnit[Batch, object]): def compute_loss(self, state: State, data: Batch) -> Tuple[torch.Tensor, object]: inputs, targets = data outputs = self.module(inputs) diff --git a/torchtnt/framework/auto_unit.py b/torchtnt/framework/auto_unit.py 
index 292e1ecde9..df3d7006fb 100644 --- a/torchtnt/framework/auto_unit.py +++ b/torchtnt/framework/auto_unit.py @@ -7,7 +7,7 @@ import contextlib from abc import ABCMeta, abstractmethod -from typing import Any, Iterator, Optional, Tuple, TypeVar, Union +from typing import Iterator, Optional, Tuple, TypeVar, Union import torch from pyre_extensions import none_throws @@ -35,6 +35,7 @@ TData = TypeVar("TData") +TStepOutput = TypeVar("TStepOutput") class _ConfigureOptimizersCaller(ABCMeta): @@ -70,7 +71,7 @@ def __call__(self, *args, **kwargs): return x -class AutoPredictUnit(PredictUnit[TPredictData]): +class AutoPredictUnit(PredictUnit[TPredictData, TStepOutput]): def __init__( self, *, @@ -144,8 +145,7 @@ def __init__( self.detect_anomaly = detect_anomaly - # pyre-fixme[3]: Return annotation cannot be `Any`. - def predict_step(self, state: State, data: Iterator[TPredictData]) -> Any: + def predict_step(self, state: State, data: Iterator[TPredictData]) -> TStepOutput: with none_throws(state.predict_state).iteration_timer.time("data_wait_time"): batch = self._get_next_batch(state, data) @@ -170,8 +170,7 @@ def on_predict_step_end( state: State, data: TPredictData, step: int, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - outputs: Any, + outputs: TStepOutput, ) -> None: """ This will be called at the end of every ``predict_step`` before returning. The user can implement this method with code to update and log their metrics, @@ -261,9 +260,9 @@ def _prefetch_next_batch( class AutoUnit( - TrainUnit[TData], - EvalUnit[TData], - PredictUnit[TData], + TrainUnit[TData, Tuple[torch.Tensor, TStepOutput]], + EvalUnit[TData, Tuple[torch.Tensor, TStepOutput]], + PredictUnit[TData, TStepOutput], metaclass=_ConfigureOptimizersCaller, ): """ @@ -414,8 +413,9 @@ def configure_optimizers_and_lr_scheduler( ... @abstractmethod - # pyre-fixme[3]: Return annotation cannot contain `Any`. - def compute_loss(self, state: State, data: TData) -> Tuple[torch.Tensor, Any]: + def compute_loss( + self, state: State, data: TData + ) -> Tuple[torch.Tensor, TStepOutput]: """ The user should implement this method with their loss computation. This will be called every ``train_step``/``eval_step``. @@ -509,10 +509,9 @@ def _get_next_batch(self, state: State, data: Iterator[TData]) -> TData: return batch - # pyre-fixme[3]: Return annotation cannot contain `Any`. def train_step( self, state: State, data: Iterator[TData] - ) -> Tuple[torch.Tensor, Any]: + ) -> Tuple[torch.Tensor, TStepOutput]: # In auto unit they will not be exclusive since data fetching is done as # part of the training step with none_throws(state.train_state).iteration_timer.time("data_wait_time"): @@ -643,8 +642,7 @@ def on_train_step_end( data: TData, step: int, loss: torch.Tensor, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - outputs: Any, + outputs: TStepOutput, ) -> None: """ This will be called at the end of every ``train_step`` before returning. The user can implement this method with code to update and log their metrics, @@ -700,8 +698,7 @@ def on_train_end(self, state: State) -> None: ): transfer_batch_norm_stats(swa_model, self.module) - # pyre-fixme[3]: Return annotation cannot contain `Any`. 
-    def eval_step(self, state: State, data: TData) -> Tuple[torch.Tensor, Any]:
+    def eval_step(self, state: State, data: TData) -> Tuple[torch.Tensor, TStepOutput]:
         with get_timing_context(
             state, f"{self.__class__.__name__}.move_data_to_device"
         ):
@@ -726,8 +723,7 @@ def on_eval_step_end(
         data: TData,
         step: int,
         loss: torch.Tensor,
-        # pyre-fixme[2]: Parameter annotation cannot be `Any`.
-        outputs: Any,
+        outputs: TStepOutput,
     ) -> None:
         """
         This will be called at the end of every ``eval_step`` before returning. The user can implement this method with code to update and log their metrics,
@@ -742,8 +738,7 @@ def on_eval_step_end(
         """
         pass

-    # pyre-fixme[3]: Return annotation cannot contain `Any`.
-    def predict_step(self, state: State, data: TData) -> Any:
+    def predict_step(self, state: State, data: TData) -> TStepOutput:
         with get_timing_context(
             state, f"{self.__class__.__name__}.move_data_to_device"
         ):
@@ -766,8 +761,7 @@ def on_predict_step_end(
         state: State,
         data: TData,
         step: int,
-        # pyre-fixme[2]: Parameter annotation cannot be `Any`.
-        outputs: Any,
+        outputs: TStepOutput,
     ) -> None:
         """
         This will be called at the end of every ``predict_step`` before returning. The user can implement this method with code to update and log their metrics,
diff --git a/torchtnt/framework/unit.py b/torchtnt/framework/unit.py
index ec266f33fd..c3e93a8112 100644
--- a/torchtnt/framework/unit.py
+++ b/torchtnt/framework/unit.py
@@ -180,9 +180,14 @@ def on_exception(self, state: State, exc: BaseException) -> None:
 TTrainData = TypeVar("TTrainData")
 TEvalData = TypeVar("TEvalData")
 TPredictData = TypeVar("TPredictData")
+TTrainStepOutput = TypeVar("TTrainStepOutput")
+TEvalStepOutput = TypeVar("TEvalStepOutput")
+TPredictStepOutput = TypeVar("TPredictStepOutput")


-class TrainUnit(AppStateMixin, _OnExceptionMixin, Generic[TTrainData], ABC):
+class TrainUnit(
+    AppStateMixin, _OnExceptionMixin, Generic[TTrainData, TTrainStepOutput], ABC
+):
     """
     The TrainUnit is an interface that can be used to organize your training logic. The core of it is the ``train_step`` which is an abstract method where you can define the code you want to run each iteration of the dataloader.
@@ -248,8 +253,7 @@ def on_train_epoch_start(self, state: State) -> None:
         pass

     @abstractmethod
-    # pyre-fixme[3]: Return annotation cannot be `Any`.
-    def train_step(self, state: State, data: TTrainData) -> Any:
+    def train_step(self, state: State, data: TTrainData) -> TTrainStepOutput:
         """Core required method for user to implement. This method will be called at each iteration of the
         train dataloader, and can return any data the user wishes.
@@ -276,7 +280,9 @@ def on_train_end(self, state: State) -> None:
         pass


-class EvalUnit(AppStateMixin, _OnExceptionMixin, Generic[TEvalData], ABC):
+class EvalUnit(
+    AppStateMixin, _OnExceptionMixin, Generic[TEvalData, TEvalStepOutput], ABC
+):
     """
     The EvalUnit is an interface that can be used to organize your evaluation logic. The core of it is the ``eval_step`` which is an abstract method where you can define the code you want to run each iteration of the dataloader.
@@ -329,8 +335,7 @@ def on_eval_epoch_start(self, state: State) -> None:
         pass

     @abstractmethod
-    # pyre-fixme[3]: Return annotation cannot be `Any`.
-    def eval_step(self, state: State, data: TEvalData) -> Any:
+    def eval_step(self, state: State, data: TEvalData) -> TEvalStepOutput:
         """
         Core required method for user to implement. This method will be called at each iteration of the
         eval dataloader, and can return any data the user wishes.
@@ -362,7 +367,7 @@ def on_eval_end(self, state: State) -> None: class PredictUnit( AppStateMixin, _OnExceptionMixin, - Generic[TPredictData], + Generic[TPredictData, TPredictStepOutput], ABC, ): """ @@ -417,8 +422,7 @@ def on_predict_epoch_start(self, state: State) -> None: pass @abstractmethod - # pyre-fixme[3]: Return annotation cannot be `Any`. - def predict_step(self, state: State, data: TPredictData) -> Any: + def predict_step(self, state: State, data: TPredictData) -> TPredictStepOutput: """ Core required method for user to implement. This method will be called at each iteration of the predict dataloader, and can return any data the user wishes. @@ -447,6 +451,6 @@ def on_predict_end(self, state: State) -> None: pass -TTrainUnit = TrainUnit[TTrainData] -TEvalUnit = EvalUnit[TEvalData] -TPredictUnit = PredictUnit[TPredictData] +TTrainUnit = TrainUnit[TTrainData, TTrainStepOutput] +TEvalUnit = EvalUnit[TEvalData, TEvalStepOutput] +TPredictUnit = PredictUnit[TPredictData, TPredictStepOutput]
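Usage note (illustrative, not part of the diff): a minimal sketch of how the new second
generic parameter is consumed after this change. It mirrors `DummyTrainUnit` from
`torchtnt/framework/_test_utils.py` above; the `MyTrainUnit` name is hypothetical.

    from typing import Tuple

    import torch
    import torch.nn as nn

    from torchtnt.framework.state import State
    from torchtnt.framework.unit import TrainUnit

    Batch = Tuple[torch.Tensor, torch.Tensor]
    # The second type parameter declares what train_step returns, so type
    # checkers can verify the return type instead of falling back to Any.
    StepOutput = Tuple[torch.Tensor, torch.Tensor]  # (loss, outputs)


    class MyTrainUnit(TrainUnit[Batch, StepOutput]):
        def __init__(self, input_dim: int) -> None:
            super().__init__()
            self.module = nn.Linear(input_dim, 2)
            self.loss_fn = nn.CrossEntropyLoss()
            self.optimizer = torch.optim.SGD(self.module.parameters(), lr=0.01)

        # Annotated against the second generic parameter rather than Any.
        def train_step(self, state: State, data: Batch) -> StepOutput:
            inputs, targets = data
            outputs = self.module(inputs)
            loss = self.loss_fn(outputs, targets)
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()
            return loss, outputs

Units whose step returns nothing can pass `None` as the second parameter
(e.g. `TrainUnit[Batch, None]`), as the torchdata and torchrec examples above do.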