diff --git a/lighter/system.py b/lighter/system.py
index e4429c9..04e27d6 100644
--- a/lighter/system.py
+++ b/lighter/system.py
@@ -87,46 +87,6 @@ def __init__(
         self.mode = None
         self.batch_size = 0
 
-    def forward(self, input: Any) -> Any:  # pylint: disable=arguments-differ
-        """
-        Forward pass through the model. Supports multi-input models.
-
-        Args:
-            input: The input data.
-
-        Returns:
-            Any: The model's output.
-        """
-
-        # Keyword arguments to pass to the forward method
-        kwargs = {}
-        # Add `epoch` argument if forward accepts it
-        if hasarg(self.model.forward, Data.EPOCH):
-            kwargs[Data.EPOCH] = self.current_epoch
-        # Add `step` argument if forward accepts it
-        if hasarg(self.model.forward, Data.STEP):
-            kwargs[Data.STEP] = self.global_step
-
-        # Use the inferer if specified and in val/test/predict mode
-        if self.inferer and self.mode in [Mode.VAL, Mode.TEST, Mode.PREDICT]:
-            return self.inferer(input, self.model, **kwargs)
-        else:
-            return self.model(input, **kwargs)
-
-    def configure_optimizers(self) -> dict:
-        """
-        Configures the optimizers and learning rate schedulers.
-
-        Returns:
-            dict: A dictionary containing the optimizer and scheduler.
-        """
-        if self.optimizer is None:
-            raise ValueError("Please specify 'system.optimizer' in the config.")
-        if self.scheduler is None:
-            return {"optimizer": self.optimizer}
-        else:
-            return {"optimizer": self.optimizer, "lr_scheduler": self.scheduler}
-
     def _step(self, batch: dict, batch_idx: int) -> dict[str, Any] | Any:
         """
         Performs a step in the specified mode, processing the batch and calculating loss and metrics.
@@ -154,6 +114,32 @@ def _prepare_batch(self, batch: dict) -> tuple[Any, Any, Any]:
         input, target, identifier = adapters.batch(batch)
         return input, target, identifier
 
+    def forward(self, input: Any) -> Any:  # pylint: disable=arguments-differ
+        """
+        Forward pass through the model. Supports multi-input models.
+
+        Args:
+            input: The input data.
+
+        Returns:
+            Any: The model's output.
+        """
+
+        # Keyword arguments to pass to the forward method
+        kwargs = {}
+        # Add `epoch` argument if forward accepts it
+        if hasarg(self.model.forward, Data.EPOCH):
+            kwargs[Data.EPOCH] = self.current_epoch
+        # Add `step` argument if forward accepts it
+        if hasarg(self.model.forward, Data.STEP):
+            kwargs[Data.STEP] = self.global_step
+
+        # Use the inferer if specified and in val/test/predict mode
+        if self.inferer and self.mode in [Mode.VAL, Mode.TEST, Mode.PREDICT]:
+            return self.inferer(input, self.model, **kwargs)
+        else:
+            return self.model(input, **kwargs)
+
     def _calculate_loss(self, input: Any, target: Any, pred: Any) -> Tensor | dict[str, Tensor] | None:
         adapters = getattr(self.adapters, self.mode)
         loss = None
@@ -238,26 +224,6 @@ def _prepare_output(
             Data.EPOCH: self.current_epoch,
         }
 
-    def on_train_start(self) -> None:
-        """Called when the train begins."""
-        self.mode = Mode.TRAIN
-        self.batch_size = self.train_dataloader().batch_size
-
-    def on_validation_start(self) -> None:
-        """Called when the validation loop begins."""
-        self.mode = Mode.VAL
-        self.batch_size = self.val_dataloader().batch_size
-
-    def on_test_start(self) -> None:
-        """Called when the test begins."""
-        self.mode = Mode.TEST
-        self.batch_size = self.test_dataloader().batch_size
-
-    def on_predict_start(self) -> None:
-        """Called when the prediction begins."""
-        self.mode = Mode.PREDICT
-        self.batch_size = self.predict_dataloader().batch_size
-
     @property
     def learning_rate(self) -> float:
         """
@@ -281,3 +247,37 @@ def learning_rate(self, value: float) -> None:
         if len(self.optimizer.param_groups) > 1:
             raise ValueError("The learning rate is not available when there are multiple optimizer parameter groups.")
         self.optimizer.param_groups[0]["lr"] = value
+
+    def configure_optimizers(self) -> dict:
+        """
+        Configures the optimizers and learning rate schedulers.
+
+        Returns:
+            dict: A dictionary containing the optimizer and scheduler.
+        """
+        if self.optimizer is None:
+            raise ValueError("Please specify 'system.optimizer' in the config.")
+        if self.scheduler is None:
+            return {"optimizer": self.optimizer}
+        else:
+            return {"optimizer": self.optimizer, "lr_scheduler": self.scheduler}
+
+    def on_train_start(self) -> None:
+        """Called when the train begins."""
+        self.mode = Mode.TRAIN
+        self.batch_size = self.train_dataloader().batch_size
+
+    def on_validation_start(self) -> None:
+        """Called when the validation loop begins."""
+        self.mode = Mode.VAL
+        self.batch_size = self.val_dataloader().batch_size
+
+    def on_test_start(self) -> None:
+        """Called when the test begins."""
+        self.mode = Mode.TEST
+        self.batch_size = self.test_dataloader().batch_size
+
+    def on_predict_start(self) -> None:
+        """Called when the prediction begins."""
+        self.mode = Mode.PREDICT
+        self.batch_size = self.predict_dataloader().batch_size