Refactor System - no more passing mode around, improved _step readability, return all data from the _step no matter what mode. Add loss, metrics, step, and epoch to the Data enum. Switch (str, Enum) to StrEnum. Tuner.scale_batch_size() does not work since we switched to dataloaders in the config - it needs read/write access to the batch size; figuring out the fix.
ibro45 committed Jan 19, 2025
1 parent 88d731a commit e6b71d9
Showing 4 changed files with 106 additions and 67 deletions.
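A note on the scale_batch_size issue mentioned in the commit message: Lightning's batch-size finder reads and overwrites a batch_size attribute on the module (or datamodule) and rebuilds the dataloaders from it, so DataLoader instances built directly in the config give it nothing to write to. Below is a minimal sketch of the pattern the tuner expects, assuming the pytorch_lightning 2.x Tuner API and a purely illustrative module (not lighter's System); the temporary .scale_batch_size* checkpoint the tuner saves is presumably also why that pattern is added to .gitignore below.

import pytorch_lightning as pl
import torch
from pytorch_lightning.tuner import Tuner
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


class ScalableSystem(pl.LightningModule):
    def __init__(self, batch_size: int = 2):
        super().__init__()
        self.batch_size = batch_size  # the tuner reads and overwrites this attribute
        self.model = nn.Linear(4, 1)

    def train_dataloader(self):
        # Rebuilt on every trial with the updated self.batch_size.
        dataset = TensorDataset(torch.randn(64, 4), torch.randn(64, 1))
        return DataLoader(dataset, batch_size=self.batch_size)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.mse_loss(self.model(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


trainer = pl.Trainer(max_epochs=1, logger=False, enable_checkpointing=False)
Tuner(trainer).scale_batch_size(ScalableSystem(), mode="power")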
1 change: 1 addition & 0 deletions .gitignore
@@ -154,3 +154,4 @@ projects/*
.aider*
test_dir/
*.code-workspace
.scale_batch_size*
8 changes: 4 additions & 4 deletions lighter/engine.py
@@ -296,14 +296,14 @@ def cli():
class Commands:
def __init__(self):
for stage in Stage:
setattr(self, stage.value, self._make_command(stage))
setattr(self, stage, self._make_command(stage))

def _make_command(self, stage: Stage):
def command(config: str, **config_overrides: Any):
return runner.run(stage=stage.value, config=config, **config_overrides)
return runner.run(stage=stage, config=config, **config_overrides)

command.__name__ = stage.value
command.__doc__ = f"Run the '{stage.value}' stage."
command.__name__ = stage
command.__doc__ = f"Run the '{stage}' stage."
return command

fire.Fire(Commands())
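For reference, a self-contained sketch of the dynamic-command pattern above, with simplified Stage members and a stand-in run function instead of lighter's runner. Because StrEnum members are plain strings, they can be used as attribute names, assigned to __name__, and interpolated into the docstring without calling .value:

from enum import StrEnum
from typing import Any

import fire


class Stage(StrEnum):
    FIT = "fit"
    TEST = "test"


def run(stage: str, config: str, **overrides: Any):
    print(f"running '{stage}' with {config} and overrides {overrides}")


class Commands:
    def __init__(self):
        for stage in Stage:
            # StrEnum members behave as str, so no `stage.value` is needed.
            setattr(self, stage, self._make_command(stage))

    def _make_command(self, stage: Stage):
        def command(config: str, **config_overrides: Any):
            return run(stage=stage, config=config, **config_overrides)

        command.__name__ = stage
        command.__doc__ = f"Run the '{stage}' stage."
        return command


if __name__ == "__main__":
    fire.Fire(Commands())  # e.g. `python cli.py fit --config=config.yaml`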
152 changes: 93 additions & 59 deletions lighter/system.py
@@ -3,10 +3,9 @@
including the model, optimizer, datasets, and more. It extends PyTorch Lightning's LightningModule.
"""

from typing import Any, Callable, List, Tuple
from typing import Any, Callable

from dataclasses import asdict
from functools import partial

import pytorch_lightning as pl
from torch import Tensor
@@ -49,7 +48,7 @@ def __init__(
optimizer: Optimizer | None = None,
scheduler: LRScheduler | None = None,
criterion: Callable | None = None,
metrics: dict[str, Metric | List[Metric] | dict[str, Metric]] | None = None,
metrics: dict[str, Metric | list[Metric] | dict[str, Metric]] | None = None,
adapters: dict[str, Callable] | None = None,
inferer: Callable | None = None,
) -> None:
@@ -67,24 +66,28 @@ def __init__(
# Register metrics to move them to the appropriate device. ModuleDict not used because 'train' is a reserved key.
for mode, metric in asdict(self.metrics).items():
if isinstance(metric, Module):
self.add_module(f"metrics_{mode}", metric)
self.add_module(f"{Data.METRICS}_{mode}", metric)

# Dataloader and step LightningModule methods
dataloaders = DataLoaders(**(dataloaders or {}))
if dataloaders.train is not None:
self.train_dataloader = lambda: dataloaders.train
self.training_step = partial(self._step, mode=Mode.TRAIN)
self.training_step = self._step
if dataloaders.val is not None:
self.val_dataloader = lambda: dataloaders.val
self.validation_step = partial(self._step, mode=Mode.VAL)
self.validation_step = self._step
if dataloaders.test is not None:
self.test_dataloader = lambda: dataloaders.test
self.test_step = partial(self._step, mode=Mode.TEST)
self.test_step = self._step
if dataloaders.predict is not None:
self.predict_dataloader = lambda: dataloaders.predict
self.predict_step = partial(self._step, mode=Mode.PREDICT)
self.predict_step = self._step

def forward(self, input: Tensor | List[Tensor] | Tuple[Tensor] | dict[str, Tensor]) -> Any:
# Keep track of the current mode and its batch size. Overridden in on_{train,validation,test,predict}_start.
self.mode = None
self.batch_size = 0

def forward(self, input: Any) -> Any: # pylint: disable=arguments-differ
"""
Forward pass through the model. Supports multi-input models.
@@ -98,13 +101,17 @@ def forward(self, input: Tensor | List[Tensor] | Tuple[Tensor] | dict[str, Tenso
# Keyword arguments to pass to the forward method
kwargs = {}
# Add `epoch` argument if forward accepts it
if hasarg(self.model.forward, "epoch"):
kwargs["epoch"] = self.current_epoch
if hasarg(self.model.forward, Data.EPOCH):
kwargs[Data.EPOCH] = self.current_epoch
# Add `step` argument if forward accepts it
if hasarg(self.model.forward, "step"):
kwargs["step"] = self.global_step
if hasarg(self.model.forward, Data.STEP):
kwargs[Data.STEP] = self.global_step

return self.model(input, **kwargs)
# Use the inferer if specified and in val/test/predict mode
if self.inferer and self.mode in [Mode.VAL, Mode.TEST, Mode.PREDICT]:
return self.inferer(input, self.model, **kwargs)
else:
return self.model(input, **kwargs)

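As an aside, a small sketch of a model that opts into the optional epoch argument injected above; hasarg here is a simplified stand-in for lighter's signature-inspection utility, and the model is purely illustrative:

import inspect

import torch
from torch import nn


def hasarg(fn, name: str) -> bool:
    # Simplified stand-in: does the callable accept an argument with this name?
    return name in inspect.signature(fn).parameters


class EpochAwareNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 1)

    def forward(self, x: torch.Tensor, epoch: int = 0) -> torch.Tensor:
        # e.g. anneal a term as training progresses
        return self.linear(x) * (0.99 ** epoch)


net = EpochAwareNet()
kwargs = {"epoch": 3} if hasarg(net.forward, "epoch") else {}
out = net(torch.randn(2, 4), **kwargs)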
def configure_optimizers(self) -> dict:
"""
@@ -120,74 +127,59 @@ def configure_optimizers(self) -> dict:
else:
return {"optimizer": self.optimizer, "lr_scheduler": self.scheduler}

def _step(self, batch: dict, batch_idx: int, mode: str) -> dict[str, Any] | Any:
def _step(self, batch: dict, batch_idx: int) -> dict[str, Any] | Any:
"""
Performs a step in the specified mode, processing the batch and calculating loss and metrics.
Args:
batch: The batch of data.
batch_idx: The index of the batch.
mode: The mode of operation (train, val, test, predict).
Returns:
dict or Any: For predict step, returns prediction only. For other steps,
returns dict with loss, metrics, input, target, pred, and identifier. Loss is None
for test step, metrics is None if unspecified.
"""
adapters = getattr(self.adapters, mode)
input, target, identifier = self._prepare_batch(batch)
pred = self.forward(input)

# Batch adapter formats the batch into the required format
input, target, identifier = adapters.batch(batch)
loss = self._calculate_loss(input, target, pred)
metrics = self._calculate_metrics(input, target, pred)

# Forward
if self.inferer and mode in [Mode.VAL, Mode.TEST, Mode.PREDICT]:
pred = self.inferer(input, self)
else:
pred = self(input)
self._log_stats(loss=loss, metrics=metrics, batch_idx=batch_idx)
output = self._prepare_output(input, target, pred, loss, metrics, identifier)
return output

# Predict mode stops here.
if mode == Mode.PREDICT:
# Logging adapter formats data for logging.
input, target, pred = adapters.logging(input, target, pred)
# Return data for callbacks (e.g. logging or writing).
return {Data.IDENTIFIER: identifier, Data.INPUT: input, Data.TARGET: target, Data.PRED: pred}
def _prepare_batch(self, batch: dict) -> tuple[Any, Any, Any]:
adapters = getattr(self.adapters, self.mode)
input, target, identifier = adapters.batch(batch)
return input, target, identifier

# Calculate the loss.
def _calculate_loss(self, input: Any, target: Any, pred: Any) -> Tensor | dict[str, Tensor] | None:
adapters = getattr(self.adapters, self.mode)
loss = None
if mode in [Mode.TRAIN, Mode.VAL]:
if self.mode in [Mode.TRAIN, Mode.VAL]:
loss = adapters.criterion(self.criterion, input=input, target=target, pred=pred)
if isinstance(loss, dict) and "total" not in loss:
raise ValueError(
"The loss dictionary must include a 'total' key that combines all sublosses. "
"Example: {'total': combined_loss, 'subloss1': loss1, ...}"
)
return loss

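A hypothetical criterion that satisfies the 'total' key contract enforced above (loss names and weights are illustrative):

import torch
from torch import nn


def composite_criterion(pred: torch.Tensor, target: torch.Tensor) -> dict[str, torch.Tensor]:
    mse = nn.functional.mse_loss(pred, target)
    l1 = nn.functional.l1_loss(pred, target)
    # 'total' must combine all sublosses, per the check above.
    return {"total": mse + 0.1 * l1, "mse": mse, "l1": l1}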
# Calculate the metrics. None if not specified.
metrics = getattr(self.metrics, mode)
def _calculate_metrics(self, input: Any, target: Any, pred: Any) -> Any | None:
adapters = getattr(self.adapters, self.mode)
metrics = getattr(self.metrics, self.mode)
if metrics is not None:
metrics = adapters.metrics(metrics, input=input, target=target, pred=pred)
return metrics

self._log_stats(loss=loss, metrics=metrics, mode=mode, batch_idx=batch_idx)

# Logging adapter formats data for logging.
input, target, pred = adapters.logging(input, target, pred)
# Return loss for backprop and data for callbacks (e.g. logging or writing).
return {
"loss": loss["total"] if isinstance(loss, dict) else loss,
Data.IDENTIFIER: identifier,
Data.INPUT: input,
Data.TARGET: target,
Data.PRED: pred,
"metrics": metrics,
}

def _log_stats(self, loss: Tensor | dict[str, Tensor], metrics: MetricCollection, mode: str, batch_idx: int) -> None:
def _log_stats(self, loss: Tensor | dict[str, Tensor], metrics: MetricCollection, batch_idx: int) -> None:
"""
Logs the loss, metrics, and optimizer statistics.
Args:
loss: The calculated loss.
metrics: The calculated metrics.
mode: The mode of operation (train, val, test, predict).
batch_idx: The index of the batch.
"""
if self.trainer.logger is None:
@@ -196,22 +188,22 @@ def _log_stats(self, loss: Tensor | dict[str, Tensor], metrics: MetricCollection
# Loss
if loss is not None:
if not isinstance(loss, dict):
self._log(f"{mode}/loss/step", loss, on_step=True)
self._log(f"{mode}/loss/epoch", loss, on_epoch=True)
self._log(f"{self.mode}/{Data.LOSS}/{Data.STEP}", loss, on_step=True)
self._log(f"{self.mode}/{Data.LOSS}/{Data.EPOCH}", loss, on_epoch=True)
else:
for name, subloss in loss.items():
self._log(f"{mode}/loss/{name}/step", subloss, on_step=True)
self._log(f"{mode}/loss/{name}/epoch", subloss, on_epoch=True)
self._log(f"{self.mode}/{Data.LOSS}/{name}/{Data.STEP}", subloss, on_step=True)
self._log(f"{self.mode}/{Data.LOSS}/{name}/{Data.EPOCH}", subloss, on_epoch=True)
# Metrics
if metrics is not None:
for name, metric in metrics.items():
self._log(f"{mode}/metrics/{name}/step", metric, on_step=True)
self._log(f"{mode}/metrics/{name}/epoch", metric, on_epoch=True)
self._log(f"{self.mode}/{Data.METRICS}/{name}/{Data.STEP}", metric, on_step=True)
self._log(f"{self.mode}/{Data.METRICS}/{name}/{Data.EPOCH}", metric, on_epoch=True)

# Optimizer's lr, momentum, beta. Logged in train mode and once per epoch.
if mode == Mode.TRAIN and batch_idx == 0:
if self.mode == Mode.TRAIN and batch_idx == 0:
for name, optimizer_stat in get_optimizer_stats(self.optimizer).items():
self._log(f"{mode}/{name}", optimizer_stat, on_epoch=True)
self._log(f"{self.mode}/{name}", optimizer_stat, on_epoch=True)

def _log(self, name: str, value: Any, on_step: bool = False, on_epoch: bool = False) -> None:
"""Log a key, value pair. Syncs across distributed nodes if `on_epoch` is True.
@@ -224,6 +216,48 @@ def _log(self, name: str, value: Any, on_step: bool = False, on_epoch: bool = Fa
"""
self.log(name, value, logger=True, batch_size=self.batch_size, on_step=on_step, on_epoch=on_epoch, sync_dist=on_epoch)

def _prepare_output(
self,
input: Any,
target: Any,
pred: Any,
loss: Tensor | dict[str, Tensor] | None,
metrics: Any | None,
identifier: Any,
) -> dict[str, Any]:
adapters = getattr(self.adapters, self.mode)
input, target, pred = adapters.logging(input, target, pred)
return {
Data.IDENTIFIER: identifier,
Data.INPUT: input,
Data.TARGET: target,
Data.PRED: pred,
Data.LOSS: loss,
Data.METRICS: metrics,
Data.STEP: self.global_step,
Data.EPOCH: self.current_epoch,
}

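Since every step now returns this full dictionary, callbacks can consume it uniformly; below is a hypothetical callback using the PyTorch Lightning 2.x hook signature (the callback itself is illustrative, not part of lighter):

import pytorch_lightning as pl


class PredictionPrinter(pl.Callback):
    def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        # Keys follow the Data enum: "identifier", "input", "target", "pred", "loss", "metrics", ...
        print(outputs["identifier"], type(outputs["pred"]))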
def on_train_start(self) -> None:
"""Called when the train begins."""
self.mode = Mode.TRAIN
self.batch_size = self.train_dataloader().batch_size

def on_validation_start(self) -> None:
"""Called when the validation loop begins."""
self.mode = Mode.VAL
self.batch_size = self.val_dataloader().batch_size

def on_test_start(self) -> None:
"""Called when the test begins."""
self.mode = Mode.TEST
self.batch_size = self.test_dataloader().batch_size

def on_predict_start(self) -> None:
"""Called when the prediction begins."""
self.mode = Mode.PREDICT
self.batch_size = self.predict_dataloader().batch_size

@property
def learning_rate(self) -> float:
"""
@@ -237,7 +271,7 @@ def learning_rate(self) -> float:
return self.optimizer.param_groups[0]["lr"]

@learning_rate.setter
def learning_rate(self, value) -> None:
def learning_rate(self, value: float) -> None:
"""
Sets the learning rate of the optimizer.
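In short, the refactor replaces the earlier partial(self._step, mode=...) bindings with one shared _step that reads self.mode, which the on_*_start hooks set once per loop. A stripped-down sketch of that pattern with a toy model and data (not lighter's System), mirroring the instance-level step assignment:

import pytorch_lightning as pl
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


class TinySystem(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = nn.Linear(4, 1)
        self.criterion = nn.MSELoss()
        self.mode = None                 # set by the on_*_start hooks below
        self.training_step = self._step  # every loop shares one step implementation
        self.validation_step = self._step

    def _step(self, batch, batch_idx):
        x, y = batch
        pred = self.model(x)
        loss = self.criterion(pred, y) if self.mode in ("train", "val") else None
        return {"loss": loss, "pred": pred, "mode": self.mode}

    def on_train_start(self):
        self.mode = "train"

    def on_validation_start(self):
        self.mode = "val"

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


loader = DataLoader(TensorDataset(torch.randn(32, 4), torch.randn(32, 1)), batch_size=8)
trainer = pl.Trainer(max_epochs=1, logger=False, enable_checkpointing=False)
trainer.fit(TinySystem(), train_dataloaders=loader, val_dataloaders=loader)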
12 changes: 8 additions & 4 deletions lighter/utils/types/enums.py
@@ -1,14 +1,18 @@
from enum import Enum
from enum import StrEnum


class Data(str, Enum):
class Data(StrEnum):
IDENTIFIER = "identifier"
INPUT = "input"
TARGET = "target"
PRED = "pred"
LOSS = "loss"
METRICS = "metrics"
STEP = "step"
EPOCH = "epoch"


class Stage(str, Enum):
class Stage(StrEnum):
FIT = "fit"
VALIDATE = "validate"
TEST = "test"
@@ -17,7 +21,7 @@ class Stage(str, Enum):
SCALE_BATCH_SIZE = "scale_batch_size"


class Mode(str, Enum):
class Mode(StrEnum):
TRAIN = "train"
VAL = "val"
TEST = "test"

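The practical effect of the enum switch: StrEnum (added in Python 3.11) members format as their plain values, so the log keys and output-dict keys built in system.py compose directly, whereas on 3.11+ a (str, Enum) member renders with its class name in f-strings. A small illustration:

from enum import Enum, StrEnum


class Mode(StrEnum):
    TRAIN = "train"


class OldMode(str, Enum):
    TRAIN = "train"


assert f"{Mode.TRAIN}/loss/step" == "train/loss/step"
# On Python 3.11+, the (str, Enum) variant formats with the class name instead:
print(f"{OldMode.TRAIN}/loss/step")  # -> "OldMode.TRAIN/loss/step"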