Skip to content
This repository was archived by the owner on Sep 18, 2024. It is now read-only.

Commit dc58203

Browse files
authored
Adopt torchmetrics (#4290)
1 parent 8fc555a commit dc58203

File tree

5 files changed

+23
-19
lines changed

5 files changed

+23
-19
lines changed

dependencies/recommended.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@ torch == 1.9.0+cpu ; sys_platform != "darwin"
88
torch == 1.9.0 ; sys_platform == "darwin"
99
torchvision == 0.10.0+cpu ; sys_platform != "darwin"
1010
torchvision == 0.10.0 ; sys_platform == "darwin"
11-
pytorch-lightning >= 1.4.2
11+
pytorch-lightning >= 1.5
12+
torchmetrics
1213
onnx
1314
peewee
1415
graphviz

dependencies/recommended_legacy.txt

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ torchvision == 0.7.0+cpu
66
# It will install pytorch-lightning 0.8.x and unit tests won't work.
77
# Latest version has conflict with tensorboard and tensorflow 1.x.
88
pytorch-lightning
9+
torchmetrics
910

1011
keras == 2.1.6
1112
onnx

nni/retiarii/evaluator/pytorch/cgo/accelerator.py

+12-11
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
1-
from typing import Any, Union, Optional, List
2-
import torch
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT license.
3+
4+
from typing import Any, List, Optional, Union
35

6+
import torch
47
from pytorch_lightning.accelerators.accelerator import Accelerator
8+
from pytorch_lightning.plugins.environments import ClusterEnvironment
59
from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin
6-
from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
710
from pytorch_lightning.trainer import Trainer
8-
9-
from pytorch_lightning.plugins import Plugin
10-
from pytorch_lightning.plugins.environments import ClusterEnvironment
11+
from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
1112

1213
from ....serializer import serialize_cls
1314

@@ -69,9 +70,8 @@ def model_to_device(self) -> None:
6970
# bypass device placement from pytorch lightning
7071
pass
7172

72-
def setup(self, model: torch.nn.Module) -> torch.nn.Module:
73-
self.model_to_device()
74-
return self.model
73+
def setup(self) -> None:
74+
pass
7575

7676
@property
7777
def is_global_zero(self) -> bool:
@@ -100,8 +100,9 @@ def get_accelerator_connector(
100100
deterministic: bool = False,
101101
precision: int = 32,
102102
amp_backend: str = 'native',
103-
amp_level: str = 'O2',
104-
plugins: Optional[Union[List[Union[Plugin, ClusterEnvironment, str]], Plugin, ClusterEnvironment, str]] = None,
103+
amp_level: Optional[str] = None,
104+
plugins: Optional[Union[List[Union[TrainingTypePlugin, ClusterEnvironment, str]],
105+
TrainingTypePlugin, ClusterEnvironment, str]] = None,
105106
**other_trainier_kwargs) -> AcceleratorConnector:
106107
gpu_ids = Trainer()._parse_devices(gpus, auto_select_gpus, tpu_cores)
107108
return AcceleratorConnector(

nni/retiarii/evaluator/pytorch/cgo/evaluator.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import torch.nn as nn
99
import torch.optim as optim
10-
import pytorch_lightning as pl
10+
import torchmetrics
1111
from torch.utils.data import DataLoader
1212

1313
import nni
@@ -19,7 +19,7 @@
1919

2020
@serialize_cls
2121
class _MultiModelSupervisedLearningModule(LightningModule):
22-
def __init__(self, criterion: nn.Module, metrics: Dict[str, pl.metrics.Metric],
22+
def __init__(self, criterion: nn.Module, metrics: Dict[str, torchmetrics.Metric],
2323
n_models: int = 0,
2424
learning_rate: float = 0.001,
2525
weight_decay: float = 0.,
@@ -119,7 +119,7 @@ class MultiModelSupervisedLearningModule(_MultiModelSupervisedLearningModule):
119119
Class for optimizer (not an instance). default: ``Adam``
120120
"""
121121

122-
def __init__(self, criterion: nn.Module, metrics: Dict[str, pl.metrics.Metric],
122+
def __init__(self, criterion: nn.Module, metrics: Dict[str, torchmetrics.Metric],
123123
learning_rate: float = 0.001,
124124
weight_decay: float = 0.,
125125
optimizer: optim.Optimizer = optim.Adam):
@@ -180,7 +180,7 @@ def __init__(self, criterion: nn.Module = nn.MSELoss,
180180
learning_rate: float = 0.001,
181181
weight_decay: float = 0.,
182182
optimizer: optim.Optimizer = optim.Adam):
183-
super().__init__(criterion, {'mse': pl.metrics.MeanSquaredError},
183+
super().__init__(criterion, {'mse': torchmetrics.MeanSquaredError},
184184
learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer)
185185

186186

nni/retiarii/evaluator/pytorch/lightning.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import pytorch_lightning as pl
1010
import torch.nn as nn
1111
import torch.optim as optim
12+
import torchmetrics
1213
from torch.utils.data import DataLoader
1314

1415
import nni
@@ -140,7 +141,7 @@ def _check_dataloader(dataloader):
140141
### The following are some commonly used Lightning modules ###
141142

142143
class _SupervisedLearningModule(LightningModule):
143-
def __init__(self, criterion: nn.Module, metrics: Dict[str, pl.metrics.Metric],
144+
def __init__(self, criterion: nn.Module, metrics: Dict[str, torchmetrics.Metric],
144145
learning_rate: float = 0.001,
145146
weight_decay: float = 0.,
146147
optimizer: optim.Optimizer = optim.Adam,
@@ -213,7 +214,7 @@ def _get_validation_metrics(self):
213214
return {name: self.trainer.callback_metrics['val_' + name].item() for name in self.metrics}
214215

215216

216-
class _AccuracyWithLogits(pl.metrics.Accuracy):
217+
class _AccuracyWithLogits(torchmetrics.Accuracy):
217218
def update(self, pred, target):
218219
return super().update(nn.functional.softmax(pred), target)
219220

@@ -278,7 +279,7 @@ def __init__(self, criterion: nn.Module = nn.MSELoss,
278279
weight_decay: float = 0.,
279280
optimizer: optim.Optimizer = optim.Adam,
280281
export_onnx: bool = True):
281-
super().__init__(criterion, {'mse': pl.metrics.MeanSquaredError},
282+
super().__init__(criterion, {'mse': torchmetrics.MeanSquaredError},
282283
learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer,
283284
export_onnx=export_onnx)
284285

0 commit comments

Comments
 (0)