diff --git a/CHANGELOG.md b/CHANGELOG.md
index 317f93a6bff..c2d6f007cf6 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -15,8 +15,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - Added Tweedie Deviance Score ([#499](https://github.com/PyTorchLightning/metrics/pull/499))
 
+- Added support for float targets in `nDCG` metric ([#437](https://github.com/PyTorchLightning/metrics/pull/437))
+
+
+- Added `average` argument to `AveragePrecision` metric for reducing multilabel and multiclass problems ([#477](https://github.com/PyTorchLightning/metrics/pull/477))
+
+
+
 ### Changed
 
+- `AveragePrecision` now outputs the `macro` average by default for multilabel and multiclass problems ([#477](https://github.com/PyTorchLightning/metrics/pull/477))
+
 ### Deprecated
 
 
@@ -28,6 +37,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - Fixed bug in `F1` with `average='macro'` and `ignore_index!=None` ([#495](https://github.com/PyTorchLightning/metrics/pull/495))
 
+
 ## [0.5.1] - 2021-08-30
 
 ### Added
diff --git a/tests/classification/test_average_precision.py b/tests/classification/test_average_precision.py
index 6c5fe7051a4..98dc4c23baf 100644
--- a/tests/classification/test_average_precision.py
+++ b/tests/classification/test_average_precision.py
@@ -30,7 +30,7 @@
 seed_all(42)
 
 
-def _sk_average_precision_score(y_true, probas_pred, num_classes=1):
+def _sk_average_precision_score(y_true, probas_pred, num_classes=1, average=None):
     if num_classes == 1:
         return sk_average_precision_score(y_true, probas_pred)
 
@@ -39,33 +39,41 @@ def _sk_average_precision_score(y_true, probas_pred, num_classes=1):
         y_true_temp = np.zeros_like(y_true)
         y_true_temp[y_true == i] = 1
         res.append(sk_average_precision_score(y_true_temp, probas_pred[:, i]))
+
+    if average == "macro":
+        return np.array(res).mean()
+    elif average == "weighted":
+        weights = np.bincount(y_true) if y_true.max() > 1 else y_true.sum(axis=0)
+        weights = weights / sum(weights)
+        return (np.array(res) * weights).sum()
+
     return res
 
 
-def _sk_avg_prec_binary_prob(preds, target, num_classes=1):
+def _sk_avg_prec_binary_prob(preds, target, num_classes=1, average=None):
     sk_preds = preds.view(-1).numpy()
     sk_target = target.view(-1).numpy()
 
-    return _sk_average_precision_score(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes)
+    return _sk_average_precision_score(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes, average=average)
 
 
-def _sk_avg_prec_multiclass_prob(preds, target, num_classes=1):
+def _sk_avg_prec_multiclass_prob(preds, target, num_classes=1, average=None):
     sk_preds = preds.reshape(-1, num_classes).numpy()
     sk_target = target.view(-1).numpy()
 
-    return _sk_average_precision_score(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes)
+    return _sk_average_precision_score(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes, average=average)
 
 
-def _sk_avg_prec_multilabel_prob(preds, target, num_classes):
+def _sk_avg_prec_multilabel_prob(preds, target, num_classes=1, average=None):
     sk_preds = preds.reshape(-1, num_classes).numpy()
     sk_target = target.view(-1, num_classes).numpy()
-    return sk_average_precision_score(sk_target, sk_preds, average=None)
+    return sk_average_precision_score(sk_target, sk_preds, average=average)
 
 
-def _sk_avg_prec_multidim_multiclass_prob(preds, target, num_classes=1):
+def _sk_avg_prec_multidim_multiclass_prob(preds, target, num_classes=1, average=None):
     sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy()
     sk_target = target.view(-1).numpy()
-    return _sk_average_precision_score(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes)
+    return _sk_average_precision_score(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes, average=average)
 
 
 @pytest.mark.parametrize(
@@ -77,30 +85,37 @@ def _sk_avg_prec_multidim_multiclass_prob(preds, target, num_classes=1):
         (_input_multilabel.preds, _input_multilabel.target, _sk_avg_prec_multilabel_prob, NUM_CLASSES),
     ],
 )
+@pytest.mark.parametrize("average", ["micro", "macro", "weighted", None])
 class TestAveragePrecision(MetricTester):
     @pytest.mark.parametrize("ddp", [True, False])
     @pytest.mark.parametrize("dist_sync_on_step", [True, False])
-    def test_average_precision(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step):
+    def test_average_precision(self, preds, target, sk_metric, num_classes, average, ddp, dist_sync_on_step):
+        if target.max() > 1 and average == "micro":
+            pytest.skip("average=micro and multiclass input cannot be used together")
+
         self.run_class_metric_test(
             ddp=ddp,
             preds=preds,
             target=target,
             metric_class=AveragePrecision,
-            sk_metric=partial(sk_metric, num_classes=num_classes),
+            sk_metric=partial(sk_metric, num_classes=num_classes, average=average),
             dist_sync_on_step=dist_sync_on_step,
-            metric_args={"num_classes": num_classes},
+            metric_args={"num_classes": num_classes, "average": average},
         )
 
-    def test_average_precision_functional(self, preds, target, sk_metric, num_classes):
+    def test_average_precision_functional(self, preds, target, sk_metric, num_classes, average):
+        if target.max() > 1 and average == "micro":
+            pytest.skip("average=micro and multiclass input cannot be used together")
+
         self.run_functional_metric_test(
             preds=preds,
             target=target,
             metric_functional=average_precision,
-            sk_metric=partial(sk_metric, num_classes=num_classes),
-            metric_args={"num_classes": num_classes},
+            sk_metric=partial(sk_metric, num_classes=num_classes, average=average),
+            metric_args={"num_classes": num_classes, "average": average},
         )
 
-    def test_average_precision_differentiability(self, preds, sk_metric, target, num_classes):
+    def test_average_precision_differentiability(self, preds, sk_metric, target, num_classes, average):
         self.run_differentiability_test(
             preds=preds,
             target=target,
@@ -126,3 +141,30 @@ def test_average_precision_differentiability(self, preds, sk_metric, target, num
 )
 def test_average_precision(scores, target, expected_score):
     assert average_precision(scores, target) == expected_score
+
+
+def test_average_precision_warnings_and_errors():
+    """Test that the correct errors and warnings get raised."""
+
+    # check average argument
+    with pytest.raises(ValueError, match="Expected argument `average` to be one .*"):
+        AveragePrecision(num_classes=5, average="samples")
+
+    # check that micro average cannot be used with multiclass input
+    pred = tensor(
+        [
+            [0.75, 0.05, 0.05, 0.05, 0.05],
+            [0.05, 0.75, 0.05, 0.05, 0.05],
+            [0.05, 0.05, 0.75, 0.05, 0.05],
+            [0.05, 0.05, 0.05, 0.75, 0.05],
+        ]
+    )
+    target = tensor([0, 1, 3, 2])
+    average_precision = AveragePrecision(num_classes=5, average="micro")
+    with pytest.raises(ValueError, match="Cannot use `micro` average with multi-class input"):
+        average_precision(pred, target)
+
+    # check that warning is thrown when average=macro and nan is encountered in individual scores
+    average_precision = AveragePrecision(num_classes=5, average="macro")
+    with pytest.warns(UserWarning, match="Average precision score for one or more classes was `nan`.*"):
+        average_precision(pred, target)
diff --git a/torchmetrics/classification/average_precision.py b/torchmetrics/classification/average_precision.py
index 06775675f43..94affd6b98c 100644
--- a/torchmetrics/classification/average_precision.py
+++ b/torchmetrics/classification/average_precision.py
@@ -44,6 +44,19 @@ class AveragePrecision(Metric):
             which for binary problem is translate to 1. For multiclass problems
             this argument should not be set as we iteratively change it in the
             range [0,num_classes-1]
+        average:
+            defines the reduction that is applied in the case of multiclass and multilabel input.
+            Should be one of the following:
+
+            - ``'macro'`` [default]: Calculate the metric for each class separately, and average the
+              metrics across classes (with equal weights for each class).
+            - ``'micro'``: Calculate the metric globally, across all samples and classes. Cannot be
+              used with multiclass input.
+            - ``'weighted'``: Calculate the metric for each class separately, and average the
+              metrics across classes, weighting each class by its support.
+            - ``'none'`` or ``None``: Calculate the metric for each class separately, and return
+              the metric for every class.
+
         compute_on_step:
             Forward only calls ``update()`` and return None if this is set to False. default: True
         dist_sync_on_step:
@@ -66,7 +79,7 @@ class AveragePrecision(Metric):
         ...                      [0.05, 0.05, 0.75, 0.05, 0.05],
         ...                      [0.05, 0.05, 0.05, 0.75, 0.05]])
         >>> target = torch.tensor([0, 1, 3, 2])
-        >>> average_precision = AveragePrecision(num_classes=5)
+        >>> average_precision = AveragePrecision(num_classes=5, average=None)
         >>> average_precision(pred, target)
         [tensor(1.), tensor(1.), tensor(0.2500), tensor(0.2500), tensor(nan)]
     """
@@ -78,6 +91,7 @@ def __init__(
         self,
         num_classes: Optional[int] = None,
         pos_label: Optional[int] = None,
+        average: Optional[str] = "macro",
         compute_on_step: bool = True,
         dist_sync_on_step: bool = False,
         process_group: Optional[Any] = None,
@@ -90,6 +104,10 @@ def __init__(
 
         self.num_classes = num_classes
         self.pos_label = pos_label
+        allowed_average = ("micro", "macro", "weighted", None)
+        if average not in allowed_average:
+            raise ValueError(f"Expected argument `average` to be one of {allowed_average}" f" but got {average}")
+        self.average = average
 
         self.add_state("preds", default=[], dist_reduce_fx="cat")
         self.add_state("target", default=[], dist_reduce_fx="cat")
@@ -107,7 +125,7 @@ def update(self, preds: Tensor, target: Tensor) -> None:  # type: ignore
             target: Ground truth values
         """
         preds, target, num_classes, pos_label = _average_precision_update(
-            preds, target, self.num_classes, self.pos_label
+            preds, target, self.num_classes, self.pos_label, self.average
         )
         self.preds.append(preds)
         self.target.append(target)
@@ -125,7 +143,7 @@ def compute(self) -> Union[Tensor, List[Tensor]]:
         target = dim_zero_cat(self.target)
         if not self.num_classes:
             raise ValueError(f"`num_classes` bas to be positive number, but got {self.num_classes}")
-        return _average_precision_compute(preds, target, self.num_classes, self.pos_label)
+        return _average_precision_compute(preds, target, self.num_classes, self.pos_label, self.average)
 
     @property
     def is_differentiable(self) -> bool:
diff --git a/torchmetrics/classification/binned_precision_recall.py b/torchmetrics/classification/binned_precision_recall.py
index 654fd851d1c..4c9b25dcb18 100644
--- a/torchmetrics/classification/binned_precision_recall.py
+++ b/torchmetrics/classification/binned_precision_recall.py
@@ -239,7 +239,7 @@ class BinnedAveragePrecision(BinnedPrecisionRecallCurve):
 
     def compute(self) -> Union[List[Tensor], Tensor]:  # type: ignore
         precisions, recalls, _ = super().compute()
-        return _average_precision_compute_with_precision_recall(precisions, recalls, self.num_classes)
+        return _average_precision_compute_with_precision_recall(precisions, recalls, self.num_classes, average=None)
 
 
 class BinnedRecallAtFixedPrecision(BinnedPrecisionRecallCurve):
diff --git a/torchmetrics/functional/classification/average_precision.py b/torchmetrics/functional/classification/average_precision.py
index bf32199fb8a..16858469ecd 100644
--- a/torchmetrics/functional/classification/average_precision.py
+++ b/torchmetrics/functional/classification/average_precision.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import warnings
 from typing import List, Optional, Sequence, Tuple, Union
 
 import torch
@@ -27,8 +28,30 @@ def _average_precision_update(
     target: Tensor,
     num_classes: Optional[int] = None,
     pos_label: Optional[int] = None,
+    average: Optional[str] = "macro",
 ) -> Tuple[Tensor, Tensor, int, Optional[int]]:
-    return _precision_recall_curve_update(preds, target, num_classes, pos_label)
+    """Format the predictions and target based on the ``num_classes``, ``pos_label`` and ``average`` parameters.
+    Args:
+        preds: predictions from model (logits or probabilities)
+        target: ground truth values
+        num_classes: integer with number of classes.
+        pos_label: integer determining the positive class. Default is ``None``
+            which for binary problems is translated to 1. For multiclass problems
+            this argument should not be set as we iteratively change it in the
+            range [0,num_classes-1]
+        average: reduction method for multi-class or multi-label problems
+    """
+    preds, target, num_classes, pos_label = _precision_recall_curve_update(preds, target, num_classes, pos_label)
+    if average == "micro":
+        if preds.ndim == target.ndim:
+            # Considering each element of the label indicator matrix as a label
+            preds = preds.flatten()
+            target = target.flatten()
+            num_classes = 1
+        else:
+            raise ValueError("Cannot use `micro` average with multi-class input")
+
+    return preds, target, num_classes, pos_label
 
 
 def _average_precision_compute(
@@ -36,6 +59,7 @@ def _average_precision_compute(
     target: Tensor,
     num_classes: int,
     pos_label: Optional[int] = None,
+    average: Optional[str] = "macro",
    sample_weights: Optional[Sequence] = None,
 ) -> Union[List[Tensor], Tensor]:
     """Computes the average precision score.
@@ -48,6 +72,7 @@ def _average_precision_compute(
             which for binary problem is translate to 1. For multiclass problems
             this argument should not be set as we iteratively change it in the
             range [0,num_classes-1]
+        average: reduction method for multi-class or multi-label problems
         sample_weights: sample weights for each data point
 
     Example:
@@ -67,19 +92,29 @@ def _average_precision_compute(
         >>> target = torch.tensor([0, 1, 3, 2])
         >>> num_classes = 5
         >>> preds, target, num_classes, pos_label = _average_precision_update(preds, target, num_classes)
-        >>> _average_precision_compute(preds, target, num_classes)
+        >>> _average_precision_compute(preds, target, num_classes, average=None)
         [tensor(1.), tensor(1.), tensor(0.2500), tensor(0.2500), tensor(nan)]
     """
 
     # todo: `sample_weights` is unused
     precision, recall, _ = _precision_recall_curve_compute(preds, target, num_classes, pos_label)
-    return _average_precision_compute_with_precision_recall(precision, recall, num_classes)
+    if average == "weighted":
+        if preds.ndim == target.ndim and target.ndim > 1:
+            weights = target.sum(dim=0).float()
+        else:
+            weights = torch.bincount(target, minlength=num_classes).float()
+        weights = weights / torch.sum(weights)
+    else:
+        weights = None
+    return _average_precision_compute_with_precision_recall(precision, recall, num_classes, average, weights)
 
 
 def _average_precision_compute_with_precision_recall(
     precision: Tensor,
     recall: Tensor,
     num_classes: int,
+    average: Optional[str] = "macro",
+    weights: Optional[Tensor] = None,
 ) -> Union[List[Tensor], Tensor]:
     """Computes the average precision score from precision and recall.
 
@@ -88,6 +123,8 @@ def _average_precision_compute_with_precision_recall(
         recall: recall values
         num_classes: integer with number of classes. Not nessesary to provide
             for binary problems.
+        average: reduction method for multi-class or multi-label problems
+        weights: weights to use when average='weighted'
 
     Example:
         >>> # binary case
@@ -96,7 +133,7 @@ def _average_precision_compute_with_precision_recall(
         >>> target = torch.tensor([0, 1, 1, 1])
         >>> pos_label = 1
         >>> preds, target, num_classes, pos_label = _average_precision_update(preds, target, pos_label=pos_label)
         >>> precision, recall, _ = _precision_recall_curve_compute(preds, target, num_classes, pos_label)
-        >>> _average_precision_compute_with_precision_recall(precision, recall, num_classes)
+        >>> _average_precision_compute_with_precision_recall(precision, recall, num_classes, average=None)
         tensor(1.)
 
         >>> # multiclass case
@@ -108,7 +145,7 @@ def _average_precision_compute_with_precision_recall(
         >>> num_classes = 5
         >>> preds, target, num_classes, pos_label = _average_precision_update(preds, target, num_classes)
         >>> precision, recall, _ = _precision_recall_curve_compute(preds, target, num_classes)
-        >>> _average_precision_compute_with_precision_recall(precision, recall, num_classes)
+        >>> _average_precision_compute_with_precision_recall(precision, recall, num_classes, average=None)
         [tensor(1.), tensor(1.), tensor(0.2500), tensor(0.2500), tensor(nan)]
     """
 
@@ -121,7 +158,23 @@ def _average_precision_compute_with_precision_recall(
     res = []
     for p, r in zip(precision, recall):
         res.append(-torch.sum((r[1:] - r[:-1]) * p[:-1]))
-    return res
+
+    # Reduce
+    if average in ("macro", "weighted"):
+        res = torch.stack(res)
+        if torch.isnan(res).any():
+            warnings.warn(
+                "Average precision score for one or more classes was `nan`. Ignoring these classes in average",
+                UserWarning,
+            )
+        if average == "macro":
+            return res[~torch.isnan(res)].mean()
+        weights = torch.ones_like(res) if weights is None else weights
+        return (res * weights)[~torch.isnan(res)].sum()
+    elif average is None:
+        return res
+    allowed_average = ("micro", "macro", "weighted", None)
+    raise ValueError(f"Expected argument `average` to be one of {allowed_average}" f" but got {average}")
 
 
 def average_precision(
@@ -129,6 +182,7 @@ def average_precision(
     target: Tensor,
     num_classes: Optional[int] = None,
     pos_label: Optional[int] = None,
+    average: Optional[str] = "macro",
     sample_weights: Optional[Sequence] = None,
 ) -> Union[List[Tensor], Tensor]:
     """Computes the average precision score.
@@ -142,6 +196,19 @@ def average_precision(
             which for binary problem is translate to 1. For multiclass problems
             this argument should not be set as we iteratively change it in the
             range [0,num_classes-1]
+        average:
+            defines the reduction that is applied in the case of multiclass and multilabel input.
+            Should be one of the following:
+
+            - ``'macro'`` [default]: Calculate the metric for each class separately, and average the
+              metrics across classes (with equal weights for each class).
+            - ``'micro'``: Calculate the metric globally, across all samples and classes. Cannot be
+              used with multiclass input.
+            - ``'weighted'``: Calculate the metric for each class separately, and average the
+              metrics across classes, weighting each class by its support.
+            - ``'none'`` or ``None``: Calculate the metric for each class separately, and return
+              the metric for every class.
+
         sample_weights: sample weights for each data point
 
     Returns:
@@ -161,9 +228,9 @@ def average_precision(
         ...                      [0.05, 0.05, 0.75, 0.05, 0.05],
         ...                      [0.05, 0.05, 0.05, 0.75, 0.05]])
         >>> target = torch.tensor([0, 1, 3, 2])
-        >>> average_precision(pred, target, num_classes=5)
+        >>> average_precision(pred, target, num_classes=5, average=None)
         [tensor(1.), tensor(1.), tensor(0.2500), tensor(0.2500), tensor(nan)]
     """
     # fixme: `sample_weights` is unused
-    preds, target, num_classes, pos_label = _average_precision_update(preds, target, num_classes, pos_label)
-    return _average_precision_compute(preds, target, num_classes, pos_label, sample_weights)
+    preds, target, num_classes, pos_label = _average_precision_update(preds, target, num_classes, pos_label, average)
+    return _average_precision_compute(preds, target, num_classes, pos_label, average, sample_weights)
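
Usage sketch (editor's note, not part of the patch): assuming a torchmetrics build that includes this change, the snippet below reuses the tensors from the docstring examples to show the new `macro` default of `AveragePrecision` and the old per-class behaviour requested explicitly with `average=None`.

import torch
from torchmetrics import AveragePrecision
from torchmetrics.functional import average_precision

# Probabilities for 4 samples over 5 classes; class 4 never appears in the target,
# so its per-class score is nan (this is the case the new warning covers).
pred = torch.tensor(
    [
        [0.75, 0.05, 0.05, 0.05, 0.05],
        [0.05, 0.75, 0.05, 0.05, 0.05],
        [0.05, 0.05, 0.75, 0.05, 0.05],
        [0.05, 0.05, 0.05, 0.75, 0.05],
    ]
)
target = torch.tensor([0, 1, 3, 2])

# New default: macro average over classes, ignoring nan scores (a UserWarning is emitted).
ap = AveragePrecision(num_classes=5)
print(ap(pred, target))  # single scalar tensor

# Previous behaviour: per-class scores, now requested explicitly with average=None.
print(average_precision(pred, target, num_classes=5, average=None))

On this input the macro reduction averages the four non-`nan` class scores, matching the `res[~torch.isnan(res)].mean()` branch added in `_average_precision_compute_with_precision_recall`.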