diff --git a/torchmetrics/utilities/data.py b/torchmetrics/utilities/data.py index 825380d99bc..ed00ab1478a 100644 --- a/torchmetrics/utilities/data.py +++ b/torchmetrics/utilities/data.py @@ -23,7 +23,7 @@ def dim_zero_cat(x: Union[Tensor, List[Tensor]]) -> Tensor: x = x if isinstance(x, (list, tuple)) else [x] - x = [xx.unsqueeze(0) if xx.numel() == 1 and xx.ndim == 0 else xx for xx in x] + x = [y.unsqueeze(0) if y.numel() == 1 and y.ndim == 0 else y for y in x] return torch.cat(x, dim=0)