Skip to content

Commit

Permalink
Reorder the System methods
Browse files Browse the repository at this point in the history
  • Loading branch information
ibro45 committed Jan 19, 2025
1 parent e6b71d9 commit 7d9f46f
Showing 1 changed file with 60 additions and 60 deletions.
120 changes: 60 additions & 60 deletions lighter/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,46 +87,6 @@ def __init__(
self.mode = None
self.batch_size = 0

def forward(self, input: Any) -> Any:  # pylint: disable=arguments-differ
    """
    Run the model (or the inferer, when applicable) on the input.

    Supports multi-input models; optionally forwards the current epoch
    and global step to models whose forward signature accepts them.

    Args:
        input: The input data.

    Returns:
        Any: The model's output.
    """
    # Collect only the optional keyword arguments that the model's
    # forward method actually declares.
    kwargs = {}
    for name, value in ((Data.EPOCH, self.current_epoch), (Data.STEP, self.global_step)):
        if hasarg(self.model.forward, name):
            kwargs[name] = value

    # Delegate to the inferer during val/test/predict when one is configured;
    # otherwise call the model directly.
    if self.inferer and self.mode in (Mode.VAL, Mode.TEST, Mode.PREDICT):
        return self.inferer(input, self.model, **kwargs)
    return self.model(input, **kwargs)

def configure_optimizers(self) -> dict:
    """
    Assemble the optimizer/scheduler configuration dictionary.

    Returns:
        dict: Always contains "optimizer"; also contains "lr_scheduler"
            when a scheduler has been set.

    Raises:
        ValueError: If no optimizer has been configured.
    """
    if self.optimizer is None:
        raise ValueError("Please specify 'system.optimizer' in the config.")
    config = {"optimizer": self.optimizer}
    if self.scheduler is not None:
        config["lr_scheduler"] = self.scheduler
    return config

def _step(self, batch: dict, batch_idx: int) -> dict[str, Any] | Any:
"""
Performs a step in the specified mode, processing the batch and calculating loss and metrics.
Expand Down Expand Up @@ -154,6 +114,32 @@ def _prepare_batch(self, batch: dict) -> tuple[Any, Any, Any]:
input, target, identifier = adapters.batch(batch)
return input, target, identifier

def forward(self, input: Any) -> Any:  # pylint: disable=arguments-differ
    """
    Run the model (or the inferer, when applicable) on the input.

    Supports multi-input models; optionally forwards the current epoch
    and global step to models whose forward signature accepts them.

    Args:
        input: The input data.

    Returns:
        Any: The model's output.
    """
    # Collect only the optional keyword arguments that the model's
    # forward method actually declares.
    kwargs = {}
    for name, value in ((Data.EPOCH, self.current_epoch), (Data.STEP, self.global_step)):
        if hasarg(self.model.forward, name):
            kwargs[name] = value

    # Delegate to the inferer during val/test/predict when one is configured;
    # otherwise call the model directly.
    if self.inferer and self.mode in (Mode.VAL, Mode.TEST, Mode.PREDICT):
        return self.inferer(input, self.model, **kwargs)
    return self.model(input, **kwargs)

def _calculate_loss(self, input: Any, target: Any, pred: Any) -> Tensor | dict[str, Tensor] | None:
adapters = getattr(self.adapters, self.mode)
loss = None
Expand Down Expand Up @@ -238,26 +224,6 @@ def _prepare_output(
Data.EPOCH: self.current_epoch,
}

def on_train_start(self) -> None:
    """Record the active mode and batch size when training begins."""
    # Track the current stage and remember the dataloader's batch size.
    self.mode, self.batch_size = Mode.TRAIN, self.train_dataloader().batch_size

def on_validation_start(self) -> None:
    """Record the active mode and batch size when validation begins."""
    # Track the current stage and remember the dataloader's batch size.
    self.mode, self.batch_size = Mode.VAL, self.val_dataloader().batch_size

def on_test_start(self) -> None:
    """Record the active mode and batch size when testing begins."""
    # Track the current stage and remember the dataloader's batch size.
    self.mode, self.batch_size = Mode.TEST, self.test_dataloader().batch_size

def on_predict_start(self) -> None:
    """Record the active mode and batch size when prediction begins."""
    # Track the current stage and remember the dataloader's batch size.
    self.mode, self.batch_size = Mode.PREDICT, self.predict_dataloader().batch_size

@property
def learning_rate(self) -> float:
"""
Expand All @@ -281,3 +247,37 @@ def learning_rate(self, value: float) -> None:
if len(self.optimizer.param_groups) > 1:
raise ValueError("The learning rate is not available when there are multiple optimizer parameter groups.")
self.optimizer.param_groups[0]["lr"] = value

def configure_optimizers(self) -> dict:
    """
    Assemble the optimizer/scheduler configuration dictionary.

    Returns:
        dict: Always contains "optimizer"; also contains "lr_scheduler"
            when a scheduler has been set.

    Raises:
        ValueError: If no optimizer has been configured.
    """
    if self.optimizer is None:
        raise ValueError("Please specify 'system.optimizer' in the config.")
    config = {"optimizer": self.optimizer}
    if self.scheduler is not None:
        config["lr_scheduler"] = self.scheduler
    return config

def on_train_start(self) -> None:
    """Record the active mode and batch size when training begins."""
    # Track the current stage and remember the dataloader's batch size.
    self.mode, self.batch_size = Mode.TRAIN, self.train_dataloader().batch_size

def on_validation_start(self) -> None:
    """Record the active mode and batch size when validation begins."""
    # Track the current stage and remember the dataloader's batch size.
    self.mode, self.batch_size = Mode.VAL, self.val_dataloader().batch_size

def on_test_start(self) -> None:
    """Record the active mode and batch size when testing begins."""
    # Track the current stage and remember the dataloader's batch size.
    self.mode, self.batch_size = Mode.TEST, self.test_dataloader().batch_size

def on_predict_start(self) -> None:
    """Record the active mode and batch size when prediction begins."""
    # Track the current stage and remember the dataloader's batch size.
    self.mode, self.batch_size = Mode.PREDICT, self.predict_dataloader().batch_size

0 comments on commit 7d9f46f

Please sign in to comment.