diff --git a/conf/bigearthnet.yaml b/conf/bigearthnet.yaml
index 3f159efa4b1..148f3104849 100644
--- a/conf/bigearthnet.yaml
+++ b/conf/bigearthnet.yaml
@@ -2,8 +2,8 @@ module:
   _target_: torchgeo.trainers.MultiLabelClassificationTask
   loss: "bce"
   model: "resnet18"
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 6
+  lr: 1e-3
+  patience: 6
   weights: null
   in_channels: 14
   num_classes: 19
diff --git a/conf/chesapeake_cvpr.yaml b/conf/chesapeake_cvpr.yaml
index 81af245a35a..03a27574f18 100644
--- a/conf/chesapeake_cvpr.yaml
+++ b/conf/chesapeake_cvpr.yaml
@@ -4,8 +4,8 @@ module:
   model: "unet"
   backbone: "resnet18"
   weights: null
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 6
+  lr: 1e-3
+  patience: 6
   in_channels: 4
   num_classes: 7
   num_filters: 256
diff --git a/conf/cowc_counting.yaml b/conf/cowc_counting.yaml
index 3b5d36779aa..d0166afd5b2 100644
--- a/conf/cowc_counting.yaml
+++ b/conf/cowc_counting.yaml
@@ -4,8 +4,8 @@ module:
   weights: null
   num_outputs: 1
   in_channels: 3
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 2
+  lr: 1e-3
+  patience: 2

 datamodule:
   _target_: torchgeo.datamodules.COWCCountingDataModule
diff --git a/conf/cyclone.yaml b/conf/cyclone.yaml
index 2bb689ed4bf..98a25acbb86 100644
--- a/conf/cyclone.yaml
+++ b/conf/cyclone.yaml
@@ -4,8 +4,8 @@ module:
   weights: null
   num_outputs: 1
   in_channels: 3
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 2
+  lr: 1e-3
+  patience: 2

 datamodule:
   _target_: torchgeo.datamodules.TropicalCycloneDataModule
diff --git a/conf/deepglobelandcover.yaml b/conf/deepglobelandcover.yaml
index 0260e0ac0f3..427fad68726 100644
--- a/conf/deepglobelandcover.yaml
+++ b/conf/deepglobelandcover.yaml
@@ -4,9 +4,8 @@ module:
   model: "unet"
   backbone: "resnet18"
   weights: null
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 6
-  verbose: false
+  lr: 1e-3
+  patience: 6
   in_channels: 3
   num_classes: 7
   num_filters: 1
diff --git a/conf/etci2021.yaml b/conf/etci2021.yaml
index e993b8ac628..afdcb3f7c2f 100644
--- a/conf/etci2021.yaml
+++ b/conf/etci2021.yaml
@@ -4,8 +4,8 @@ module:
   model: "unet"
   backbone: "resnet18"
   weights: true
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 6
+  lr: 1e-3
+  patience: 6
   in_channels: 6
   num_classes: 2
   ignore_index: 0
diff --git a/conf/eurosat.yaml b/conf/eurosat.yaml
index b90f7823e01..b4d58e8e7b9 100644
--- a/conf/eurosat.yaml
+++ b/conf/eurosat.yaml
@@ -2,8 +2,8 @@ module:
   _target_: torchgeo.trainers.ClassificationTask
   loss: "ce"
   model: "resnet18"
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 6
+  lr: 1e-3
+  patience: 6
   weights: null
   in_channels: 13
   num_classes: 10
diff --git a/conf/gid15.yaml b/conf/gid15.yaml
index f46672da6ce..83578c27ebb 100644
--- a/conf/gid15.yaml
+++ b/conf/gid15.yaml
@@ -4,9 +4,8 @@ module:
   model: "unet"
   backbone: "resnet18"
   weights: null
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 6
-  verbose: false
+  lr: 1e-3
+  patience: 6
   in_channels: 3
   num_classes: 16
   num_filters: 1
diff --git a/conf/inria.yaml b/conf/inria.yaml
index bbf73669a6a..1a07d81026b 100644
--- a/conf/inria.yaml
+++ b/conf/inria.yaml
@@ -4,8 +4,8 @@ module:
   model: "unet"
   backbone: "resnet18"
   weights: true
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 6
+  lr: 1e-3
+  patience: 6
   in_channels: 3
   num_classes: 2
   ignore_index: null
diff --git a/conf/l7irish.yaml b/conf/l7irish.yaml
index 5f221aa9ef0..44eb15cd3db 100644
--- a/conf/l7irish.yaml
+++ b/conf/l7irish.yaml
@@ -7,8 +7,8 @@ module:
   num_classes: 5
   loss: "ce"
   ignore_index: 0
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 6
+  lr: 1e-3
+  patience: 6

 datamodule:
   _target_: torchgeo.datamodules.L7IrishDataModule
diff --git a/conf/l8biome.yaml b/conf/l8biome.yaml
index b5bf7b552de..481990017c2 100644
--- a/conf/l8biome.yaml
+++ b/conf/l8biome.yaml
@@ -7,8 +7,8 @@ module:
   num_classes: 5
   loss: "ce"
   ignore_index: 0
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 6
+  lr: 1e-3
+  patience: 6

 datamodule:
   _target_: torchgeo.datamodules.L8BiomeDataModule
diff --git a/conf/landcoverai.yaml b/conf/landcoverai.yaml
index f70667fe056..f3c1d92217e 100644
--- a/conf/landcoverai.yaml
+++ b/conf/landcoverai.yaml
@@ -4,8 +4,8 @@ module:
   model: "unet"
   backbone: "resnet18"
   weights: true
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 6
+  lr: 1e-3
+  patience: 6
   in_channels: 3
   num_classes: 5
   num_filters: 256
diff --git a/conf/naipchesapeake.yaml b/conf/naipchesapeake.yaml
index 94f6cafcab6..e1e3039291d 100644
--- a/conf/naipchesapeake.yaml
+++ b/conf/naipchesapeake.yaml
@@ -4,8 +4,8 @@ module:
   model: "deeplabv3+"
   backbone: "resnet34"
   weights: true
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 2
+  lr: 1e-3
+  patience: 2
   in_channels: 4
   num_classes: 14
   num_filters: 64
diff --git a/conf/nasa_marine_debris.yaml b/conf/nasa_marine_debris.yaml
index d176e95c0e1..39aba0df5ba 100644
--- a/conf/nasa_marine_debris.yaml
+++ b/conf/nasa_marine_debris.yaml
@@ -3,9 +3,8 @@ module:
   model: "faster-rcnn"
   backbone: "resnet50"
   num_classes: 2
-  learning_rate: 1.2e-4
-  learning_rate_schedule_patience: 6
-  verbose: false
+  lr: 1.2e-4
+  patience: 6

 datamodule:
   _target_: torchgeo.datamodules.NASAMarineDebrisDataModule
diff --git a/conf/potsdam2d.yaml b/conf/potsdam2d.yaml
index 747e99c2047..f069d84c497 100644
--- a/conf/potsdam2d.yaml
+++ b/conf/potsdam2d.yaml
@@ -4,9 +4,8 @@ module:
   model: "unet"
   backbone: "resnet18"
   weights: null
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 6
-  verbose: false
+  lr: 1e-3
+  patience: 6
   in_channels: 4
   num_classes: 6
   num_filters: 1
diff --git a/conf/resisc45.yaml b/conf/resisc45.yaml
index fc22c9ca9e3..777f0c3749a 100644
--- a/conf/resisc45.yaml
+++ b/conf/resisc45.yaml
@@ -2,8 +2,8 @@ module:
   _target_: torchgeo.trainers.ClassificationTask
   loss: "ce"
   model: "resnet18"
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 6
+  lr: 1e-3
+  patience: 6
   weights: null
   in_channels: 3
   num_classes: 45
diff --git a/conf/seco_100k.yaml b/conf/seco_100k.yaml
index 41c6338bc02..7d8058ae2f1 100644
--- a/conf/seco_100k.yaml
+++ b/conf/seco_100k.yaml
@@ -3,8 +3,8 @@ module:
   in_channels: 12
   backbone: "resnet18"
   weights: True
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 6
+  lr: 1e-3
+  patience: 6
   optimizer: "Adam"

 datamodule:
diff --git a/conf/sen12ms.yaml b/conf/sen12ms.yaml
index f1b4643c426..f9021450719 100644
--- a/conf/sen12ms.yaml
+++ b/conf/sen12ms.yaml
@@ -4,8 +4,8 @@ module:
   model: "unet"
   backbone: "resnet18"
   weights: null
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 2
+  lr: 1e-3
+  patience: 2
   in_channels: 15
   num_classes: 11
   ignore_index: null
diff --git a/conf/so2sat.yaml b/conf/so2sat.yaml
index 4a785a50e00..dea87f7adf4 100644
--- a/conf/so2sat.yaml
+++ b/conf/so2sat.yaml
@@ -2,8 +2,8 @@ module:
   _target_: torchgeo.trainers.ClassificationTask
   loss: "ce"
   model: "resnet18"
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 6
+  lr: 1e-3
+  patience: 6
   weights: null
   in_channels: 18
   num_classes: 17
diff --git a/conf/spacenet1.yaml b/conf/spacenet1.yaml
index 82955319a57..b3da966f2af 100644
--- a/conf/spacenet1.yaml
+++ b/conf/spacenet1.yaml
@@ -4,8 +4,8 @@ module:
   model: "unet"
   backbone: "resnet18"
   weights: true
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 6
+  lr: 1e-3
+  patience: 6
   in_channels: 3
   num_classes: 3
   ignore_index: 0
diff --git a/conf/ssl4eo_benchmark_etm_sr_cdl.yaml b/conf/ssl4eo_benchmark_etm_sr_cdl.yaml
index ed64e22b701..c5be32ae035 100644
--- a/conf/ssl4eo_benchmark_etm_sr_cdl.yaml
+++ b/conf/ssl4eo_benchmark_etm_sr_cdl.yaml
@@ -7,8 +7,8 @@ module:
   num_classes: 18
   loss: "ce"
   ignore_index: 0
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 6
+  lr: 1e-3
+  patience: 6

 datamodule:
   _target_: torchgeo.datamodules.SSL4EOLBenchmarkDataModule
diff --git a/conf/ssl4eo_benchmark_etm_sr_nlcd.yaml b/conf/ssl4eo_benchmark_etm_sr_nlcd.yaml
index ba6a6dd8dfc..066cbf0dbb7 100644
--- a/conf/ssl4eo_benchmark_etm_sr_nlcd.yaml
+++ b/conf/ssl4eo_benchmark_etm_sr_nlcd.yaml
@@ -7,8 +7,8 @@ module:
   num_classes: 14
   loss: "ce"
   ignore_index: 0
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 6
+  lr: 1e-3
+  patience: 6

 datamodule:
   _target_: torchgeo.datamodules.SSL4EOLBenchmarkDataModule
diff --git a/conf/ssl4eo_benchmark_etm_toa_cdl.yaml b/conf/ssl4eo_benchmark_etm_toa_cdl.yaml
index da11cf9f42c..641b387de74 100644
--- a/conf/ssl4eo_benchmark_etm_toa_cdl.yaml
+++ b/conf/ssl4eo_benchmark_etm_toa_cdl.yaml
@@ -7,8 +7,8 @@ module:
   num_classes: 18
   loss: "ce"
   ignore_index: 0
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 6
+  lr: 1e-3
+  patience: 6

 datamodule:
   _target_: torchgeo.datamodules.SSL4EOLBenchmarkDataModule
diff --git a/conf/ssl4eo_benchmark_etm_toa_nlcd.yaml b/conf/ssl4eo_benchmark_etm_toa_nlcd.yaml
index 8e7e701416b..4e8ff97bcba 100644
--- a/conf/ssl4eo_benchmark_etm_toa_nlcd.yaml
+++ b/conf/ssl4eo_benchmark_etm_toa_nlcd.yaml
@@ -7,8 +7,8 @@ module:
   num_classes: 14
   loss: "ce"
   ignore_index: 0
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 6
+  lr: 1e-3
+  patience: 6

 datamodule:
   _target_: torchgeo.datamodules.SSL4EOLBenchmarkDataModule
diff --git a/conf/ssl4eo_benchmark_oli_sr_cdl.yaml b/conf/ssl4eo_benchmark_oli_sr_cdl.yaml
index 292390cc25c..92f1f1be77c 100644
--- a/conf/ssl4eo_benchmark_oli_sr_cdl.yaml
+++ b/conf/ssl4eo_benchmark_oli_sr_cdl.yaml
@@ -7,8 +7,8 @@ module:
   num_classes: 18
   loss: "ce"
   ignore_index: 0
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 6
+  lr: 1e-3
+  patience: 6

 datamodule:
   _target_: torchgeo.datamodules.SSL4EOLBenchmarkDataModule
diff --git a/conf/ssl4eo_benchmark_oli_sr_nlcd.yaml b/conf/ssl4eo_benchmark_oli_sr_nlcd.yaml
index 982f8cd5b02..ef309f357ce 100644
--- a/conf/ssl4eo_benchmark_oli_sr_nlcd.yaml
+++ b/conf/ssl4eo_benchmark_oli_sr_nlcd.yaml
@@ -7,8 +7,8 @@ module:
   num_classes: 14
   loss: "ce"
   ignore_index: 0
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 6
+  lr: 1e-3
+  patience: 6

 datamodule:
   _target_: torchgeo.datamodules.SSL4EOLBenchmarkDataModule
diff --git a/conf/ssl4eo_benchmark_oli_tirs_toa_cdl.yaml b/conf/ssl4eo_benchmark_oli_tirs_toa_cdl.yaml
index 7ab684024b0..933f980a5d0 100644
--- a/conf/ssl4eo_benchmark_oli_tirs_toa_cdl.yaml
+++ b/conf/ssl4eo_benchmark_oli_tirs_toa_cdl.yaml
@@ -7,8 +7,8 @@ module:
   num_classes: 18
   loss: "ce"
   ignore_index: 0
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 6
+  lr: 1e-3
+  patience: 6

 datamodule:
   _target_: torchgeo.datamodules.SSL4EOLBenchmarkDataModule
diff --git a/conf/ssl4eo_benchmark_oli_tirs_toa_nlcd.yaml b/conf/ssl4eo_benchmark_oli_tirs_toa_nlcd.yaml
index 050801e6964..7059b818e7c 100644
--- a/conf/ssl4eo_benchmark_oli_tirs_toa_nlcd.yaml
+++ b/conf/ssl4eo_benchmark_oli_tirs_toa_nlcd.yaml
@@ -7,8 +7,8 @@ module:
   num_classes: 14
   loss: "ce"
   ignore_index: 0
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 6
+  lr: 1e-3
+  patience: 6

 datamodule:
   _target_: torchgeo.datamodules.SSL4EOLBenchmarkDataModule
diff --git a/conf/ssl4eo_benchmark_tm_toa_cdl.yaml b/conf/ssl4eo_benchmark_tm_toa_cdl.yaml
index bc3ccdc4396..a5721b3814b 100644
--- a/conf/ssl4eo_benchmark_tm_toa_cdl.yaml
+++ b/conf/ssl4eo_benchmark_tm_toa_cdl.yaml
@@ -7,8 +7,8 @@ module:
   num_classes: 18
   loss: "ce"
   ignore_index: 0
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 6
+  lr: 1e-3
+  patience: 6

 datamodule:
   _target_: torchgeo.datamodules.SSL4EOLBenchmarkDataModule
diff --git a/conf/ssl4eo_benchmark_tm_toa_nlcd.yaml b/conf/ssl4eo_benchmark_tm_toa_nlcd.yaml
index d81cfaff6f5..03cba077955 100644
--- a/conf/ssl4eo_benchmark_tm_toa_nlcd.yaml
+++ b/conf/ssl4eo_benchmark_tm_toa_nlcd.yaml
@@ -7,8 +7,8 @@ module:
   num_classes: 14
   loss: "ce"
   ignore_index: 0
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 6
+  lr: 1e-3
+  patience: 6

 datamodule:
   _target_: torchgeo.datamodules.SSL4EOLBenchmarkDataModule
diff --git a/conf/ucmerced.yaml b/conf/ucmerced.yaml
index 95fbe6fb87c..33d8c415c30 100644
--- a/conf/ucmerced.yaml
+++ b/conf/ucmerced.yaml
@@ -3,8 +3,8 @@ module:
   loss: "ce"
   model: "resnet18"
   weights: null
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 6
+  lr: 1e-3
+  patience: 6
   in_channels: 3
   num_classes: 21

diff --git a/conf/vaihingen2d.yaml b/conf/vaihingen2d.yaml
index 4c5cf3b139a..6b53f0d32ab 100644
--- a/conf/vaihingen2d.yaml
+++ b/conf/vaihingen2d.yaml
@@ -4,9 +4,8 @@ module:
   model: "unet"
   backbone: "resnet18"
   weights: null
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 6
-  verbose: false
+  lr: 1e-3
+  patience: 6
   in_channels: 3
   num_classes: 7
   num_filters: 1
diff --git a/tests/conf/bigearthnet_all.yaml b/tests/conf/bigearthnet_all.yaml
index 3babdc7fd8b..64330c127af 100644
--- a/tests/conf/bigearthnet_all.yaml
+++ b/tests/conf/bigearthnet_all.yaml
@@ -2,9 +2,6 @@ module:
   _target_: torchgeo.trainers.MultiLabelClassificationTask
   loss: "bce"
   model: "resnet18"
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 6
-  weights: null
   in_channels: 14
   num_classes: 19

diff --git a/tests/conf/bigearthnet_s1.yaml b/tests/conf/bigearthnet_s1.yaml
index 8c07950cb5f..b3744b9313b 100644
--- a/tests/conf/bigearthnet_s1.yaml
+++ b/tests/conf/bigearthnet_s1.yaml
@@ -2,9 +2,6 @@ module:
   _target_: torchgeo.trainers.MultiLabelClassificationTask
   loss: "bce"
   model: "resnet18"
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 6
-  weights: null
   in_channels: 2
   num_classes: 19

diff --git a/tests/conf/bigearthnet_s2.yaml b/tests/conf/bigearthnet_s2.yaml
index 9408e20b633..46642e4c1dc 100644
--- a/tests/conf/bigearthnet_s2.yaml
+++ b/tests/conf/bigearthnet_s2.yaml
@@ -2,9 +2,6 @@ module:
   _target_: torchgeo.trainers.MultiLabelClassificationTask
   loss: "bce"
   model: "resnet18"
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 6
-  weights: null
   in_channels: 12
   num_classes: 19

diff --git a/tests/conf/chesapeake_cvpr_5.yaml b/tests/conf/chesapeake_cvpr_5.yaml
index a3f8e08b48d..fda13ed9d23 100644
--- a/tests/conf/chesapeake_cvpr_5.yaml
+++ b/tests/conf/chesapeake_cvpr_5.yaml
@@ -3,9 +3,6 @@ module:
   loss: "ce"
   model: "unet"
   backbone: "resnet50"
-  weights: null
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 6
   in_channels: 4
   num_classes: 5
   num_filters: 1
diff --git a/tests/conf/chesapeake_cvpr_7.yaml b/tests/conf/chesapeake_cvpr_7.yaml
index 5b1f0669423..9e6ee726e0b 100644
--- a/tests/conf/chesapeake_cvpr_7.yaml
+++ b/tests/conf/chesapeake_cvpr_7.yaml
@@ -3,13 +3,10 @@ module:
   loss: "ce"
   model: "unet"
   backbone: "resnet18"
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 6
   in_channels: 4
   num_classes: 7
   num_filters: 1
   ignore_index: null
-  weights: null

 datamodule:
   _target_: torchgeo.datamodules.ChesapeakeCVPRDataModule
diff --git a/tests/conf/chesapeake_cvpr_prior_byol.yaml b/tests/conf/chesapeake_cvpr_prior_byol.yaml
index 3ccf939feff..fc10f3c3430 100644
--- a/tests/conf/chesapeake_cvpr_prior_byol.yaml
+++ b/tests/conf/chesapeake_cvpr_prior_byol.yaml
@@ -1,10 +1,7 @@
 module:
   _target_: torchgeo.trainers.BYOLTask
   in_channels: 4
-  backbone: "resnet18"
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 6
-  weights: null
+  model: "resnet18"

 datamodule:
   _target_: torchgeo.datamodules.ChesapeakeCVPRDataModule
diff --git a/tests/conf/cowc_counting.yaml b/tests/conf/cowc_counting.yaml
index f67b1b6a1be..64eada40334 100644
--- a/tests/conf/cowc_counting.yaml
+++ b/tests/conf/cowc_counting.yaml
@@ -1,11 +1,8 @@
 module:
   _target_: torchgeo.trainers.RegressionTask
   model: resnet18
-  weights: null
   num_outputs: 1
   in_channels: 3
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 2
   loss: "mse"

 datamodule:
diff --git a/tests/conf/cyclone.yaml b/tests/conf/cyclone.yaml
index f7ecff850ba..3eb3a517111 100644
--- a/tests/conf/cyclone.yaml
+++ b/tests/conf/cyclone.yaml
@@ -1,11 +1,8 @@
 module:
   _target_: torchgeo.trainers.RegressionTask
   model: "resnet18"
-  weights: null
   num_outputs: 1
   in_channels: 3
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 2
   loss: "mse"

 datamodule:
diff --git a/tests/conf/deepglobelandcover.yaml b/tests/conf/deepglobelandcover.yaml
index 392fe3ce7b7..b871570b993 100644
--- a/tests/conf/deepglobelandcover.yaml
+++ b/tests/conf/deepglobelandcover.yaml
@@ -3,10 +3,6 @@ module:
   loss: "ce"
   model: "unet"
   backbone: "resnet18"
-  weights: null
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 6
-  verbose: false
   in_channels: 3
   num_classes: 7
   num_filters: 1
diff --git a/tests/conf/etci2021.yaml b/tests/conf/etci2021.yaml
index 9af839e92e3..5eada383d2c 100644
--- a/tests/conf/etci2021.yaml
+++ b/tests/conf/etci2021.yaml
@@ -3,9 +3,6 @@ module:
   loss: "ce"
   model: "unet"
   backbone: "resnet18"
-  weights: null
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 6
   in_channels: 6
   num_classes: 2
   ignore_index: 0
diff --git a/tests/conf/eurosat.yaml b/tests/conf/eurosat.yaml
index 7066f7f66ce..502ab0954d4 100644
--- a/tests/conf/eurosat.yaml
+++ b/tests/conf/eurosat.yaml
@@ -2,9 +2,6 @@ module:
   _target_: torchgeo.trainers.ClassificationTask
   loss: "ce"
   model: "resnet18"
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 6
-  weights: null
   in_channels: 13
   num_classes: 2

diff --git a/tests/conf/eurosat100.yaml b/tests/conf/eurosat100.yaml
index 65e4be957f2..69628c2e195 100644
--- a/tests/conf/eurosat100.yaml
+++ b/tests/conf/eurosat100.yaml
@@ -2,9 +2,6 @@ module:
   _target_: torchgeo.trainers.ClassificationTask
   loss: "ce"
   model: "resnet18"
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 6
-  weights: null
   in_channels: 13
   num_classes: 2

diff --git a/tests/conf/fire_risk.yaml b/tests/conf/fire_risk.yaml
index 8971ee6839a..2982e112e25 100644
--- a/tests/conf/fire_risk.yaml
+++ b/tests/conf/fire_risk.yaml
@@ -2,9 +2,6 @@ module:
   _target_: torchgeo.trainers.ClassificationTask
   loss: "ce"
   model: "resnet18"
-  weights: null
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 6
   in_channels: 3
   num_classes: 5

diff --git a/tests/conf/gid15.yaml b/tests/conf/gid15.yaml
index c9af542d037..0c0027db002 100644
--- a/tests/conf/gid15.yaml
+++ b/tests/conf/gid15.yaml
@@ -3,10 +3,6 @@ module:
   loss: "ce"
   model: "unet"
   backbone: "resnet18"
-  weights: null
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 6
-  verbose: false
   in_channels: 3
   num_classes: 16
   num_filters: 1
diff --git a/tests/conf/inria.yaml b/tests/conf/inria.yaml
index df4f4043fc4..90e26dc19bd 100644
--- a/tests/conf/inria.yaml
+++ b/tests/conf/inria.yaml
@@ -3,9 +3,6 @@ module:
   loss: "ce"
   model: "unet"
   backbone: "resnet18"
-  weights: null
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 6
   in_channels: 3
   num_classes: 2
   ignore_index: null
diff --git a/tests/conf/l7irish.yaml b/tests/conf/l7irish.yaml
index d5147b0032d..2949d612c75 100644
--- a/tests/conf/l7irish.yaml
+++ b/tests/conf/l7irish.yaml
@@ -3,10 +3,6 @@ module:
   loss: "ce"
   model: "unet"
   backbone: "resnet18"
-  weights: null
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 6
-  verbose: false
   in_channels: 9
   num_classes: 5
   num_filters: 1
diff --git a/tests/conf/l8biome.yaml b/tests/conf/l8biome.yaml
index ae42f6efff3..ac85d14d934 100644
--- a/tests/conf/l8biome.yaml
+++ b/tests/conf/l8biome.yaml
@@ -3,10 +3,6 @@ module:
   loss: "ce"
   model: "unet"
   backbone: "resnet18"
-  weights: null
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 6
-  verbose: false
   in_channels: 11
   num_classes: 5
   num_filters: 1
diff --git a/tests/conf/landcoverai.yaml b/tests/conf/landcoverai.yaml
index 691d19bb9be..21b957e3014 100644
--- a/tests/conf/landcoverai.yaml
+++ b/tests/conf/landcoverai.yaml
@@ -3,10 +3,6 @@ module:
   loss: "ce"
   model: "unet"
   backbone: "resnet18"
-  weights: null
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 6
-  verbose: false
   in_channels: 3
   num_classes: 6
   num_filters: 1
diff --git a/tests/conf/loveda.yaml b/tests/conf/loveda.yaml
index 7a558ea2207..7d526b160ee 100644
--- a/tests/conf/loveda.yaml
+++ b/tests/conf/loveda.yaml
@@ -3,10 +3,6 @@ module:
   loss: "ce"
   model: "unet"
   backbone: "resnet18"
-  weights: null
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 6
-  verbose: false
   in_channels: 3
   num_classes: 8
   num_filters: 1
diff --git a/tests/conf/naipchesapeake.yaml b/tests/conf/naipchesapeake.yaml
index f9c0e4880fa..60d77fd3089 100644
--- a/tests/conf/naipchesapeake.yaml
+++ b/tests/conf/naipchesapeake.yaml
@@ -3,9 +3,6 @@ module:
   loss: "ce"
   model: "deeplabv3+"
   backbone: "resnet34"
-  weights: null
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 2
   in_channels: 4
   num_classes: 14
   num_filters: 1
diff --git a/tests/conf/nasa_marine_debris.yaml b/tests/conf/nasa_marine_debris.yaml
index 7103560c5f3..36375e89c76 100644
--- a/tests/conf/nasa_marine_debris.yaml
+++ b/tests/conf/nasa_marine_debris.yaml
@@ -3,9 +3,6 @@ module:
   model: "faster-rcnn"
   backbone: "resnet18"
   num_classes: 2
-  learning_rate: 1.2e-4
-  learning_rate_schedule_patience: 6
-  verbose: false

 datamodule:
   _target_: torchgeo.datamodules.NASAMarineDebrisDataModule
diff --git a/tests/conf/potsdam2d.yaml b/tests/conf/potsdam2d.yaml
index bd5f8f6c0ca..8f75b170b42 100644
--- a/tests/conf/potsdam2d.yaml
+++ b/tests/conf/potsdam2d.yaml
@@ -3,10 +3,6 @@ module:
   loss: "ce"
   model: "unet"
   backbone: "resnet18"
-  weights: null
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 6
-  verbose: false
   in_channels: 4
   num_classes: 6
   num_filters: 1
diff --git a/tests/conf/resisc45.yaml b/tests/conf/resisc45.yaml
index f8d1729572e..a8c5ea0f643 100644
--- a/tests/conf/resisc45.yaml
+++ b/tests/conf/resisc45.yaml
@@ -2,9 +2,6 @@ module:
   _target_: torchgeo.trainers.ClassificationTask
   loss: "ce"
   model: "resnet18"
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 6
-  weights: null
   in_channels: 3
   num_classes: 3

diff --git a/tests/conf/seco_byol_1.yaml b/tests/conf/seco_byol_1.yaml
index 5f7e0b91b20..ae5e73a79c6 100644
--- a/tests/conf/seco_byol_1.yaml
+++ b/tests/conf/seco_byol_1.yaml
@@ -1,10 +1,7 @@
 module:
   _target_: torchgeo.trainers.BYOLTask
   in_channels: 3
-  backbone: "resnet18"
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 6
-  weights: null
+  model: "resnet18"

 datamodule:
   _target_: torchgeo.datamodules.SeasonalContrastS2DataModule
diff --git a/tests/conf/seco_byol_2.yaml b/tests/conf/seco_byol_2.yaml
index 07ff81c0132..90c17f1ed01 100644
--- a/tests/conf/seco_byol_2.yaml
+++ b/tests/conf/seco_byol_2.yaml
@@ -1,10 +1,7 @@
 module:
   _target_: torchgeo.trainers.BYOLTask
   in_channels: 3
-  backbone: "resnet18"
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 6
-  weights: null
+  model: "resnet18"

 datamodule:
   _target_: torchgeo.datamodules.SeasonalContrastS2DataModule
diff --git a/tests/conf/sen12ms_all.yaml b/tests/conf/sen12ms_all.yaml
index fe3d592a356..c4f53aee82b 100644
--- a/tests/conf/sen12ms_all.yaml
+++ b/tests/conf/sen12ms_all.yaml
@@ -3,9 +3,6 @@ module:
   loss: "ce"
   model: "unet"
   backbone: "resnet18"
-  weights: null
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 2
   in_channels: 15
   num_classes: 11
   ignore_index: null
diff --git a/tests/conf/sen12ms_s1.yaml b/tests/conf/sen12ms_s1.yaml
index b0b9d553931..889d321950b 100644
--- a/tests/conf/sen12ms_s1.yaml
+++ b/tests/conf/sen12ms_s1.yaml
@@ -4,9 +4,6 @@ module:
   model: "fcn"
   num_filters: 1
   backbone: "resnet18"
-  weights: null
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 2
   in_channels: 2
   num_classes: 11
   ignore_index: null
diff --git a/tests/conf/sen12ms_s2_all.yaml b/tests/conf/sen12ms_s2_all.yaml
index e80b74896e0..255077a3019 100644
--- a/tests/conf/sen12ms_s2_all.yaml
+++ b/tests/conf/sen12ms_s2_all.yaml
@@ -3,9 +3,6 @@ module:
   loss: "ce"
   model: "unet"
   backbone: "resnet18"
-  weights: null
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 2
   in_channels: 13
   num_classes: 11
   ignore_index: null
diff --git a/tests/conf/sen12ms_s2_reduced.yaml b/tests/conf/sen12ms_s2_reduced.yaml
index 15758690e03..2b7965ead7d 100644
--- a/tests/conf/sen12ms_s2_reduced.yaml
+++ b/tests/conf/sen12ms_s2_reduced.yaml
@@ -3,9 +3,6 @@ module:
   loss: "ce"
   model: "unet"
   backbone: "resnet18"
-  weights: null
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 2
   in_channels: 6
   num_classes: 11
   ignore_index: null
diff --git a/tests/conf/skippd.yaml b/tests/conf/skippd.yaml
index 14dd1bcaabe..b1d5ec108e8 100644
--- a/tests/conf/skippd.yaml
+++ b/tests/conf/skippd.yaml
@@ -1,11 +1,8 @@
 module:
   _target_: torchgeo.trainers.RegressionTask
   model: "resnet18"
-  weights: null
   num_outputs: 1
   in_channels: 3
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 2
   loss: "mse"

 datamodule:
diff --git a/tests/conf/so2sat_all.yaml b/tests/conf/so2sat_all.yaml
index 22919afe697..7a3eb33d174 100644
--- a/tests/conf/so2sat_all.yaml
+++ b/tests/conf/so2sat_all.yaml
@@ -2,9 +2,6 @@ module:
   _target_: torchgeo.trainers.ClassificationTask
   loss: "ce"
   model: "resnet18"
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 6
-  weights: null
   in_channels: 18
   num_classes: 17

diff --git a/tests/conf/so2sat_rgb.yaml b/tests/conf/so2sat_rgb.yaml
index 75f7490ce22..fe7d462cbf7 100644
--- a/tests/conf/so2sat_rgb.yaml
+++ b/tests/conf/so2sat_rgb.yaml
@@ -2,9 +2,6 @@ module:
   _target_: torchgeo.trainers.ClassificationTask
   loss: "ce"
   model: "resnet18"
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 6
-  weights: null
   in_channels: 3
   num_classes: 17

diff --git a/tests/conf/so2sat_s1.yaml b/tests/conf/so2sat_s1.yaml
index c81e79742b8..8173ad5ccaa 100644
--- a/tests/conf/so2sat_s1.yaml
+++ b/tests/conf/so2sat_s1.yaml
@@ -2,9 +2,6 @@ module:
   _target_: torchgeo.trainers.ClassificationTask
   loss: "focal"
   model: "resnet18"
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 6
-  weights: null
   in_channels: 8
   num_classes: 17

diff --git a/tests/conf/so2sat_s2.yaml b/tests/conf/so2sat_s2.yaml
index d7ba063efac..56a67ef5c26 100644
--- a/tests/conf/so2sat_s2.yaml
+++ b/tests/conf/so2sat_s2.yaml
@@ -2,9 +2,6 @@ module:
   _target_: torchgeo.trainers.ClassificationTask
   loss: "jaccard"
   model: "resnet18"
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 6
-  weights: null
   in_channels: 10
   num_classes: 17

diff --git a/tests/conf/spacenet1.yaml b/tests/conf/spacenet1.yaml
index dc88c2504d1..f39017e6cac 100644
--- a/tests/conf/spacenet1.yaml
+++ b/tests/conf/spacenet1.yaml
@@ -3,10 +3,6 @@ module:
   loss: "ce"
   model: "unet"
   backbone: "resnet18"
-  weights: null
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 6
-  verbose: false
   in_channels: 3
   num_classes: 3
   num_filters: 1
diff --git a/tests/conf/ssl4eo_l_benchmark_cdl.yaml b/tests/conf/ssl4eo_l_benchmark_cdl.yaml
index f44abedb3a7..a0aa2cb27e4 100644
--- a/tests/conf/ssl4eo_l_benchmark_cdl.yaml
+++ b/tests/conf/ssl4eo_l_benchmark_cdl.yaml
@@ -3,9 +3,6 @@ module:
   loss: "ce"
   model: "unet"
   backbone: "resnet18"
-  weights: null
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 6
   in_channels: 7
   num_classes: 134
   num_filters: 1
diff --git a/tests/conf/ssl4eo_l_benchmark_nlcd.yaml b/tests/conf/ssl4eo_l_benchmark_nlcd.yaml
index 6dd85d935b7..c6f9cf793af 100644
--- a/tests/conf/ssl4eo_l_benchmark_nlcd.yaml
+++ b/tests/conf/ssl4eo_l_benchmark_nlcd.yaml
@@ -3,9 +3,6 @@ module:
   loss: "ce"
   model: "unet"
   backbone: "resnet18"
-  weights: null
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 6
   in_channels: 6
   num_classes: 17
   num_filters: 1
diff --git a/tests/conf/ssl4eo_l_byol_1.yaml b/tests/conf/ssl4eo_l_byol_1.yaml
index a8e3dc0cd79..f4236eb334e 100644
--- a/tests/conf/ssl4eo_l_byol_1.yaml
+++ b/tests/conf/ssl4eo_l_byol_1.yaml
@@ -1,10 +1,7 @@
 module:
   _target_: torchgeo.trainers.BYOLTask
   in_channels: 7
-  backbone: "resnet18"
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 6
-  weights: null
+  model: "resnet18"

 datamodule:
   _target_: torchgeo.datamodules.SSL4EOLDataModule
diff --git a/tests/conf/ssl4eo_l_byol_2.yaml b/tests/conf/ssl4eo_l_byol_2.yaml
index 2f1d87d83ff..fb0e2b3c02c 100644
--- a/tests/conf/ssl4eo_l_byol_2.yaml
+++ b/tests/conf/ssl4eo_l_byol_2.yaml
@@ -1,10 +1,7 @@
 module:
   _target_: torchgeo.trainers.BYOLTask
   in_channels: 6
-  backbone: "resnet18"
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 6
-  weights: null
+  model: "resnet18"

 datamodule:
   _target_: torchgeo.datamodules.SSL4EOLDataModule
diff --git a/tests/conf/ssl4eo_s12_byol_1.yaml b/tests/conf/ssl4eo_s12_byol_1.yaml
index 8d261d9de27..ab942c1e355 100644
--- a/tests/conf/ssl4eo_s12_byol_1.yaml
+++ b/tests/conf/ssl4eo_s12_byol_1.yaml
@@ -1,10 +1,7 @@
 module:
   _target_: torchgeo.trainers.BYOLTask
   in_channels: 2
-  backbone: "resnet18"
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 6
-  weights: null
+  model: "resnet18"

 datamodule:
   _target_: torchgeo.datamodules.SSL4EOS12DataModule
diff --git a/tests/conf/ssl4eo_s12_byol_2.yaml b/tests/conf/ssl4eo_s12_byol_2.yaml
index 0bf2164b0b5..e45f5d3bbc3 100644
--- a/tests/conf/ssl4eo_s12_byol_2.yaml
+++ b/tests/conf/ssl4eo_s12_byol_2.yaml
@@ -1,10 +1,7 @@
 module:
   _target_: torchgeo.trainers.BYOLTask
   in_channels: 13
-  backbone: "resnet18"
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 6
-  weights: null
+  model: "resnet18"

 datamodule:
   _target_: torchgeo.datamodules.SSL4EOS12DataModule
diff --git a/tests/conf/sustainbench_crop_yield.yaml b/tests/conf/sustainbench_crop_yield.yaml
index 9b092aab674..4be0cf22a7c 100644
--- a/tests/conf/sustainbench_crop_yield.yaml
+++ b/tests/conf/sustainbench_crop_yield.yaml
@@ -1,11 +1,8 @@
 module:
   _target_: torchgeo.trainers.RegressionTask
   model: "resnet18"
-  weights: null
   num_outputs: 1
   in_channels: 9
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 2
   loss: "mse"

 datamodule:
diff --git a/tests/conf/ucmerced.yaml b/tests/conf/ucmerced.yaml
index 93e37db6059..0cb37c1c94f 100644
--- a/tests/conf/ucmerced.yaml
+++ b/tests/conf/ucmerced.yaml
@@ -2,9 +2,6 @@ module:
   _target_: torchgeo.trainers.ClassificationTask
   loss: "ce"
   model: "resnet18"
-  weights: null
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 6
   in_channels: 3
   num_classes: 2

diff --git a/tests/conf/vaihingen2d.yaml b/tests/conf/vaihingen2d.yaml
index ebdc8613ad2..0920772eaee 100644
--- a/tests/conf/vaihingen2d.yaml
+++ b/tests/conf/vaihingen2d.yaml
@@ -3,10 +3,6 @@ module:
   loss: "ce"
   model: "unet"
   backbone: "resnet18"
-  weights: null
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 6
-  verbose: false
   in_channels: 3
   num_classes: 7
   num_filters: 1
diff --git a/tests/trainers/test_byol.py b/tests/trainers/test_byol.py
index 83ad3099ae6..b4efbba6fe4 100644
--- a/tests/trainers/test_byol.py
+++ b/tests/trainers/test_byol.py
@@ -22,8 +22,6 @@
 from torchgeo.trainers import BYOLTask
 from torchgeo.trainers.byol import BYOL, SimCLRAugmentation

-from .test_segmentation import SegmentationTestModel
-

 def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]:
     state_dict: dict[str, Any] = torch.load(url)
@@ -32,8 +30,8 @@

 class TestBYOL:
     def test_custom_augment_fn(self) -> None:
-        backbone = resnet18()
-        layer = backbone.conv1
+        model = resnet18()
+        layer = model.conv1
         new_layer = nn.Conv2d(
             in_channels=4,
             out_channels=layer.out_channels,
@@ -42,9 +40,9 @@ def test_custom_augment_fn(self) -> None:
             padding=layer.padding,
             bias=layer.bias,
         ).requires_grad_()
-        backbone.conv1 = new_layer
+        model.conv1 = new_layer
         augment_fn = SimCLRAugmentation((2, 2))
-        BYOL(backbone, augment_fn=augment_fn)
+        BYOL(model, augment_fn=augment_fn)


 class TestBYOLTask:
@@ -76,7 +74,6 @@ def test_trainer(

         # Instantiate model
         model = instantiate(conf.module)
-        model.backbone = SegmentationTestModel(**conf.module)

         # Instantiate trainer
         trainer = Trainer(
@@ -87,16 +84,6 @@ def test_trainer(
         )
         trainer.fit(model=model, datamodule=datamodule)

-    @pytest.fixture
-    def model_kwargs(self) -> dict[str, Any]:
-        return {
-            "backbone": "resnet18",
-            "in_channels": 13,
-            "loss": "ce",
-            "num_classes": 10,
-            "weights": None,
-        }
-
     @pytest.fixture
     def weights(self) -> WeightsEnum:
         return ResNet18_Weights.SENTINEL2_ALL_MOCO
@@ -117,41 +104,36 @@ def mocked_weights(
         monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load)
         return weights

-    def test_weight_file(self, model_kwargs: dict[str, Any], checkpoint: str) -> None:
model_kwargs["weights"] = checkpoint + def test_weight_file(self, checkpoint: str) -> None: with pytest.warns(UserWarning): - BYOLTask(**model_kwargs) + BYOLTask(model="resnet18", in_channels=13, weights=checkpoint) - def test_weight_enum( - self, model_kwargs: dict[str, Any], mocked_weights: WeightsEnum - ) -> None: - model_kwargs["backbone"] = mocked_weights.meta["model"] - model_kwargs["in_channels"] = mocked_weights.meta["in_chans"] - model_kwargs["weights"] = mocked_weights - BYOLTask(**model_kwargs) + def test_weight_enum(self, mocked_weights: WeightsEnum) -> None: + BYOLTask( + model=mocked_weights.meta["model"], + weights=mocked_weights, + in_channels=mocked_weights.meta["in_chans"], + ) - def test_weight_str( - self, model_kwargs: dict[str, Any], mocked_weights: WeightsEnum - ) -> None: - model_kwargs["backbone"] = mocked_weights.meta["model"] - model_kwargs["in_channels"] = mocked_weights.meta["in_chans"] - model_kwargs["weights"] = str(mocked_weights) - BYOLTask(**model_kwargs) + def test_weight_str(self, mocked_weights: WeightsEnum) -> None: + BYOLTask( + model=mocked_weights.meta["model"], + weights=str(mocked_weights), + in_channels=mocked_weights.meta["in_chans"], + ) @pytest.mark.slow - def test_weight_enum_download( - self, model_kwargs: dict[str, Any], weights: WeightsEnum - ) -> None: - model_kwargs["backbone"] = weights.meta["model"] - model_kwargs["in_channels"] = weights.meta["in_chans"] - model_kwargs["weights"] = weights - BYOLTask(**model_kwargs) + def test_weight_enum_download(self, weights: WeightsEnum) -> None: + BYOLTask( + model=weights.meta["model"], + weights=weights, + in_channels=weights.meta["in_chans"], + ) @pytest.mark.slow - def test_weight_str_download( - self, model_kwargs: dict[str, Any], weights: WeightsEnum - ) -> None: - model_kwargs["backbone"] = weights.meta["model"] - model_kwargs["in_channels"] = weights.meta["in_chans"] - model_kwargs["weights"] = str(weights) - BYOLTask(**model_kwargs) + def test_weight_str_download(self, weights: WeightsEnum) -> None: + BYOLTask( + model=weights.meta["model"], + weights=str(weights), + in_channels=weights.meta["in_chans"], + ) diff --git a/tests/trainers/test_classification.py b/tests/trainers/test_classification.py index fb8b9c91d69..ae80e348c91 100644 --- a/tests/trainers/test_classification.py +++ b/tests/trainers/test_classification.py @@ -113,16 +113,6 @@ def test_trainer( except MisconfigurationException: pass - @pytest.fixture - def model_kwargs(self) -> dict[str, Any]: - return { - "model": "resnet18", - "in_channels": 13, - "loss": "ce", - "num_classes": 10, - "weights": None, - } - @pytest.fixture def weights(self) -> WeightsEnum: return ResNet18_Weights.SENTINEL2_ALL_MOCO @@ -143,61 +133,59 @@ def mocked_weights( monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load) return weights - def test_weight_file(self, model_kwargs: dict[str, Any], checkpoint: str) -> None: - model_kwargs["weights"] = checkpoint + def test_weight_file(self, checkpoint: str) -> None: with pytest.warns(UserWarning): - ClassificationTask(**model_kwargs) + ClassificationTask( + model="resnet18", weights=checkpoint, in_channels=13, num_classes=10 + ) - def test_weight_enum( - self, model_kwargs: dict[str, Any], mocked_weights: WeightsEnum - ) -> None: - model_kwargs["model"] = mocked_weights.meta["model"] - model_kwargs["in_channels"] = mocked_weights.meta["in_chans"] - model_kwargs["weights"] = mocked_weights + def test_weight_enum(self, mocked_weights: WeightsEnum) -> None: with 
         with pytest.warns(UserWarning):
-            ClassificationTask(**model_kwargs)
+            ClassificationTask(
+                model=mocked_weights.meta["model"],
+                weights=mocked_weights,
+                in_channels=mocked_weights.meta["in_chans"],
+                num_classes=10,
+            )

-    def test_weight_str(
-        self, model_kwargs: dict[str, Any], mocked_weights: WeightsEnum
-    ) -> None:
-        model_kwargs["model"] = mocked_weights.meta["model"]
-        model_kwargs["in_channels"] = mocked_weights.meta["in_chans"]
-        model_kwargs["weights"] = str(mocked_weights)
+    def test_weight_str(self, mocked_weights: WeightsEnum) -> None:
         with pytest.warns(UserWarning):
-            ClassificationTask(**model_kwargs)
+            ClassificationTask(
+                model=mocked_weights.meta["model"],
+                weights=str(mocked_weights),
+                in_channels=mocked_weights.meta["in_chans"],
+                num_classes=10,
+            )

     @pytest.mark.slow
-    def test_weight_enum_download(
-        self, model_kwargs: dict[str, Any], weights: WeightsEnum
-    ) -> None:
-        model_kwargs["model"] = weights.meta["model"]
-        model_kwargs["in_channels"] = weights.meta["in_chans"]
-        model_kwargs["weights"] = weights
-        ClassificationTask(**model_kwargs)
+    def test_weight_enum_download(self, weights: WeightsEnum) -> None:
+        ClassificationTask(
+            model=weights.meta["model"],
+            weights=weights,
+            in_channels=weights.meta["in_chans"],
+            num_classes=10,
+        )

     @pytest.mark.slow
-    def test_weight_str_download(
-        self, model_kwargs: dict[str, Any], weights: WeightsEnum
-    ) -> None:
-        model_kwargs["model"] = weights.meta["model"]
-        model_kwargs["in_channels"] = weights.meta["in_chans"]
-        model_kwargs["weights"] = str(weights)
-        ClassificationTask(**model_kwargs)
+    def test_weight_str_download(self, weights: WeightsEnum) -> None:
+        ClassificationTask(
+            model=weights.meta["model"],
+            weights=str(weights),
+            in_channels=weights.meta["in_chans"],
+            num_classes=10,
+        )

-    def test_invalid_loss(self, model_kwargs: dict[str, Any]) -> None:
-        model_kwargs["loss"] = "invalid_loss"
+    def test_invalid_loss(self) -> None:
         match = "Loss type 'invalid_loss' is not valid."
         with pytest.raises(ValueError, match=match):
-            ClassificationTask(**model_kwargs)
+            ClassificationTask(model="resnet18", loss="invalid_loss")

-    def test_no_rgb(
-        self, monkeypatch: MonkeyPatch, model_kwargs: dict[Any, Any], fast_dev_run: bool
-    ) -> None:
+    def test_no_rgb(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None:
         monkeypatch.setattr(EuroSATDataModule, "plot", plot)
         datamodule = EuroSATDataModule(
             root="tests/data/eurosat", batch_size=1, num_workers=0
         )
-        model = ClassificationTask(**model_kwargs)
+        model = ClassificationTask(model="resnet18", in_channels=13, num_classes=10)
         trainer = Trainer(
             accelerator="cpu",
             fast_dev_run=fast_dev_run,
@@ -206,11 +194,11 @@ def test_no_rgb(
         )
         trainer.validate(model=model, datamodule=datamodule)

-    def test_predict(self, model_kwargs: dict[Any, Any], fast_dev_run: bool) -> None:
+    def test_predict(self, fast_dev_run: bool) -> None:
         datamodule = PredictClassificationDataModule(
             root="tests/data/eurosat", batch_size=1, num_workers=0
         )
-        model = ClassificationTask(**model_kwargs)
+        model = ClassificationTask(model="resnet18", in_channels=13, num_classes=10)
         trainer = Trainer(
             accelerator="cpu",
             fast_dev_run=fast_dev_run,
@@ -222,12 +210,8 @@ def test_predict(self, fast_dev_run: bool) -> None
     @pytest.mark.parametrize(
         "model_name", ["resnet18", "efficientnetv2_s", "vit_base_patch16_384"]
     )
-    def test_freeze_backbone(
-        self, model_name: str, model_kwargs: dict[Any, Any]
-    ) -> None:
-        model_kwargs["freeze_backbone"] = True
-        model_kwargs["model"] = model_name
-        model = ClassificationTask(**model_kwargs)
+    def test_freeze_backbone(self, model_name: str) -> None:
+        model = ClassificationTask(model=model_name, freeze_backbone=True)
         assert not all([param.requires_grad for param in model.model.parameters()])
         assert all(
             [param.requires_grad for param in model.model.get_classifier().parameters()]
@@ -267,30 +251,19 @@ def test_trainer(
         except MisconfigurationException:
             pass

-    @pytest.fixture
-    def model_kwargs(self) -> dict[str, Any]:
-        return {
-            "model": "resnet18",
-            "in_channels": 14,
-            "loss": "bce",
-            "num_classes": 19,
-            "weights": None,
-        }
-
-    def test_invalid_loss(self, model_kwargs: dict[str, Any]) -> None:
-        model_kwargs["loss"] = "invalid_loss"
+    def test_invalid_loss(self) -> None:
         match = "Loss type 'invalid_loss' is not valid."
         with pytest.raises(ValueError, match=match):
-            MultiLabelClassificationTask(**model_kwargs)
+            MultiLabelClassificationTask(model="resnet18", loss="invalid_loss")

-    def test_no_rgb(
-        self, monkeypatch: MonkeyPatch, model_kwargs: dict[Any, Any], fast_dev_run: bool
-    ) -> None:
+    def test_no_rgb(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None:
         monkeypatch.setattr(BigEarthNetDataModule, "plot", plot)
         datamodule = BigEarthNetDataModule(
             root="tests/data/bigearthnet", batch_size=1, num_workers=0
         )
-        model = MultiLabelClassificationTask(**model_kwargs)
+        model = MultiLabelClassificationTask(
+            model="resnet18", in_channels=14, num_classes=19, loss="bce"
+        )
         trainer = Trainer(
             accelerator="cpu",
             fast_dev_run=fast_dev_run,
@@ -299,11 +272,13 @@ def test_no_rgb(
         )
         trainer.validate(model=model, datamodule=datamodule)

-    def test_predict(self, model_kwargs: dict[Any, Any], fast_dev_run: bool) -> None:
+    def test_predict(self, fast_dev_run: bool) -> None:
         datamodule = PredictMultiLabelClassificationDataModule(
             root="tests/data/bigearthnet", batch_size=1, num_workers=0
         )
-        model = MultiLabelClassificationTask(**model_kwargs)
+        model = MultiLabelClassificationTask(
+            model="resnet18", in_channels=14, num_classes=19, loss="bce"
+        )
         trainer = Trainer(
             accelerator="cpu",
             fast_dev_run=fast_dev_run,
diff --git a/tests/trainers/test_detection.py b/tests/trainers/test_detection.py
index 77ac3a3d768..23e08d87869 100644
--- a/tests/trainers/test_detection.py
+++ b/tests/trainers/test_detection.py
@@ -98,34 +98,25 @@ def test_trainer(
         except MisconfigurationException:
             pass

-    @pytest.fixture
-    def model_kwargs(self) -> dict[Any, Any]:
-        return {"model": "faster-rcnn", "backbone": "resnet18", "num_classes": 2}
-
-    def test_invalid_model(self, model_kwargs: dict[Any, Any]) -> None:
-        model_kwargs["model"] = "invalid_model"
+    def test_invalid_model(self) -> None:
         match = "Model type 'invalid_model' is not valid."
         with pytest.raises(ValueError, match=match):
-            ObjectDetectionTask(**model_kwargs)
+            ObjectDetectionTask(model="invalid_model")

-    def test_invalid_backbone(self, model_kwargs: dict[Any, Any]) -> None:
-        model_kwargs["backbone"] = "invalid_backbone"
+    def test_invalid_backbone(self) -> None:
         match = "Backbone type 'invalid_backbone' is not valid."
         with pytest.raises(ValueError, match=match):
-            ObjectDetectionTask(**model_kwargs)
+            ObjectDetectionTask(backbone="invalid_backbone")

-    def test_non_pretrained_backbone(self, model_kwargs: dict[Any, Any]) -> None:
-        model_kwargs["pretrained"] = False
-        ObjectDetectionTask(**model_kwargs)
+    def test_pretrained_backbone(self) -> None:
+        ObjectDetectionTask(backbone="resnet18", weights=True)

-    def test_no_rgb(
-        self, monkeypatch: MonkeyPatch, model_kwargs: dict[Any, Any], fast_dev_run: bool
-    ) -> None:
+    def test_no_rgb(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None:
         monkeypatch.setattr(NASAMarineDebrisDataModule, "plot", plot)
         datamodule = NASAMarineDebrisDataModule(
             root="tests/data/nasa_marine_debris", batch_size=1, num_workers=0
         )
-        model = ObjectDetectionTask(**model_kwargs)
+        model = ObjectDetectionTask(backbone="resnet18", num_classes=2)
         trainer = Trainer(
             accelerator="cpu",
             fast_dev_run=fast_dev_run,
@@ -134,11 +125,11 @@ def test_no_rgb(
         )
         trainer.validate(model=model, datamodule=datamodule)

-    def test_predict(self, model_kwargs: dict[Any, Any], fast_dev_run: bool) -> None:
+    def test_predict(self, fast_dev_run: bool) -> None:
         datamodule = PredictObjectDetectionDataModule(
             root="tests/data/nasa_marine_debris", batch_size=1, num_workers=0
         )
-        model = ObjectDetectionTask(**model_kwargs)
+        model = ObjectDetectionTask(backbone="resnet18", num_classes=2)
         trainer = Trainer(
             accelerator="cpu",
             fast_dev_run=fast_dev_run,
@@ -148,10 +139,8 @@ def test_predict(self, fast_dev_run: bool) -> None:
         trainer.predict(model=model, datamodule=datamodule)

     @pytest.mark.parametrize("model_name", ["faster-rcnn", "fcos", "retinanet"])
-    def test_freeze_backbone(
-        self, model_name: str, model_kwargs: dict[Any, Any]
-    ) -> None:
-        model_kwargs["freeze_backbone"] = True
-        model_kwargs["model"] = model_name
-        model = ObjectDetectionTask(**model_kwargs)
+    def test_freeze_backbone(self, model_name: str) -> None:
+        model = ObjectDetectionTask(
+            model=model_name, backbone="resnet18", freeze_backbone=True
+        )
         assert not all([param.requires_grad for param in model.model.parameters()])
diff --git a/tests/trainers/test_moco.py b/tests/trainers/test_moco.py
index ec5acf25bd9..18fb86bce36 100644
--- a/tests/trainers/test_moco.py
+++ b/tests/trainers/test_moco.py
@@ -105,49 +105,44 @@ def mocked_weights(
         return weights

     def test_weight_file(self, checkpoint: str) -> None:
-        model_kwargs: dict[str, Any] = {"model": "resnet18", "weights": checkpoint}
         match = "num classes .* != num classes in pretrained model"
         with pytest.warns(UserWarning, match=match):
-            MoCoTask(**model_kwargs)
+            MoCoTask(model="resnet18", weights=checkpoint)

     def test_weight_enum(self, mocked_weights: WeightsEnum) -> None:
-        model_kwargs: dict[str, Any] = {
-            "model": mocked_weights.meta["model"],
-            "weights": mocked_weights,
-            "in_channels": mocked_weights.meta["in_chans"],
-        }
         match = "num classes .* != num classes in pretrained model"
         with pytest.warns(UserWarning, match=match):
-            MoCoTask(**model_kwargs)
+            MoCoTask(
+                model=mocked_weights.meta["model"],
+                weights=mocked_weights,
+                in_channels=mocked_weights.meta["in_chans"],
+            )

     def test_weight_str(self, mocked_weights: WeightsEnum) -> None:
-        model_kwargs: dict[str, Any] = {
-            "model": mocked_weights.meta["model"],
-            "weights": str(mocked_weights),
-            "in_channels": mocked_weights.meta["in_chans"],
-        }
         match = "num classes .* != num classes in pretrained model"
         with pytest.warns(UserWarning, match=match):
-            MoCoTask(**model_kwargs)
+            MoCoTask(
+                model=mocked_weights.meta["model"],
+                weights=str(mocked_weights),
+                in_channels=mocked_weights.meta["in_chans"],
+            )

     @pytest.mark.slow
     def test_weight_enum_download(self, weights: WeightsEnum) -> None:
-        model_kwargs: dict[str, Any] = {
-            "model": weights.meta["model"],
-            "weights": weights,
-            "in_channels": weights.meta["in_chans"],
-        }
         match = "num classes .* != num classes in pretrained model"
         with pytest.warns(UserWarning, match=match):
-            MoCoTask(**model_kwargs)
+            MoCoTask(
+                model=weights.meta["model"],
+                weights=weights,
+                in_channels=weights.meta["in_chans"],
+            )

     @pytest.mark.slow
     def test_weight_str_download(self, weights: WeightsEnum) -> None:
-        model_kwargs: dict[str, Any] = {
-            "model": weights.meta["model"],
-            "weights": str(weights),
-            "in_channels": weights.meta["in_chans"],
-        }
         match = "num classes .* != num classes in pretrained model"
         with pytest.warns(UserWarning, match=match):
-            MoCoTask(**model_kwargs)
+            MoCoTask(
+                model=weights.meta["model"],
+                weights=str(weights),
+                in_channels=weights.meta["in_chans"],
+            )
diff --git a/tests/trainers/test_regression.py b/tests/trainers/test_regression.py
index 5bf3443e61b..3c64eca9b94 100644
--- a/tests/trainers/test_regression.py
+++ b/tests/trainers/test_regression.py
@@ -95,16 +95,6 @@ def test_trainer(self, name: str, fast_dev_run: bool) -> None:
         except MisconfigurationException:
             pass

-    @pytest.fixture
-    def model_kwargs(self) -> dict[str, Any]:
-        return {
-            "model": "resnet18",
-            "weights": None,
-            "num_outputs": 1,
-            "in_channels": 3,
-            "loss": "mse",
-        }
-
     @pytest.fixture
     def weights(self) -> WeightsEnum:
         return ResNet18_Weights.SENTINEL2_ALL_MOCO
@@ -125,55 +115,48 @@ def mocked_weights(
         monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load)
         return weights

-    def test_weight_file(self, model_kwargs: dict[str, Any], checkpoint: str) -> None:
-        model_kwargs["weights"] = checkpoint
+    def test_weight_file(self, checkpoint: str) -> None:
         with pytest.warns(UserWarning):
-            RegressionTask(**model_kwargs)
+            RegressionTask(model="resnet18", weights=checkpoint)

-    def test_weight_enum(
-        self, model_kwargs: dict[str, Any], mocked_weights: WeightsEnum
-    ) -> None:
-        model_kwargs["model"] = mocked_weights.meta["model"]
-        model_kwargs["in_channels"] = mocked_weights.meta["in_chans"]
-        model_kwargs["weights"] = mocked_weights
+    def test_weight_enum(self, mocked_weights: WeightsEnum) -> None:
         with pytest.warns(UserWarning):
-            RegressionTask(**model_kwargs)
+            RegressionTask(
+                model=mocked_weights.meta["model"],
+                weights=mocked_weights,
+                in_channels=mocked_weights.meta["in_chans"],
+            )

-    def test_weight_str(
-        self, model_kwargs: dict[str, Any], mocked_weights: WeightsEnum
-    ) -> None:
-        model_kwargs["model"] = mocked_weights.meta["model"]
-        model_kwargs["in_channels"] = mocked_weights.meta["in_chans"]
-        model_kwargs["weights"] = str(mocked_weights)
+    def test_weight_str(self, mocked_weights: WeightsEnum) -> None:
         with pytest.warns(UserWarning):
-            RegressionTask(**model_kwargs)
+            RegressionTask(
+                model=mocked_weights.meta["model"],
+                weights=str(mocked_weights),
+                in_channels=mocked_weights.meta["in_chans"],
+            )

     @pytest.mark.slow
-    def test_weight_enum_download(
-        self, model_kwargs: dict[str, Any], weights: WeightsEnum
-    ) -> None:
-        model_kwargs["model"] = weights.meta["model"]
-        model_kwargs["in_channels"] = weights.meta["in_chans"]
-        model_kwargs["weights"] = weights
-        RegressionTask(**model_kwargs)
+    def test_weight_enum_download(self, weights: WeightsEnum) -> None:
+        RegressionTask(
model=weights.meta["model"], + weights=weights, + in_channels=weights.meta["in_chans"], + ) @pytest.mark.slow - def test_weight_str_download( - self, model_kwargs: dict[str, Any], weights: WeightsEnum - ) -> None: - model_kwargs["model"] = weights.meta["model"] - model_kwargs["in_channels"] = weights.meta["in_chans"] - model_kwargs["weights"] = str(weights) - RegressionTask(**model_kwargs) + def test_weight_str_download(self, weights: WeightsEnum) -> None: + RegressionTask( + model=weights.meta["model"], + weights=str(weights), + in_channels=weights.meta["in_chans"], + ) - def test_no_rgb( - self, monkeypatch: MonkeyPatch, model_kwargs: dict[Any, Any], fast_dev_run: bool - ) -> None: + def test_no_rgb(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None: monkeypatch.setattr(TropicalCycloneDataModule, "plot", plot) datamodule = TropicalCycloneDataModule( root="tests/data/cyclone", batch_size=1, num_workers=0 ) - model = RegressionTask(**model_kwargs) + model = RegressionTask(model="resnet18") trainer = Trainer( accelerator="cpu", fast_dev_run=fast_dev_run, @@ -182,11 +165,11 @@ def test_no_rgb( ) trainer.validate(model=model, datamodule=datamodule) - def test_predict(self, model_kwargs: dict[Any, Any], fast_dev_run: bool) -> None: + def test_predict(self, fast_dev_run: bool) -> None: datamodule = PredictRegressionDataModule( root="tests/data/cyclone", batch_size=1, num_workers=0 ) - model = RegressionTask(**model_kwargs) + model = RegressionTask(model="resnet18") trainer = Trainer( accelerator="cpu", fast_dev_run=fast_dev_run, @@ -195,21 +178,16 @@ def test_predict(self, model_kwargs: dict[Any, Any], fast_dev_run: bool) -> None ) trainer.predict(model=model, datamodule=datamodule) - def test_invalid_loss(self, model_kwargs: dict[str, Any]) -> None: - model_kwargs["loss"] = "invalid_loss" + def test_invalid_loss(self) -> None: match = "Loss type 'invalid_loss' is not valid." 
         with pytest.raises(ValueError, match=match):
-            RegressionTask(**model_kwargs)
+            RegressionTask(model="resnet18", loss="invalid_loss")

     @pytest.mark.parametrize(
         "model_name", ["resnet18", "efficientnetv2_s", "vit_base_patch16_384"]
     )
-    def test_freeze_backbone(
-        self, model_name: str, model_kwargs: dict[Any, Any]
-    ) -> None:
-        model_kwargs["freeze_backbone"] = True
-        model_kwargs["model"] = model_name
-        model = RegressionTask(**model_kwargs)
+    def test_freeze_backbone(self, model_name: str) -> None:
+        model = RegressionTask(model=model_name, freeze_backbone=True)
         assert not all([param.requires_grad for param in model.model.parameters()])
         assert all(
             [param.requires_grad for param in model.model.get_classifier().parameters()]
@@ -233,7 +211,6 @@ def test_trainer(
         loss: str,
         model_type: str,
         fast_dev_run: bool,
-        model_kwargs: dict[str, Any],
     ) -> None:
         conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml"))

@@ -244,16 +221,11 @@ def test_trainer(
         # Instantiate model
         monkeypatch.setattr(smp, "Unet", create_model)
         monkeypatch.setattr(smp, "DeepLabV3Plus", create_model)
-        model_kwargs["model"] = model_type
-        model_kwargs["loss"] = loss
-
-        if model_type == "fcn":
-            model_kwargs["num_filters"] = 2
-
-        model = PixelwiseRegressionTask(**model_kwargs)
-        model.model = PixelwiseRegressionTestModel(
-            in_channels=model_kwargs["in_channels"]
+        model = PixelwiseRegressionTask(
+            model=model_type, backbone="resnet18", loss=loss
         )
+        model.model = PixelwiseRegressionTestModel()

         # Instantiate trainer
         trainer = Trainer(
@@ -273,24 +245,10 @@ def test_trainer(
         except MisconfigurationException:
             pass

-    def test_invalid_model(self, model_kwargs: dict[str, Any]) -> None:
-        model_kwargs["model"] = "invalid_model"
+    def test_invalid_model(self) -> None:
         match = "Model type 'invalid_model' is not valid."
         with pytest.raises(ValueError, match=match):
-            PixelwiseRegressionTask(**model_kwargs)
-
-    @pytest.fixture
-    def model_kwargs(self) -> dict[str, Any]:
-        return {
-            "model": "unet",
-            "backbone": "resnet18",
-            "weights": None,
-            "num_outputs": 1,
-            "in_channels": 3,
-            "loss": "mse",
-            "learning_rate": 1e-3,
-            "learning_rate_schedule_patience": 6,
-        }
+            PixelwiseRegressionTask(model="invalid_model")

     @pytest.fixture
     def weights(self) -> WeightsEnum:
         return ResNet18_Weights.SENTINEL2_ALL_MOCO
@@ -312,55 +270,51 @@ def mocked_weights(
         monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load)
         return weights

-    def test_weight_file(self, model_kwargs: dict[str, Any], checkpoint: str) -> None:
-        model_kwargs["weights"] = checkpoint
-        PixelwiseRegressionTask(**model_kwargs)
+    def test_weight_file(self, checkpoint: str) -> None:
+        PixelwiseRegressionTask(model="unet", backbone="resnet18", weights=checkpoint)

-    def test_weight_enum(
-        self, model_kwargs: dict[str, Any], mocked_weights: WeightsEnum
-    ) -> None:
-        model_kwargs["backbone"] = mocked_weights.meta["model"]
-        model_kwargs["in_channels"] = mocked_weights.meta["in_chans"]
-        model_kwargs["weights"] = mocked_weights
-        PixelwiseRegressionTask(**model_kwargs)
+    def test_weight_enum(self, mocked_weights: WeightsEnum) -> None:
+        PixelwiseRegressionTask(
+            model="unet",
+            backbone=mocked_weights.meta["model"],
+            weights=mocked_weights,
+            in_channels=mocked_weights.meta["in_chans"],
+        )

-    def test_weight_str(
-        self, model_kwargs: dict[str, Any], mocked_weights: WeightsEnum
-    ) -> None:
-        model_kwargs["backbone"] = mocked_weights.meta["model"]
-        model_kwargs["in_channels"] = mocked_weights.meta["in_chans"]
-        model_kwargs["weights"] = str(mocked_weights)
-        PixelwiseRegressionTask(**model_kwargs)
+    def test_weight_str(self, mocked_weights: WeightsEnum) -> None:
+        PixelwiseRegressionTask(
+            model="unet",
+            backbone=mocked_weights.meta["model"],
+            weights=str(mocked_weights),
+            in_channels=mocked_weights.meta["in_chans"],
+        )

     @pytest.mark.slow
-    def test_weight_enum_download(
-        self, model_kwargs: dict[str, Any], weights: WeightsEnum
-    ) -> None:
-        model_kwargs["backbone"] = weights.meta["model"]
-        model_kwargs["in_channels"] = weights.meta["in_chans"]
-        model_kwargs["weights"] = weights
-        PixelwiseRegressionTask(**model_kwargs)
+    def test_weight_enum_download(self, weights: WeightsEnum) -> None:
+        PixelwiseRegressionTask(
+            model="unet",
+            backbone=weights.meta["model"],
+            weights=weights,
+            in_channels=weights.meta["in_chans"],
+        )

     @pytest.mark.slow
-    def test_weight_str_download(
-        self, model_kwargs: dict[str, Any], weights: WeightsEnum
-    ) -> None:
-        model_kwargs["backbone"] = weights.meta["model"]
-        model_kwargs["in_channels"] = weights.meta["in_chans"]
-        model_kwargs["weights"] = str(weights)
-        PixelwiseRegressionTask(**model_kwargs)
+    def test_weight_str_download(self, weights: WeightsEnum) -> None:
+        PixelwiseRegressionTask(
+            model="unet",
+            backbone=weights.meta["model"],
+            weights=str(weights),
+            in_channels=weights.meta["in_chans"],
+        )

+    @pytest.mark.parametrize("model_name", ["unet", "deeplabv3+"])
     @pytest.mark.parametrize(
         "backbone", ["resnet18", "mobilenet_v2", "efficientnet-b0"]
     )
-    @pytest.mark.parametrize("model_name", ["unet", "deeplabv3+"])
-    def test_freeze_backbone(
-        self, backbone: str, model_name: str, model_kwargs: dict[Any, Any]
-    ) -> None:
-        model_kwargs["freeze_backbone"] = True
-        model_kwargs["model"] = model_name
-        model_kwargs["backbone"] = backbone
-        model = PixelwiseRegressionTask(**model_kwargs)
+    def test_freeze_backbone(self, model_name: str, backbone: str) -> None:
+        model = PixelwiseRegressionTask(
+            model=model_name, backbone=backbone, freeze_backbone=True
+        )
         assert all(
             [param.requires_grad is False for param in model.model.encoder.parameters()]
         )
@@ -373,12 +327,10 @@ def test_freeze_backbone(
         )

     @pytest.mark.parametrize("model_name", ["unet", "deeplabv3+"])
-    def test_freeze_decoder(
-        self, model_name: str, model_kwargs: dict[Any, Any]
-    ) -> None:
-        model_kwargs["freeze_decoder"] = True
-        model_kwargs["model"] = model_name
-        model = PixelwiseRegressionTask(**model_kwargs)
+    def test_freeze_decoder(self, model_name: str) -> None:
+        model = PixelwiseRegressionTask(
+            model=model_name, backbone="resnet18", freeze_decoder=True
+        )
         assert all(
             [param.requires_grad is False for param in model.model.decoder.parameters()]
         )
diff --git a/tests/trainers/test_segmentation.py b/tests/trainers/test_segmentation.py
index c29d37e24eb..920b5966265 100644
--- a/tests/trainers/test_segmentation.py
+++ b/tests/trainers/test_segmentation.py
@@ -5,7 +5,6 @@
 from pathlib import Path
 from typing import Any, cast

-import numpy as np
 import pytest
 import segmentation_models_pytorch as smp
 import timm
@@ -112,18 +111,6 @@ def test_trainer(
         except MisconfigurationException:
             pass

-    @pytest.fixture
-    def model_kwargs(self) -> dict[Any, Any]:
-        return {
-            "model": "unet",
-            "backbone": "resnet18",
-            "weights": None,
-            "in_channels": 3,
-            "num_classes": 6,
-            "loss": "ce",
-            "ignore_index": 0,
-        }
-
     @pytest.fixture
     def weights(self) -> WeightsEnum:
         return ResNet18_Weights.SENTINEL2_ALL_MOCO
@@ -144,78 +131,62 @@ def mocked_weights(
         monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load)
         return weights

-    def test_weight_file(self, model_kwargs: dict[str, Any], checkpoint: str) -> None:
-        model_kwargs["weights"] = checkpoint
-        SemanticSegmentationTask(**model_kwargs)
+    def test_weight_file(self, checkpoint: str) -> None:
+        SemanticSegmentationTask(backbone="resnet18", weights=checkpoint, num_classes=6)

-    def test_weight_enum(
-        self, model_kwargs: dict[str, Any], mocked_weights: WeightsEnum
-    ) -> None:
-        model_kwargs["backbone"] = mocked_weights.meta["model"]
-        model_kwargs["in_channels"] = mocked_weights.meta["in_chans"]
-        model_kwargs["weights"] = mocked_weights
-        SemanticSegmentationTask(**model_kwargs)
+    def test_weight_enum(self, mocked_weights: WeightsEnum) -> None:
+        SemanticSegmentationTask(
+            backbone=mocked_weights.meta["model"],
+            weights=mocked_weights,
+            in_channels=mocked_weights.meta["in_chans"],
+        )

-    def test_weight_str(
-        self, model_kwargs: dict[str, Any], mocked_weights: WeightsEnum
-    ) -> None:
-        model_kwargs["backbone"] = mocked_weights.meta["model"]
-        model_kwargs["in_channels"] = mocked_weights.meta["in_chans"]
-        model_kwargs["weights"] = str(mocked_weights)
-        SemanticSegmentationTask(**model_kwargs)
+    def test_weight_str(self, mocked_weights: WeightsEnum) -> None:
+        SemanticSegmentationTask(
+            backbone=mocked_weights.meta["model"],
+            weights=str(mocked_weights),
+            in_channels=mocked_weights.meta["in_chans"],
+        )

     @pytest.mark.slow
-    def test_weight_enum_download(
-        self, model_kwargs: dict[str, Any], weights: WeightsEnum
-    ) -> None:
-        model_kwargs["backbone"] = weights.meta["model"]
-        model_kwargs["in_channels"] = weights.meta["in_chans"]
-        model_kwargs["weights"] = weights
-        SemanticSegmentationTask(**model_kwargs)
+    def test_weight_enum_download(self, weights: WeightsEnum) -> None:
+        SemanticSegmentationTask(
+            backbone=weights.meta["model"],
+            weights=weights,
in_channels=weights.meta["in_chans"], + ) @pytest.mark.slow - def test_weight_str_download( - self, model_kwargs: dict[str, Any], weights: WeightsEnum - ) -> None: - model_kwargs["backbone"] = weights.meta["model"] - model_kwargs["in_channels"] = weights.meta["in_chans"] - model_kwargs["weights"] = str(weights) - SemanticSegmentationTask(**model_kwargs) + def test_weight_str_download(self, weights: WeightsEnum) -> None: + SemanticSegmentationTask( + backbone=weights.meta["model"], + weights=str(weights), + in_channels=weights.meta["in_chans"], + ) - def test_invalid_model(self, model_kwargs: dict[Any, Any]) -> None: - model_kwargs["model"] = "invalid_model" + def test_invalid_model(self) -> None: match = "Model type 'invalid_model' is not valid." with pytest.raises(ValueError, match=match): - SemanticSegmentationTask(**model_kwargs) + SemanticSegmentationTask(model="invalid_model") - def test_invalid_loss(self, model_kwargs: dict[Any, Any]) -> None: - model_kwargs["loss"] = "invalid_loss" + def test_invalid_loss(self) -> None: match = "Loss type 'invalid_loss' is not valid." with pytest.raises(ValueError, match=match): - SemanticSegmentationTask(**model_kwargs) - - def test_invalid_ignoreindex(self, model_kwargs: dict[Any, Any]) -> None: - model_kwargs["ignore_index"] = "0" - match = "ignore_index must be an int or None" - with pytest.raises(ValueError, match=match): - SemanticSegmentationTask(**model_kwargs) + SemanticSegmentationTask(loss="invalid_loss") - def test_ignoreindex_with_jaccard(self, model_kwargs: dict[Any, Any]) -> None: - model_kwargs["loss"] = "jaccard" - model_kwargs["ignore_index"] = 0 + def test_ignoreindex_with_jaccard(self) -> None: match = "ignore_index has no effect on training when loss='jaccard'" with pytest.warns(UserWarning, match=match): - SemanticSegmentationTask(**model_kwargs) + SemanticSegmentationTask(loss="jaccard", ignore_index=0) - def test_no_rgb( - self, monkeypatch: MonkeyPatch, model_kwargs: dict[Any, Any], fast_dev_run: bool - ) -> None: - model_kwargs["in_channels"] = 15 + def test_no_rgb(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None: monkeypatch.setattr(SEN12MSDataModule, "plot", plot) datamodule = SEN12MSDataModule( root="tests/data/sen12ms", batch_size=1, num_workers=0 ) - model = SemanticSegmentationTask(**model_kwargs) + model = SemanticSegmentationTask( + backbone="resnet18", in_channels=15, num_classes=6 + ) trainer = Trainer( accelerator="cpu", fast_dev_run=fast_dev_run, @@ -224,17 +195,14 @@ def test_no_rgb( ) trainer.validate(model=model, datamodule=datamodule) + @pytest.mark.parametrize("model_name", ["unet", "deeplabv3+"]) @pytest.mark.parametrize( "backbone", ["resnet18", "mobilenet_v2", "efficientnet-b0"] ) - @pytest.mark.parametrize("model_name", ["unet", "deeplabv3+"]) - def test_freeze_backbone( - self, backbone: str, model_name: str, model_kwargs: dict[Any, Any] - ) -> None: - model_kwargs["freeze_backbone"] = True - model_kwargs["model"] = model_name - model_kwargs["backbone"] = backbone - model = SemanticSegmentationTask(**model_kwargs) + def test_freeze_backbone(self, model_name: str, backbone: str) -> None: + model = SemanticSegmentationTask( + model=model_name, backbone=backbone, freeze_backbone=True + ) assert all( [param.requires_grad is False for param in model.model.encoder.parameters()] ) @@ -247,12 +215,8 @@ def test_freeze_backbone( ) @pytest.mark.parametrize("model_name", ["unet", "deeplabv3+"]) - def test_freeze_decoder( - self, model_name: str, model_kwargs: dict[Any, Any] - ) -> None: - 
model_kwargs["freeze_decoder"] = True - model_kwargs["model"] = model_name - model = SemanticSegmentationTask(**model_kwargs) + def test_freeze_decoder(self, model_name: str) -> None: + model = SemanticSegmentationTask(model=model_name, freeze_decoder=True) assert all( [param.requires_grad is False for param in model.model.decoder.parameters()] ) @@ -263,23 +227,3 @@ def test_freeze_decoder( for param in model.model.segmentation_head.parameters() ] ) - - @pytest.mark.parametrize( - "class_weights", [torch.tensor([1, 2, 3]), np.array([1, 2, 3]), [1, 2, 3]] - ) - def test_classweights_valid( - self, class_weights: Any, model_kwargs: dict[Any, Any] - ) -> None: - model_kwargs["class_weights"] = class_weights - sst = SemanticSegmentationTask(**model_kwargs) - assert isinstance(sst.loss.weight, torch.Tensor) - assert torch.equal(sst.loss.weight, torch.tensor([1.0, 2.0, 3.0])) - assert sst.loss.weight.dtype == torch.float32 - - @pytest.mark.parametrize("class_weights", [[], None]) - def test_classweights_empty( - self, class_weights: Any, model_kwargs: dict[Any, Any] - ) -> None: - model_kwargs["class_weights"] = class_weights - sst = SemanticSegmentationTask(**model_kwargs) - assert sst.loss.weight is None diff --git a/tests/trainers/test_simclr.py b/tests/trainers/test_simclr.py index 5fa8f15ea15..add431f5414 100644 --- a/tests/trainers/test_simclr.py +++ b/tests/trainers/test_simclr.py @@ -103,49 +103,44 @@ def mocked_weights( return weights def test_weight_file(self, checkpoint: str) -> None: - model_kwargs: dict[str, Any] = {"model": "resnet18", "weights": checkpoint} match = "num classes .* != num classes in pretrained model" with pytest.warns(UserWarning, match=match): - SimCLRTask(**model_kwargs) + SimCLRTask(model="resnet18", weights=checkpoint) def test_weight_enum(self, mocked_weights: WeightsEnum) -> None: - model_kwargs: dict[str, Any] = { - "model": mocked_weights.meta["model"], - "weights": mocked_weights, - "in_channels": mocked_weights.meta["in_chans"], - } match = "num classes .* != num classes in pretrained model" with pytest.warns(UserWarning, match=match): - SimCLRTask(**model_kwargs) + SimCLRTask( + model=mocked_weights.meta["model"], + weights=mocked_weights, + in_channels=mocked_weights.meta["in_chans"], + ) def test_weight_str(self, mocked_weights: WeightsEnum) -> None: - model_kwargs: dict[str, Any] = { - "model": mocked_weights.meta["model"], - "weights": str(mocked_weights), - "in_channels": mocked_weights.meta["in_chans"], - } match = "num classes .* != num classes in pretrained model" with pytest.warns(UserWarning, match=match): - SimCLRTask(**model_kwargs) + SimCLRTask( + model=mocked_weights.meta["model"], + weights=str(mocked_weights), + in_channels=mocked_weights.meta["in_chans"], + ) @pytest.mark.slow def test_weight_enum_download(self, weights: WeightsEnum) -> None: - model_kwargs: dict[str, Any] = { - "model": weights.meta["model"], - "weights": weights, - "in_channels": weights.meta["in_chans"], - } match = "num classes .* != num classes in pretrained model" with pytest.warns(UserWarning, match=match): - SimCLRTask(**model_kwargs) + SimCLRTask( + model=weights.meta["model"], + weights=weights, + in_channels=weights.meta["in_chans"], + ) @pytest.mark.slow def test_weight_str_download(self, weights: WeightsEnum) -> None: - model_kwargs: dict[str, Any] = { - "model": weights.meta["model"], - "weights": str(weights), - "in_channels": weights.meta["in_chans"], - } match = "num classes .* != num classes in pretrained model" with pytest.warns(UserWarning, 
match=match): - SimCLRTask(**model_kwargs) + SimCLRTask( + model=weights.meta["model"], + weights=str(weights), + in_channels=weights.meta["in_chans"], + ) diff --git a/torchgeo/trainers/__init__.py b/torchgeo/trainers/__init__.py index b39bc483b40..ec8d916a012 100644 --- a/torchgeo/trainers/__init__.py +++ b/torchgeo/trainers/__init__.py @@ -3,6 +3,7 @@ """TorchGeo trainers.""" +from .base import BaseTask from .byol import BYOLTask from .classification import ClassificationTask, MultiLabelClassificationTask from .detection import ObjectDetectionTask @@ -12,13 +13,17 @@ from .simclr import SimCLRTask __all__ = ( - "BYOLTask", + # Supervised "ClassificationTask", - "MoCoTask", "MultiLabelClassificationTask", "ObjectDetectionTask", "PixelwiseRegressionTask", "RegressionTask", "SemanticSegmentationTask", + # Self-supervised + "BYOLTask", + "MoCoTask", "SimCLRTask", + # Base classes + "BaseTask", ) diff --git a/torchgeo/trainers/base.py b/torchgeo/trainers/base.py new file mode 100644 index 00000000000..4c63f7d9699 --- /dev/null +++ b/torchgeo/trainers/base.py @@ -0,0 +1,67 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""Base classes for all :mod:`torchgeo` trainers.""" + +from abc import ABC, abstractmethod +from typing import Any + +from lightning.pytorch import LightningModule +from torch.optim import AdamW +from torch.optim.lr_scheduler import ReduceLROnPlateau + + +class BaseTask(LightningModule, ABC): + """Abstract base class for all TorchGeo trainers. + + .. versionadded:: 0.5 + """ + + #: Model to train + model: Any + + #: Performance metric to monitor in learning rate scheduler and callbacks + monitor = "val_loss" + + def __init__(self) -> None: + """Initialize a new BaseTask instance.""" + super().__init__() + self.save_hyperparameters() + self.configure_losses() + self.configure_metrics() + self.configure_models() + + def configure_losses(self) -> None: + """Initialize the loss criterion.""" + + def configure_metrics(self) -> None: + """Initialize the performance metrics.""" + + @abstractmethod + def configure_models(self) -> None: + """Initialize the model.""" + + def configure_optimizers(self) -> dict[str, Any]: + """Initialize the optimizer and learning rate scheduler. + + Returns: + Optimizer and learning rate scheduler. + """ + optimizer = AdamW(self.parameters(), lr=self.hparams["lr"]) + scheduler = ReduceLROnPlateau(optimizer, patience=self.hparams["patience"]) + return { + "optimizer": optimizer, + "lr_scheduler": {"scheduler": scheduler, "monitor": self.monitor}, + } + + def forward(self, *args: Any, **kwargs: Any) -> Any: + """Forward pass of the model. + + Args: + args: Arguments to pass to model. + kwargs: Keyword arguments to pass to model. + + Returns: + Output of the model. + """ + return self.model(*args, **kwargs) diff --git a/torchgeo/trainers/byol.py b/torchgeo/trainers/byol.py index 00315f4028c..56b0290ca34 100644 --- a/torchgeo/trainers/byol.py +++ b/torchgeo/trainers/byol.py @@ -1,23 +1,22 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. 
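To make the new contract concrete, a minimal sketch of a BaseTask subclass (ToyTask is hypothetical, not part of this diff): the subclass declares its hyperparameters in its own __init__ and implements configure_models(); save_hyperparameters(), the AdamW optimizer, and the ReduceLROnPlateau scheduler keyed on self.hparams["lr"] and self.hparams["patience"] are all inherited.

import torch.nn as nn

from torchgeo.trainers import BaseTask


class ToyTask(BaseTask):
    """Hypothetical task used only to illustrate the BaseTask hooks."""

    def __init__(self, lr: float = 1e-3, patience: int = 10) -> None:
        # BaseTask.__init__ calls save_hyperparameters(), which captures
        # lr/patience from this signature, then runs the configure_* hooks.
        super().__init__()

    def configure_models(self) -> None:
        # The only abstract hook; configure_losses/configure_metrics are
        # optional no-ops by default.
        self.model = nn.Linear(1, 1)


task = ToyTask(lr=1e-4, patience=6)
assert task.hparams["lr"] == 1e-4
cfg = task.configure_optimizers()  # AdamW + ReduceLROnPlateau monitoring "val_loss"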
-"""BYOL tasks.""" +"""BYOL trainer for self-supervised learning (SSL).""" import os -from typing import Any, Optional, cast +from typing import Any, Optional, Union import timm import torch import torch.nn as nn import torch.nn.functional as F from kornia import augmentation as K -from lightning.pytorch import LightningModule -from torch import Tensor, optim -from torch.optim.lr_scheduler import ReduceLROnPlateau +from torch import Tensor from torchvision.models._api import WeightsEnum from ..models import get_weight from . import utils +from .base import BaseTask def normalized_mse(x: Tensor, y: Tensor) -> Tensor: @@ -75,7 +74,8 @@ def forward(self, x: Tensor) -> Tensor: Returns: an augmented batch of imagery """ - return cast(Tensor, self.augmentation(x)) + z: Tensor = self.augmentation(x) + return z class MLP(nn.Module): @@ -108,7 +108,8 @@ def forward(self, x: Tensor) -> Tensor: Returns: embedded version of the input """ - return cast(Tensor, self.mlp(x)) + z: Tensor = self.mlp(x) + return z class BackboneWrapper(nn.Module): @@ -122,7 +123,7 @@ class BackboneWrapper(nn.Module): * The forward call returns the output of the projection head .. versionchanged 0.4: Name changed from *EncoderWrapper* to - *BackboneWrapper*. + *BackboneWrapper*. """ def __init__( @@ -270,7 +271,8 @@ def forward(self, x: Tensor) -> Tensor: Returns: output from the model """ - return cast(Tensor, self.predictor(self.backbone(x))) + z: Tensor = self.predictor(self.backbone(x)) + return z def update_target(self) -> None: """Method to update the "target" model weights.""" @@ -278,29 +280,58 @@ def update_target(self) -> None: pt.data = self.beta * pt.data + (1 - self.beta) * p.data -class BYOLTask(LightningModule): - """Class for pre-training any PyTorch model using BYOL. +class BYOLTask(BaseTask): + """BYOL: Bootstrap Your Own Latent. - Supports any available `Timm model - `_ - as an architecture choice. To see a list of available pretrained - models, you can do: + Reference implementation: - .. code-block:: python + * https://github.com/deepmind/deepmind-research/tree/master/byol - import timm - print(timm.list_models()) + If you use this trainer in your research, please cite the following paper: + + * https://arxiv.org/abs/2006.07733 """ - def config_task(self) -> None: - """Configures the task based on kwargs parameters passed to the constructor.""" - # Create model - in_channels = self.hyperparams["in_channels"] - weights = self.hyperparams["weights"] + monitor = "train_loss" + + def __init__( + self, + model: str = "resnet50", + weights: Optional[Union[WeightsEnum, str, bool]] = None, + in_channels: int = 3, + lr: float = 1e-3, + patience: int = 10, + ) -> None: + """Initialize a new BYOLTask instance. + + Args: + model: Name of the `timm + `__ model to use. + weights: Initial model weights. Either a weight enum, the string + representation of a weight enum, True for ImageNet weights, False + or None for random weights, or the path to a saved model state dict. + in_channels: Number of input channels to model. + lr: Learning rate for optimizer. + patience: Patience for learning rate scheduler. + + .. versionchanged:: 0.4 + *backbone_name* was renamed to *backbone*. Changed backbone support from + torchvision.models to timm. + + .. versionchanged:: 0.5 + *backbone*, *learning_rate*, and *learning_rate_schedule_patience* were + renamed to *model*, *lr*, and *patience*. 
+ """ + super().__init__() + + def configure_models(self) -> None: + """Initialize the model.""" + weights: Optional[Union[WeightsEnum, str, bool]] = self.hparams["weights"] + in_channels: int = self.hparams["in_channels"] + + # Create backbone backbone = timm.create_model( - self.hyperparams["backbone"], - in_chans=in_channels, - pretrained=weights is True, + self.hparams["model"], in_chans=in_channels, pretrained=weights is True ) # Load weights @@ -315,79 +346,25 @@ def config_task(self) -> None: self.model = BYOL(backbone, in_channels=in_channels, image_size=(224, 224)) - def __init__(self, **kwargs: Any) -> None: - """Initialize a LightningModule for pre-training a model with BYOL. - - Keyword Args: - in_channels: Number of input channels to model - backbone: Name of the timm model to use - weights: Either a weight enum, the string representation of a weight enum, - True for ImageNet weights, False or None for random weights, - or the path to a saved model state dict. - learning_rate: Learning rate for optimizer - learning_rate_schedule_patience: Patience for learning rate scheduler - - Raises: - ValueError: if kwargs arguments are invalid - - .. versionchanged:: 0.4 - The *backbone_name* parameter was renamed to *backbone*. Change backbone - support from torchvision.models to timm. - """ - super().__init__() - - # Creates `self.hparams` from kwargs - self.save_hyperparameters() - self.hyperparams = cast(dict[str, Any], self.hparams) - - self.config_task() - - def forward(self, *args: Any, **kwargs: Any) -> Any: - """Forward pass of the model. + def training_step( + self, batch: Any, batch_idx: int, dataloader_idx: int = 0 + ) -> Tensor: + """Compute the training loss and additional metrics. Args: - x: tensor of data to run through the model - - Returns: - output from the model - """ - return self.model(*args, **kwargs) - - def configure_optimizers(self) -> dict[str, Any]: - """Initialize the optimizer and learning rate scheduler. + batch: The output of your DataLoader. + batch_idx: Integer displaying index of this batch. + dataloader_idx: Index of the current dataloader. Returns: - learning rate dictionary. - """ - optimizer_class = getattr(optim, self.hyperparams.get("optimizer", "Adam")) - lr = self.hyperparams.get("learning_rate", 1e-4) - weight_decay = self.hyperparams.get("weight_decay", 1e-6) - optimizer = optimizer_class(self.parameters(), lr=lr, weight_decay=weight_decay) - - return { - "optimizer": optimizer, - "lr_scheduler": { - "scheduler": ReduceLROnPlateau( - optimizer, - patience=self.hyperparams["learning_rate_schedule_patience"], - ), - "monitor": "train_loss", - }, - } - - def training_step(self, *args: Any, **kwargs: Any) -> Tensor: - """Compute and return the training loss. + The loss tensor. - Args: - batch: the output of your DataLoader - - Returns: - training loss + Raises: + AssertionError: If channel dimensions are incorrect. 
""" - batch = args[0] x = batch["image"] - in_channels = self.hyperparams["in_channels"] + in_channels = self.hparams["in_channels"] assert x.size(1) == in_channels or x.size(1) == 2 * in_channels if x.size(1) == in_channels: @@ -409,16 +386,18 @@ def training_step(self, *args: Any, **kwargs: Any) -> Tensor: loss = torch.mean(normalized_mse(pred1, targ2) + normalized_mse(pred2, targ1)) - self.log("train_loss", loss, on_step=True, on_epoch=False) + self.log("train_loss", loss) self.model.update_target() return loss - def validation_step(self, *args: Any, **kwargs: Any) -> None: + def validation_step( + self, batch: Any, batch_idx: int, dataloader_idx: int = 0 + ) -> None: """No-op, does nothing.""" - def test_step(self, *args: Any, **kwargs: Any) -> None: + def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: """No-op, does nothing.""" - def predict_step(self, *args: Any, **kwargs: Any) -> None: + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: """No-op, does nothing.""" diff --git a/torchgeo/trainers/classification.py b/torchgeo/trainers/classification.py index 29f96f2f9fd..bc0d0fd37cd 100644 --- a/torchgeo/trainers/classification.py +++ b/torchgeo/trainers/classification.py @@ -1,19 +1,17 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -"""Classification tasks.""" +"""Trainers for image classification.""" import os -from typing import Any, cast +from typing import Any, Optional, Union import matplotlib.pyplot as plt import timm import torch import torch.nn as nn -from lightning.pytorch import LightningModule from segmentation_models_pytorch.losses import FocalLoss, JaccardLoss from torch import Tensor -from torch.optim.lr_scheduler import ReduceLROnPlateau from torchmetrics import MetricCollection from torchmetrics.classification import ( MulticlassAccuracy, @@ -27,172 +25,159 @@ from ..datasets import unbind_samples from ..models import get_weight from . import utils +from .base import BaseTask -class ClassificationTask(LightningModule): - """LightningModule for image classification. +class ClassificationTask(BaseTask): + """Image classification.""" - Supports any available `Timm model - `_ - as an architecture choice. To see a list of available - models, you can do: + def __init__( + self, + model: str = "resnet50", + weights: Optional[Union[WeightsEnum, str, bool]] = None, + in_channels: int = 3, + num_classes: int = 1000, + loss: str = "ce", + lr: float = 1e-3, + patience: int = 10, + freeze_backbone: bool = False, + ) -> None: + """Initialize a new ClassificationTask instance. - .. 
code-block:: python - - import timm - print(timm.list_models()) - """ - - def config_model(self) -> None: - """Configures the model based on kwargs parameters passed to the constructor.""" - # Create model - weights = self.hyperparams["weights"] - self.model = timm.create_model( - self.hyperparams["model"], - num_classes=self.hyperparams["num_classes"], - in_chans=self.hyperparams["in_channels"], - pretrained=weights is True, - ) - - # Load weights - if weights and weights is not True: - if isinstance(weights, WeightsEnum): - state_dict = weights.get_state_dict(progress=True) - elif os.path.exists(weights): - _, state_dict = utils.extract_backbone(weights) - else: - state_dict = get_weight(weights).get_state_dict(progress=True) - self.model = utils.load_state_dict(self.model, state_dict) - - # Freeze backbone and unfreeze classifier head - if self.hyperparams.get("freeze_backbone", False): - for param in self.model.parameters(): - param.requires_grad = False - for param in self.model.get_classifier().parameters(): - param.requires_grad = True - - def config_task(self) -> None: - """Configures the task based on kwargs parameters passed to the constructor.""" - self.config_model() - - if self.hyperparams["loss"] == "ce": - self.loss: nn.Module = nn.CrossEntropyLoss() - elif self.hyperparams["loss"] == "jaccard": - self.loss = JaccardLoss(mode="multiclass") - elif self.hyperparams["loss"] == "focal": - self.loss = FocalLoss(mode="multiclass", normalized=True) - else: - raise ValueError(f"Loss type '{self.hyperparams['loss']}' is not valid.") - - def __init__(self, **kwargs: Any) -> None: - """Initialize the LightningModule with a model and loss function. - - Keyword Args: - model: Name of the classification model use - loss: Name of the loss function, accepts 'ce', 'jaccard', or 'focal' - weights: Either a weight enum, the string representation of a weight enum, - True for ImageNet weights, False or None for random weights, - or the path to a saved model state dict. - num_classes: Number of prediction classes - in_channels: Number of input channels to model - learning_rate: Learning rate for optimizer - learning_rate_schedule_patience: Patience for learning rate scheduler + Args: + model: Name of the `timm + `__ model to use. + weights: Initial model weights. Either a weight enum, the string + representation of a weight enum, True for ImageNet weights, False + or None for random weights, or the path to a saved model state dict. + in_channels: Number of input channels to model. + num_classes: Number of prediction classes. + loss: One of 'ce', 'bce', 'jaccard', or 'focal'. + lr: Learning rate for optimizer. + patience: Patience for learning rate scheduler. freeze_backbone: Freeze the backbone network to linear probe - the classifier head + the classifier head. .. versionchanged:: 0.4 - The *classification_model* parameter was renamed to *model*. + *classification_model* was renamed to *model*. .. versionadded:: 0.5 The *freeze_backbone* parameter. + + .. versionchanged:: 0.5 + *learning_rate* and *learning_rate_schedule_patience* were renamed to + *lr* and *patience*. """ super().__init__() - # Creates `self.hparams` from kwargs - self.save_hyperparameters() - self.hyperparams = cast(dict[str, Any], self.hparams) + def configure_losses(self) -> None: + """Initialize the loss criterion. - self.config_task() + Raises: + ValueError: If *loss* is invalid. 
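The loss table implemented just below, as a usage sketch ('ce' and 'bce' map to torch.nn losses; 'jaccard' and 'focal' come from segmentation_models_pytorch). With the default weights=None no pretrained checkpoint is downloaded:

from torchgeo.trainers import ClassificationTask

task = ClassificationTask(model="resnet18", in_channels=3, num_classes=10, loss="focal")
print(type(task.criterion).__name__)  # FocalLoss

# An unknown loss fails fast at construction time:
# ClassificationTask(model="resnet18", num_classes=10, loss="invalid_loss")
# -> ValueError: Loss type 'invalid_loss' is not valid.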
+ """ + loss: str = self.hparams["loss"] + if loss == "ce": + self.criterion: nn.Module = nn.CrossEntropyLoss() + elif loss == "bce": + self.criterion = nn.BCEWithLogitsLoss() + elif loss == "jaccard": + self.criterion = JaccardLoss(mode="multiclass") + elif loss == "focal": + self.criterion = FocalLoss(mode="multiclass", normalized=True) + else: + raise ValueError(f"Loss type '{loss}' is not valid.") - self.train_metrics = MetricCollection( + def configure_metrics(self) -> None: + """Initialize the performance metrics.""" + metrics = MetricCollection( { "OverallAccuracy": MulticlassAccuracy( - num_classes=self.hyperparams["num_classes"], average="micro" + num_classes=self.hparams["num_classes"], average="micro" ), "AverageAccuracy": MulticlassAccuracy( - num_classes=self.hyperparams["num_classes"], average="macro" + num_classes=self.hparams["num_classes"], average="macro" ), "JaccardIndex": MulticlassJaccardIndex( - num_classes=self.hyperparams["num_classes"] + num_classes=self.hparams["num_classes"] ), "F1Score": MulticlassFBetaScore( - num_classes=self.hyperparams["num_classes"], - beta=1.0, - average="micro", + num_classes=self.hparams["num_classes"], beta=1.0, average="micro" ), - }, - prefix="train_", + } ) - self.val_metrics = self.train_metrics.clone(prefix="val_") - self.test_metrics = self.train_metrics.clone(prefix="test_") + self.train_metrics = metrics.clone(prefix="train_") + self.val_metrics = metrics.clone(prefix="val_") + self.test_metrics = metrics.clone(prefix="test_") - def forward(self, *args: Any, **kwargs: Any) -> Any: - """Forward pass of the model. + def configure_models(self) -> None: + """Initialize the model.""" + weights: Optional[Union[WeightsEnum, str, bool]] = self.hparams["weights"] - Args: - x: input image + # Create model + self.model = timm.create_model( + self.hparams["model"], + num_classes=self.hparams["num_classes"], + in_chans=self.hparams["in_channels"], + pretrained=weights is True, + ) - Returns: - prediction - """ - return self.model(*args, **kwargs) + # Load weights + if weights and weights is not True: + if isinstance(weights, WeightsEnum): + state_dict = weights.get_state_dict(progress=True) + elif os.path.exists(weights): + _, state_dict = utils.extract_backbone(weights) + else: + state_dict = get_weight(weights).get_state_dict(progress=True) + self.model = utils.load_state_dict(self.model, state_dict) - def training_step(self, *args: Any, **kwargs: Any) -> Tensor: - """Compute and return the training loss. + # Freeze backbone and unfreeze classifier head + if self.hparams["freeze_backbone"]: + for param in self.model.parameters(): + param.requires_grad = False + for param in self.model.get_classifier().parameters(): + param.requires_grad = True + + def training_step( + self, batch: Any, batch_idx: int, dataloader_idx: int = 0 + ) -> Tensor: + """Compute the training loss and additional metrics. Args: - batch: the output of your DataLoader + batch: The output of your DataLoader. + batch_idx: Integer displaying index of this batch. + dataloader_idx: Index of the current dataloader. Returns: - training loss + The loss tensor. 
""" - batch = args[0] x = batch["image"] y = batch["label"] y_hat = self(x) y_hat_hard = y_hat.argmax(dim=1) - - loss = self.loss(y_hat, y) - - # by default, the train step logs every `log_every_n_steps` steps where - # `log_every_n_steps` is a parameter to the `Trainer` object - self.log("train_loss", loss, on_step=True, on_epoch=False) + loss: Tensor = self.criterion(y_hat, y) + self.log("train_loss", loss) self.train_metrics(y_hat_hard, y) - return cast(Tensor, loss) + return loss - def on_train_epoch_end(self) -> None: - """Logs epoch-level training metrics.""" - self.log_dict(self.train_metrics.compute()) - self.train_metrics.reset() - - def validation_step(self, *args: Any, **kwargs: Any) -> None: - """Compute validation loss and log example predictions. + def validation_step( + self, batch: Any, batch_idx: int, dataloader_idx: int = 0 + ) -> None: + """Compute the validation loss and additional metrics. Args: - batch: the output of your DataLoader - batch_idx: the index of this batch + batch: The output of your DataLoader. + batch_idx: Integer displaying index of this batch. + dataloader_idx: Index of the current dataloader. """ - batch = args[0] - batch_idx = args[1] x = batch["image"] y = batch["label"] y_hat = self(x) y_hat_hard = y_hat.argmax(dim=1) - - loss = self.loss(y_hat, y) - - self.log("val_loss", loss, on_step=False, on_epoch=True) + loss = self.criterion(y_hat, y) + self.log("val_loss", loss) self.val_metrics(y_hat_hard, y) if ( @@ -217,162 +202,100 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None: except ValueError: pass - def on_validation_epoch_end(self) -> None: - """Logs epoch level validation metrics.""" - self.log_dict(self.val_metrics.compute()) - self.val_metrics.reset() - - def test_step(self, *args: Any, **kwargs: Any) -> None: - """Compute test loss. + def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: + """Compute the test loss and additional metrics. Args: - batch: the output of your DataLoader + batch: The output of your DataLoader. + batch_idx: Integer displaying index of this batch. + dataloader_idx: Index of the current dataloader. """ - batch = args[0] x = batch["image"] y = batch["label"] y_hat = self(x) y_hat_hard = y_hat.argmax(dim=1) - - loss = self.loss(y_hat, y) - - # by default, the test and validation steps only log per *epoch* - self.log("test_loss", loss, on_step=False, on_epoch=True) + loss = self.criterion(y_hat, y) + self.log("test_loss", loss) self.test_metrics(y_hat_hard, y) - def on_test_epoch_end(self) -> None: - """Logs epoch level test metrics.""" - self.log_dict(self.test_metrics.compute()) - self.test_metrics.reset() - - def predict_step(self, *args: Any, **kwargs: Any) -> Tensor: - """Compute and return the predictions. + def predict_step( + self, batch: Any, batch_idx: int, dataloader_idx: int = 0 + ) -> Tensor: + """Compute the predicted class probabilities. Args: - batch: the output of your DataLoader + batch: The output of your DataLoader. + batch_idx: Integer displaying index of this batch. + dataloader_idx: Index of the current dataloader. Returns: - predicted softmax probabilities + Output predicted probabilities. """ - batch = args[0] x = batch["image"] y_hat: Tensor = self(x).softmax(dim=-1) return y_hat - def configure_optimizers(self) -> dict[str, Any]: - """Initialize the optimizer and learning rate scheduler. 
- - Returns: - learning rate dictionary - """ - optimizer = torch.optim.AdamW( - self.model.parameters(), lr=self.hyperparams["learning_rate"] - ) - return { - "optimizer": optimizer, - "lr_scheduler": { - "scheduler": ReduceLROnPlateau( - optimizer, - patience=self.hyperparams["learning_rate_schedule_patience"], - ), - "monitor": "val_loss", - }, - } - class MultiLabelClassificationTask(ClassificationTask): - """LightningModule for multi-label image classification.""" - - def config_task(self) -> None: - """Configures the task based on kwargs parameters passed to the constructor.""" - self.config_model() + """Multi-label image classification.""" - if self.hyperparams["loss"] == "bce": - self.loss = nn.BCEWithLogitsLoss() - else: - raise ValueError(f"Loss type '{self.hyperparams['loss']}' is not valid.") - - def __init__(self, **kwargs: Any) -> None: - """Initialize the LightningModule with a model and loss function. - - Keyword Args: - model: Name of the classification model use - loss: Name of the loss function, currently only supports 'bce' - weights: Either "random" or 'imagenet' - num_classes: Number of prediction classes - in_channels: Number of input channels to model - learning_rate: Learning rate for optimizer - learning_rate_schedule_patience: Patience for learning rate scheduler - freeze_backbone: Freeze the backbone network to linear probe - the classifier head - - .. versionchanged:: 0.4 - The *classification_model* parameter was renamed to *model*. - - .. versionadded:: 0.5 - The *freeze_backbone* parameter. - """ - super().__init__(**kwargs) - - self.train_metrics = MetricCollection( + def configure_metrics(self) -> None: + """Initialize the performance metrics.""" + metrics = MetricCollection( { "OverallAccuracy": MultilabelAccuracy( - num_labels=self.hyperparams["num_classes"], average="micro" + num_labels=self.hparams["num_classes"], average="micro" ), "AverageAccuracy": MultilabelAccuracy( - num_labels=self.hyperparams["num_classes"], average="macro" + num_labels=self.hparams["num_classes"], average="macro" ), "F1Score": MultilabelFBetaScore( - num_labels=self.hyperparams["num_classes"], - beta=1.0, - average="micro", + num_labels=self.hparams["num_classes"], beta=1.0, average="micro" ), - }, - prefix="train_", + } ) - self.val_metrics = self.train_metrics.clone(prefix="val_") - self.test_metrics = self.train_metrics.clone(prefix="test_") + self.train_metrics = metrics.clone(prefix="train_") + self.val_metrics = metrics.clone(prefix="val_") + self.test_metrics = metrics.clone(prefix="test_") - def training_step(self, *args: Any, **kwargs: Any) -> Tensor: - """Compute and return the training loss. + def training_step( + self, batch: Any, batch_idx: int, dataloader_idx: int = 0 + ) -> Tensor: + """Compute the training loss and additional metrics. Args: - batch: the output of your DataLoader + batch: The output of your DataLoader. + batch_idx: Integer displaying index of this batch. + dataloader_idx: Index of the current dataloader. Returns: - training loss + The loss tensor. 
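The multi-label steps that follow cast the integer multi-hot labels to float because BCEWithLogitsLoss rejects integer targets; the same detail in isolation:

import torch
import torch.nn as nn

criterion = nn.BCEWithLogitsLoss()
y_hat = torch.tensor([[2.0, -1.0, 0.5]])    # raw logits for 3 labels
y = torch.tensor([[1, 0, 1]])               # multi-hot integer labels
loss = criterion(y_hat, y.to(torch.float))  # the float cast is required
probs = torch.sigmoid(y_hat)                # per-label probabilities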
""" - batch = args[0] x = batch["image"] y = batch["label"] y_hat = self(x) y_hat_hard = torch.sigmoid(y_hat) - - loss = self.loss(y_hat, y.to(torch.float)) - - # by default, the train step logs every `log_every_n_steps` steps where - # `log_every_n_steps` is a parameter to the `Trainer` object - self.log("train_loss", loss, on_step=True, on_epoch=False) + loss: Tensor = self.criterion(y_hat, y.to(torch.float)) + self.log("train_loss", loss) self.train_metrics(y_hat_hard, y) - return cast(Tensor, loss) + return loss - def validation_step(self, *args: Any, **kwargs: Any) -> None: - """Compute validation loss and log example predictions. + def validation_step( + self, batch: Any, batch_idx: int, dataloader_idx: int = 0 + ) -> None: + """Compute the validation loss and additional metrics. Args: - batch: the output of your DataLoader - batch_idx: the index of this batch + batch: The output of your DataLoader. + batch_idx: Integer displaying index of this batch. + dataloader_idx: Index of the current dataloader. """ - batch = args[0] - batch_idx = args[1] x = batch["image"] y = batch["label"] y_hat = self(x) y_hat_hard = torch.sigmoid(y_hat) - - loss = self.loss(y_hat, y.to(torch.float)) - + loss = self.criterion(y_hat, y.to(torch.float)) self.log("val_loss", loss, on_step=False, on_epoch=True) self.val_metrics(y_hat_hard, y) @@ -397,33 +320,35 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None: except ValueError: pass - def test_step(self, *args: Any, **kwargs: Any) -> None: - """Compute test loss. + def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: + """Compute the test loss and additional metrics. Args: - batch: the output of your DataLoader + batch: The output of your DataLoader. + batch_idx: Integer displaying index of this batch. + dataloader_idx: Index of the current dataloader. """ - batch = args[0] x = batch["image"] y = batch["label"] y_hat = self(x) y_hat_hard = torch.sigmoid(y_hat) - - loss = self.loss(y_hat, y.to(torch.float)) - - # by default, the test and validation steps only log per *epoch* - self.log("test_loss", loss, on_step=False, on_epoch=True) + loss = self.criterion(y_hat, y.to(torch.float)) + self.log("test_loss", loss) self.test_metrics(y_hat_hard, y) - def predict_step(self, *args: Any, **kwargs: Any) -> Tensor: - """Compute and return the predictions. + def predict_step( + self, batch: Any, batch_idx: int, dataloader_idx: int = 0 + ) -> Tensor: + """Compute the predicted class probabilities. Args: - batch: the output of your DataLoader + batch: The output of your DataLoader. + batch_idx: Integer displaying index of this batch. + dataloader_idx: Index of the current dataloader. + Returns: - predicted sigmoid probabilities + Output predicted probabilities. """ - batch = args[0] x = batch["image"] y_hat = torch.sigmoid(self(x)) return y_hat diff --git a/torchgeo/trainers/detection.py b/torchgeo/trainers/detection.py index ed58bf945a2..27e2bcaeedf 100644 --- a/torchgeo/trainers/detection.py +++ b/torchgeo/trainers/detection.py @@ -1,17 +1,15 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. 
-"""Detection tasks.""" +"""Trainers for object detection.""" from functools import partial -from typing import Any, cast +from typing import Any, Optional import matplotlib.pyplot as plt import torch import torchvision.models.detection -from lightning.pytorch import LightningModule from torch import Tensor -from torch.optim.lr_scheduler import ReduceLROnPlateau from torchmetrics import MetricCollection from torchmetrics.detection.mean_ap import MeanAveragePrecision from torchvision.models import resnet as R @@ -21,6 +19,7 @@ from torchvision.ops import MultiScaleRoIAlign, feature_pyramid_network, misc from ..datasets.utils import unbind_samples +from .base import BaseTask BACKBONE_LAT_DIM_MAP = { "resnet18": 512, @@ -47,48 +46,87 @@ } -class ObjectDetectionTask(LightningModule): - """LightningModule for object detection of images. +class ObjectDetectionTask(BaseTask): + """Object detection. - Currently, supports Faster R-CNN, FCOS, and RetinaNet models from - `torchvision - `_ with - one of the following *backbone* arguments: + .. versionadded:: 0.4 + """ - .. code-block:: python + monitor = "val_map" + + def __init__( + self, + model: str = "faster-rcnn", + backbone: str = "resnet50", + weights: Optional[bool] = None, + in_channels: int = 3, + num_classes: int = 1000, + trainable_layers: int = 3, + lr: float = 1e-3, + patience: int = 10, + freeze_backbone: bool = False, + ) -> None: + """Initialize a new ObjectDetectionTask instance. - ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', - 'resnext50_32x4d','resnext101_32x8d', 'wide_resnet50_2', - 'wide_resnet101_2'] + Args: + model: Name of the `torchvision + `__ + model to use. One of 'faster-rcnn', 'fcos', or 'retinanet'. + backbone: Name of the `torchvision + `__ + backbone to use. One of 'resnet18', 'resnet34', 'resnet50', + 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', + 'wide_resnet50_2', or 'wide_resnet101_2'. + weights: Initial model weights. True for ImageNet weights, False or None + for random weights. + in_channels: Number of input channels to model. + num_classes: Number of prediction classes. + trainable_layers: Number of trainable layers. + lr: Learning rate for optimizer. + patience: Patience for learning rate scheduler. + freeze_backbone: Freeze the backbone network to fine-tune the detection + head. - .. versionadded:: 0.4 - """ + .. versionchanged:: 0.4 + *detection_model* was renamed to *model*. - def config_task(self) -> None: - """Configures the task based on kwargs parameters passed to the constructor.""" - backbone_pretrained = self.hyperparams.get("pretrained", True) + .. versionadded:: 0.5 + The *freeze_backbone* parameter. + + .. versionchanged:: 0.5 + *pretrained*, *learning_rate*, and *learning_rate_schedule_patience* were + renamed to *weights*, *lr*, and *patience*. + """ + super().__init__() + + def configure_models(self) -> None: + """Initialize the model. - if self.hyperparams["backbone"] in BACKBONE_LAT_DIM_MAP: + Raises: + ValueError: If *model* or *backbone* are invalid. 
+ """ + backbone: str = self.hparams["backbone"] + model: str = self.hparams["model"] + weights: Optional[bool] = self.hparams["weights"] + num_classes: int = self.hparams["num_classes"] + freeze_backbone: bool = self.hparams["freeze_backbone"] + + if backbone in BACKBONE_LAT_DIM_MAP: kwargs = { - "backbone_name": self.hyperparams["backbone"], - "trainable_layers": self.hyperparams.get("trainable_layers", 3), + "backbone_name": backbone, + "trainable_layers": self.hparams["trainable_layers"], } - if backbone_pretrained: - kwargs["weights"] = BACKBONE_WEIGHT_MAP[self.hyperparams["backbone"]] + if weights: + kwargs["weights"] = BACKBONE_WEIGHT_MAP[backbone] else: kwargs["weights"] = None - latent_dim = BACKBONE_LAT_DIM_MAP[self.hyperparams["backbone"]] + latent_dim = BACKBONE_LAT_DIM_MAP[backbone] else: - raise ValueError( - f"Backbone type '{self.hyperparams['backbone']}' is not valid." - ) + raise ValueError(f"Backbone type '{backbone}' is not valid.") - num_classes = self.hyperparams["num_classes"] - - if self.hyperparams["model"] == "faster-rcnn": - backbone = resnet_fpn_backbone(**kwargs) + if model == "faster-rcnn": + model_backbone = resnet_fpn_backbone(**kwargs) anchor_generator = AnchorGenerator( sizes=((32), (64), (128), (256), (512)), aspect_ratios=((0.5, 1.0, 2.0)) ) @@ -97,40 +135,40 @@ def config_task(self) -> None: featmap_names=["0", "1", "2", "3"], output_size=7, sampling_ratio=2 ) - if self.hyperparams.get("freeze_backbone", False): - for param in backbone.parameters(): + if freeze_backbone: + for param in model_backbone.parameters(): param.requires_grad = False self.model = torchvision.models.detection.FasterRCNN( - backbone, + model_backbone, num_classes, rpn_anchor_generator=anchor_generator, box_roi_pool=roi_pooler, ) - elif self.hyperparams["model"] == "fcos": + elif model == "fcos": kwargs["extra_blocks"] = feature_pyramid_network.LastLevelP6P7(256, 256) kwargs["norm_layer"] = ( - misc.FrozenBatchNorm2d if kwargs["weights"] else torch.nn.BatchNorm2d + misc.FrozenBatchNorm2d if weights else torch.nn.BatchNorm2d ) - backbone = resnet_fpn_backbone(**kwargs) + model_backbone = resnet_fpn_backbone(**kwargs) anchor_generator = AnchorGenerator( sizes=((8,), (16,), (32,), (64,), (128,), (256,)), aspect_ratios=((1.0,), (1.0,), (1.0,), (1.0,), (1.0,), (1.0,)), ) - if self.hyperparams.get("freeze_backbone", False): - for param in backbone.parameters(): + if freeze_backbone: + for param in model_backbone.parameters(): param.requires_grad = False self.model = torchvision.models.detection.FCOS( - backbone, num_classes, anchor_generator=anchor_generator + model_backbone, num_classes, anchor_generator=anchor_generator ) - elif self.hyperparams["model"] == "retinanet": + elif model == "retinanet": kwargs["extra_blocks"] = feature_pyramid_network.LastLevelP6P7( latent_dim, 256 ) - backbone = resnet_fpn_backbone(**kwargs) + model_backbone = resnet_fpn_backbone(**kwargs) anchor_sizes = ( (16, 20, 25), @@ -144,75 +182,44 @@ def config_task(self) -> None: anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios) head = RetinaNetHead( - backbone.out_channels, + model_backbone.out_channels, anchor_generator.num_anchors_per_location()[0], num_classes, norm_layer=partial(torch.nn.GroupNorm, 32), ) - if self.hyperparams.get("freeze_backbone", False): - for param in backbone.parameters(): + if freeze_backbone: + for param in model_backbone.parameters(): param.requires_grad = False self.model = torchvision.models.detection.RetinaNet( - backbone, num_classes, 
anchor_generator=anchor_generator, head=head + model_backbone, + num_classes, + anchor_generator=anchor_generator, + head=head, ) else: - raise ValueError(f"Model type '{self.hyperparams['model']}' is not valid.") - - def __init__(self, **kwargs: Any) -> None: - """Initialize the LightningModule with a model and loss function. - - Keyword Args: - model: Name of the detection model type to use - backbone: Name of the model backbone to use - in_channels: Number of channels in input image - num_classes: Number of semantic classes to predict - learning_rate: Learning rate for optimizer - learning_rate_schedule_patience: Patience for learning rate scheduler - freeze_backbone: Freeze the backbone network to fine-tune the detection head - - Raises: - ValueError: if kwargs arguments are invalid - - .. versionchanged:: 0.4 - The *detection_model* parameter was renamed to *model*. - - .. versionadded:: 0.5 - The *freeze_backbone* parameter. - """ - super().__init__() - # Creates `self.hparams` from kwargs - self.save_hyperparameters() - self.hyperparams = cast(dict[str, Any], self.hparams) - - self.config_task() + raise ValueError(f"Model type '{model}' is not valid.") + def configure_metrics(self) -> None: + """Initialize the performance metrics.""" metrics = MetricCollection([MeanAveragePrecision()]) self.val_metrics = metrics.clone(prefix="val_") self.test_metrics = metrics.clone(prefix="test_") - def forward(self, *args: Any, **kwargs: Any) -> Any: - """Forward pass of the model. - - Args: - x: tensor of data to run through the model - - Returns: - output from the model - """ - return self.model(*args, **kwargs) - - def training_step(self, *args: Any, **kwargs: Any) -> Tensor: - """Compute and return the training loss. + def training_step( + self, batch: Any, batch_idx: int, dataloader_idx: int = 0 + ) -> Tensor: + """Compute the training loss. Args: - batch: the output of your DataLoader + batch: The output of your DataLoader. + batch_idx: Integer displaying index of this batch. + dataloader_idx: Index of the current dataloader. Returns: - training loss + The loss tensor. """ - batch = args[0] x = batch["image"] batch_size = x.shape[0] y = [ @@ -220,21 +227,20 @@ def training_step(self, *args: Any, **kwargs: Any) -> Tensor: for i in range(batch_size) ] loss_dict = self(x, y) - train_loss = sum(loss_dict.values()) - + train_loss: Tensor = sum(loss_dict.values()) self.log_dict(loss_dict) + return train_loss - return cast(Tensor, train_loss) - - def validation_step(self, *args: Any, **kwargs: Any) -> None: - """Compute validation loss and log example predictions. + def validation_step( + self, batch: Any, batch_idx: int, dataloader_idx: int = 0 + ) -> None: + """Compute the validation metrics. Args: - batch: the output of your DataLoader - batch_idx: the index of this batch + batch: The output of your DataLoader. + batch_idx: Integer displaying index of this batch. + dataloader_idx: Index of the current dataloader. """ - batch = args[0] - batch_idx = args[1] x = batch["image"] batch_size = x.shape[0] y = [ @@ -242,8 +248,7 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None: for i in range(batch_size) ] y_hat = self(x) - - self.val_metrics.update(y_hat, y) + self.val_metrics(y_hat, y) if ( batch_idx < 10 @@ -273,7 +278,8 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None: pass def on_validation_epoch_end(self) -> None: - """Logs epoch level validation metrics.""" + """Log epoch level validation metrics.""" + # TODO: why is this method necessary? 
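The y lists assembled in the steps above follow torchvision's detection target convention: one dict per image with float (x1, y1, x2, y2) boxes and integer labels. A sketch with hypothetical values; the per-sample "boxes"/"labels" batch keys are assumed to match the trainer's usage:

import torch

batch = {
    "image": torch.rand(1, 3, 256, 256),
    "boxes": [torch.tensor([[10.0, 20.0, 50.0, 60.0]])],
    "labels": [torch.tensor([1])],
}
y = [
    {"boxes": batch["boxes"][i], "labels": batch["labels"][i]}
    for i in range(batch["image"].shape[0])
]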
metrics = self.val_metrics.compute() # https://github.com/Lightning-AI/torchmetrics/pull/1832#issuecomment-1623890714 @@ -282,13 +288,14 @@ def on_validation_epoch_end(self) -> None: self.log_dict(metrics) self.val_metrics.reset() - def test_step(self, *args: Any, **kwargs: Any) -> None: - """Compute test MAP. + def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: + """Compute the test metrics. Args: - batch: the output of your DataLoader + batch: The output of your DataLoader. + batch_idx: Integer displaying index of this batch. + dataloader_idx: Index of the current dataloader. """ - batch = args[0] x = batch["image"] batch_size = x.shape[0] y = [ @@ -296,50 +303,21 @@ def test_step(self, *args: Any, **kwargs: Any) -> None: for i in range(batch_size) ] y_hat = self(x) - self.test_metrics.update(y_hat, y) - def on_test_epoch_end(self) -> None: - """Logs epoch level test metrics.""" - metrics = self.test_metrics.compute() - - # https://github.com/Lightning-AI/torchmetrics/pull/1832#issuecomment-1623890714 - metrics.pop("test_classes", None) - - self.log_dict(metrics) - self.test_metrics.reset() - - def predict_step(self, *args: Any, **kwargs: Any) -> list[dict[str, Tensor]]: - """Compute and return the predictions. + def predict_step( + self, batch: Any, batch_idx: int, dataloader_idx: int = 0 + ) -> list[dict[str, Tensor]]: + """Compute the predicted bounding boxes. Args: - batch: the output of your DataLoader + batch: The output of your DataLoader. + batch_idx: Integer displaying index of this batch. + dataloader_idx: Index of the current dataloader. Returns: - list of predicted boxes, labels and scores + Output predicted probabilities. """ - batch = args[0] x = batch["image"] y_hat: list[dict[str, Tensor]] = self(x) return y_hat - - def configure_optimizers(self) -> dict[str, Any]: - """Initialize the optimizer and learning rate scheduler. - - Returns: - learning rate dictionary - """ - optimizer = torch.optim.Adam( - self.model.parameters(), lr=self.hyperparams["learning_rate"] - ) - return { - "optimizer": optimizer, - "lr_scheduler": { - "scheduler": ReduceLROnPlateau( - optimizer, - mode="max", - patience=self.hyperparams["learning_rate_schedule_patience"], - ), - "monitor": "val_map", - }, - } diff --git a/torchgeo/trainers/moco.py b/torchgeo/trainers/moco.py index f5425390bdd..6b059d58c5e 100644 --- a/torchgeo/trainers/moco.py +++ b/torchgeo/trainers/moco.py @@ -6,7 +6,7 @@ import os import warnings from collections.abc import Sequence -from typing import Any, Optional, Union, cast +from typing import Any, Optional, Union import kornia.augmentation as K import timm @@ -17,7 +17,6 @@ from lightly.models.modules import MoCoProjectionHead from lightly.models.utils import deactivate_requires_grad, update_momentum from lightly.utils.scheduler import cosine_schedule -from lightning import LightningModule from torch import Tensor from torch.optim import SGD, AdamW, Optimizer from torch.optim.lr_scheduler import ( @@ -32,6 +31,7 @@ from ..models import get_weight from . import utils +from .base import BaseTask try: from torch.optim.lr_scheduler import LRScheduler @@ -118,7 +118,7 @@ def moco_augmentations( return aug1, aug2 -class MoCoTask(LightningModule): +class MoCoTask(BaseTask): """MoCo: Momentum Contrast. Reference implementations: @@ -135,6 +135,8 @@ class MoCoTask(LightningModule): .. versionadded:: 0.5 """ + monitor = "train_loss" + def __init__( self, model: str = "resnet50", @@ -160,7 +162,8 @@ def __init__( """Initialize a new MoCoTask instance. 
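On the metrics.pop above: per the torchmetrics comment linked in the code, MeanAveragePrecision.compute() includes a non-scalar "classes" tensor that log_dict cannot handle, hence the pop (the key carries the val_/test_ prefix after cloning). In isolation:

import torch
from torchmetrics.detection.mean_ap import MeanAveragePrecision

metric = MeanAveragePrecision()
preds = [{
    "boxes": torch.tensor([[10.0, 20.0, 50.0, 60.0]]),
    "scores": torch.tensor([0.9]),
    "labels": torch.tensor([1]),
}]
target = [{"boxes": torch.tensor([[10.0, 20.0, 50.0, 60.0]]), "labels": torch.tensor([1])}]
metric.update(preds, target)
results = metric.compute()
results.pop("classes", None)  # drop the non-scalar entry; the rest is loggable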
Args: - model: Name of the timm model to use. + model: Name of the `timm + `__ model to use. weights: Initial model weights. Either a weight enum, the string representation of a weight enum, True for ImageNet weights, False or None for random weights, or the path to a saved model state dict. @@ -198,8 +201,6 @@ def __init__( Warns: UserWarning: If hyperparameters do not match MoCo version requested. """ - super().__init__() - # Validate hyperparameters assert version in range(1, 4) if version == 1: @@ -216,13 +217,31 @@ def __init__( if memory_bank_size > 0: warnings.warn("MoCo v3 does not use a memory bank") - self.save_hyperparameters(ignore=["augmentation1", "augmentation2"]) + super().__init__() grayscale_weights = grayscale_weights or torch.ones(in_channels) aug1, aug2 = moco_augmentations(version, size, grayscale_weights) self.augmentation1 = augmentation1 or aug1 self.augmentation2 = augmentation2 or aug2 + def configure_losses(self) -> None: + """Initialize the loss criterion.""" + self.criterion = NTXentLoss( + self.hparams["temperature"], + self.hparams["memory_bank_size"], + self.hparams["gather_distributed"], + ) + + def configure_models(self) -> None: + """Initialize the model.""" + model: str = self.hparams["model"] + weights: Optional[Union[WeightsEnum, str, bool]] = self.hparams["weights"] + in_channels: int = self.hparams["in_channels"] + version: int = self.hparams["version"] + layers: int = self.hparams["layers"] + hidden_dim: int = self.hparams["hidden_dim"] + output_dim: int = self.hparams["output_dim"] + # Create backbone self.backbone = timm.create_model( model, in_chans=in_channels, num_classes=0, pretrained=weights is True @@ -258,12 +277,52 @@ def __init__( output_dim, hidden_dim, output_dim, num_layers=2, batch_norm=batch_norm ) - # Define loss function - self.criterion = NTXentLoss(temperature, memory_bank_size, gather_distributed) - # Initialize moving average of output self.avg_output_std = 0.0 + def configure_optimizers(self) -> dict[str, Any]: + """Initialize the optimizer and learning rate scheduler. + + Returns: + Optimizer and learning rate scheduler. + """ + if self.hparams["version"] == 3: + optimizer: Optimizer = AdamW( + params=self.parameters(), + lr=self.hparams["lr"], + weight_decay=self.hparams["weight_decay"], + ) + warmup_epochs = 40 + max_epochs = 200 + if self.trainer and self.trainer.max_epochs: + max_epochs = self.trainer.max_epochs + scheduler: LRScheduler = SequentialLR( + optimizer, + schedulers=[ + LinearLR( + optimizer, + start_factor=1 / warmup_epochs, + total_iters=warmup_epochs, + ), + CosineAnnealingLR(optimizer, T_max=max_epochs), + ], + milestones=[warmup_epochs], + ) + else: + optimizer = SGD( + params=self.parameters(), + lr=self.hparams["lr"], + momentum=self.hparams["momentum"], + weight_decay=self.hparams["weight_decay"], + ) + scheduler = MultiStepLR( + optimizer=optimizer, milestones=self.hparams["schedule"] + ) + return { + "optimizer": optimizer, + "lr_scheduler": {"scheduler": scheduler, "monitor": self.monitor}, + } + def forward(self, x: Tensor) -> tuple[Tensor, Tensor]: """Forward pass of the model. @@ -271,15 +330,15 @@ def forward(self, x: Tensor) -> tuple[Tensor, Tensor]: x: Mini-batch of images. 
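The v3 branch of configure_optimizers above chains a linear warmup into cosine annealing; the same recipe in isolation, with a toy parameter and the trainer's default epoch counts:

import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = AdamW(params, lr=1e-3, weight_decay=0.1)

warmup_epochs, max_epochs = 40, 200
scheduler = SequentialLR(
    optimizer,
    schedulers=[
        LinearLR(optimizer, start_factor=1 / warmup_epochs, total_iters=warmup_epochs),
        CosineAnnealingLR(optimizer, T_max=max_epochs),
    ],
    milestones=[warmup_epochs],  # hand off from warmup to cosine at epoch 40
)

for _ in range(3):
    optimizer.step()
    scheduler.step()
    print(scheduler.get_last_lr())  # lr climbs linearly toward 1e-3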
Returns: - Output from the model and backbone + Output of the model and backbone """ - h = self.backbone(x) + h: Tensor = self.backbone(x) q = h if self.hparams["version"] > 1: q = self.projection_head(q) if self.hparams["version"] == 3: q = self.prediction_head(q) - return cast(Tensor, q), cast(Tensor, h) + return q, h def forward_momentum(self, x: Tensor) -> Tensor: """Forward pass of the momentum model. @@ -290,10 +349,10 @@ def forward_momentum(self, x: Tensor) -> Tensor: Returns: Output from the momentum model. """ - k = self.backbone_momentum(x) + k: Tensor = self.backbone_momentum(x) if self.hparams["version"] > 1: k = self.projection_head_momentum(k) - return cast(Tensor, k) + return k def training_step( self, batch: Any, batch_idx: int, dataloader_idx: int = 0 @@ -330,7 +389,7 @@ def training_step( with torch.no_grad(): update_momentum(self.backbone, self.backbone_momentum, m) k = self.forward_momentum(x2) - loss = self.criterion(q, k) + loss: Tensor = self.criterion(q, k) elif self.hparams["version"] == 2: q, h1 = self.forward(x1) with torch.no_grad(): @@ -360,7 +419,7 @@ def training_step( self.log("train_ssl_std", self.avg_output_std) self.log("train_loss", loss) - return cast(Tensor, loss) + return loss def validation_step( self, batch: Any, batch_idx: int, dataloader_idx: int = 0 @@ -372,43 +431,3 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: """No-op, does nothing.""" - - def configure_optimizers(self) -> tuple[list[Optimizer], list[LRScheduler]]: - """Initialize the optimizer and learning rate scheduler. - - Returns: - Optimizer and learning rate scheduler. - """ - if self.hparams["version"] == 3: - optimizer: Optimizer = AdamW( - params=self.parameters(), - lr=self.hparams["lr"], - weight_decay=self.hparams["weight_decay"], - ) - warmup_epochs = 40 - max_epochs = 200 - if self.trainer and self.trainer.max_epochs: - max_epochs = self.trainer.max_epochs - lr_scheduler: LRScheduler = SequentialLR( - optimizer, - schedulers=[ - LinearLR( - optimizer, - start_factor=1 / warmup_epochs, - total_iters=warmup_epochs, - ), - CosineAnnealingLR(optimizer, T_max=max_epochs), - ], - milestones=[warmup_epochs], - ) - else: - optimizer = SGD( - params=self.parameters(), - lr=self.hparams["lr"], - momentum=self.hparams["momentum"], - weight_decay=self.hparams["weight_decay"], - ) - lr_scheduler = MultiStepLR( - optimizer=optimizer, milestones=self.hparams["schedule"] - ) - return [optimizer], [lr_scheduler] diff --git a/torchgeo/trainers/regression.py b/torchgeo/trainers/regression.py index ce62a64947a..2a47937a6a5 100644 --- a/torchgeo/trainers/regression.py +++ b/torchgeo/trainers/regression.py @@ -1,51 +1,120 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -"""Regression tasks.""" +"""Trainers for regression.""" import os -from typing import Any, cast +from typing import Any, Optional, Union import matplotlib.pyplot as plt import segmentation_models_pytorch as smp import timm import torch import torch.nn as nn -from lightning.pytorch import LightningModule from torch import Tensor -from torch.optim.lr_scheduler import ReduceLROnPlateau from torchmetrics import MeanAbsoluteError, MeanSquaredError, MetricCollection from torchvision.models._api import WeightsEnum from ..datasets import unbind_samples from ..models import FCN, get_weight from . 
import utils +from .base import BaseTask + + +class RegressionTask(BaseTask): + """Regression.""" + + target_key = "label" + + def __init__( + self, + model: str = "resnet50", + backbone: str = "resnet50", + weights: Optional[Union[WeightsEnum, str, bool]] = None, + in_channels: int = 3, + num_outputs: int = 1, + num_filters: int = 3, + loss: str = "mse", + lr: float = 1e-3, + patience: int = 10, + freeze_backbone: bool = False, + freeze_decoder: bool = False, + ) -> None: + """Initialize a new RegressionTask instance. + Args: + model: Name of the + `timm `__ or + `smp `__ model to use. + backbone: Name of the + `timm `__ or + `smp `__ backbone + to use. Only applicable to PixelwiseRegressionTask. + weights: Initial model weights. Either a weight enum, the string + representation of a weight enum, True for ImageNet weights, False + or None for random weights, or the path to a saved model state dict. + in_channels: Number of input channels to model. + num_outputs: Number of prediction outputs. + num_filters: Number of filters. Only applicable when model='fcn'. + loss: One of 'mse' or 'mae'. + lr: Learning rate for optimizer. + patience: Patience for learning rate scheduler. + freeze_backbone: Freeze the backbone network to linear probe + the regression head. Does not support FCN models. + freeze_decoder: Freeze the decoder network to linear probe + the regression head. Does not support FCN models. + Only applicable to PixelwiseRegressionTask. -class RegressionTask(LightningModule): - """LightningModule for training models on regression datasets. + .. versionchanged:: 0.4 + Change regression model support from torchvision.models to timm - Supports any available `Timm model - `_ - as an architecture choice. To see a list of available - models, you can do: + .. versionadded:: 0.5 + The *freeze_backbone* and *freeze_decoder* parameters. - .. code-block:: python + .. versionchanged:: 0.5 + *learning_rate* and *learning_rate_schedule_patience* were renamed to + *lr* and *patience*. + """ + super().__init__() - import timm - print(timm.list_models()) - """ + def configure_losses(self) -> None: + """Initialize the loss criterion. - target_key: str = "label" + Raises: + ValueError: If *loss* is invalid. + """ + loss: str = self.hparams["loss"] + if loss == "mse": + self.criterion: nn.Module = nn.MSELoss() + elif loss == "mae": + self.criterion = nn.L1Loss() + else: + raise ValueError( + f"Loss type '{loss}' is not valid. " + "Currently, supports 'mse' or 'mae' loss." 
-    def config_model(self) -> None:
-        """Configures the model based on kwargs parameters."""
+    def configure_metrics(self) -> None:
+        """Initialize the performance metrics."""
+        metrics = MetricCollection(
+            {
+                "RMSE": MeanSquaredError(squared=False),
+                "MSE": MeanSquaredError(squared=True),
+                "MAE": MeanAbsoluteError(),
+            }
+        )
+        self.train_metrics = metrics.clone(prefix="train_")
+        self.val_metrics = metrics.clone(prefix="val_")
+        self.test_metrics = metrics.clone(prefix="test_")
+
+    def configure_models(self) -> None:
+        """Initialize the model."""
        # Create model
-        weights = self.hyperparams["weights"]
+        weights: Optional[Union[WeightsEnum, str, bool]] = self.hparams["weights"]
        self.model = timm.create_model(
-            self.hyperparams["model"],
-            num_classes=self.hyperparams["num_outputs"],
-            in_chans=self.hyperparams["in_channels"],
+            self.hparams["model"],
+            num_classes=self.hparams["num_outputs"],
+            in_chans=self.hparams["in_channels"],
            pretrained=weights is True,
        )
@@ -60,127 +129,56 @@ def config_model(self) -> None:
            self.model = utils.load_state_dict(self.model, state_dict)

        # Freeze backbone and unfreeze classifier head
-        if self.hyperparams.get("freeze_backbone", False):
+        if self.hparams["freeze_backbone"]:
            for param in self.model.parameters():
                param.requires_grad = False
            for param in self.model.get_classifier().parameters():
                param.requires_grad = True

-    def config_task(self) -> None:
-        """Configures the task based on kwargs parameters."""
-        self.config_model()
-
-        self.loss: nn.Module
-        if self.hyperparams["loss"] == "mse":
-            self.loss = nn.MSELoss()
-        elif self.hyperparams["loss"] == "mae":
-            self.loss = nn.L1Loss()
-        else:
-            raise ValueError(
-                f"Loss type '{self.hyperparams['loss']}' is not valid. "
-                f"Currently, supports 'mse' or 'mae' loss."
-            )
-
-    def __init__(self, **kwargs: Any) -> None:
-        """Initialize a new LightningModule for training simple regression models.
-
-        Keyword Args:
-            model: Name of the timm model to use
-            weights: Either a weight enum, the string representation of a weight enum,
-                True for ImageNet weights, False or None for random weights,
-                or the path to a saved model state dict.
-            num_outputs: Number of prediction outputs
-            in_channels: Number of input channels to model
-            learning_rate: Learning rate for optimizer
-            learning_rate_schedule_patience: Patience for learning rate scheduler
-            freeze_backbone: Freeze the backbone network to linear probe
-                the regression head. Does not support FCN models.
-            freeze_decoder: Freeze the decoder network to linear probe
-                the regression head. Does not support FCN models.
-                Only applicable to PixelwiseRegressionTask.
-
-        .. versionchanged:: 0.4
-            Change regression model support from torchvision.models to timm
-
-        .. versionadded:: 0.5
-            The *freeze_backbone* and *freeze_decoder* parameters.
-        """
-        super().__init__()
-
-        # Creates `self.hparams` from kwargs
-        self.save_hyperparameters()
-        self.hyperparams = cast(dict[str, Any], self.hparams)
-        self.config_task()
-
-        self.train_metrics = MetricCollection(
-            {
-                "RMSE": MeanSquaredError(squared=False),
-                "MSE": MeanSquaredError(squared=True),
-                "MAE": MeanAbsoluteError(),
-            },
-            prefix="train_",
-        )
-        self.val_metrics = self.train_metrics.clone(prefix="val_")
-        self.test_metrics = self.train_metrics.clone(prefix="test_")
-
-    def forward(self, *args: Any, **kwargs: Any) -> Any:
-        """Forward pass of the model.
-
-        Args:
-            x: tensor of data to run through the model
-
-        Returns:
-            output from the model
-        """
-        return self.model(*args, **kwargs)
-
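`configure_metrics` relies on torchmetrics' `MetricCollection.clone` to stamp out per-split copies with distinct logging prefixes. A self-contained sketch of that pattern with toy values:

```python
# One metric definition, cloned per split, as in configure_metrics above.
import torch
from torchmetrics import MeanAbsoluteError, MeanSquaredError, MetricCollection

metrics = MetricCollection(
    {
        "RMSE": MeanSquaredError(squared=False),
        "MAE": MeanAbsoluteError(),
    }
)
train_metrics = metrics.clone(prefix="train_")
val_metrics = metrics.clone(prefix="val_")

preds, target = torch.tensor([1.0, 2.0]), torch.tensor([1.5, 2.0])
print(train_metrics(preds, target))  # keys: train_RMSE, train_MAE
```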
-    def training_step(self, *args: Any, **kwargs: Any) -> Tensor:
-        """Compute and return the training loss.
+    def training_step(
+        self, batch: Any, batch_idx: int, dataloader_idx: int = 0
+    ) -> Tensor:
+        """Compute the training loss and additional metrics.

        Args:
-            batch: the output of your DataLoader
+            batch: The output of your DataLoader.
+            batch_idx: Integer displaying index of this batch.
+            dataloader_idx: Index of the current dataloader.

        Returns:
-            training loss
+            The loss tensor.
        """
-        batch = args[0]
        x = batch["image"]
-        y = batch[self.target_key]
+        # TODO: remove .to(...) once we have a real pixelwise regression dataset
+        y = batch[self.target_key].to(torch.float)
        y_hat = self(x)
-
        if y_hat.ndim != y.ndim:
            y = y.unsqueeze(dim=1)
-
-        loss: Tensor = self.loss(y_hat, y.to(torch.float))
-        self.log("train_loss", loss)  # logging to TensorBoard
-        self.train_metrics(y_hat, y.to(torch.float))
+        loss: Tensor = self.criterion(y_hat, y)
+        self.log("train_loss", loss)
+        self.train_metrics(y_hat, y)

        return loss

-    def on_train_epoch_end(self) -> None:
-        """Logs epoch-level training metrics."""
-        self.log_dict(self.train_metrics.compute())
-        self.train_metrics.reset()
-
-    def validation_step(self, *args: Any, **kwargs: Any) -> None:
-        """Compute validation loss and log example predictions.
+    def validation_step(
+        self, batch: Any, batch_idx: int, dataloader_idx: int = 0
+    ) -> None:
+        """Compute the validation loss and additional metrics.

        Args:
-            batch: the output of your DataLoader
-            batch_idx: the index of this batch
+            batch: The output of your DataLoader.
+            batch_idx: Integer displaying index of this batch.
+            dataloader_idx: Index of the current dataloader.
        """
-        batch = args[0]
-        batch_idx = args[1]
        x = batch["image"]
-        y = batch[self.target_key]
+        # TODO: remove .to(...) once we have a real pixelwise regression dataset
+        y = batch[self.target_key].to(torch.float)
        y_hat = self(x)
-
        if y_hat.ndim != y.ndim:
            y = y.unsqueeze(dim=1)
-
-        loss = self.loss(y_hat, y.to(torch.float))
+        loss = self.criterion(y_hat, y)
        self.log("val_loss", loss)
-        self.val_metrics(y_hat, y.to(torch.float))
+        self.val_metrics(y_hat, y)

        if (
            batch_idx < 10
@@ -207,112 +205,81 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None:
            except ValueError:
                pass

-    def on_validation_epoch_end(self) -> None:
-        """Logs epoch level validation metrics."""
-        self.log_dict(self.val_metrics.compute())
-        self.val_metrics.reset()
-
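The `unsqueeze` in these steps aligns `(batch,)` targets with the `(batch, num_outputs)` predictions that timm classifiers emit. A toy illustration of the shape fix, independent of TorchGeo:

```python
# Why training/validation unsqueeze the target before computing the loss.
import torch

y_hat = torch.randn(8, 1)   # model output: (batch, num_outputs)
y = torch.randn(8)          # target from the dataloader: (batch,)
if y_hat.ndim != y.ndim:
    y = y.unsqueeze(dim=1)  # -> (batch, 1), so MSELoss matches elementwise
loss = torch.nn.MSELoss()(y_hat, y)
print(y.shape, loss.item())
```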
-    def test_step(self, *args: Any, **kwargs: Any) -> None:
-        """Compute test loss.
+    def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
+        """Compute the test loss and additional metrics.

        Args:
-            batch: the output of your DataLoader
+            batch: The output of your DataLoader.
+            batch_idx: Integer displaying index of this batch.
+            dataloader_idx: Index of the current dataloader.
        """
-        batch = args[0]
        x = batch["image"]
-        y = batch[self.target_key]
+        # TODO: remove .to(...) once we have a real pixelwise regression dataset
+        y = batch[self.target_key].to(torch.float)
        y_hat = self(x)
-
        if y_hat.ndim != y.ndim:
            y = y.unsqueeze(dim=1)
-
-        loss = self.loss(y_hat, y.to(torch.float))
+        loss = self.criterion(y_hat, y)
        self.log("test_loss", loss)
-        self.test_metrics(y_hat, y.to(torch.float))
+        self.test_metrics(y_hat, y)

-    def on_test_epoch_end(self) -> None:
-        """Logs epoch level test metrics."""
-        self.log_dict(self.test_metrics.compute())
-        self.test_metrics.reset()
-
-    def predict_step(self, *args: Any, **kwargs: Any) -> Tensor:
-        """Compute and return the predictions.
+    def predict_step(
+        self, batch: Any, batch_idx: int, dataloader_idx: int = 0
+    ) -> Tensor:
+        """Compute the predicted regression values.

        Args:
-            batch: the output of your DataLoader
+            batch: The output of your DataLoader.
+            batch_idx: Integer displaying index of this batch.
+            dataloader_idx: Index of the current dataloader.
+
        Returns:
-            predicted values
+            Output predicted regression values.
        """
-        batch = args[0]
        x = batch["image"]
        y_hat: Tensor = self(x)
        return y_hat
-
-    def configure_optimizers(self) -> dict[str, Any]:
-        """Initialize the optimizer and learning rate scheduler.
-
-        Returns:
-            learning rate dictionary
-        """
-        optimizer = torch.optim.AdamW(
-            self.model.parameters(), lr=self.hyperparams["learning_rate"]
-        )
-        return {
-            "optimizer": optimizer,
-            "lr_scheduler": {
-                "scheduler": ReduceLROnPlateau(
-                    optimizer,
-                    patience=self.hyperparams["learning_rate_schedule_patience"],
-                ),
-                "monitor": "val_loss",
-            },
-        }
-

 class PixelwiseRegressionTask(RegressionTask):
     """LightningModule for pixelwise regression of images.

-    Supports `Segmentation Models Pytorch
-    `_
-    as an architecture choice in combination with any of these
-    `TIMM backbones `_.
-
    .. versionadded:: 0.5
    """

-    target_key: str = "mask"
+    target_key = "mask"

-    def config_model(self) -> None:
-        """Configures the model based on kwargs parameters."""
-        weights = self.hyperparams["weights"]
+    def configure_models(self) -> None:
+        """Initialize the model."""
+        weights: Optional[Union[WeightsEnum, str, bool]] = self.hparams["weights"]

-        if self.hyperparams["model"] == "unet":
+        if self.hparams["model"] == "unet":
            self.model = smp.Unet(
-                encoder_name=self.hyperparams["backbone"],
+                encoder_name=self.hparams["backbone"],
                encoder_weights="imagenet" if weights is True else None,
-                in_channels=self.hyperparams["in_channels"],
+                in_channels=self.hparams["in_channels"],
                classes=1,
            )
-        elif self.hyperparams["model"] == "deeplabv3+":
+        elif self.hparams["model"] == "deeplabv3+":
            self.model = smp.DeepLabV3Plus(
-                encoder_name=self.hyperparams["backbone"],
+                encoder_name=self.hparams["backbone"],
                encoder_weights="imagenet" if weights is True else None,
-                in_channels=self.hyperparams["in_channels"],
+                in_channels=self.hparams["in_channels"],
                classes=1,
            )
-        elif self.hyperparams["model"] == "fcn":
+        elif self.hparams["model"] == "fcn":
            self.model = FCN(
-                in_channels=self.hyperparams["in_channels"],
+                in_channels=self.hparams["in_channels"],
                classes=1,
-                num_filters=self.hyperparams["num_filters"],
+                num_filters=self.hparams["num_filters"],
            )
        else:
            raise ValueError(
-                f"Model type '{self.hyperparams['model']}' is not valid. "
-                f"Currently, only supports 'unet', 'deeplabv3+' and 'fcn'."
+                f"Model type '{self.hparams['model']}' is not valid. "
+                "Currently, only supports 'unet', 'deeplabv3+' and 'fcn'."
            )
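The pixelwise variant builds single-channel heads: `classes=1` turns an smp model into a `(B, C, H, W)` to `(B, 1, H, W)` regression network. A quick shape check, with arbitrary channel counts and image sizes, assuming `segmentation_models_pytorch` is installed:

```python
# Sanity check of the classes=1 heads constructed in configure_models above.
import segmentation_models_pytorch as smp
import torch

model = smp.Unet(
    encoder_name="resnet18", encoder_weights=None, in_channels=4, classes=1
)
x = torch.randn(2, 4, 64, 64)
print(model(x).shape)  # torch.Size([2, 1, 64, 64])
```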
-        if self.hyperparams["model"] != "fcn":
+        if self.hparams["model"] != "fcn":
            if weights and weights is not True:
                if isinstance(weights, WeightsEnum):
                    state_dict = weights.get_state_dict(progress=True)
@@ -323,15 +290,17 @@ def config_model(self) -> None:
                self.model.encoder.load_state_dict(state_dict)

        # Freeze backbone
-        if self.hyperparams.get("freeze_backbone", False) and self.hyperparams[
-            "model"
-        ] in ["unet", "deeplabv3+"]:
+        if self.hparams.get("freeze_backbone", False) and self.hparams["model"] in [
+            "unet",
+            "deeplabv3+",
+        ]:
            for param in self.model.encoder.parameters():
                param.requires_grad = False

        # Freeze decoder
-        if self.hyperparams.get("freeze_decoder", False) and self.hyperparams[
-            "model"
-        ] in ["unet", "deeplabv3+"]:
+        if self.hparams.get("freeze_decoder", False) and self.hparams["model"] in [
+            "unet",
+            "deeplabv3+",
+        ]:
            for param in self.model.decoder.parameters():
                param.requires_grad = False
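What the freeze loops amount to in practice is linear probing: encoder parameters are excluded from gradient updates while the decoder and head keep training. A minimal sketch using the same `requires_grad` loop; `smp.Unet` is used only for concreteness:

```python
# Linear probing in miniature: freeze the encoder, count what is left trainable.
import segmentation_models_pytorch as smp

model = smp.Unet(
    encoder_name="resnet18", encoder_weights=None, in_channels=3, classes=1
)
for param in model.encoder.parameters():
    param.requires_grad = False  # encoder excluded from gradient updates

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"{trainable}/{total} parameters remain trainable")
```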
diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py
index e0497de1b9c..8497436d305 100644
--- a/torchgeo/trainers/segmentation.py
+++ b/torchgeo/trainers/segmentation.py
@@ -1,19 +1,16 @@
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # Licensed under the MIT License.

-"""Segmentation tasks."""
+"""Trainers for semantic segmentation."""

 import os
 import warnings
-from typing import Any, cast
+from typing import Any, Optional, Union

 import matplotlib.pyplot as plt
 import segmentation_models_pytorch as smp
-import torch
 import torch.nn as nn
-from lightning.pytorch import LightningModule
 from torch import Tensor
-from torch.optim.lr_scheduler import ReduceLROnPlateau
 from torchmetrics import MetricCollection
 from torchmetrics.classification import MulticlassAccuracy, MulticlassJaccardIndex
 from torchvision.models._api import WeightsEnum
@@ -21,74 +18,169 @@
 from ..datasets.utils import unbind_samples
 from ..models import FCN, get_weight
 from . import utils
+from .base import BaseTask
+
+
+class SemanticSegmentationTask(BaseTask):
+    """Semantic Segmentation."""
+
+    def __init__(
+        self,
+        model: str = "unet",
+        backbone: str = "resnet50",
+        weights: Optional[Union[WeightsEnum, str, bool]] = None,
+        in_channels: int = 3,
+        num_classes: int = 1000,
+        num_filters: int = 3,
+        loss: str = "ce",
+        class_weights: Optional[Tensor] = None,
+        ignore_index: Optional[int] = None,
+        lr: float = 1e-3,
+        patience: int = 10,
+        freeze_backbone: bool = False,
+        freeze_decoder: bool = False,
+    ) -> None:
+        """Initialize a new SemanticSegmentationTask instance.
+
+        Args:
+            model: Name of the
+                `smp `__ model to use.
+            backbone: Name of the `timm
+                `__ or `smp
+                `__ backbone to use.
+            weights: Initial model weights. Either a weight enum, the string
+                representation of a weight enum, True for ImageNet weights, False or
+                None for random weights, or the path to a saved model state dict. FCN
+                model does not support pretrained weights. Pretrained ViT weight enums
+                are not supported yet.
+            in_channels: Number of input channels to model.
+            num_classes: Number of prediction classes.
+            num_filters: Number of filters. Only applicable when model='fcn'.
+            loss: Name of the loss function, currently supports
+                'ce', 'jaccard' or 'focal' loss.
+            class_weights: Optional rescaling weight given to each
+                class and used with 'ce' loss.
+            ignore_index: Optional integer class index to ignore in the loss and
+                metrics.
+            lr: Learning rate for optimizer.
+            patience: Patience for learning rate scheduler.
+            freeze_backbone: Freeze the backbone network to fine-tune the
+                decoder and segmentation head.
+            freeze_decoder: Freeze the decoder network to linear probe
+                the segmentation head.

-class SemanticSegmentationTask(LightningModule):
-    """LightningModule for semantic segmentation of images.
+        Warns:
+            UserWarning: When loss='jaccard' and ignore_index is specified.

-    Supports `Segmentation Models Pytorch
-    `_
-    as an architecture choice in combination with any of these
-    `TIMM backbones `_.
-    """
+        .. versionchanged:: 0.3
+           *ignore_zeros* was renamed to *ignore_index*.

-    def config_task(self) -> None:
-        """Configures the task based on kwargs parameters passed to the constructor."""
-        weights = self.hyperparams["weights"]
+        .. versionchanged:: 0.4
+           *segmentation_model*, *encoder_name*, and *encoder_weights*
+           were renamed to *model*, *backbone*, and *weights*.

-        if self.hyperparams["model"] == "unet":
-            self.model = smp.Unet(
-                encoder_name=self.hyperparams["backbone"],
-                encoder_weights="imagenet" if weights is True else None,
-                in_channels=self.hyperparams["in_channels"],
-                classes=self.hyperparams["num_classes"],
+        .. versionadded:: 0.5
+           The *class_weights*, *freeze_backbone*, and *freeze_decoder* parameters.
+
+        .. versionchanged:: 0.5
+           The *weights* parameter now supports WeightEnums and checkpoint paths.
+           *learning_rate* and *learning_rate_schedule_patience* were renamed to
+           *lr* and *patience*.
+        """
+        if ignore_index is not None and loss == "jaccard":
+            warnings.warn(
+                "ignore_index has no effect on training when loss='jaccard'",
+                UserWarning,
            )
-        elif self.hyperparams["model"] == "deeplabv3+":
-            self.model = smp.DeepLabV3Plus(
-                encoder_name=self.hyperparams["backbone"],
-                encoder_weights="imagenet" if weights is True else None,
-                in_channels=self.hyperparams["in_channels"],
-                classes=self.hyperparams["num_classes"],
+
+        super().__init__()
+
+    def configure_losses(self) -> None:
+        """Initialize the loss criterion.
+
+        Raises:
+            ValueError: If *loss* is invalid.
+        """
+        loss: str = self.hparams["loss"]
+        ignore_index = self.hparams["ignore_index"]
+        if loss == "ce":
+            ignore_value = -1000 if ignore_index is None else ignore_index
+            self.criterion = nn.CrossEntropyLoss(
+                ignore_index=ignore_value, weight=self.hparams["class_weights"]
            )
-        elif self.hyperparams["model"] == "fcn":
-            self.model = FCN(
-                in_channels=self.hyperparams["in_channels"],
-                classes=self.hyperparams["num_classes"],
-                num_filters=self.hyperparams["num_filters"],
+        elif loss == "jaccard":
+            self.criterion = smp.losses.JaccardLoss(
+                mode="multiclass", classes=self.hparams["num_classes"]
+            )
+        elif loss == "focal":
+            self.criterion = smp.losses.FocalLoss(
+                "multiclass", ignore_index=ignore_index, normalized=True
            )
        else:
            raise ValueError(
-                f"Model type '{self.hyperparams['model']}' is not valid. "
-                f"Currently, only supports 'unet', 'deeplabv3+' and 'fcn'."
+                f"Loss type '{loss}' is not valid. "
+                "Currently, supports 'ce', 'jaccard' or 'focal' loss."
            )
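The `-1000` sentinel exists because `nn.CrossEntropyLoss` always takes an integer `ignore_index`; when the user supplies none, any value that no real class uses works. A toy demonstration that ignored pixels contribute nothing to the loss:

```python
# CrossEntropyLoss with an ignore_index, as wired up in configure_losses above.
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss(ignore_index=0)  # mask out class 0
logits = torch.randn(2, 5, 4, 4)                 # (batch, classes, H, W)
target = torch.randint(0, 5, (2, 4, 4))          # pixels labeled 0 are skipped
print(criterion(logits, target).item())
```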
-        if self.hyperparams["loss"] == "ce":
-            ignore_value = -1000 if self.ignore_index is None else self.ignore_index
+    def configure_metrics(self) -> None:
+        """Initialize the performance metrics."""
+        num_classes: int = self.hparams["num_classes"]
+        ignore_index: Optional[int] = self.hparams["ignore_index"]
+        metrics = MetricCollection(
+            [
+                MulticlassAccuracy(
+                    num_classes=num_classes,
+                    ignore_index=ignore_index,
+                    multidim_average="global",
+                    average="micro",
+                ),
+                MulticlassJaccardIndex(
+                    num_classes=num_classes, ignore_index=ignore_index, average="micro"
+                ),
+            ]
+        )
+        self.train_metrics = metrics.clone(prefix="train_")
+        self.val_metrics = metrics.clone(prefix="val_")
+        self.test_metrics = metrics.clone(prefix="test_")

-        class_weights = None
-        if isinstance(self.class_weights, torch.Tensor):
-            class_weights = self.class_weights.to(dtype=torch.float32)
-        elif hasattr(self.class_weights, "__array__") or self.class_weights:
-            class_weights = torch.tensor(self.class_weights, dtype=torch.float32)
+    def configure_models(self) -> None:
+        """Initialize the model.

-        self.loss = nn.CrossEntropyLoss(
-            ignore_index=ignore_value, weight=class_weights
+        Raises:
+            ValueError: If *model* is invalid.
+        """
+        model: str = self.hparams["model"]
+        backbone: str = self.hparams["backbone"]
+        weights: Optional[Union[WeightsEnum, str, bool]] = self.hparams["weights"]
+        in_channels: int = self.hparams["in_channels"]
+        num_classes: int = self.hparams["num_classes"]
+        num_filters: int = self.hparams["num_filters"]
+
+        if model == "unet":
+            self.model = smp.Unet(
+                encoder_name=backbone,
+                encoder_weights="imagenet" if weights is True else None,
+                in_channels=in_channels,
+                classes=num_classes,
            )
-        elif self.hyperparams["loss"] == "jaccard":
-            self.loss = smp.losses.JaccardLoss(
-                mode="multiclass", classes=self.hyperparams["num_classes"]
+        elif model == "deeplabv3+":
+            self.model = smp.DeepLabV3Plus(
+                encoder_name=backbone,
+                encoder_weights="imagenet" if weights is True else None,
+                in_channels=in_channels,
+                classes=num_classes,
            )
-        elif self.hyperparams["loss"] == "focal":
-            self.loss = smp.losses.FocalLoss(
-                "multiclass", ignore_index=self.ignore_index, normalized=True
+        elif model == "fcn":
+            self.model = FCN(
+                in_channels=in_channels, classes=num_classes, num_filters=num_filters
            )
        else:
            raise ValueError(
-                f"Loss type '{self.hyperparams['loss']}' is not valid. "
-                f"Currently, supports 'ce', 'jaccard' or 'focal' loss."
+                f"Model type '{model}' is not valid. "
+                "Currently, only supports 'unet', 'deeplabv3+' and 'fcn'."
            )

-        if self.hyperparams["model"] != "fcn":
+        if model != "fcn":
            if weights and weights is not True:
                if isinstance(weights, WeightsEnum):
                    state_dict = weights.get_state_dict(progress=True)
@@ -99,156 +191,53 @@ def config_task(self) -> None:
                self.model.encoder.load_state_dict(state_dict)

        # Freeze backbone
-        if self.hyperparams.get("freeze_backbone", False) and self.hyperparams[
-            "model"
-        ] in ["unet", "deeplabv3+"]:
+        if self.hparams["freeze_backbone"] and model in ["unet", "deeplabv3+"]:
            for param in self.model.encoder.parameters():
                param.requires_grad = False

        # Freeze decoder
-        if self.hyperparams.get("freeze_decoder", False) and self.hyperparams[
-            "model"
-        ] in ["unet", "deeplabv3+"]:
+        if self.hparams["freeze_decoder"] and model in ["unet", "deeplabv3+"]:
            for param in self.model.decoder.parameters():
                param.requires_grad = False

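The segmentation metrics are micro-averaged over all pixels, with `ignore_index` excluded from both numerator and denominator. A compact check on toy-sized hard predictions; only the torchmetrics constructors mirror `configure_metrics`, the shapes and the `ignore_index` value are placeholders:

```python
# Micro-averaged pixel accuracy and IoU, as configured above.
import torch
from torchmetrics.classification import MulticlassAccuracy, MulticlassJaccardIndex

num_classes, ignore_index = 5, 0
accuracy = MulticlassAccuracy(
    num_classes=num_classes,
    ignore_index=ignore_index,
    multidim_average="global",
    average="micro",
)
jaccard = MulticlassJaccardIndex(
    num_classes=num_classes, ignore_index=ignore_index, average="micro"
)

preds = torch.randint(0, num_classes, (2, 8, 8))   # (batch, H, W) class ids
target = torch.randint(0, num_classes, (2, 8, 8))
print(accuracy(preds, target), jaccard(preds, target))
```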
-    def __init__(self, **kwargs: Any) -> None:
-        """Initialize the LightningModule with a model and loss function.
-
-        Keyword Args:
-            model: Name of the segmentation model type to use
-            backbone: Name of the timm backbone to use
-            weights: Either a weight enum, the string representation of a weight enum,
-                True for ImageNet weights, False or None for random weights,
-                or the path to a saved model state dict. FCN model does not support
-                pretrained weights. Pretrained ViT weight enums are not supported yet.
-            in_channels: Number of channels in input image
-            num_classes: Number of semantic classes to predict
-            loss: Name of the loss function, currently supports
-                'ce', 'jaccard' or 'focal' loss
-            class_weights: Optional rescaling weight given to each
-                class and used with 'ce' loss
-            ignore_index: Optional integer class index to ignore in the loss and metrics
-            learning_rate: Learning rate for optimizer
-            learning_rate_schedule_patience: Patience for learning rate scheduler
-            freeze_backbone: Freeze the backbone network to fine-tune the
-                decoder and segmentation head
-            freeze_decoder: Freeze the decoder network to linear probe
-                the segmentation head
-
-        Raises:
-            ValueError: if kwargs arguments are invalid
-
-        .. versionchanged:: 0.3
-           The *ignore_zeros* parameter was renamed to *ignore_index*.
-
-        .. versionchanged:: 0.4
-           The *segmentation_model* parameter was renamed to *model*,
-           *encoder_name* renamed to *backbone*, and
-           *encoder_weights* to *weights*.
-
-        .. versionadded: 0.5
-           The *class_weights*, *freeze_backbone*,
-           and *freeze_decoder* parameters.
-
-        .. versionchanged:: 0.5
-           The *weights* parameter now supports WeightEnums and checkpoint paths.
-
-        """
-        super().__init__()
-
-        # Creates `self.hparams` from kwargs
-        self.save_hyperparameters()
-        self.hyperparams = cast(dict[str, Any], self.hparams)
-
-        if not isinstance(kwargs["ignore_index"], (int, type(None))):
-            raise ValueError("ignore_index must be an int or None")
-        if (kwargs["ignore_index"] is not None) and (kwargs["loss"] == "jaccard"):
-            warnings.warn(
-                "ignore_index has no effect on training when loss='jaccard'",
-                UserWarning,
-            )
-        self.ignore_index = kwargs["ignore_index"]
-        self.class_weights = kwargs.get("class_weights", None)
-
-        self.config_task()
-
-        self.train_metrics = MetricCollection(
-            [
-                MulticlassAccuracy(
-                    num_classes=self.hyperparams["num_classes"],
-                    ignore_index=self.ignore_index,
-                    multidim_average="global",
-                    average="micro",
-                ),
-                MulticlassJaccardIndex(
-                    num_classes=self.hyperparams["num_classes"],
-                    ignore_index=self.ignore_index,
-                    average="micro",
-                ),
-            ],
-            prefix="train_",
-        )
-        self.val_metrics = self.train_metrics.clone(prefix="val_")
-        self.test_metrics = self.train_metrics.clone(prefix="test_")
-
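Background for the removed boilerplate: Lightning's `save_hyperparameters()` captures the `__init__` arguments into `self.hparams`, which is presumably what `BaseTask` now does centrally (an assumption, `BaseTask`'s body is not part of this diff), and it is why the new code can read `self.hparams["lr"]` without the old `cast(...)` dance. A minimal sketch; `TinyTask` is a hypothetical class used only to illustrate the mechanism:

```python
# How save_hyperparameters populates self.hparams (TinyTask is hypothetical).
from lightning.pytorch import LightningModule

class TinyTask(LightningModule):
    def __init__(self, lr: float = 1e-3, patience: int = 10) -> None:
        super().__init__()
        self.save_hyperparameters()  # exposes lr and patience via self.hparams

task = TinyTask(lr=1e-2)
print(task.hparams["lr"], task.hparams["patience"])  # 0.01 10
```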
-    def forward(self, *args: Any, **kwargs: Any) -> Any:
-        """Forward pass of the model.
+    def training_step(
+        self, batch: Any, batch_idx: int, dataloader_idx: int = 0
+    ) -> Tensor:
+        """Compute the training loss and additional metrics.

        Args:
-            x: tensor of data to run through the model
+            batch: The output of your DataLoader.
+            batch_idx: Integer displaying index of this batch.
+            dataloader_idx: Index of the current dataloader.

        Returns:
-            output from the model
+            The loss tensor.
        """
-        return self.model(*args, **kwargs)
-
-    def training_step(self, *args: Any, **kwargs: Any) -> Tensor:
-        """Compute and return the training loss.
-
-        Args:
-            batch: the output of your DataLoader
-
-        Returns:
-            training loss
-        """
-        batch = args[0]
        x = batch["image"]
        y = batch["mask"]
        y_hat = self(x)
        y_hat_hard = y_hat.argmax(dim=1)
-
-        loss = self.loss(y_hat, y)
-
-        # by default, the train step logs every `log_every_n_steps` steps where
-        # `log_every_n_steps` is a parameter to the `Trainer` object
-        self.log("train_loss", loss, on_step=True, on_epoch=False)
+        loss: Tensor = self.criterion(y_hat, y)
+        self.log("train_loss", loss)
        self.train_metrics(y_hat_hard, y)
+        return loss

-        return cast(Tensor, loss)
-
-    def on_train_epoch_end(self) -> None:
-        """Logs epoch level training metrics."""
-        self.log_dict(self.train_metrics.compute())
-        self.train_metrics.reset()
-
-    def validation_step(self, *args: Any, **kwargs: Any) -> None:
-        """Compute validation loss and log example predictions.
+    def validation_step(
+        self, batch: Any, batch_idx: int, dataloader_idx: int = 0
+    ) -> None:
+        """Compute the validation loss and additional metrics.

        Args:
-            batch: the output of your DataLoader
-            batch_idx: the index of this batch
+            batch: The output of your DataLoader.
+            batch_idx: Integer displaying index of this batch.
+            dataloader_idx: Index of the current dataloader.
        """
-        batch = args[0]
-        batch_idx = args[1]
        x = batch["image"]
        y = batch["mask"]
        y_hat = self(x)
        y_hat_hard = y_hat.argmax(dim=1)
-
-        loss = self.loss(y_hat, y)
-
-        self.log("val_loss", loss, on_step=False, on_epoch=True)
+        loss = self.criterion(y_hat, y)
+        self.log("val_loss", loss)
        self.val_metrics(y_hat_hard, y)

        if (
@@ -273,69 +262,35 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None:
            except ValueError:
                pass

-    def on_validation_epoch_end(self) -> None:
-        """Logs epoch level validation metrics."""
-        self.log_dict(self.val_metrics.compute())
-        self.val_metrics.reset()
-
-    def test_step(self, *args: Any, **kwargs: Any) -> None:
-        """Compute test loss.
+    def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
+        """Compute the test loss and additional metrics.

        Args:
-            batch: the output of your DataLoader
+            batch: The output of your DataLoader.
+            batch_idx: Integer displaying index of this batch.
+            dataloader_idx: Index of the current dataloader.
        """
-        batch = args[0]
        x = batch["image"]
        y = batch["mask"]
        y_hat = self(x)
        y_hat_hard = y_hat.argmax(dim=1)
-
-        loss = self.loss(y_hat, y)
-
-        # by default, the test and validation steps only log per *epoch*
-        self.log("test_loss", loss, on_step=False, on_epoch=True)
+        loss = self.criterion(y_hat, y)
+        self.log("test_loss", loss)
        self.test_metrics(y_hat_hard, y)

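Each step derives hard labels from the logits with `argmax(dim=1)` before updating metrics, while `predict_step` (below) returns per-pixel softmax probabilities. A toy illustration of both post-processing steps:

```python
# From logits to probabilities to a hard class map. Shapes are placeholders.
import torch

logits = torch.randn(2, 7, 32, 32)   # (batch, num_classes, H, W)
probs = logits.softmax(dim=1)        # what predict_step returns
classes = probs.argmax(dim=1)        # (batch, H, W) hard labels for metrics
print(probs.shape, classes.shape)
```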
""" - batch = args[0] x = batch["image"] y_hat: Tensor = self(x).softmax(dim=1) return y_hat - - def configure_optimizers(self) -> dict[str, Any]: - """Initialize the optimizer and learning rate scheduler. - - Returns: - learning rate dictionary - """ - optimizer = torch.optim.Adam( - self.model.parameters(), lr=self.hyperparams["learning_rate"] - ) - return { - "optimizer": optimizer, - "lr_scheduler": { - "scheduler": ReduceLROnPlateau( - optimizer, - patience=self.hyperparams["learning_rate_schedule_patience"], - ), - "monitor": "val_loss", - }, - } diff --git a/torchgeo/trainers/simclr.py b/torchgeo/trainers/simclr.py index 1dd546f7c26..ca4ab679713 100644 --- a/torchgeo/trainers/simclr.py +++ b/torchgeo/trainers/simclr.py @@ -5,7 +5,7 @@ import os import warnings -from typing import Any, Optional, Union, cast +from typing import Any, Optional, Union import kornia.augmentation as K import timm @@ -14,9 +14,8 @@ import torch.nn.functional as F from lightly.loss import NTXentLoss from lightly.models.modules import SimCLRProjectionHead -from lightning import LightningModule from torch import Tensor -from torch.optim import Adam, Optimizer +from torch.optim import Adam from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR from torchvision.models._api import WeightsEnum @@ -24,11 +23,7 @@ from ..models import get_weight from . import utils - -try: - from torch.optim.lr_scheduler import LRScheduler -except ImportError: - from torch.optim.lr_scheduler import _LRScheduler as LRScheduler +from .base import BaseTask def simclr_augmentations(size: int, weights: Tensor) -> nn.Module: @@ -57,7 +52,7 @@ def simclr_augmentations(size: int, weights: Tensor) -> nn.Module: ) -class SimCLRTask(LightningModule): +class SimCLRTask(BaseTask): """SimCLR: a simple framework for contrastive learning of visual representations. Reference implementation: @@ -72,6 +67,8 @@ class SimCLRTask(LightningModule): .. versionadded:: 0.5 """ + monitor = "train_loss" + def __init__( self, model: str = "resnet50", @@ -93,7 +90,8 @@ def __init__( """Initialize a new SimCLRTask instance. Args: - model: Name of the timm model to use. + model: Name of the `timm + `__ model to use. weights: Initial model weights. Either a weight enum, the string representation of a weight enum, True for ImageNet weights, False or None for random weights, or the path to a saved model state dict. @@ -122,8 +120,6 @@ def __init__( Warns: UserWarning: If hyperparameters do not match SimCLR version requested. 
""" - super().__init__() - # Validate hyperparameters assert version in range(1, 3) if version == 1: @@ -137,16 +133,33 @@ def __init__( if memory_bank_size == 0: warnings.warn("SimCLR v2 uses a memory bank") - self.save_hyperparameters(ignore=["augmentations"]) + super().__init__() grayscale_weights = grayscale_weights or torch.ones(in_channels) self.augmentations = augmentations or simclr_augmentations( size, grayscale_weights ) + def configure_losses(self) -> None: + """Initialize the loss criterion.""" + self.criterion = NTXentLoss( + self.hparams["temperature"], + self.hparams["memory_bank_size"], + self.hparams["gather_distributed"], + ) + + def configure_models(self) -> None: + """Initialize the model.""" + weights: Optional[Union[WeightsEnum, str, bool]] = self.hparams["weights"] + hidden_dim: int = self.hparams["hidden_dim"] + output_dim: int = self.hparams["output_dim"] + # Create backbone self.backbone = timm.create_model( - model, in_chans=in_channels, num_classes=0, pretrained=weights is True + self.hparams["model"], + in_chans=self.hparams["in_channels"], + num_classes=0, + pretrained=weights is True, ) # Load weights @@ -167,12 +180,9 @@ def __init__( output_dim = input_dim self.projection_head = SimCLRProjectionHead( - input_dim, hidden_dim, output_dim, layers + input_dim, hidden_dim, output_dim, self.hparams["layers"] ) - # Define loss function - self.criterion = NTXentLoss(temperature, memory_bank_size, gather_distributed) - # Initialize moving average of output self.avg_output_std = 0.0 @@ -187,11 +197,11 @@ def forward(self, x: Tensor) -> tuple[Tensor, Tensor]: x: Mini-batch of images. Returns: - Output from the model and backbone. + Output of the model and backbone. """ - h = self.backbone(x) # shape of batch_size x num_features + h: Tensor = self.backbone(x) # shape of batch_size x num_features z = self.projection_head(h) - return cast(Tensor, z), cast(Tensor, h) + return z, h def training_step( self, batch: Any, batch_idx: int, dataloader_idx: int = 0 @@ -205,10 +215,13 @@ def training_step( Returns: The loss tensor. + + Raises: + AssertionError: If channel dimensions are incorrect. """ x = batch["image"] - in_channels = self.hparams["in_channels"] + in_channels: int = self.hparams["in_channels"] assert x.size(1) == in_channels or x.size(1) == 2 * in_channels if x.size(1) == in_channels: @@ -225,7 +238,7 @@ def training_step( z1, h1 = self(x1) z2, h2 = self(x2) - loss = self.criterion(z1, z2) + loss: Tensor = self.criterion(z1, z2) # Calculate the mean normalized standard deviation over features dimensions. # If this is << 1 / sqrt(h1.shape[1]), then the model is not learning anything. @@ -238,7 +251,7 @@ def training_step( self.log("train_ssl_std", self.avg_output_std) self.log("train_loss", loss) - return cast(Tensor, loss) + return loss def validation_step( self, batch: Any, batch_idx: int, dataloader_idx: int = 0 @@ -253,7 +266,7 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: """No-op, does nothing.""" - def configure_optimizers(self) -> tuple[list[Optimizer], list[LRScheduler]]: + def configure_optimizers(self) -> dict[str, Any]: """Initialize the optimizer and learning rate scheduler. 
-    def configure_optimizers(self) -> tuple[list[Optimizer], list[LRScheduler]]:
+    def configure_optimizers(self) -> dict[str, Any]:
        """Initialize the optimizer and learning rate scheduler.

        Returns:
@@ -272,7 +285,7 @@ def configure_optimizers(self) -> tuple[list[Optimizer], list[LRScheduler]]:
            warmup_epochs = 10
        else:
            warmup_epochs = int(max_epochs * 0.05)
-        lr_scheduler = SequentialLR(
+        scheduler = SequentialLR(
            optimizer,
            schedulers=[
                LinearLR(optimizer, total_iters=warmup_epochs),
@@ -280,4 +293,7 @@
            ],
            milestones=[warmup_epochs],
        )
-        return [optimizer], [lr_scheduler]
+        return {
+            "optimizer": optimizer,
+            "lr_scheduler": {"scheduler": scheduler, "monitor": self.monitor},
+        }