From 33864db0c30df394301e1a22446abdb8d0f366a1 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Wed, 5 May 2021 19:35:00 +0200 Subject: [PATCH] Fix concatenation of zero dim states (#229) * fix * changelog Co-authored-by: Jirka Borovec --- CHANGELOG.md | 3 +++ torchmetrics/functional/regression/spearman.py | 4 ++-- torchmetrics/utilities/data.py | 7 ++++--- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 862f9a3c4c4..8feedfaaf28 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -54,6 +54,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed metric calculation with unequal batch sizes ([#220](https://github.com/PyTorchLightning/metrics/pull/220)) +- Fixed metric concatenation for list states for zero-dim input ([#229](https://github.com/PyTorchLightning/metrics/pull/229)) + + ## [0.3.1] - 2021-04-21 - Cleaning remaining inconsistency and fix PL develop integration ( diff --git a/torchmetrics/functional/regression/spearman.py b/torchmetrics/functional/regression/spearman.py index fdfd56b7e16..ab97e150d96 100644 --- a/torchmetrics/functional/regression/spearman.py +++ b/torchmetrics/functional/regression/spearman.py @@ -59,10 +59,10 @@ def _spearman_corrcoef_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Te f" Got preds: {preds.dtype} and target: {target.dtype}." ) _check_same_shape(preds, target) - + preds = preds.squeeze() + target = target.squeeze() if preds.ndim > 1 or target.ndim > 1: raise ValueError('Expected both predictions and target to be 1 dimensional tensors.') - return preds, target diff --git a/torchmetrics/utilities/data.py b/torchmetrics/utilities/data.py index 002d06395cd..ed00ab1478a 100644 --- a/torchmetrics/utilities/data.py +++ b/torchmetrics/utilities/data.py @@ -21,16 +21,17 @@ METRIC_EPS = 1e-6 -def dim_zero_cat(x): +def dim_zero_cat(x: Union[Tensor, List[Tensor]]) -> Tensor: x = x if isinstance(x, (list, tuple)) else [x] + x = [y.unsqueeze(0) if y.numel() == 1 and y.ndim == 0 else y for y in x] return torch.cat(x, dim=0) -def dim_zero_sum(x): +def dim_zero_sum(x: Tensor) -> Tensor: return torch.sum(x, dim=0) -def dim_zero_mean(x): +def dim_zero_mean(x: Tensor) -> Tensor: return torch.mean(x, dim=0)