GeneralizedDiceScore fails with 2D tensors and input_format="index" #2816

Closed
fguiotte opened this issue Nov 1, 2024 · 1 comment · Fixed by #2832
Labels: bug / fix (Something isn't working), help wanted (Extra attention is needed), v1.5.x


fguiotte commented Nov 1, 2024

🐛 Bug

GeneralizedDiceScore raises a ValueError for 2D integer tensors of shape (N, L) with input_format="index".

To Reproduce

```python
import torch
from torchmetrics.segmentation import GeneralizedDiceScore

batch_size = 16
num_classes = 8
L = 100

# Integer class-index tensors of shape (N, L)
y = torch.randint(num_classes, (batch_size, L))
pred = torch.randint(num_classes, (batch_size, L))

dice = GeneralizedDiceScore(num_classes=num_classes, input_format="index")

dice(pred, y)  # argument order is (preds, target)
```

```
ValueError: Expected both `preds` and `target` to have at least 3 dimensions, but got 2.
```
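Until a fix lands, one possible workaround (a sketch, not verified against 1.5.1) is to convert the index tensors to the one-hot layout (N, C, L) yourself and keep the default input_format="one-hot", since the failing check only requires at least 3 dimensions. The helper name `index_to_one_hot` is hypothetical; it relies only on `torch.nn.functional.one_hot` and `Tensor.movedim`.

```python
import torch
import torch.nn.functional as F

num_classes = 8
y = torch.randint(num_classes, (16, 100))
pred = torch.randint(num_classes, (16, 100))


def index_to_one_hot(x: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Hypothetical helper: (N, ...) int64 indices -> (N, C, ...) one-hot."""
    # one_hot appends the class axis last: (N, L) -> (N, L, C)
    # movedim moves it to position 1:       (N, L, C) -> (N, C, L)
    return F.one_hot(x, num_classes).movedim(-1, 1)


pred_oh = index_to_one_hot(pred, num_classes)
y_oh = index_to_one_hot(y, num_classes)
print(pred_oh.shape)  # torch.Size([16, 8, 100])
```

The converted tensors could then be passed as `GeneralizedDiceScore(num_classes=num_classes, input_format="one-hot")(pred_oh, y_oh)`; that this matches what the index path would return once fixed is an assumption.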

Expected behavior

According to the documentation, GeneralizedDiceScore should accept integer tensors of shape (N, ...) when input_format="index":

- ``preds`` (:class:`~torch.Tensor`): A one-hot boolean tensor of shape ``(N, C, ...)`` with ``N`` being
the number of samples and ``C`` the number of classes. Alternatively, an integer tensor of shape ``(N, ...)``
can be provided, where the integer values correspond to the class index. The input type can be controlled
with the ``input_format`` argument.
- ``target`` (:class:`~torch.Tensor`): A one-hot boolean tensor of shape ``(N, C, ...)`` with ``N`` being
the number of samples and ``C`` the number of classes. Alternatively, an integer tensor of shape ``(N, ...)``
can be provided, where the integer values correspond to the class index. The input type can be controlled
with the ``input_format`` argument.
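For reference, a simplified sketch of the kind of score these (N, L) index inputs should produce. It assumes "square" weighting (each class weighted by the inverse squared count of its target elements), replaces the infinite weight of a class absent from a sample's target with that sample's largest finite weight, and averages over samples; torchmetrics' exact handling of absent classes, background, and averaging may differ. The function name `generalized_dice_index` is hypothetical.

```python
import torch
import torch.nn.functional as F


def generalized_dice_index(preds, target, num_classes):
    """Simplified generalized dice for index inputs of shape (N, ...).

    Expects at least one dimension after N, e.g. (N, L).
    """
    # (N, ...) -> (N, ..., C) -> (N, C, ...)
    p = F.one_hot(preds, num_classes).movedim(-1, 1).float()
    t = F.one_hot(target, num_classes).movedim(-1, 1).float()
    dims = tuple(range(2, p.ndim))  # every axis after N and C

    intersection = (p * t).sum(dim=dims)  # (N, C)
    cardinality = (p + t).sum(dim=dims)   # (N, C)

    # Square weighting; classes absent from the target get 1/0 = inf
    weights = 1.0 / t.sum(dim=dims) ** 2
    finite = torch.isfinite(weights)
    # Assumption: replace inf weights with the sample's largest finite weight
    max_w = torch.where(finite, weights, torch.zeros_like(weights)).amax(dim=1, keepdim=True)
    weights = torch.where(finite, weights, max_w)

    numerator = 2 * (weights * intersection).sum(dim=1)
    denominator = (weights * cardinality).sum(dim=1)
    return (numerator / denominator).mean()  # average over the N samples
```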

Environment

  • TorchMetrics version 1.5.1

Additional context

The exception is raised by the input check in the functional dice implementation:

```python
if preds.ndim < 3:
    raise ValueError(f"Expected both `preds` and `target` to have at least 3 dimensions, but got {preds.ndim}.")
```
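One way the check could be made format-aware (a hypothetical sketch, not the change merged in #2832): index inputs of shape (N, ...) need only 2 dimensions because the class axis is added during one-hot conversion, while one-hot inputs of shape (N, C, ...) keep the existing 3-dimension requirement. The function `check_input_dims` and its shape-tuple arguments are illustrative only.

```python
def check_input_dims(preds_shape, target_shape, input_format):
    """Hypothetical validation whose minimum rank depends on ``input_format``."""
    if input_format not in ("one-hot", "index"):
        raise ValueError(f"Unknown input_format: {input_format!r}")
    if tuple(preds_shape) != tuple(target_shape):
        raise ValueError(f"Shape mismatch: {tuple(preds_shape)} vs {tuple(target_shape)}")
    # "one-hot": (N, C, ...) keeps the current requirement of at least 3 dims.
    # "index":   (N, ...)    needs N plus at least one axis; C is added later.
    min_ndim = 3 if input_format == "one-hot" else 2
    if len(preds_shape) < min_ndim:
        raise ValueError(
            f"Expected both `preds` and `target` to have at least {min_ndim} dimensions "
            f"for input_format={input_format!r}, but got {len(preds_shape)}."
        )


# The (N, L) index inputs from the report pass; 2D one-hot inputs are still rejected.
check_input_dims((16, 100), (16, 100), "index")
try:
    check_input_dims((16, 100), (16, 100), "one-hot")
except ValueError as err:
    print(err)
```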

fguiotte added the bug / fix and help wanted labels on Nov 1, 2024

github-actions bot commented Nov 1, 2024

Hi! Thanks for your contribution, great first issue!
