Skip to content

Commit 68ea8d0

Browse files
authored
Merge branch 'master' into feature/multioutput_mse
2 parents 28b2da7 + 51b9047 commit 68ea8d0

File tree

3 files changed

+14
-2
lines changed

3 files changed

+14
-2
lines changed

CHANGELOG.md

+3
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
4141

4242
### Fixed
4343

44+
- Fix support for int input for when `multidim_average="samplewise"` in classification metrics ([#1977](https://github.com/Lightning-AI/torchmetrics/pull/1977))
45+
46+
4447
- Fixed x/y labels when plotting confusion matrices ([#1976](https://github.com/Lightning-AI/torchmetrics/pull/1976))
4548

4649

src/torchmetrics/functional/classification/stat_scores.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -374,10 +374,10 @@ def _multiclass_stat_scores_update(
374374
preds_oh = torch.movedim(select_topk(preds, topk=top_k, dim=1), 1, -1)
375375
else:
376376
preds_oh = torch.nn.functional.one_hot(
377-
preds, num_classes + 1 if ignore_index is not None and not ignore_in else num_classes
377+
preds.long(), num_classes + 1 if ignore_index is not None and not ignore_in else num_classes
378378
)
379379
target_oh = torch.nn.functional.one_hot(
380-
target, num_classes + 1 if ignore_index is not None and not ignore_in else num_classes
380+
target.long(), num_classes + 1 if ignore_index is not None and not ignore_in else num_classes
381381
)
382382
if ignore_index is not None:
383383
if 0 <= ignore_index <= num_classes - 1:

tests/unittests/classification/test_stat_scores.py

+9
Original file line numberDiff line numberDiff line change
@@ -542,6 +542,15 @@ def test_multilabel_stat_scores_dtype_gpu(self, inputs, dtype):
542542
)
543543

544544

545+
def test_support_for_int():
546+
"""See issue: https://github.com/Lightning-AI/torchmetrics/issues/1970."""
547+
metric = MulticlassStatScores(num_classes=4, average="none", multidim_average="samplewise", ignore_index=0)
548+
prediction = torch.randint(low=0, high=4, size=(1, 224, 224)).to(torch.uint8)
549+
label = torch.randint(low=0, high=4, size=(1, 224, 224)).to(torch.uint8)
550+
score = metric(preds=prediction, target=label)
551+
assert score.shape == (1, 4, 5)
552+
553+
545554
@pytest.mark.parametrize(
546555
("metric", "kwargs"),
547556
[

0 commit comments

Comments
 (0)