Skip to content

Commit

Permalink
Refactor trainers (#1541)
Browse files Browse the repository at this point in the history
* Refactor trainers

* Update conf files

* Fix pydocstyle

* Add scheduler monitor

* Update conf files

* Fix BYOL backbone

* Remove broken configure_optimizers out type

* Fix type hints

* No casts

* Increase test coverage

* Better documentation of supported models

* Remove unimportant configuration

* Remove unimportant configuration

* Drop model_kwargs

* Docstring improvements

* Add base class for all torchgeo trainers

* Add configure_* methods for losses/metrics/models

* init must come first

* More type hints
  • Loading branch information
adamjstewart authored Sep 11, 2023
1 parent 5400840 commit 578aded
Show file tree
Hide file tree
Showing 92 changed files with 1,227 additions and 1,628 deletions.
4 changes: 2 additions & 2 deletions conf/bigearthnet.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ module:
_target_: torchgeo.trainers.MultiLabelClassificationTask
loss: "bce"
model: "resnet18"
learning_rate: 1e-3
learning_rate_schedule_patience: 6
lr: 1e-3
patience: 6
weights: null
in_channels: 14
num_classes: 19
Expand Down
4 changes: 2 additions & 2 deletions conf/chesapeake_cvpr.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ module:
model: "unet"
backbone: "resnet18"
weights: null
learning_rate: 1e-3
learning_rate_schedule_patience: 6
lr: 1e-3
patience: 6
in_channels: 4
num_classes: 7
num_filters: 256
Expand Down
4 changes: 2 additions & 2 deletions conf/cowc_counting.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ module:
weights: null
num_outputs: 1
in_channels: 3
learning_rate: 1e-3
learning_rate_schedule_patience: 2
lr: 1e-3
patience: 2

datamodule:
_target_: torchgeo.datamodules.COWCCountingDataModule
Expand Down
4 changes: 2 additions & 2 deletions conf/cyclone.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ module:
weights: null
num_outputs: 1
in_channels: 3
learning_rate: 1e-3
learning_rate_schedule_patience: 2
lr: 1e-3
patience: 2

datamodule:
_target_: torchgeo.datamodules.TropicalCycloneDataModule
Expand Down
5 changes: 2 additions & 3 deletions conf/deepglobelandcover.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@ module:
model: "unet"
backbone: "resnet18"
weights: null
learning_rate: 1e-3
learning_rate_schedule_patience: 6
verbose: false
lr: 1e-3
patience: 6
in_channels: 3
num_classes: 7
num_filters: 1
Expand Down
4 changes: 2 additions & 2 deletions conf/etci2021.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ module:
model: "unet"
backbone: "resnet18"
weights: true
learning_rate: 1e-3
learning_rate_schedule_patience: 6
lr: 1e-3
patience: 6
in_channels: 6
num_classes: 2
ignore_index: 0
Expand Down
4 changes: 2 additions & 2 deletions conf/eurosat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ module:
_target_: torchgeo.trainers.ClassificationTask
loss: "ce"
model: "resnet18"
learning_rate: 1e-3
learning_rate_schedule_patience: 6
lr: 1e-3
patience: 6
weights: null
in_channels: 13
num_classes: 10
Expand Down
5 changes: 2 additions & 3 deletions conf/gid15.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@ module:
model: "unet"
backbone: "resnet18"
weights: null
learning_rate: 1e-3
learning_rate_schedule_patience: 6
verbose: false
lr: 1e-3
patience: 6
in_channels: 3
num_classes: 16
num_filters: 1
Expand Down
4 changes: 2 additions & 2 deletions conf/inria.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ module:
model: "unet"
backbone: "resnet18"
weights: true
learning_rate: 1e-3
learning_rate_schedule_patience: 6
lr: 1e-3
patience: 6
in_channels: 3
num_classes: 2
ignore_index: null
Expand Down
4 changes: 2 additions & 2 deletions conf/l7irish.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ module:
num_classes: 5
loss: "ce"
ignore_index: 0
learning_rate: 1e-3
learning_rate_schedule_patience: 6
lr: 1e-3
patience: 6

datamodule:
_target_: torchgeo.datamodules.L7IrishDataModule
Expand Down
4 changes: 2 additions & 2 deletions conf/l8biome.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ module:
num_classes: 5
loss: "ce"
ignore_index: 0
learning_rate: 1e-3
learning_rate_schedule_patience: 6
lr: 1e-3
patience: 6

datamodule:
_target_: torchgeo.datamodules.L8BiomeDataModule
Expand Down
4 changes: 2 additions & 2 deletions conf/landcoverai.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ module:
model: "unet"
backbone: "resnet18"
weights: true
learning_rate: 1e-3
learning_rate_schedule_patience: 6
lr: 1e-3
patience: 6
in_channels: 3
num_classes: 5
num_filters: 256
Expand Down
4 changes: 2 additions & 2 deletions conf/naipchesapeake.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ module:
model: "deeplabv3+"
backbone: "resnet34"
weights: true
learning_rate: 1e-3
learning_rate_schedule_patience: 2
lr: 1e-3
patience: 2
in_channels: 4
num_classes: 14
num_filters: 64
Expand Down
5 changes: 2 additions & 3 deletions conf/nasa_marine_debris.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@ module:
model: "faster-rcnn"
backbone: "resnet50"
num_classes: 2
learning_rate: 1.2e-4
learning_rate_schedule_patience: 6
verbose: false
lr: 1.2e-4
patience: 6

datamodule:
_target_: torchgeo.datamodules.NASAMarineDebrisDataModule
Expand Down
5 changes: 2 additions & 3 deletions conf/potsdam2d.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@ module:
model: "unet"
backbone: "resnet18"
weights: null
learning_rate: 1e-3
learning_rate_schedule_patience: 6
verbose: false
lr: 1e-3
patience: 6
in_channels: 4
num_classes: 6
num_filters: 1
Expand Down
4 changes: 2 additions & 2 deletions conf/resisc45.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ module:
_target_: torchgeo.trainers.ClassificationTask
loss: "ce"
model: "resnet18"
learning_rate: 1e-3
learning_rate_schedule_patience: 6
lr: 1e-3
patience: 6
weights: null
in_channels: 3
num_classes: 45
Expand Down
4 changes: 2 additions & 2 deletions conf/seco_100k.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ module:
in_channels: 12
backbone: "resnet18"
weights: True
learning_rate: 1e-3
learning_rate_schedule_patience: 6
lr: 1e-3
patience: 6
optimizer: "Adam"

datamodule:
Expand Down
4 changes: 2 additions & 2 deletions conf/sen12ms.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ module:
model: "unet"
backbone: "resnet18"
weights: null
learning_rate: 1e-3
learning_rate_schedule_patience: 2
lr: 1e-3
patience: 2
in_channels: 15
num_classes: 11
ignore_index: null
Expand Down
4 changes: 2 additions & 2 deletions conf/so2sat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ module:
_target_: torchgeo.trainers.ClassificationTask
loss: "ce"
model: "resnet18"
learning_rate: 1e-3
learning_rate_schedule_patience: 6
lr: 1e-3
patience: 6
weights: null
in_channels: 18
num_classes: 17
Expand Down
4 changes: 2 additions & 2 deletions conf/spacenet1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ module:
model: "unet"
backbone: "resnet18"
weights: true
learning_rate: 1e-3
learning_rate_schedule_patience: 6
lr: 1e-3
patience: 6
in_channels: 3
num_classes: 3
ignore_index: 0
Expand Down
4 changes: 2 additions & 2 deletions conf/ssl4eo_benchmark_etm_sr_cdl.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ module:
num_classes: 18
loss: "ce"
ignore_index: 0
learning_rate: 1e-3
learning_rate_schedule_patience: 6
lr: 1e-3
patience: 6

datamodule:
_target_: torchgeo.datamodules.SSL4EOLBenchmarkDataModule
Expand Down
4 changes: 2 additions & 2 deletions conf/ssl4eo_benchmark_etm_sr_nlcd.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ module:
num_classes: 14
loss: "ce"
ignore_index: 0
learning_rate: 1e-3
learning_rate_schedule_patience: 6
lr: 1e-3
patience: 6

datamodule:
_target_: torchgeo.datamodules.SSL4EOLBenchmarkDataModule
Expand Down
4 changes: 2 additions & 2 deletions conf/ssl4eo_benchmark_etm_toa_cdl.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ module:
num_classes: 18
loss: "ce"
ignore_index: 0
learning_rate: 1e-3
learning_rate_schedule_patience: 6
lr: 1e-3
patience: 6

datamodule:
_target_: torchgeo.datamodules.SSL4EOLBenchmarkDataModule
Expand Down
4 changes: 2 additions & 2 deletions conf/ssl4eo_benchmark_etm_toa_nlcd.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ module:
num_classes: 14
loss: "ce"
ignore_index: 0
learning_rate: 1e-3
learning_rate_schedule_patience: 6
lr: 1e-3
patience: 6

datamodule:
_target_: torchgeo.datamodules.SSL4EOLBenchmarkDataModule
Expand Down
4 changes: 2 additions & 2 deletions conf/ssl4eo_benchmark_oli_sr_cdl.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ module:
num_classes: 18
loss: "ce"
ignore_index: 0
learning_rate: 1e-3
learning_rate_schedule_patience: 6
lr: 1e-3
patience: 6

datamodule:
_target_: torchgeo.datamodules.SSL4EOLBenchmarkDataModule
Expand Down
4 changes: 2 additions & 2 deletions conf/ssl4eo_benchmark_oli_sr_nlcd.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ module:
num_classes: 14
loss: "ce"
ignore_index: 0
learning_rate: 1e-3
learning_rate_schedule_patience: 6
lr: 1e-3
patience: 6

datamodule:
_target_: torchgeo.datamodules.SSL4EOLBenchmarkDataModule
Expand Down
4 changes: 2 additions & 2 deletions conf/ssl4eo_benchmark_oli_tirs_toa_cdl.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ module:
num_classes: 18
loss: "ce"
ignore_index: 0
learning_rate: 1e-3
learning_rate_schedule_patience: 6
lr: 1e-3
patience: 6

datamodule:
_target_: torchgeo.datamodules.SSL4EOLBenchmarkDataModule
Expand Down
4 changes: 2 additions & 2 deletions conf/ssl4eo_benchmark_oli_tirs_toa_nlcd.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ module:
num_classes: 14
loss: "ce"
ignore_index: 0
learning_rate: 1e-3
learning_rate_schedule_patience: 6
lr: 1e-3
patience: 6

datamodule:
_target_: torchgeo.datamodules.SSL4EOLBenchmarkDataModule
Expand Down
4 changes: 2 additions & 2 deletions conf/ssl4eo_benchmark_tm_toa_cdl.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ module:
num_classes: 18
loss: "ce"
ignore_index: 0
learning_rate: 1e-3
learning_rate_schedule_patience: 6
lr: 1e-3
patience: 6

datamodule:
_target_: torchgeo.datamodules.SSL4EOLBenchmarkDataModule
Expand Down
4 changes: 2 additions & 2 deletions conf/ssl4eo_benchmark_tm_toa_nlcd.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ module:
num_classes: 14
loss: "ce"
ignore_index: 0
learning_rate: 1e-3
learning_rate_schedule_patience: 6
lr: 1e-3
patience: 6

datamodule:
_target_: torchgeo.datamodules.SSL4EOLBenchmarkDataModule
Expand Down
4 changes: 2 additions & 2 deletions conf/ucmerced.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ module:
loss: "ce"
model: "resnet18"
weights: null
learning_rate: 1e-3
learning_rate_schedule_patience: 6
lr: 1e-3
patience: 6
in_channels: 3
num_classes: 21

Expand Down
5 changes: 2 additions & 3 deletions conf/vaihingen2d.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@ module:
model: "unet"
backbone: "resnet18"
weights: null
learning_rate: 1e-3
learning_rate_schedule_patience: 6
verbose: false
lr: 1e-3
patience: 6
in_channels: 3
num_classes: 7
num_filters: 1
Expand Down
3 changes: 0 additions & 3 deletions tests/conf/bigearthnet_all.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@ module:
_target_: torchgeo.trainers.MultiLabelClassificationTask
loss: "bce"
model: "resnet18"
learning_rate: 1e-3
learning_rate_schedule_patience: 6
weights: null
in_channels: 14
num_classes: 19

Expand Down
3 changes: 0 additions & 3 deletions tests/conf/bigearthnet_s1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@ module:
_target_: torchgeo.trainers.MultiLabelClassificationTask
loss: "bce"
model: "resnet18"
learning_rate: 1e-3
learning_rate_schedule_patience: 6
weights: null
in_channels: 2
num_classes: 19

Expand Down
3 changes: 0 additions & 3 deletions tests/conf/bigearthnet_s2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@ module:
_target_: torchgeo.trainers.MultiLabelClassificationTask
loss: "bce"
model: "resnet18"
learning_rate: 1e-3
learning_rate_schedule_patience: 6
weights: null
in_channels: 12
num_classes: 19

Expand Down
3 changes: 0 additions & 3 deletions tests/conf/chesapeake_cvpr_5.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@ module:
loss: "ce"
model: "unet"
backbone: "resnet50"
weights: null
learning_rate: 1e-3
learning_rate_schedule_patience: 6
in_channels: 4
num_classes: 5
num_filters: 1
Expand Down
Loading

0 comments on commit 578aded

Please sign in to comment.