add StepOutput generic to Unit #553

Open · wants to merge 1 commit into base: master
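This PR threads a second type parameter, the step output, through `Unit` and its subclasses (`TrainUnit`, `EvalUnit`, `PredictUnit`, `AutoUnit`, `AutoPredictUnit`), so the type checker can see what each `*_step` method returns and the scattered `pyre-fixme` suppressions for `Any` can be dropped. As a rough sketch of the intended usage after this change (the import paths match those used in the diffs below; the unit itself is hypothetical):

```python
# Hypothetical unit illustrating the new two-parameter generic.
from typing import Tuple

import torch
from torchtnt.framework.state import State
from torchtnt.framework.unit import TrainUnit

Batch = Tuple[torch.Tensor, torch.Tensor]
StepOutput = torch.Tensor  # the type train_step returns


class MyUnit(TrainUnit[Batch, StepOutput]):
    def __init__(self, module: torch.nn.Module) -> None:
        super().__init__()
        self.module = module

    def train_step(self, state: State, data: Batch) -> StepOutput:
        inputs, _ = data
        return self.module(inputs)
```

Units whose steps return nothing parameterize the output as `None`, as in the `TrainUnit[Batch, None]` examples below.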
14 changes: 7 additions & 7 deletions examples/auto_unit_example.py
@@ -28,6 +28,7 @@
 
 
 Batch = Tuple[torch.Tensor, torch.Tensor]
+ModelStepOutput = torch.Tensor
 NUM_PROCESSES = 2
 
 
@@ -60,7 +61,7 @@ def prepare_dataloader(
     )
 
 
-class MyUnit(AutoUnit[Batch]):
+class MyUnit(AutoUnit[Batch, ModelStepOutput]):
     # pyre-fixme[3]: Return type must be annotated.
     def __init__(
         self,
@@ -110,8 +111,9 @@ def configure_optimizers_and_lr_scheduler(
         lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
         return optimizer, lr_scheduler
 
-    # pyre-fixme[3]: Return annotation cannot contain `Any`.
-    def compute_loss(self, state: State, data: Batch) -> Tuple[torch.Tensor, Any]:
+    def compute_loss(
+        self, state: State, data: Batch
+    ) -> Tuple[torch.Tensor, ModelStepOutput]:
         inputs, targets = data
         # convert targets to float Tensor for binary_cross_entropy_with_logits
         targets = targets.float()
@@ -127,8 +129,7 @@ def on_train_step_end(
         data: Batch,
         step: int,
         loss: torch.Tensor,
-        # pyre-fixme[2]: Parameter annotation cannot be `Any`.
-        outputs: Any,
+        outputs: ModelStepOutput,
     ) -> None:
         _, targets = data
         self.train_accuracy.update(outputs, targets)
@@ -143,8 +144,7 @@ def on_eval_step_end(
         data: Batch,
         step: int,
         loss: torch.Tensor,
-        # pyre-fixme[2]: Parameter annotation cannot be `Any`.
-        outputs: Any,
+        outputs: ModelStepOutput,
     ) -> None:
         _, targets = data
         self.eval_accuracy.update(outputs, targets)
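Note that in `AutoUnit` the second parameter does double duty: it types the non-loss element of `compute_loss`'s return value and the `outputs` argument of `on_train_step_end`/`on_eval_step_end`, which is what allows the `pyre-fixme` `Any` suppressions above to be deleted.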
3 changes: 2 additions & 1 deletion examples/mingpt/main.py
@@ -32,6 +32,7 @@
 logging.basicConfig(level=logging.INFO)
 
 Batch = Tuple[torch.Tensor, torch.Tensor]
+ModelStepOutput = torch.Tensor
 PATH: str = parutil.get_file_path("data/input.txt", pkg=__package__)
 
 
@@ -59,7 +60,7 @@ def get_datasets(
     return train_set, eval_set, dataset
 
 
-class MinGPTUnit(AutoUnit[Batch]):
+class MinGPTUnit(AutoUnit[Batch, ModelStepOutput]):
     def __init__(
         self,
         tb_logger: TensorBoardLogger,
3 changes: 2 additions & 1 deletion examples/mnist/main.py
@@ -26,6 +26,7 @@
 from torchvision import datasets, transforms
 
 Batch = Tuple[torch.Tensor, torch.Tensor]
+ModelStepOutput = torch.Tensor
 
 
 class Net(nn.Module):
@@ -56,7 +57,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return output
 
 
-class MyUnit(AutoUnit[Batch]):
+class MyUnit(AutoUnit[Batch, ModelStepOutput]):
     def __init__(
         self,
         *,
2 changes: 1 addition & 1 deletion examples/torchdata_train_example.py
@@ -78,7 +78,7 @@ def prepare_dataloader(
     return dataloader
 
 
-class MyTrainUnit(TrainUnit[Batch]):
+class MyTrainUnit(TrainUnit[Batch, None]):
     def __init__(
         self,
         module: torch.nn.Module,
11 changes: 4 additions & 7 deletions examples/torchrec/main.py
@@ -180,10 +180,10 @@ def parse_args(argv: List[str]) -> argparse.Namespace:
     return parser.parse_args(argv)
 
 
-Batch = Tuple[torch.Tensor, torch.Tensor]
+Batch = Iterator[Tuple[torch.Tensor, torch.Tensor]]
 
 
-class MyUnit(TrainUnit[Batch], EvalUnit[Batch]):
+class MyUnit(TrainUnit[Batch, None], EvalUnit[Batch, None]):
     def __init__(
         self,
         module: torch.nn.Module,
@@ -205,9 +205,7 @@ def __init__(
         self.tb_logger = tb_logger
         self.log_every_n_steps = log_every_n_steps
 
-    # pyre-fixme[14]: `train_step` overrides method defined in `TrainUnit`
-    #  inconsistently.
-    def train_step(self, state: State, data: Iterator[Batch]) -> None:
+    def train_step(self, state: State, data: Batch) -> None:
         step = self.train_progress.num_steps_completed
         loss, logits, labels = self.pipeline.progress(data)
         preds = torch.sigmoid(logits)
@@ -222,8 +220,7 @@ def on_train_epoch_end(self, state: State) -> None:
         # reset the metric every epoch
         self.train_auroc.reset()
 
-    # pyre-fixme[14]: `eval_step` overrides method defined in `EvalUnit` inconsistently.
-    def eval_step(self, state: State, data: Iterator[Batch]) -> None:
+    def eval_step(self, state: State, data: Batch) -> None:
         step = self.eval_progress.num_steps_completed
         loss, _, _ = self.pipeline.progress(data)
         if step % self.log_every_n_steps == 0:
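Typing `Batch` as the iterator itself is what removes both `pyre-fixme[14]` suppressions here: the overrides now match the base `train_step`/`eval_step` signatures exactly, and TorchRec's pipeline wants the raw iterator anyway so it can control prefetching. A minimal sketch of the pattern, with a stand-in pipeline object (the `Pipeline` class below is illustrative, not TorchRec's):

```python
# Sketch only: Pipeline stands in for a TorchRec-style train pipeline whose
# progress() pulls from the data iterator itself.
from typing import Iterator, Tuple

import torch
from torchtnt.framework.state import State
from torchtnt.framework.unit import TrainUnit

# the unit's "batch" is the raw iterator handed to train_step
Batch = Iterator[Tuple[torch.Tensor, torch.Tensor]]


class Pipeline:
    def progress(self, data_iter: Batch) -> torch.Tensor:
        inputs, _ = next(data_iter)
        return inputs.sum()  # placeholder for forward/backward/optimizer work


class PipelinedUnit(TrainUnit[Batch, None]):
    def __init__(self, pipeline: Pipeline) -> None:
        super().__init__()
        self.pipeline = pipeline

    def train_step(self, state: State, data: Batch) -> None:
        # the pipeline advances the iterator itself
        self.pipeline.progress(data)
```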
2 changes: 1 addition & 1 deletion examples/train_unit_example.py
@@ -57,7 +57,7 @@ def prepare_dataloader(
     )
 
 
-class MyTrainUnit(TrainUnit[Batch]):
+class MyTrainUnit(TrainUnit[Batch, None]):
     def __init__(
         self,
         module: torch.nn.Module,
6 changes: 3 additions & 3 deletions tests/framework/callbacks/test_csv_writer.py
@@ -13,7 +13,7 @@
 from torchtnt.framework.callbacks.base_csv_writer import BaseCSVWriter
 from torchtnt.framework.predict import predict
 from torchtnt.framework.state import State
-from torchtnt.framework.unit import PredictUnit, TPredictData
+from torchtnt.framework.unit import TPredictUnit
 
 _HEADER_ROW = ["output"]
 _FILENAME = "test_csv_writer.csv"
@@ -23,7 +23,7 @@ class CustomCSVWriter(BaseCSVWriter):
     def get_step_output_rows(
         self,
         state: State,
-        unit: PredictUnit[TPredictData],
+        unit: TPredictUnit,
         # pyre-fixme[2]: Parameter annotation cannot be `Any`.
         step_output: Any,
     ) -> Union[List[str], List[List[str]]]:
@@ -34,7 +34,7 @@ class CustomCSVWriterSingleRow(BaseCSVWriter):
     def get_step_output_rows(
         self,
         state: State,
-        unit: PredictUnit[TPredictData],
+        unit: TPredictUnit,
         # pyre-fixme[2]: Parameter annotation cannot be `Any`.
         step_output: Any,
     ) -> Union[List[str], List[List[str]]]:
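Since `PredictUnit` is now generic over two parameters, spelling out `PredictUnit[TPredictData]` in these signatures would be incomplete; the tests switch to the `TPredictUnit` alias that `torchtnt.framework.unit` exports. A hypothetical writer in the same shape as the fixtures above:

```python
# Hypothetical CSV writer mirroring the test fixtures; TPredictUnit keeps the
# signature agnostic to both the data and step-output parameters.
from typing import Any, List, Union

from torchtnt.framework.callbacks.base_csv_writer import BaseCSVWriter
from torchtnt.framework.state import State
from torchtnt.framework.unit import TPredictUnit


class EchoCSVWriter(BaseCSVWriter):
    def get_step_output_rows(
        self,
        state: State,
        unit: TPredictUnit,
        step_output: Any,
    ) -> Union[List[str], List[List[str]]]:
        # one row per step, stringifying whatever predict_step produced
        return [[str(step_output)]]
```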
2 changes: 1 addition & 1 deletion tests/framework/callbacks/test_lambda.py
@@ -28,7 +28,7 @@
 Batch = Tuple[torch.Tensor, torch.Tensor]
 
 
-class DummyTrainExceptUnit(TrainUnit[Batch]):
+class DummyTrainExceptUnit(TrainUnit[Batch, None]):
     def __init__(self, input_dim: int) -> None:
         super().__init__()
 
29 changes: 14 additions & 15 deletions tests/framework/test_auto_unit.py
@@ -738,14 +738,15 @@ def test_predict_detect_anomaly(self, mock_detect_anomaly) -> None:
 Batch = Tuple[torch.Tensor, torch.Tensor]
 
 
-class DummyLRSchedulerAutoUnit(AutoUnit[Batch]):
+class DummyLRSchedulerAutoUnit(AutoUnit[Batch, torch.Tensor]):
     # pyre-fixme[3]: Return type must be annotated.
     # pyre-fixme[2]: Parameter must be annotated.
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
-    # pyre-fixme[3]: Return annotation cannot contain `Any`.
-    def compute_loss(self, state: State, data: Batch) -> Tuple[torch.Tensor, Any]:
+    def compute_loss(
+        self, state: State, data: Batch
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         inputs, targets = data
         outputs = self.module(inputs)
         loss = torch.nn.functional.cross_entropy(outputs, targets)
@@ -760,15 +761,16 @@ def configure_optimizers_and_lr_scheduler(
         return my_optimizer, my_lr_scheduler
 
 
-class DummyComplexAutoUnit(AutoUnit[Batch]):
+class DummyComplexAutoUnit(AutoUnit[Batch, torch.Tensor]):
     # pyre-fixme[3]: Return type must be annotated.
     # pyre-fixme[2]: Parameter must be annotated.
     def __init__(self, lr: float, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.lr = lr
 
-    # pyre-fixme[3]: Return annotation cannot contain `Any`.
-    def compute_loss(self, state: State, data: Batch) -> Tuple[torch.Tensor, Any]:
+    def compute_loss(
+        self, state: State, data: Batch
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         inputs, targets = data
         outputs = self.module(inputs)
         loss = torch.nn.functional.cross_entropy(outputs, targets)
@@ -783,7 +785,7 @@ def configure_optimizers_and_lr_scheduler(
         return my_optimizer, my_lr_scheduler
 
 
-class LastBatchAutoUnit(AutoUnit[Batch]):
+class LastBatchAutoUnit(AutoUnit[Batch, torch.Tensor]):
     def __init__(self, module: torch.nn.Module, expected_steps_per_epoch: int) -> None:
         super().__init__(module=module)
         self.expected_steps_per_epoch = expected_steps_per_epoch
@@ -812,7 +814,7 @@ def configure_optimizers_and_lr_scheduler(
         return my_optimizer, my_lr_scheduler
 
 
-class TimingAutoUnit(AutoUnit[Batch]):
+class TimingAutoUnit(AutoUnit[Batch, torch.Tensor]):
     def __init__(self, module: torch.nn.Module) -> None:
         super().__init__(module=module)
         self.loss_fn = torch.nn.CrossEntropyLoss()
@@ -839,8 +841,7 @@ def on_train_step_end(
         data: Batch,
         step: int,
         loss: torch.Tensor,
-        # pyre-fixme[2]: Parameter annotation cannot be `Any`.
-        outputs: Any,
+        outputs: torch.Tensor,
     ) -> None:
         assert state.train_state
         if self.train_progress.num_steps_completed_in_epoch == 1:
@@ -871,8 +872,7 @@ def on_eval_step_end(
         data: Batch,
         step: int,
         loss: torch.Tensor,
-        # pyre-fixme[2]: Parameter annotation cannot be `Any`.
-        outputs: Any,
+        outputs: torch.Tensor,
     ) -> None:
         if self.eval_progress.num_steps_completed_in_epoch == 1:
             tc = unittest.TestCase()
@@ -918,7 +918,7 @@ def on_predict_step_end(
             tc.assertNotIn("TimingAutoUnit.predict_step", recorded_timer_keys)
 
 
-class TimingAutoPredictUnit(AutoPredictUnit[Batch]):
+class TimingAutoPredictUnit(AutoPredictUnit[Batch, torch.Tensor]):
     def __init__(self, module: torch.nn.Module) -> None:
         super().__init__(module=module)
         self.loss_fn = torch.nn.CrossEntropyLoss()
@@ -944,8 +944,7 @@ def on_predict_step_end(
         state: State,
         data: TPredictData,
         step: int,
-        # pyre-fixme[2]: Parameter annotation cannot be `Any`.
-        outputs: Any,
+        outputs: torch.Tensor,
     ) -> None:
         if self.predict_progress.num_steps_completed_in_epoch == 1:
             tc = unittest.TestCase()
20 changes: 12 additions & 8 deletions tests/framework/test_evaluate.py
@@ -85,7 +85,12 @@ def test_evaluate_stop(self) -> None:
         self.assertEqual(my_unit.steps_processed, steps_before_stopping)
 
     def test_evaluate_data_iter_step(self) -> None:
-        class EvalIteratorUnit(EvalUnit[Iterator[Tuple[torch.Tensor, torch.Tensor]]]):
+        class EvalIteratorUnit(
+            EvalUnit[
+                Iterator[Tuple[torch.Tensor, torch.Tensor]],
+                Tuple[torch.Tensor, torch.Tensor],
+            ]
+        ):
             def __init__(self, input_dim: int) -> None:
                 super().__init__()
                 self.module = nn.Linear(input_dim, 2)
@@ -204,7 +209,11 @@ def test_evaluate_timing(self) -> None:
         self.assertIn("evaluate.next(data_iter)", timer.recorded_durations.keys())
 
 
-class StopEvalUnit(EvalUnit[Tuple[torch.Tensor, torch.Tensor]]):
+Batch = Tuple[torch.Tensor, torch.Tensor]
+StepOutput = Tuple[torch.Tensor, torch.Tensor]
+
+
+class StopEvalUnit(EvalUnit[Batch, StepOutput]):
     def __init__(self, input_dim: int, steps_before_stopping: int) -> None:
         super().__init__()
         # initialize module & loss_fn
@@ -213,9 +222,7 @@ def __init__(self, input_dim: int, steps_before_stopping: int) -> None:
         self.steps_processed = 0
         self.steps_before_stopping = steps_before_stopping
 
-    def eval_step(
-        self, state: State, data: Tuple[torch.Tensor, torch.Tensor]
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    def eval_step(self, state: State, data: Batch) -> StepOutput:
 
         inputs, targets = data
 
@@ -231,6 +238,3 @@ def eval_step(
 
         self.steps_processed += 1
         return loss, outputs
-
-
-Batch = Tuple[torch.Tensor, torch.Tensor]
11 changes: 4 additions & 7 deletions tests/framework/test_fit.py
@@ -123,8 +123,9 @@ def test_fit_evaluate_every_n_steps(self) -> None:
 
     def test_fit_stop(self) -> None:
         Batch = Tuple[torch.Tensor, torch.Tensor]
+        StepOutput = Tuple[torch.Tensor, torch.Tensor]
 
-        class FitStop(TrainUnit[Batch], EvalUnit[Batch]):
+        class FitStop(TrainUnit[Batch, StepOutput], EvalUnit[Batch, StepOutput]):
             def __init__(self, input_dim: int, steps_before_stopping: int) -> None:
                 super().__init__()
                 # initialize module, loss_fn, & optimizer
@@ -134,9 +135,7 @@ def __init__(self, input_dim: int, steps_before_stopping: int) -> None:
                 self.steps_processed = 0
                 self.steps_before_stopping = steps_before_stopping
 
-            def train_step(
-                self, state: State, data: Batch
-            ) -> Tuple[torch.Tensor, torch.Tensor]:
+            def train_step(self, state: State, data: Batch) -> StepOutput:
                 inputs, targets = data
 
                 outputs = self.module(inputs)
@@ -156,9 +155,7 @@ def train_step(
                 self.steps_processed += 1
                 return loss, outputs
 
-            def eval_step(
-                self, state: State, data: Batch
-            ) -> Tuple[torch.Tensor, torch.Tensor]:
+            def eval_step(self, state: State, data: Batch) -> StepOutput:
                 inputs, targets = data
                 outputs = self.module(inputs)
                 loss = self.loss_fn(outputs, targets)
10 changes: 7 additions & 3 deletions tests/framework/test_predict.py
@@ -158,7 +158,10 @@ def on_predict_end(self, state: State, unit: TPredictUnit) -> None:
 
     def test_predict_data_iter_step(self) -> None:
         class PredictIteratorUnit(
-            PredictUnit[Iterator[Tuple[torch.Tensor, torch.Tensor]]]
+            PredictUnit[
+                Iterator[Tuple[torch.Tensor, torch.Tensor]],
+                Tuple[torch.Tensor, torch.Tensor],
+            ]
         ):
             def __init__(self, input_dim: int) -> None:
                 super().__init__()
@@ -213,17 +216,18 @@ def test_predict_timing(self) -> None:
 
 
 Batch = Tuple[torch.Tensor, torch.Tensor]
+StepOutput = torch.Tensor
 
 
-class StopPredictUnit(PredictUnit[Batch]):
+class StopPredictUnit(PredictUnit[Batch, StepOutput]):
     def __init__(self, input_dim: int, steps_before_stopping: int) -> None:
         super().__init__()
         # initialize module
         self.module = nn.Linear(input_dim, 2)
         self.steps_processed = 0
         self.steps_before_stopping = steps_before_stopping
 
-    def predict_step(self, state: State, data: Batch) -> torch.Tensor:
+    def predict_step(self, state: State, data: Batch) -> StepOutput:
         inputs, targets = data
 
         outputs = self.module(inputs)