add StepOutput generic to Unit
Summary: Add a second generic parameter to `Unit`, denoting the `step_output` type.

Differential Revision: D49391712
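
For context, here is a minimal sketch of what a subclass looks like with the second parameter, modeled on the example changes below. The `Batch` and `StepOutput` aliases and the module wiring are illustrative assumptions, not part of this diff:

```python
from typing import Tuple

import torch
from torchtnt.framework.state import State
from torchtnt.framework.unit import TrainUnit

Batch = Tuple[torch.Tensor, torch.Tensor]
StepOutput = torch.Tensor  # bound to the new, second generic parameter


class MyTrainUnit(TrainUnit[Batch, StepOutput]):
    def __init__(self, module: torch.nn.Module) -> None:
        super().__init__()
        self.module = module

    # The step's return type is now pinned by StepOutput instead of the
    # `Any` annotations that previously required pyre-fixme suppressions.
    def train_step(self, state: State, data: Batch) -> StepOutput:
        inputs, _targets = data
        return self.module(inputs)
```

Units that produce no step output parameterize with `None`, as the `TrainUnit[Batch, None]` examples below do.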
galrotem authored and facebook-github-bot committed Sep 26, 2023
1 parent cd2c4f8 commit 1a2f5a1
Showing 16 changed files with 112 additions and 110 deletions.
14 changes: 7 additions & 7 deletions examples/auto_unit_example.py
@@ -28,6 +28,7 @@


Batch = Tuple[torch.Tensor, torch.Tensor]
+ModelStepOutput = torch.Tensor
NUM_PROCESSES = 2


@@ -60,7 +61,7 @@ def prepare_dataloader(
)


-class MyUnit(AutoUnit[Batch]):
+class MyUnit(AutoUnit[Batch, ModelStepOutput]):
# pyre-fixme[3]: Return type must be annotated.
def __init__(
self,
@@ -110,8 +111,9 @@ def configure_optimizers_and_lr_scheduler(
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
return optimizer, lr_scheduler

-# pyre-fixme[3]: Return annotation cannot contain `Any`.
-def compute_loss(self, state: State, data: Batch) -> Tuple[torch.Tensor, Any]:
+def compute_loss(
+    self, state: State, data: Batch
+) -> Tuple[torch.Tensor, ModelStepOutput]:
inputs, targets = data
# convert targets to float Tensor for binary_cross_entropy_with_logits
targets = targets.float()
@@ -127,8 +129,7 @@ def on_train_step_end(
data: Batch,
step: int,
loss: torch.Tensor,
-# pyre-fixme[2]: Parameter annotation cannot be `Any`.
-outputs: Any,
+outputs: ModelStepOutput,
) -> None:
_, targets = data
self.train_accuracy.update(outputs, targets)
@@ -143,8 +144,7 @@ def on_eval_step_end(
data: Batch,
step: int,
loss: torch.Tensor,
-# pyre-fixme[2]: Parameter annotation cannot be `Any`.
-outputs: Any,
+outputs: ModelStepOutput,
) -> None:
_, targets = data
self.eval_accuracy.update(outputs, targets)
3 changes: 2 additions & 1 deletion examples/mingpt/main.py
@@ -32,6 +32,7 @@
logging.basicConfig(level=logging.INFO)

Batch = Tuple[torch.Tensor, torch.Tensor]
+ModelStepOutput = torch.Tensor
PATH: str = parutil.get_file_path("data/input.txt", pkg=__package__)


@@ -59,7 +60,7 @@ def get_datasets(
return train_set, eval_set, dataset


-class MinGPTUnit(AutoUnit[Batch]):
+class MinGPTUnit(AutoUnit[Batch, ModelStepOutput]):
def __init__(
self,
tb_logger: TensorBoardLogger,
3 changes: 2 additions & 1 deletion examples/mnist/main.py
@@ -26,6 +26,7 @@
from torchvision import datasets, transforms

Batch = Tuple[torch.Tensor, torch.Tensor]
+ModelStepOutput = torch.Tensor


class Net(nn.Module):
@@ -56,7 +57,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return output


-class MyUnit(AutoUnit[Batch]):
+class MyUnit(AutoUnit[Batch, ModelStepOutput]):
def __init__(
self,
*,
2 changes: 1 addition & 1 deletion examples/torchdata_train_example.py
@@ -78,7 +78,7 @@ def prepare_dataloader(
return dataloader


-class MyTrainUnit(TrainUnit[Batch]):
+class MyTrainUnit(TrainUnit[Batch, None]):
def __init__(
self,
module: torch.nn.Module,
11 changes: 4 additions & 7 deletions examples/torchrec/main.py
@@ -180,10 +180,10 @@ def parse_args(argv: List[str]) -> argparse.Namespace:
return parser.parse_args(argv)


-Batch = Tuple[torch.Tensor, torch.Tensor]
+Batch = Iterator[Tuple[torch.Tensor, torch.Tensor]]


-class MyUnit(TrainUnit[Batch], EvalUnit[Batch]):
+class MyUnit(TrainUnit[Batch, None], EvalUnit[Batch, None]):
def __init__(
self,
module: torch.nn.Module,
@@ -205,9 +205,7 @@ def __init__(
self.tb_logger = tb_logger
self.log_every_n_steps = log_every_n_steps

-# pyre-fixme[14]: `train_step` overrides method defined in `TrainUnit`
-# inconsistently.
-def train_step(self, state: State, data: Iterator[Batch]) -> None:
+def train_step(self, state: State, data: Batch) -> None:
step = self.train_progress.num_steps_completed
loss, logits, labels = self.pipeline.progress(data)
preds = torch.sigmoid(logits)
@@ -222,8 +220,7 @@ def on_train_epoch_end(self, state: State) -> None:
# reset the metric every epoch
self.train_auroc.reset()

-# pyre-fixme[14]: `eval_step` overrides method defined in `EvalUnit` inconsistently.
-def eval_step(self, state: State, data: Iterator[Batch]) -> None:
+def eval_step(self, state: State, data: Batch) -> None:
step = self.eval_progress.num_steps_completed
loss, _, _ = self.pipeline.progress(data)
if step % self.log_every_n_steps == 0:
2 changes: 1 addition & 1 deletion examples/train_unit_example.py
@@ -57,7 +57,7 @@ def prepare_dataloader(
)


-class MyTrainUnit(TrainUnit[Batch]):
+class MyTrainUnit(TrainUnit[Batch, None]):
def __init__(
self,
module: torch.nn.Module,
6 changes: 3 additions & 3 deletions tests/framework/callbacks/test_csv_writer.py
@@ -13,7 +13,7 @@
from torchtnt.framework.callbacks.base_csv_writer import BaseCSVWriter
from torchtnt.framework.predict import predict
from torchtnt.framework.state import State
-from torchtnt.framework.unit import PredictUnit, TPredictData
+from torchtnt.framework.unit import TPredictUnit

_HEADER_ROW = ["output"]
_FILENAME = "test_csv_writer.csv"
@@ -23,7 +23,7 @@ class CustomCSVWriter(BaseCSVWriter):
def get_step_output_rows(
self,
state: State,
-unit: PredictUnit[TPredictData],
+unit: TPredictUnit,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
step_output: Any,
) -> Union[List[str], List[List[str]]]:
@@ -34,7 +34,7 @@ class CustomCSVWriterSingleRow(BaseCSVWriter):
def get_step_output_rows(
self,
state: State,
-unit: PredictUnit[TPredictData],
+unit: TPredictUnit,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
step_output: Any,
) -> Union[List[str], List[List[str]]]:
2 changes: 1 addition & 1 deletion tests/framework/callbacks/test_lambda.py
@@ -28,7 +28,7 @@
Batch = Tuple[torch.Tensor, torch.Tensor]


-class DummyTrainExceptUnit(TrainUnit[Batch]):
+class DummyTrainExceptUnit(TrainUnit[Batch, None]):
def __init__(self, input_dim: int) -> None:
super().__init__()

29 changes: 14 additions & 15 deletions tests/framework/test_auto_unit.py
@@ -738,14 +738,15 @@ def test_predict_detect_anomaly(self, mock_detect_anomaly) -> None:
Batch = Tuple[torch.Tensor, torch.Tensor]


-class DummyLRSchedulerAutoUnit(AutoUnit[Batch]):
+class DummyLRSchedulerAutoUnit(AutoUnit[Batch, torch.Tensor]):
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

-# pyre-fixme[3]: Return annotation cannot contain `Any`.
-def compute_loss(self, state: State, data: Batch) -> Tuple[torch.Tensor, Any]:
+def compute_loss(
+    self, state: State, data: Batch
+) -> Tuple[torch.Tensor, torch.Tensor]:
inputs, targets = data
outputs = self.module(inputs)
loss = torch.nn.functional.cross_entropy(outputs, targets)
Expand All @@ -760,15 +761,16 @@ def configure_optimizers_and_lr_scheduler(
return my_optimizer, my_lr_scheduler


-class DummyComplexAutoUnit(AutoUnit[Batch]):
+class DummyComplexAutoUnit(AutoUnit[Batch, torch.Tensor]):
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def __init__(self, lr: float, *args, **kwargs):
super().__init__(*args, **kwargs)
self.lr = lr

-# pyre-fixme[3]: Return annotation cannot contain `Any`.
-def compute_loss(self, state: State, data: Batch) -> Tuple[torch.Tensor, Any]:
+def compute_loss(
+    self, state: State, data: Batch
+) -> Tuple[torch.Tensor, torch.Tensor]:
inputs, targets = data
outputs = self.module(inputs)
loss = torch.nn.functional.cross_entropy(outputs, targets)
Expand All @@ -783,7 +785,7 @@ def configure_optimizers_and_lr_scheduler(
return my_optimizer, my_lr_scheduler


-class LastBatchAutoUnit(AutoUnit[Batch]):
+class LastBatchAutoUnit(AutoUnit[Batch, torch.Tensor]):
def __init__(self, module: torch.nn.Module, expected_steps_per_epoch: int) -> None:
super().__init__(module=module)
self.expected_steps_per_epoch = expected_steps_per_epoch
@@ -812,7 +814,7 @@ def configure_optimizers_and_lr_scheduler(
return my_optimizer, my_lr_scheduler


-class TimingAutoUnit(AutoUnit[Batch]):
+class TimingAutoUnit(AutoUnit[Batch, torch.Tensor]):
def __init__(self, module: torch.nn.Module) -> None:
super().__init__(module=module)
self.loss_fn = torch.nn.CrossEntropyLoss()
@@ -839,8 +841,7 @@ def on_train_step_end(
data: Batch,
step: int,
loss: torch.Tensor,
-# pyre-fixme[2]: Parameter annotation cannot be `Any`.
-outputs: Any,
+outputs: torch.Tensor,
) -> None:
assert state.train_state
if self.train_progress.num_steps_completed_in_epoch == 1:
@@ -871,8 +872,7 @@ def on_eval_step_end(
data: Batch,
step: int,
loss: torch.Tensor,
-# pyre-fixme[2]: Parameter annotation cannot be `Any`.
-outputs: Any,
+outputs: torch.Tensor,
) -> None:
if self.eval_progress.num_steps_completed_in_epoch == 1:
tc = unittest.TestCase()
@@ -918,7 +918,7 @@ def on_predict_step_end(
tc.assertNotIn("TimingAutoUnit.predict_step", recorded_timer_keys)


-class TimingAutoPredictUnit(AutoPredictUnit[Batch]):
+class TimingAutoPredictUnit(AutoPredictUnit[Batch, torch.Tensor]):
def __init__(self, module: torch.nn.Module) -> None:
super().__init__(module=module)
self.loss_fn = torch.nn.CrossEntropyLoss()
@@ -944,8 +944,7 @@ def on_predict_step_end(
state: State,
data: TPredictData,
step: int,
-# pyre-fixme[2]: Parameter annotation cannot be `Any`.
-outputs: Any,
+outputs: torch.Tensor,
) -> None:
if self.predict_progress.num_steps_completed_in_epoch == 1:
tc = unittest.TestCase()
20 changes: 12 additions & 8 deletions tests/framework/test_evaluate.py
@@ -85,7 +85,12 @@ def test_evaluate_stop(self) -> None:
self.assertEqual(my_unit.steps_processed, steps_before_stopping)

def test_evaluate_data_iter_step(self) -> None:
-class EvalIteratorUnit(EvalUnit[Iterator[Tuple[torch.Tensor, torch.Tensor]]]):
+class EvalIteratorUnit(
+    EvalUnit[
+        Iterator[Tuple[torch.Tensor, torch.Tensor]],
+        Tuple[torch.Tensor, torch.Tensor],
+    ]
+):
def __init__(self, input_dim: int) -> None:
super().__init__()
self.module = nn.Linear(input_dim, 2)
@@ -204,7 +209,11 @@ def test_evaluate_timing(self) -> None:
self.assertIn("evaluate.next(data_iter)", timer.recorded_durations.keys())


-class StopEvalUnit(EvalUnit[Tuple[torch.Tensor, torch.Tensor]]):
+Batch = Tuple[torch.Tensor, torch.Tensor]
+StepOutput = Tuple[torch.Tensor, torch.Tensor]
+
+
+class StopEvalUnit(EvalUnit[Batch, StepOutput]):
def __init__(self, input_dim: int, steps_before_stopping: int) -> None:
super().__init__()
# initialize module & loss_fn
@@ -213,9 +222,7 @@ def __init__(self, input_dim: int, steps_before_stopping: int) -> None:
self.steps_processed = 0
self.steps_before_stopping = steps_before_stopping

-def eval_step(
-    self, state: State, data: Tuple[torch.Tensor, torch.Tensor]
-) -> Tuple[torch.Tensor, torch.Tensor]:
+def eval_step(self, state: State, data: Batch) -> StepOutput:

inputs, targets = data

@@ -231,6 +238,3 @@ def eval_step(

self.steps_processed += 1
return loss, outputs


-Batch = Tuple[torch.Tensor, torch.Tensor]
11 changes: 4 additions & 7 deletions tests/framework/test_fit.py
@@ -123,8 +123,9 @@ def test_fit_evaluate_every_n_steps(self) -> None:

def test_fit_stop(self) -> None:
Batch = Tuple[torch.Tensor, torch.Tensor]
+StepOutput = Tuple[torch.Tensor, torch.Tensor]

-class FitStop(TrainUnit[Batch], EvalUnit[Batch]):
+class FitStop(TrainUnit[Batch, StepOutput], EvalUnit[Batch, StepOutput]):
def __init__(self, input_dim: int, steps_before_stopping: int) -> None:
super().__init__()
# initialize module, loss_fn, & optimizer
@@ -134,9 +135,7 @@ def __init__(self, input_dim: int, steps_before_stopping: int) -> None:
self.steps_processed = 0
self.steps_before_stopping = steps_before_stopping

-def train_step(
-    self, state: State, data: Batch
-) -> Tuple[torch.Tensor, torch.Tensor]:
+def train_step(self, state: State, data: Batch) -> StepOutput:
inputs, targets = data

outputs = self.module(inputs)
@@ -156,9 +155,7 @@ def train_step(
self.steps_processed += 1
return loss, outputs

-def eval_step(
-    self, state: State, data: Batch
-) -> Tuple[torch.Tensor, torch.Tensor]:
+def eval_step(self, state: State, data: Batch) -> StepOutput:
inputs, targets = data
outputs = self.module(inputs)
loss = self.loss_fn(outputs, targets)
10 changes: 7 additions & 3 deletions tests/framework/test_predict.py
@@ -158,7 +158,10 @@ def on_predict_end(self, state: State, unit: TPredictUnit) -> None:

def test_predict_data_iter_step(self) -> None:
class PredictIteratorUnit(
-    PredictUnit[Iterator[Tuple[torch.Tensor, torch.Tensor]]]
+    PredictUnit[
+        Iterator[Tuple[torch.Tensor, torch.Tensor]],
+        Tuple[torch.Tensor, torch.Tensor],
+    ]
):
def __init__(self, input_dim: int) -> None:
super().__init__()
@@ -213,17 +216,18 @@ def test_predict_timing(self) -> None:


Batch = Tuple[torch.Tensor, torch.Tensor]
+StepOutput = torch.Tensor


-class StopPredictUnit(PredictUnit[Batch]):
+class StopPredictUnit(PredictUnit[Batch, StepOutput]):
def __init__(self, input_dim: int, steps_before_stopping: int) -> None:
super().__init__()
# initialize module
self.module = nn.Linear(input_dim, 2)
self.steps_processed = 0
self.steps_before_stopping = steps_before_stopping

-def predict_step(self, state: State, data: Batch) -> torch.Tensor:
+def predict_step(self, state: State, data: Batch) -> StepOutput:
inputs, targets = data

outputs = self.module(inputs)
(4 more changed files not shown)
