Skip to content

Commit

Permalink
Initialize aggregation metrics with default floating type (#2366)
Browse files Browse the repository at this point in the history
  • Loading branch information
SkafteNicki authored Feb 12, 2024
1 parent 4527aaf commit b6f6e07
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 5 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed cached network in `FeatureShare` not being moved to the correct device ([#2348](https://github.com/Lightning-AI/torchmetrics/pull/2348))


- Fixed initialize aggregation metrics with default floating type ([#2366](https://github.com/Lightning-AI/torchmetrics/pull/2366))

---

## [1.3.0] - 2024-01-10
Expand Down
10 changes: 5 additions & 5 deletions src/torchmetrics/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def __init__(
) -> None:
super().__init__(
"max",
-torch.tensor(float("inf")),
-torch.tensor(float("inf"), dtype=torch.get_default_dtype()),
nan_strategy,
state_name="max_value",
**kwargs,
Expand Down Expand Up @@ -262,7 +262,7 @@ def __init__(
) -> None:
super().__init__(
"min",
torch.tensor(float("inf")),
torch.tensor(float("inf"), dtype=torch.get_default_dtype()),
nan_strategy,
state_name="min_value",
**kwargs,
Expand Down Expand Up @@ -366,7 +366,7 @@ def __init__(
) -> None:
super().__init__(
"sum",
torch.tensor(0.0),
torch.tensor(0.0, dtype=torch.get_default_dtype()),
nan_strategy,
state_name="sum_value",
**kwargs,
Expand Down Expand Up @@ -536,12 +536,12 @@ def __init__(
) -> None:
super().__init__(
"sum",
torch.tensor(0.0),
torch.tensor(0.0, dtype=torch.get_default_dtype()),
nan_strategy,
state_name="mean_value",
**kwargs,
)
self.add_state("weight", default=torch.tensor(0.0), dist_reduce_fx="sum")
self.add_state("weight", default=torch.tensor(0.0, dtype=torch.get_default_dtype()), dist_reduce_fx="sum")

def update(self, value: Union[float, Tensor], weight: Union[float, Tensor] = 1.0) -> None:
"""Update state with data.
Expand Down
17 changes: 17 additions & 0 deletions tests/unittests/bases/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,3 +204,20 @@ def test_mean_metric_broadcast(nan_strategy):
metric.update(x, w)
res = metric.compute()
assert round(res.item(), 4) == 3.2222 # (0*0 + 2*2 + 3*3 + 4*4) / (0 + 2 + 3 + 4)


@pytest.mark.parametrize(
("metric_class", "compare_function"),
[(MinMetric, torch.min), (MaxMetric, torch.max), (SumMetric, torch.sum), (MeanMetric, torch.mean)],
)
def test_with_default_dtype(metric_class, compare_function):
"""Test that the metric works with a default dtype of float64."""
torch.set_default_dtype(torch.float64)
metric = metric_class()
values = torch.randn(10000)
metric.update(values)
result = metric.compute()
assert result.dtype == torch.float64
assert result.dtype == values.dtype
assert torch.allclose(result, compare_function(values), atol=1e-12)
torch.set_default_dtype(torch.float32)

0 comments on commit b6f6e07

Please sign in to comment.