-
Notifications
You must be signed in to change notification settings - Fork 412
/
collections.py
409 lines (349 loc) · 17.7 KB
/
collections.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
from typing import Any, Dict, Hashable, Iterable, List, Optional, Sequence, Tuple, Union
import torch
from torch import Tensor
from torch.nn import Module, ModuleDict
from torchmetrics.metric import Metric
from torchmetrics.utilities import rank_zero_warn
from torchmetrics.utilities.data import _flatten_dict
# this is just a bypass for this module name collision with build-in one
from torchmetrics.utilities.imports import OrderedDict
class MetricCollection(ModuleDict):
    """MetricCollection class can be used to chain metrics that have the same call pattern into one single class.

    Args:
        metrics: One of the following

            * list or tuple (sequence): if metrics are passed in as a list or tuple, will use the metrics class name
              as key for output dict. Therefore, two metrics of the same class cannot be chained this way.

            * arguments: similar to passing in as a list, metrics passed in as arguments will use their metric
              class name as key for the output dict.

            * dict: if metrics are passed in as a dict, will use each key in the dict as key for output dict.
              Use this format if you want to chain together multiple of the same metric with different parameters.
              Note that the keys in the output dict will be sorted alphabetically.

        prefix: a string to prepend to the keys of the output dict
        postfix: a string to append after the keys of the output dict
        compute_groups:
            By default the MetricCollection will try to reduce the computations needed for the metrics in the
            collection by checking if they belong to the same **compute group**. All metrics in a compute group
            share the same metric state and are therefore only different in their compute step e.g. accuracy,
            precision and recall can all be computed from the true positives/negatives and false
            positives/negatives. By default, this argument is ``True`` which enables this feature. Set this argument
            to ``False`` for disabling this behaviour. Can also be set to a list of list of metric names for setting
            the compute groups yourself; metrics of the collection that are not listed in any group are placed in a
            group of their own.

    .. note::
        Metric collections can be nested at initialization (see last example) but the output of the collection will
        still be a single flattened dictionary combining the prefix and postfix arguments from the nested collection.

    Raises:
        ValueError:
            If one of the elements of ``metrics`` is not an instance of ``torchmetrics.Metric`` or
            ``torchmetrics.MetricCollection``.
        ValueError:
            If two elements in ``metrics`` have the same ``name``.
        ValueError:
            If ``metrics`` is not a ``list``, ``tuple`` or a ``dict``.
        ValueError:
            If ``metrics`` is ``dict`` and additional_metrics are passed in.
        ValueError:
            If ``prefix`` is set and it is not a string.
        ValueError:
            If ``postfix`` is set and it is not a string.
        ValueError:
            If ``compute_groups`` is a list and one of its entries does not match a metric in the collection.

    Example (input as list):
        >>> import torch
        >>> from pprint import pprint
        >>> from torchmetrics import MetricCollection, Accuracy, Precision, Recall, MeanSquaredError
        >>> target = torch.tensor([0, 2, 0, 2, 0, 1, 0, 2])
        >>> preds = torch.tensor([2, 1, 2, 0, 1, 2, 2, 2])
        >>> metrics = MetricCollection([Accuracy(),
        ...                             Precision(num_classes=3, average='macro'),
        ...                             Recall(num_classes=3, average='macro')])
        >>> metrics(preds, target)
        {'Accuracy': tensor(0.1250), 'Precision': tensor(0.0667), 'Recall': tensor(0.1111)}

    Example (input as arguments):
        >>> metrics = MetricCollection(Accuracy(), Precision(num_classes=3, average='macro'),
        ...                            Recall(num_classes=3, average='macro'))
        >>> metrics(preds, target)
        {'Accuracy': tensor(0.1250), 'Precision': tensor(0.0667), 'Recall': tensor(0.1111)}

    Example (input as dict):
        >>> metrics = MetricCollection({'micro_recall': Recall(num_classes=3, average='micro'),
        ...                             'macro_recall': Recall(num_classes=3, average='macro')})
        >>> same_metric = metrics.clone()
        >>> pprint(metrics(preds, target))
        {'macro_recall': tensor(0.1111), 'micro_recall': tensor(0.1250)}
        >>> pprint(same_metric(preds, target))
        {'macro_recall': tensor(0.1111), 'micro_recall': tensor(0.1250)}

    Example (specification of compute groups):
        >>> metrics = MetricCollection(
        ...     Accuracy(),
        ...     Precision(num_classes=3, average='macro'),
        ...     MeanSquaredError(),
        ...     compute_groups=[['Accuracy', 'Precision'], ['MeanSquaredError']]
        ... )
        >>> pprint(metrics(preds, target))
        {'Accuracy': tensor(0.1250), 'MeanSquaredError': tensor(2.3750), 'Precision': tensor(0.0667)}

    Example (nested metric collections):
        >>> metrics = MetricCollection([
        ...     MetricCollection([
        ...         Accuracy(num_classes=3, average='macro'),
        ...         Precision(num_classes=3, average='macro')
        ...     ], postfix='_macro'),
        ...     MetricCollection([
        ...         Accuracy(num_classes=3, average='micro'),
        ...         Precision(num_classes=3, average='micro')
        ...     ], postfix='_micro'),
        ... ], prefix='valmetrics/')
        >>> pprint(metrics(preds, target))  # doctest: +NORMALIZE_WHITESPACE
        {'valmetrics/Accuracy_macro': tensor(0.1111),
         'valmetrics/Accuracy_micro': tensor(0.1250),
         'valmetrics/Precision_macro': tensor(0.0667),
         'valmetrics/Precision_micro': tensor(0.1250)}
    """

    # maps a group index to the (base) names of the metrics in that compute group
    _groups: Dict[int, List[str]]

    def __init__(
        self,
        metrics: Union[Metric, Sequence[Metric], Dict[str, Metric]],
        *additional_metrics: Metric,
        prefix: Optional[str] = None,
        postfix: Optional[str] = None,
        compute_groups: Union[bool, List[List[str]]] = True,
    ) -> None:
        super().__init__()

        self.prefix = self._check_arg(prefix, "prefix")
        self.postfix = self._check_arg(postfix, "postfix")
        self._enable_compute_groups = compute_groups
        # becomes True once the compute groups are final (after the first ``update`` or when user-specified)
        self._groups_checked: bool = False

        self.add_metrics(metrics, *additional_metrics)

    @torch.jit.unused
    def forward(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
        """Iteratively call forward for each metric.

        Positional arguments (args) will be passed to every metric in the collection, while keyword arguments (kwargs)
        will be filtered based on the signature of the individual metric.
        """
        res = {k: m(*args, **m._filter_kwargs(**kwargs)) for k, m in self.items(keep_base=True)}
        res = _flatten_dict(res)
        return {self._set_name(k): v for k, v in res.items()}

    def update(self, *args: Any, **kwargs: Any) -> None:
        """Iteratively call update for each metric.

        Positional arguments (args) will be passed to every metric in the collection, while keyword arguments (kwargs)
        will be filtered based on the signature of the individual metric.

        Once the compute groups are final, only the first metric of every group is updated; its state is handed to
        the other members of the group in :meth:`compute`.
        """
        if self._groups_checked:
            # compute groups already initialized and checked: only update the first member of each group
            for cg in self._groups.values():
                m0 = getattr(self, cg[0])
                m0.update(*args, **m0._filter_kwargs(**kwargs))
        else:
            # the first update always updates every metric, so that the resulting states can be compared
            for _, m in self.items(keep_base=True):
                m.update(*args, **m._filter_kwargs(**kwargs))

            # groups are only final when the feature is enabled; with compute groups disabled ``_groups`` is empty,
            # so marking the groups as checked here would make every later ``update`` a no-op
            if self._enable_compute_groups:
                self._merge_compute_groups()
                self._groups_checked = True

    def _merge_compute_groups(self) -> None:
        """Merge compute groups whose metrics currently hold identical states.

        Groups are compared through their first members. Whenever two groups match, the second one is folded into
        the first and the search starts over, until no further pair can be merged. Afterwards the groups are
        re-indexed with consecutive integers starting at 0.
        """
        pair = self._find_mergeable_groups()
        while pair is not None:
            idx1, idx2 = pair
            self._groups[idx1].extend(self._groups.pop(idx2))
            pair = self._find_mergeable_groups()

        # re-index groups
        self._groups = dict(enumerate(self._groups.values()))

    def _find_mergeable_groups(self) -> Optional[Tuple[int, int]]:
        """Return the first pair of distinct group indices whose first members share state, or ``None``."""
        for idx1, members1 in self._groups.items():
            metric1 = getattr(self, members1[0])
            for idx2, members2 in self._groups.items():
                if idx1 != idx2 and self._equal_metric_states(metric1, getattr(self, members2[0])):
                    return idx1, idx2
        return None

    @staticmethod
    def _equal_tensors(tensor1: Tensor, tensor2: Tensor) -> bool:
        """Check that two tensors have the same shape, dtype and device and (approximately) equal values."""
        # ``torch.allclose`` is not meant to compare tensors of differing dtype or device; treat those as different
        if tensor1.shape != tensor2.shape or tensor1.dtype != tensor2.dtype or tensor1.device != tensor2.device:
            return False
        return torch.allclose(tensor1, tensor2)

    @staticmethod
    def _equal_metric_states(metric1: Metric, metric2: Metric) -> bool:
        """Check if the metric state of two metrics are the same.

        Both metrics must declare the same (non-empty) set of states in ``_defaults``. Every state must have the same
        type in both metrics. Tensor states, and every element of list states, must additionally agree in shape,
        dtype, device and value (up to ``torch.allclose`` tolerance). States of any other type are not compared
        by value.
        """
        # metrics without state can never share a compute group
        if len(metric1._defaults) == 0 or len(metric2._defaults) == 0:
            return False

        if metric1._defaults.keys() != metric2._defaults.keys():
            return False

        # every state has to match; a match in one state must not end the comparison early
        for key in metric1._defaults.keys():
            state1 = getattr(metric1, key)
            state2 = getattr(metric2, key)

            if type(state1) != type(state2):
                return False

            if isinstance(state1, Tensor) and isinstance(state2, Tensor):
                if not MetricCollection._equal_tensors(state1, state2):
                    return False

            elif isinstance(state1, list) and isinstance(state2, list):
                # ``zip`` alone would ignore trailing elements of the longer list
                if len(state1) != len(state2):
                    return False
                if not all(MetricCollection._equal_tensors(s1, s2) for s1, s2 in zip(state1, state2)):
                    return False

        return True

    def compute(self) -> Dict[str, Any]:
        """Compute the result for each metric in the collection.

        When compute groups are active, the states of the first metric of every group are first assigned to the other
        members of that group, since only the first member was updated in :meth:`update`.
        """
        if self._enable_compute_groups and self._groups_checked:
            for cg in self._groups.values():
                m0 = getattr(self, cg[0])
                # copy the state to the remaining metrics in the compute group
                # NOTE(review): states are assigned by reference (not cloned), and any result a member may have
                # cached from an earlier ``compute`` is not invalidated here - confirm against ``Metric.compute``
                for name in cg[1:]:
                    mi = getattr(self, name)
                    for state in m0._defaults:
                        setattr(mi, state, getattr(m0, state))
        res = {k: m.compute() for k, m in self.items(keep_base=True)}
        res = _flatten_dict(res)
        return {self._set_name(k): v for k, v in res.items()}

    def reset(self) -> None:
        """Iteratively call reset for each metric."""
        for _, m in self.items(keep_base=True):
            m.reset()

    def clone(self, prefix: Optional[str] = None, postfix: Optional[str] = None) -> "MetricCollection":
        """Make a copy of the metric collection.

        Args:
            prefix: a string to prepend to the metric keys; the original prefix is kept when not given (or empty)
            postfix: a string to append after the keys of the output dict; the original postfix is kept when not
                given (or empty)
        """
        mc = deepcopy(self)
        if prefix:
            mc.prefix = self._check_arg(prefix, "prefix")
        if postfix:
            mc.postfix = self._check_arg(postfix, "postfix")
        return mc

    def persistent(self, mode: bool = True) -> None:
        """Method for post-init to change if metric states should be saved to its state_dict."""
        for _, m in self.items(keep_base=True):
            m.persistent(mode)

    def add_metrics(
        self, metrics: Union[Metric, Sequence[Metric], Dict[str, Metric]], *additional_metrics: Metric
    ) -> None:
        """Add new metrics to Metric Collection.

        Accepts the same ``metrics`` / ``additional_metrics`` inputs as the constructor. Nested collections are
        flattened into this one. Afterwards the compute groups are re-initialized and must be checked again.
        """
        if isinstance(metrics, Metric):
            # set compatible with original type expectations
            metrics = [metrics]
        if isinstance(metrics, Sequence):
            # prepare for optional additions
            metrics = list(metrics)
            remain: list = []
            for m in additional_metrics:
                (metrics if isinstance(m, Metric) else remain).append(m)

            if remain:
                rank_zero_warn(
                    f"You have passes extra arguments {remain} which are not `Metric` so they will be ignored."
                )
        elif additional_metrics:
            raise ValueError(
                f"You have passes extra arguments {additional_metrics} which are not compatible"
                f" with first passed dictionary {metrics} so they will be ignored."
            )

        if isinstance(metrics, dict):
            # Check all values are metrics
            # Make sure that metrics are added in deterministic order
            for name in sorted(metrics.keys()):
                metric = metrics[name]
                if not isinstance(metric, (Metric, MetricCollection)):
                    raise ValueError(
                        f"Value {metric} belonging to key {name} is not an instance of"
                        " `torchmetrics.Metric` or `torchmetrics.MetricCollection`"
                    )
                if isinstance(metric, Metric):
                    self[name] = metric
                else:
                    # flatten a nested collection, qualifying its (prefixed/postfixed) keys with the dict key
                    for k, v in metric.items(keep_base=False):
                        self[f"{name}_{k}"] = v
        elif isinstance(metrics, Sequence):
            for metric in metrics:
                if not isinstance(metric, (Metric, MetricCollection)):
                    raise ValueError(
                        f"Input {metric} to `MetricCollection` is not a instance of"
                        " `torchmetrics.Metric` or `torchmetrics.MetricCollection`"
                    )
                if isinstance(metric, Metric):
                    name = metric.__class__.__name__
                    if name in self:
                        raise ValueError(f"Encountered two metrics both named {name}")
                    self[name] = metric
                else:
                    # flatten a nested collection using its (prefixed/postfixed) keys
                    for k, v in metric.items(keep_base=False):
                        self[k] = v
        else:
            raise ValueError("Unknown input to MetricCollection.")

        self._groups_checked = False
        if self._enable_compute_groups:
            self._init_compute_groups()
        else:
            self._groups = {}

    def _init_compute_groups(self) -> None:
        """Initialize compute groups.

        If user provided a list, we check that all metrics in the list are also in the collection, drop empty groups
        and put every metric of the collection that is not listed into a group of its own. Such groups are final
        immediately. If set to ``True`` we simply initialize each metric in the collection as its own group.

        Raises:
            ValueError:
                If a name in the user-provided ``compute_groups`` does not match a metric in the collection.
        """
        if isinstance(self._enable_compute_groups, list):
            # copy the user's lists so that later group manipulation never mutates the argument
            self._groups = dict(enumerate(list(cg) for cg in self._enable_compute_groups if cg))
            for v in self._groups.values():
                for metric in v:
                    if metric not in self:
                        raise ValueError(
                            f"Input {metric} in `compute_groups` argument does not match a metric in the collection."
                            f" Please make sure that {self._enable_compute_groups} matches {self.keys(keep_base=True)}"
                        )
            # once groups are checked only the first member of each group is updated, so a metric missing from
            # every group would never see any data; give each such metric its own group
            grouped = {name for members in self._groups.values() for name in members}
            for k in self.keys(keep_base=True):
                if k not in grouped:
                    self._groups[len(self._groups)] = [str(k)]
            self._groups_checked = True
        else:
            # Initialize all metrics as their own compute group
            self._groups = {i: [str(k)] for i, k in enumerate(self.keys(keep_base=True))}

    @property
    def compute_groups(self) -> Dict[int, List[str]]:
        """Return a dict with the current compute groups in the collection."""
        return self._groups

    def _set_name(self, base: str) -> str:
        """Adjust name of metric with both prefix and postfix."""
        name = base if self.prefix is None else self.prefix + base
        name = name if self.postfix is None else name + self.postfix
        return name

    def _to_renamed_ordered_dict(self) -> OrderedDict:
        """Return the child modules keyed by their prefixed/postfixed names, in insertion order."""
        od = OrderedDict()
        for k, v in self._modules.items():
            od[self._set_name(k)] = v
        return od

    def keys(self, keep_base: bool = False) -> Iterable[Hashable]:
        r"""Return an iterable of the ModuleDict key.

        Args:
            keep_base: if ``True`` return the raw keys; otherwise return keys with prefix/postfix added.
        """
        if keep_base:
            return self._modules.keys()
        return self._to_renamed_ordered_dict().keys()

    def items(self, keep_base: bool = False) -> Iterable[Tuple[str, Module]]:
        r"""Return an iterable of the ModuleDict key/value pairs.

        Args:
            keep_base: if ``True`` return the raw keys; otherwise return keys with prefix/postfix added.
        """
        if keep_base:
            return self._modules.items()
        return self._to_renamed_ordered_dict().items()

    @staticmethod
    def _check_arg(arg: Optional[str], name: str) -> Optional[str]:
        """Return ``arg`` unchanged if it is ``None`` or a string, otherwise raise a ``ValueError``."""
        if arg is None or isinstance(arg, str):
            return arg
        raise ValueError(f"Expected input `{name}` to be a string, but got {type(arg)}")

    def __repr__(self) -> str:
        # strip the closing "\n)" of the ``ModuleDict`` repr so prefix/postfix can be appended before re-closing it
        repr_str = super().__repr__()[:-2]
        if self.prefix:
            repr_str += f",\n prefix={self.prefix}{',' if self.postfix else ''}"
        if self.postfix:
            repr_str += f"{',' if not self.prefix else ''}\n postfix={self.postfix}"
        return repr_str + "\n)"