
Commit 5ae4601

add metrics
Signed-off-by: Zhiyuan Chen <[email protected]>
1 parent 83ae20a commit 5ae4601

File tree

1 file changed: +227 -0 lines changed

torcheval/metrics/metrics.py

@@ -0,0 +1,227 @@
# pylint: disable=E1101,W0622

from __future__ import annotations

from functools import partial
from math import nan
from typing import Any, Callable, Iterable

try:
    from functools import cached_property  # type: ignore
except ImportError:
    from functools import lru_cache

    def cached_property(f):  # type: ignore
        return property(lru_cache()(f))


import torch
from chanfig import FlatDict
from torch import Tensor
from torch import distributed as dist

from .metric import Metric
from . import functional as F


class flist(list):  # pylint: disable=R0903
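    """A list that applies a single format spec to every element.

    e.g. `f"{flist([0.9, 0.8]):.2f}"` -> `"0.90 0.80"`.
    """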

    def __format__(self, *args, **kwargs):
        return " ".join([x.__format__(*args, **kwargs) for x in self])


class Metrics(Metric):
    r"""
    A metric class that wraps multiple metric functions around a single shared state.

    Typically, there are many metrics that we want to compute for a single task.
    For example, we usually need to compute `accuracy`, `auroc`, and `auprc` for a classification task.
    Computing them one by one is inefficient, especially when evaluating in a distributed environment.

    To solve this problem, Metrics maintains a shared state for multiple metric functions.

    Attributes:
        metrics: A dictionary of metrics to be computed.
        input: The input tensor of the latest batch.
        target: The target tensor of the latest batch.
        inputs: All input tensors seen so far.
        targets: All target tensors seen so far.

    Args:
        *args: A single mapping of metrics.
        **metrics: Metric functions given as keyword arguments.
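
    Examples:
        A minimal, illustrative sketch; it assumes the functional metrics
        accept `(input, target)` tensors, as those in `binary_metrics`
        below do:

        >>> metrics = Metrics(acc=F.binary_accuracy, auroc=F.binary_auroc)
        >>> metrics.update([0.1, 0.9], [0, 1])
        >>> metrics.compute()  # doctest: +SKIP
        FlatDict(acc=1.0, auroc=1.0)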
"""
52+
53+
metrics: FlatDict[str, Callable]
54+
_input: Tensor
55+
_target: Tensor
56+
_inputs: list[Tensor]
57+
_targets: list[Tensor]
58+
_input_buffer: list[Tensor]
59+
_target_buffer: list[Tensor]
60+
index: str
61+
best_fn: Callable
62+
63+
def __init__(self, *args, **metrics: FlatDict[str, Callable]):
64+
super().__init__()
65+
self._add_state("_input", torch.empty(0))
66+
self._add_state("_target", torch.empty(0))
67+
self._add_state("_inputs", [])
68+
self._add_state("_targets", [])
69+
self._add_state("_input_buffer", [])
70+
self._add_state("_target_buffer", [])
71+
self.metrics = FlatDict(*args, **metrics)

    @torch.inference_mode()
    def update(self, input: Any, target: Any) -> None:
        if not isinstance(input, torch.Tensor):
            input = torch.tensor(input)
        if not isinstance(target, torch.Tensor):
            target = torch.tensor(target)
        input, target = input.to(self.device), target.to(self.device)
        self._input, self._target = input, target
        self._input_buffer.append(input)
        self._target_buffer.append(target)
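
    # `compute`/`comp` score the latest batch on this process only;
    # `value`/`val` score the latest batch synced across all processes;
    # `average`/`avg` score all batches accumulated since the last reset.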
    def compute(self) -> FlatDict[str, float]:
        return self.comp

    def value(self) -> FlatDict[str, float]:
        return self.val

    def average(self) -> FlatDict[str, float]:
        return self.avg

    @cached_property
    def comp(self) -> FlatDict[str, float]:
        return self._compute(self._input, self._target)

    @cached_property
    def val(self) -> FlatDict[str, float]:
        return self._compute(self.input, self.target)

    @cached_property
    def avg(self) -> FlatDict[str, float]:
        return self._compute(self.inputs, self.targets)

    @torch.inference_mode()
    def _compute(self, input: Tensor, target: Tensor) -> FlatDict[str, float | flist]:
        # With no data yet, report NaN for every metric instead of crashing.
        if input.numel() == 0 == target.numel():
            return FlatDict({name: nan for name in self.metrics.keys()})
        ret = FlatDict()
        for name, metric in self.metrics.items():
            score = metric(input, target)
            # Scalars become floats; multi-valued scores keep flist formatting.
            ret[name] = score.item() if score.numel() == 1 else flist(score.tolist())
        return ret

    @torch.inference_mode()
    def merge_state(self, metrics: Iterable):
        raise NotImplementedError()
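
    # The four properties below synchronize tensors across processes when
    # torch.distributed is initialized: `input`/`target` expose the latest
    # batch gathered from every rank, while `inputs`/`targets` expose the
    # concatenation of all batches seen since the last reset.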
    @cached_property
    @torch.inference_mode()
    def input(self):
        if not dist.is_initialized() or dist.get_world_size() == 1:
            return self._input
        synced_input = [torch.zeros_like(self._input) for _ in range(dist.get_world_size())]
        dist.all_gather(synced_input, self._input)
        return torch.cat([t.to(self.device) for t in synced_input], 0)

    @cached_property
    @torch.inference_mode()
    def target(self):
        if not dist.is_initialized() or dist.get_world_size() == 1:
            return self._target
        synced_target = [torch.zeros_like(self._target) for _ in range(dist.get_world_size())]
        dist.all_gather(synced_target, self._target)
        return torch.cat([t.to(self.device) for t in synced_target], 0)

    @cached_property
    @torch.inference_mode()
    def inputs(self):
        if self._input_buffer:
            if dist.is_initialized() and dist.get_world_size() > 1:
                # all_gather_object returns one buffer (a list of tensors)
                # per rank, so the gathered lists must be flattened before
                # extending `_inputs`.
                synced_inputs = [None for _ in range(dist.get_world_size())]
                dist.all_gather_object(synced_inputs, self._input_buffer)
                self._inputs.extend(t for buffer in synced_inputs for t in buffer)
            else:
                self._inputs.extend(self._input_buffer)
            self._input_buffer = []
        if not self._inputs:
            return torch.empty(0)
        return torch.cat(self._inputs, 0)

    @cached_property
    @torch.inference_mode()
    def targets(self):
        # Same buffering and gathering scheme as `inputs`, applied to targets.
        if self._target_buffer:
            if dist.is_initialized() and dist.get_world_size() > 1:
                synced_targets = [None for _ in range(dist.get_world_size())]
                dist.all_gather_object(synced_targets, self._target_buffer)
                self._targets.extend(t for buffer in synced_targets for t in buffer)
            else:
                self._targets.extend(self._target_buffer)
            self._target_buffer = []
        if not self._targets:
            return torch.empty(0)
        return torch.cat(self._targets, 0)

    def __repr__(self):
        keys = tuple(self.metrics.keys())
        return f"{self.__class__.__name__}{keys}"

    def __format__(self, format_spec):
        val, avg = self.compute(), self.average()
        return "\n".join(
            f"{key}: {format(val[key], format_spec)} ({format(avg[key], format_spec)})" for key in self.metrics
        )


class IndexMetrics(Metrics):
    r"""
    IndexMetrics is a subclass of Metrics that supports scoring.

    A score is a single value that best represents the performance of the model.
    It is the core metric that we use to compare different models.
    For example, in classification, we usually use `auroc` as the score.

    IndexMetrics requires two additional arguments: `index` and `best_fn`.
    `index` is the name of the metric used to compute the score.
    `best_fn` is a function that takes a list of values and returns the best one.
    `best_fn` is not used by IndexMetrics itself; it is meant to be accessed by other classes.

    Attributes:
        index: The name of the metric used to compute the score.
        best_fn: A function that takes a list of values and returns the best one.

    Args:
        *args: A single mapping of metrics.
        index: The name of the metric used to compute the score. Defaults to the first metric.
        best_fn: A function that takes a list of values and returns the best one. Defaults to `max`.
        **metrics: Metric functions given as keyword arguments.
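
    Examples:
        A minimal, illustrative sketch; `index` defaults to the first metric
        when omitted:

        >>> metrics = IndexMetrics(auroc=F.binary_auroc, acc=F.binary_accuracy, index="auroc")
        >>> metrics.update([0.1, 0.9], [0, 1])
        >>> metrics.score("batch")  # doctest: +SKIP
        1.0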
"""
193+
194+
index: str
195+
best_fn: Callable
196+
197+
def __init__(
198+
self, *args, index: str | None = None, best_fn: Callable | None = max, **metrics: FlatDict[str, Callable]
199+
):
200+
super().__init__(*args, **metrics)
201+
self.index = index or next(iter(self.metrics.keys()))
202+
self.metric = self.metrics[self.index]
203+
self.best_fn = best_fn or max
204+
205+

    def score(self, scope: str) -> float | flist:
        if scope == "batch":
            return self.batch_score()
        if scope == "average":
            return self.average_score()
        raise ValueError(f"Unknown scope: {scope}")

    def batch_score(self) -> float | flist:
        # Compute only the index metric, on the latest (synced) batch.
        score = self.metric(self.input, self.target)
        return score.item() if score.numel() == 1 else flist(score.tolist())

    def average_score(self) -> float | flist:
        # Compute only the index metric, on all accumulated inputs/targets.
        score = self.metric(self.inputs, self.targets)
        return score.item() if score.numel() == 1 else flist(score.tolist())


def binary_metrics():
    return Metrics(auroc=F.binary_auroc, auprc=F.binary_auprc, acc=F.binary_accuracy)


def multiclass_metrics(num_classes: int):
    auroc = partial(F.multiclass_auroc, num_classes=num_classes)
    auprc = partial(F.multiclass_auprc, num_classes=num_classes)
    acc = partial(F.multiclass_accuracy, num_classes=num_classes)
    return Metrics(auroc=auroc, auprc=auprc, acc=acc)
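
A minimal usage sketch for the helpers above (the `model` and `loader` names
are illustrative assumptions, not part of this commit):

    metrics = multiclass_metrics(num_classes=10)
    for input, target in loader:
        metrics.update(model(input), target)
        print(f"{metrics:.4f}")  # "<name>: <latest> (<average>)" per line, via __format__
    print(metrics.average())    # FlatDict of scores over all accumulated batches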
