Skip to content

Commit

Permalink
Fix plot splitter (#2060)
Browse files Browse the repository at this point in the history
* fix + tests

* changelog
  • Loading branch information
SkafteNicki authored Sep 6, 2023
1 parent ca9fe3d commit acaf4cc
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 3 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed

- Fixed performance issues in `RecallAtFixedPrecision` for large batch sizes ([#2042](https://github.com/Lightning-AI/torchmetrics/pull/2042))
-


- Fixed bug when creating multiple plots that lead to not all plots being shown ([#2060](https://github.com/Lightning-AI/torchmetrics/pull/2060))


## [1.1.1] - 2023-08-29
Expand Down
4 changes: 2 additions & 2 deletions src/torchmetrics/utilities/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,9 +172,9 @@ def plot_single_or_multi_val(
def _get_col_row_split(n: int) -> Tuple[int, int]:
"""Split `n` figures into `rows` x `cols` figures."""
nsq = sqrt(n)
if nsq * nsq == n:
if int(nsq) == nsq: # square number
return int(nsq), int(nsq)
if floor(nsq) * ceil(nsq) > n:
if floor(nsq) * ceil(nsq) >= n:
return floor(nsq), ceil(nsq)
return ceil(nsq), ceil(nsq)

Expand Down
12 changes: 12 additions & 0 deletions tests/unittests/utilities/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@
_TORCH_GREATER_EQUAL_1_10,
_TORCHAUDIO_GREATER_EQUAL_0_10,
)
from torchmetrics.utilities.plot import _get_col_row_split
from torchmetrics.wrappers import (
BootStrapper,
ClasswiseWrapper,
Expand Down Expand Up @@ -794,6 +795,17 @@ def test_plot_methods_retrieval(metric_class, preds, target, indexes, num_vals):
plt.close(fig)


@pytest.mark.parametrize(
("n", "expected_row", "expected_col"),
[(1, 1, 1), (2, 1, 2), (3, 2, 2), (4, 2, 2), (5, 2, 3), (6, 2, 3), (7, 3, 3), (8, 3, 3), (9, 3, 3), (10, 3, 4)],
)
def test_row_col_splitter(n, expected_row, expected_col):
"""Test the row col splitter function works as expected."""
row, col = _get_col_row_split(n)
assert row == expected_row
assert col == expected_col


@pytest.mark.parametrize(
("metric_class", "preds", "target", "labels"),
[
Expand Down

0 comments on commit acaf4cc

Please sign in to comment.