Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor trainers #1541

Merged
merged 19 commits into from
Sep 11, 2023
Merged

Refactor trainers #1541

merged 19 commits into from
Sep 11, 2023

Conversation

adamjstewart
Copy link
Collaborator

@adamjstewart adamjstewart commented Sep 1, 2023

In #1195 we introduced a new structure for our trainers. This PR updates all existing trainers to match.

  • Add new base class (reduces code duplication)
  • No *args or **kwargs (prevents argument typos, adds default values and type hints, required for LightningCLI)
  • No typing.cast (declare type when initialized)
  • __init__ first (should be first thing in docs)
  • Added configure_(losses|metrics|models) methods (easier to override in subclass)

This is a great time to bring #996 back up. At this point, our "trainers" are referred to as:

  • tasks: they all have a Task suffix
  • trainers: they live in torchgeo.trainers
  • models: this is what Lightning refers to them as
  • modules: our YAML config file list them as module: (this will go away in my next PR)

I would love it if we could decide on a single naming scheme and be consistent...

Closes #1393

@adamjstewart adamjstewart added this to the 0.5.0 milestone Sep 1, 2023
@github-actions github-actions bot added the trainers PyTorch Lightning trainers label Sep 1, 2023
@github-actions github-actions bot added the testing Continuous integration testing label Sep 1, 2023
@adamjstewart adamjstewart added the backwards-incompatible Changes that are not backwards compatible label Sep 1, 2023
@adamjstewart adamjstewart marked this pull request as draft September 2, 2023 01:03
@@ -76,7 +74,6 @@ def test_trainer(

# Instantiate model
model = instantiate(conf.module)
model.backbone = SegmentationTestModel(**conf.module)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BYOLTask has never had a backbone attribute. This line of code didn't do anything, and the tests are still slower and require more memory than necessary.


.. versionadded:: 0.4
"""

def config_task(self) -> None:
"""Configures the task based on kwargs parameters passed to the constructor."""
backbone_pretrained = self.hyperparams.get("pretrained", True)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Previously this would use a pretrained backbone and download weights from the internet by default, which doesn't match the behavior of any other trainer.

@@ -4,9 +4,8 @@ module:
model: "unet"
backbone: "resnet18"
weights: null
learning_rate: 1e-3
learning_rate_schedule_patience: 6
verbose: false
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Verbose isn't an option

or None for random weights, or the path to a saved model state dict.
in_channels: Number of input channels to model.
lr: Learning rate for optimizer.
weight_decay: Weight decay (L2 penalty).
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Many of these parameters were not previously documented.

@@ -274,6 +290,7 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None:

def on_validation_epoch_end(self) -> None:
"""Logs epoch level validation metrics."""
# TODO: why is this method necessary?
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

None of our other trainers require special handling for metric logging. Is it only because of Lightning-AI/torchmetrics#1832 (comment)?

@adamjstewart adamjstewart marked this pull request as ready for review September 2, 2023 17:17
@nilsleh
Copy link
Collaborator

nilsleh commented Sep 5, 2023

  • Should we have a public configure_models, configure_losses, configure_metrics for every trainer to make it easier for people to override these? This would be similar to configure_optimizers.

Regarding this point, for another domain library project I am also writing "trainers" that need a model, loss and metrics. When Isaac told me about hydra configs and instantiation I started using that at least for the model, loss and optimizer part. So you would have something like:

class BaseModel(LightningModule):
    def __init__(
        self,
        model: nn.Module,
        optimizer: type[torch.optim.Optimizer],
        loss_fn: nn.Module,
    ) -> None:
        self.model = model
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.loss_fn = loss_fn

    def configure_optimizers(self) -> dict[str, Any]:
        """Initialize the optimizer."""
        optimizer = self.optimizer(params=self.parameters())
        return optimizer

You could also provide defaults. But what I like is that you have some nice control via config files, so I could have:

method:
  _target_: BaseModel
  model:
    _target_: some.MLP # timm.create_model for timm models
    # all arguments to initiate the mlp for example
    n_outputs: 1
    n_hidden: [50]
  optimizer:
    _target_: torch.optim.Adam # can change optimizers here easily
    _partial_: true
    lr: 0.003
  loss_fn:
    _target_: torch.nn.MSELoss

So just via config you have control about optimizer and model architecture. Not sure if this is good practice but it has been convenient for running experiments and keeping track of configurations.

@adamjstewart
Copy link
Collaborator Author

You're in luck. Once this PR is merged I'll open a follow-up PR that switches everything to LightningCLI which supports exactly what you're describing without even changing the code:

This all relies on jsonargparse which supports command-line, YAML, and JSON configuration.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backwards-incompatible Changes that are not backwards compatible testing Continuous integration testing trainers PyTorch Lightning trainers
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Refactor trainers
2 participants