Refactor trainers #1541
Conversation
@@ -76,7 +74,6 @@ def test_trainer(

        # Instantiate model
        model = instantiate(conf.module)
-       model.backbone = SegmentationTestModel(**conf.module)
BYOLTask has never had a backbone attribute. This line of code didn't do anything, and the tests are still slower and require more memory than necessary.
        .. versionadded:: 0.4
        """

    def config_task(self) -> None:
        """Configures the task based on kwargs parameters passed to the constructor."""
        backbone_pretrained = self.hyperparams.get("pretrained", True)
Previously this would use a pretrained backbone and download weights from the internet by default, which doesn't match the behavior of any other trainer.
@@ -4,9 +4,8 @@ module:
  model: "unet"
  backbone: "resnet18"
  weights: null
  learning_rate: 1e-3
  learning_rate_schedule_patience: 6
- verbose: false
Verbose isn't an option
torchgeo/trainers/byol.py
            or None for random weights, or the path to a saved model state dict.
        in_channels: Number of input channels to model.
        lr: Learning rate for optimizer.
        weight_decay: Weight decay (L2 penalty).
Many of these parameters were not previously documented.
@@ -274,6 +290,7 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None:

    def on_validation_epoch_end(self) -> None:
        """Logs epoch level validation metrics."""
        # TODO: why is this method necessary?
None of our other trainers require special handling for metric logging. Is it only because of Lightning-AI/torchmetrics#1832 (comment)?
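(For context, a hedged illustration of the pattern in question, not torchgeo's actual code: if the MetricCollection itself is logged from validation_step, Lightning and torchmetrics aggregate and reset it at epoch end, so no explicit epoch-end hook is needed. The class and metric choices below are made up for illustration.)

```python
from lightning.pytorch import LightningModule
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassAccuracy, MulticlassJaccardIndex


class MetricLoggingSketch(LightningModule):  # hypothetical module, for illustration only
    def __init__(self, num_classes: int = 10) -> None:
        super().__init__()
        self.val_metrics = MetricCollection(
            {
                "OverallAccuracy": MulticlassAccuracy(num_classes),
                "JaccardIndex": MulticlassJaccardIndex(num_classes),
            },
            prefix="val_",
        )

    def validation_step(self, batch: dict, batch_idx: int) -> None:
        y_hat, y = batch["prediction"], batch["label"]
        # Updating and logging the collection here lets Lightning compute and
        # reset it at the end of the validation epoch automatically.
        self.val_metrics(y_hat, y)
        self.log_dict(self.val_metrics)
```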
Regarding this point, for another domain library project I am also writing "trainers" that need a model, a loss, and metrics. When Isaac told me about Hydra configs and instantiation, I started using that at least for the model, loss, and optimizer parts. So you would have something like:
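(A hypothetical sketch of that setup; the config keys and _target_ classes below are illustrative choices, not torchgeo code.)

```python
from hydra.utils import instantiate
from omegaconf import OmegaConf

# Illustrative config: each section names a class via _target_ plus its kwargs.
conf = OmegaConf.create(
    """
    model:
      _target_: torchvision.models.resnet18
      num_classes: 10
    loss:
      _target_: torch.nn.CrossEntropyLoss
    optimizer:
      _target_: torch.optim.AdamW
      lr: 0.001
    """
)

model = instantiate(conf.model)
criterion = instantiate(conf.loss)
# The optimizer needs the model parameters, passed as an extra kwarg at call time.
optimizer = instantiate(conf.optimizer, params=model.parameters())
```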
You could also provide defaults. But what I like is that you have some nice control via config files, so I could have:
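(Again purely illustrative: swapping the architecture and optimizer by editing only the config, with no change to the training code.)

```python
from omegaconf import OmegaConf

# Same code path as above, different experiment: only the YAML changes.
conf = OmegaConf.create(
    """
    model:
      _target_: torchvision.models.resnet50
      num_classes: 10
    optimizer:
      _target_: torch.optim.SGD
      lr: 0.01
      momentum: 0.9
    """
)
```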
So just via the config you have control over the optimizer and model architecture. Not sure if this is good practice, but it has been convenient for running experiments and keeping track of configurations.
You're in luck. Once this PR is merged I'll open a follow-up PR that switches everything to LightningCLI, which supports exactly what you're describing without even changing the code. This all relies on jsonargparse, which supports command-line, YAML, and JSON configuration.
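(A hedged sketch of the kind of entry point LightningCLI provides, not the actual follow-up PR; the task/datamodule pairing below is chosen only for illustration.)

```python
from lightning.pytorch.cli import LightningCLI

from torchgeo.datamodules import EuroSATDataModule  # illustrative choice of datamodule
from torchgeo.trainers import ClassificationTask  # illustrative choice of task


def main() -> None:
    # Every __init__ argument of the task and datamodule becomes a command-line
    # or YAML option, e.g. `python train.py fit --config config.yaml` plus any
    # --model.* / --data.* overrides.
    LightningCLI(model_class=ClassificationTask, datamodule_class=EuroSATDataModule)


if __name__ == "__main__":
    main()
```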
In #1195 we introduced a new structure for our trainers. This PR updates all existing trainers to match:

- No more *args or **kwargs (prevents argument typos, adds default values and type hints, required for LightningCLI)
- No more typing.cast (declare type when initialized)
- __init__ first (should be first thing in docs)
- New configure_(losses|metrics|models) methods (easier to override in subclass); see the sketch below
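(A minimal sketch of the structure described above, illustrative rather than the exact torchgeo code: explicit __init__ arguments instead of *args/**kwargs, and separate configure_* hooks that subclasses can override. The class name and stub model are made up.)

```python
from typing import Any

import torch
import torch.nn as nn
from lightning.pytorch import LightningModule
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassAccuracy


class ExampleTask(LightningModule):  # hypothetical task, for illustration only
    def __init__(self, model: str = "resnet18", num_classes: int = 10, lr: float = 1e-3) -> None:
        super().__init__()
        self.save_hyperparameters()
        self.configure_models()
        self.configure_losses()
        self.configure_metrics()

    def configure_models(self) -> None:
        # A real task would build self.hparams["model"] here (e.g. with timm);
        # a stub layer with an arbitrary input size keeps the sketch self-contained.
        self.model = nn.Linear(32, self.hparams["num_classes"])

    def configure_losses(self) -> None:
        self.criterion = nn.CrossEntropyLoss()

    def configure_metrics(self) -> None:
        metrics = MetricCollection(
            {"OverallAccuracy": MulticlassAccuracy(self.hparams["num_classes"])}
        )
        self.train_metrics = metrics.clone(prefix="train_")
        self.val_metrics = metrics.clone(prefix="val_")

    def configure_optimizers(self) -> Any:
        return torch.optim.AdamW(self.parameters(), lr=self.hparams["lr"])
```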
This is a great time to bring #996 back up. At this point, our "trainers" are referred to as:

- Task suffix
- torchgeo.trainers module
- ... (this will go away in my next PR)

I would love it if we could decide on a single naming scheme and be consistent...
Closes #1393