Adds average argument to AveragePrecision metric #477

Merged
merged 15 commits on Sep 6, 2021
10 changes: 10 additions & 0 deletions CHANGELOG.md
@@ -13,8 +13,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added Learned Perceptual Image Patch Similarity (LPIPS) ([#431](https://github.com/PyTorchLightning/metrics/issues/431))


- Added support for float targets in `nDCG` metric ([#437](https://github.com/PyTorchLightning/metrics/pull/437))


- Added `average` argument to `AveragePrecision` metric for reducing multilabel and multiclass problems ([#477](https://github.com/PyTorchLightning/metrics/pull/477))



### Changed

- `AveragePrecision` will now output the `macro` average by default for multilabel and multiclass problems ([#477](https://github.com/PyTorchLightning/metrics/pull/477))


### Deprecated

@@ -26,6 +35,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Fixed bug in `F1` with `average='macro'` and `ignore_index!=None` ([#495](https://github.com/PyTorchLightning/metrics/pull/495))


## [0.5.1] - 2021-08-30

### Added
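Usage note for the changelog entries above: a minimal sketch of the new `average` argument on the class-based metric, reusing the example tensors from the updated docstring further down (outputs omitted; with this toy data one class has no positive samples, so macro averaging may warn about `nan` scores):

```python
import torch
from torchmetrics import AveragePrecision

preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
                      [0.05, 0.75, 0.05, 0.05, 0.05],
                      [0.05, 0.05, 0.75, 0.05, 0.05],
                      [0.05, 0.05, 0.05, 0.75, 0.05]])
target = torch.tensor([0, 1, 3, 2])

# New default: a single macro-averaged score across classes
ap_macro = AveragePrecision(num_classes=5)  # average="macro" is now the default
print(ap_macro(preds, target))

# Previous behaviour: one score per class
ap_per_class = AveragePrecision(num_classes=5, average=None)
print(ap_per_class(preds, target))
```
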
74 changes: 58 additions & 16 deletions tests/classification/test_average_precision.py
@@ -30,7 +30,7 @@
seed_all(42)


def _sk_average_precision_score(y_true, probas_pred, num_classes=1):
def _sk_average_precision_score(y_true, probas_pred, num_classes=1, average=None):
if num_classes == 1:
return sk_average_precision_score(y_true, probas_pred)

@@ -39,33 +39,41 @@ def _sk_average_precision_score(y_true, probas_pred, num_classes=1):
y_true_temp = np.zeros_like(y_true)
y_true_temp[y_true == i] = 1
res.append(sk_average_precision_score(y_true_temp, probas_pred[:, i]))

if average == "macro":
return np.array(res).mean()
elif average == "weighted":
weights = np.bincount(y_true) if y_true.max() > 1 else y_true.sum(axis=0)
weights = weights / sum(weights)
return (np.array(res) * weights).sum()

return res

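As a side note, a tiny sketch of the two reductions implemented above, with made-up per-class scores and supports (the numbers are purely illustrative and not taken from the tests):

```python
import numpy as np

res = np.array([1.0, 0.5])      # hypothetical per-class average precision scores
support = np.array([10, 90])    # hypothetical per-class sample counts

macro = res.mean()                    # 0.75: every class counts equally
weights = support / support.sum()
weighted = (res * weights).sum()      # 0.55: the majority class dominates
```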

def _sk_avg_prec_binary_prob(preds, target, num_classes=1):
def _sk_avg_prec_binary_prob(preds, target, num_classes=1, average=None):
sk_preds = preds.view(-1).numpy()
sk_target = target.view(-1).numpy()

return _sk_average_precision_score(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes)
return _sk_average_precision_score(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes, average=average)


def _sk_avg_prec_multiclass_prob(preds, target, num_classes=1):
def _sk_avg_prec_multiclass_prob(preds, target, num_classes=1, average=None):
sk_preds = preds.reshape(-1, num_classes).numpy()
sk_target = target.view(-1).numpy()

return _sk_average_precision_score(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes)
return _sk_average_precision_score(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes, average=average)


def _sk_avg_prec_multilabel_prob(preds, target, num_classes):
def _sk_avg_prec_multilabel_prob(preds, target, num_classes=1, average=None):
sk_preds = preds.reshape(-1, num_classes).numpy()
sk_target = target.view(-1, num_classes).numpy()
return sk_average_precision_score(sk_target, sk_preds, average=None)
return sk_average_precision_score(sk_target, sk_preds, average=average)


def _sk_avg_prec_multidim_multiclass_prob(preds, target, num_classes=1):
def _sk_avg_prec_multidim_multiclass_prob(preds, target, num_classes=1, average=None):
sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy()
sk_target = target.view(-1).numpy()
return _sk_average_precision_score(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes)
return _sk_average_precision_score(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes, average=average)


@pytest.mark.parametrize(
@@ -77,30 +85,37 @@ def _sk_avg_prec_multidim_multiclass_prob(preds, target, num_classes=1):
(_input_multilabel.preds, _input_multilabel.target, _sk_avg_prec_multilabel_prob, NUM_CLASSES),
],
)
@pytest.mark.parametrize("average", ["micro", "macro", "weighted", None])
class TestAveragePrecision(MetricTester):
@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_average_precision(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step):
def test_average_precision(self, preds, target, sk_metric, num_classes, average, ddp, dist_sync_on_step):
if target.max() > 1 and average == "micro":
pytest.skip("average=micro and multiclass input cannot be used together")

self.run_class_metric_test(
ddp=ddp,
preds=preds,
target=target,
metric_class=AveragePrecision,
sk_metric=partial(sk_metric, num_classes=num_classes),
sk_metric=partial(sk_metric, num_classes=num_classes, average=average),
dist_sync_on_step=dist_sync_on_step,
metric_args={"num_classes": num_classes},
metric_args={"num_classes": num_classes, "average": average},
)

def test_average_precision_functional(self, preds, target, sk_metric, num_classes):
def test_average_precision_functional(self, preds, target, sk_metric, num_classes, average):
if target.max() > 1 and average == "micro":
pytest.skip("average=micro and multiclass input cannot be used together")

self.run_functional_metric_test(
preds=preds,
target=target,
metric_functional=average_precision,
sk_metric=partial(sk_metric, num_classes=num_classes),
metric_args={"num_classes": num_classes},
sk_metric=partial(sk_metric, num_classes=num_classes, average=average),
metric_args={"num_classes": num_classes, "average": average},
)

def test_average_precision_differentiability(self, preds, sk_metric, target, num_classes):
def test_average_precision_differentiability(self, preds, sk_metric, target, num_classes, average):
self.run_differentiability_test(
preds=preds,
target=target,
@@ -126,3 +141,30 @@ def test_average_precision_differentiability(self, preds, sk_metric, target, num
)
def test_average_precision(scores, target, expected_score):
assert average_precision(scores, target) == expected_score


def test_average_precision_warnings_and_errors():
"""Test that the correct errors and warnings gets raised."""

# check average argument
with pytest.raises(ValueError, match="Expected argument `average` to be one .*"):
AveragePrecision(num_classes=5, average="samples")

# check that micro average cannot be used with multi-class input
pred = tensor(
[
[0.75, 0.05, 0.05, 0.05, 0.05],
[0.05, 0.75, 0.05, 0.05, 0.05],
[0.05, 0.05, 0.75, 0.05, 0.05],
[0.05, 0.05, 0.05, 0.75, 0.05],
]
)
target = tensor([0, 1, 3, 2])
average_precision = AveragePrecision(num_classes=5, average="micro")
with pytest.raises(ValueError, match="Cannot use `micro` average with multi-class input"):
average_precision(pred, target)

# check that a warning is thrown when average=macro and nan is encountered in the individual scores
average_precision = AveragePrecision(num_classes=5, average="macro")
with pytest.warns(UserWarning, match="Average precision score for one or more classes was `nan`.*"):
average_precision(pred, target)
24 changes: 21 additions & 3 deletions torchmetrics/classification/average_precision.py
@@ -44,6 +44,19 @@ class AveragePrecision(Metric):
which for binary problems translates to 1. For multiclass problems,
this argument should not be set as we iteratively change it in the
range [0, num_classes-1]
average:
defines the reduction that is applied in the case of multiclass and multilabel input.
Should be one of the following:

- ``'macro'`` [default]: Calculate the metric for each class separately, and average the
metrics across classes (with equal weights for each class).
- ``'micro'``: Calculate the metric globally, across all samples and classes. Cannot be
used with multiclass input.
- ``'weighted'``: Calculate the metric for each class separately, and average the
metrics across classes, weighting each class by its support.
- ``'none'`` or ``None``: Calculate the metric for each class separately, and return
the metric for every class.

compute_on_step:
Forward only calls ``update()`` and returns ``None`` if this is set to False. default: True
dist_sync_on_step:
@@ -66,7 +79,7 @@ class AveragePrecision(Metric):
... [0.05, 0.05, 0.75, 0.05, 0.05],
... [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> average_precision = AveragePrecision(num_classes=5)
>>> average_precision = AveragePrecision(num_classes=5, average=None)
>>> average_precision(pred, target)
[tensor(1.), tensor(1.), tensor(0.2500), tensor(0.2500), tensor(nan)]
"""
@@ -78,6 +91,7 @@ def __init__(
self,
num_classes: Optional[int] = None,
pos_label: Optional[int] = None,
average: Optional[str] = "macro",
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
@@ -90,6 +104,10 @@ def __init__(

self.num_classes = num_classes
self.pos_label = pos_label
allowed_average = ("micro", "macro", "weighted", None)
if average not in allowed_average:
raise ValueError(f"Expected argument `average` to be one of {allowed_average}" f" but got {average}")
self.average = average

self.add_state("preds", default=[], dist_reduce_fx="cat")
self.add_state("target", default=[], dist_reduce_fx="cat")
@@ -107,7 +125,7 @@ def update(self, preds: Tensor, target: Tensor) -> None:  # type: ignore
target: Ground truth values
"""
preds, target, num_classes, pos_label = _average_precision_update(
preds, target, self.num_classes, self.pos_label
preds, target, self.num_classes, self.pos_label, self.average
)
self.preds.append(preds)
self.target.append(target)
@@ -125,7 +143,7 @@ def compute(self) -> Union[Tensor, List[Tensor]]:
target = dim_zero_cat(self.target)
if not self.num_classes:
raise ValueError(f"`num_classes` has to be a positive number, but got {self.num_classes}")
return _average_precision_compute(preds, target, self.num_classes, self.pos_label)
return _average_precision_compute(preds, target, self.num_classes, self.pos_label, self.average)

@property
def is_differentiable(self) -> bool:
2 changes: 1 addition & 1 deletion torchmetrics/classification/binned_precision_recall.py
@@ -239,7 +239,7 @@ class BinnedAveragePrecision(BinnedPrecisionRecallCurve):

def compute(self) -> Union[List[Tensor], Tensor]: # type: ignore
precisions, recalls, _ = super().compute()
return _average_precision_compute_with_precision_recall(precisions, recalls, self.num_classes)
return _average_precision_compute_with_precision_recall(precisions, recalls, self.num_classes, average=None)


class BinnedRecallAtFixedPrecision(BinnedPrecisionRecallCurve):