Skip to content

Commit

Permalink
Feature/tracker higher is better integration (#2649)
Browse files Browse the repository at this point in the history
* implementation
* add test cases
* chlog

---------

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: jirka <[email protected]>
  • Loading branch information
3 people authored Aug 6, 2024
1 parent f697f35 commit 589916c
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 10 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

-
- Changed `MetricTracker` to accept `maximize=None` and infer the optimization direction from each metric's `higher_is_better` attribute ([#2649](https://github.com/Lightning-AI/torchmetrics/pull/2649))


### Removed
Expand Down
1 change: 1 addition & 0 deletions src/torchmetrics/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
# Lazily-evaluated requirement probes used to gate version-dependent behavior.
_TORCH_GREATER_EQUAL_2_0 = RequirementCache("torch>=2.0.0")
_TORCH_GREATER_EQUAL_2_1 = RequirementCache("torch>=2.1.0")
_TORCH_GREATER_EQUAL_2_2 = RequirementCache("torch>=2.2.0")
# NOTE(review): name/requirement mismatch — the constant is named `_1_6` but the
# requirement string checks `torchmetrics>=1.7.0`. The deprecation message in
# tracker.py schedules the change for v1.7.0, so the string is likely the
# intended semantics and the constant should be `_TORCHMETRICS_GREATER_EQUAL_1_7`;
# renaming also touches tests/unittests/wrappers/test_tracker.py — confirm
# before changing either side.
_TORCHMETRICS_GREATER_EQUAL_1_6 = RequirementCache("torchmetrics>=1.7.0")

# Optional third-party text-metric dependencies.
_NLTK_AVAILABLE = RequirementCache("nltk")
_ROUGE_SCORE_AVAILABLE = RequirementCache("rouge_score")
Expand Down
49 changes: 41 additions & 8 deletions src/torchmetrics/wrappers/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,21 +105,54 @@ class MetricTracker(ModuleList):
"""

# NOTE(review): this span is a unified-diff hunk scraped with its "+"/"-"
# markers stripped, so OLD (removed) and NEW (added) lines appear interleaved
# and the original indentation is lost — it is not runnable Python as-is.
# The next line appears to be the OLD one-line signature removed by the change:
def __init__(self, metric: Union[Metric, MetricCollection], maximize: Union[bool, List[bool]] = True) -> None:
# NEW: class-level attribute annotation plus a widened signature that also
# accepts `maximize=None` (meaning "infer from the metric").
maximize: Union[bool, List[bool]]

def __init__(
self, metric: Union[Metric, MetricCollection], maximize: Optional[Union[bool, List[bool]]] = True
) -> None:
super().__init__()
# Only torchmetrics Metric / MetricCollection instances are trackable.
if not isinstance(metric, (Metric, MetricCollection)):
raise TypeError(
"Metric arg need to be an instance of a torchmetrics"
f" `Metric` or `MetricCollection` but got {metric}"
)
self._base_metric = metric
# NOTE(review): the next six lines duplicate the validation block near the end
# of this span — presumably these are the OLD (removed) lines of the diff.
if not isinstance(maximize, (bool, list)):
raise ValueError("Argument `maximize` should either be a single bool or list of bool")
if isinstance(maximize, list) and isinstance(metric, MetricCollection) and len(maximize) != len(metric):
raise ValueError("The len of argument `maximize` should match the length of the metric collection")
if isinstance(metric, Metric) and not isinstance(maximize, bool):
raise ValueError("Argument `maximize` should be a single bool when `metric` is a single Metric")
self.maximize = maximize

# NEW behavior: `maximize=None` infers the direction per metric from its
# `higher_is_better` attribute, raising when that attribute is undefined.
if maximize is None:
if isinstance(metric, Metric):
if getattr(metric, "higher_is_better", None) is None:
raise AttributeError(
f"The metric '{metric.__class__.__name__}' does not have a 'higher_is_better' attribute."
" Please provide the `maximize` argument explicitly."
)
self.maximize = metric.higher_is_better  # type: ignore[assignment] # this is false alarm
elif isinstance(metric, MetricCollection):
self.maximize = []
for name, m in metric.items():
if getattr(m, "higher_is_better", None) is None:
raise AttributeError(
f"The metric '{name}' in the MetricCollection does not have a 'higher_is_better' attribute."
" Please provide the `maximize` argument explicitly."
)
self.maximize.append(m.higher_is_better)  # type: ignore[arg-type] # this is false alarm
else:
# Deprecation notice: the default for `maximize` flips from True to
# None in v1.7.0, so any explicit bool/list value warns until then.
rank_zero_warn(
"The default value for `maximize` will be changed from `True` to `None` in v1.7.0 of TorchMetrics,"
"will automatically infer the value based on the `higher_is_better` attribute of the metric"
" (if such attribute exists) or raise an error if it does not. If you are explicitly setting the"
" `maximize` argument to either `True` or `False` already, you can ignore this warning.",
FutureWarning,
)

# NOTE(review): in the upstream file these checks most likely live inside the
# `else` branch above (indentation was lost in extraction); read flat, the
# `maximize is None` path would fail the isinstance check below — confirm
# against the repository before reusing this text as source.
if not isinstance(maximize, (bool, list)):
raise ValueError("Argument `maximize` should either be a single bool or list of bool")
if isinstance(maximize, list) and not all(isinstance(m, bool) for m in maximize):
raise ValueError("Argument `maximize` is list but not type of bool.")
if isinstance(maximize, list) and isinstance(metric, MetricCollection) and len(maximize) != len(metric):
raise ValueError("The len of argument `maximize` should match the length of the metric collection")
if isinstance(metric, Metric) and not isinstance(maximize, bool):
raise ValueError("Argument `maximize` should be a single bool when `metric` is a single Metric")
self.maximize = maximize

# Tracking state: `increment()` must be called before the first update.
self._increment_called = False

Expand Down
39 changes: 38 additions & 1 deletion tests/unittests/wrappers/test_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings

import pytest
import torch
from torchmetrics import MetricCollection
from torchmetrics import Metric, MetricCollection
from torchmetrics.classification import (
MulticlassAccuracy,
MulticlassConfusionMatrix,
MulticlassPrecision,
MulticlassRecall,
)
from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError
from torchmetrics.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_1_6
from torchmetrics.wrappers import MetricTracker, MultioutputWrapper

from unittests._helpers import seed_all
Expand Down Expand Up @@ -216,3 +219,37 @@ def test_metric_tracker_and_collection_multioutput(input_to_tracker, assert_type
else:
assert best_metric is None
assert which_epoch is None


def test_tracker_futurewarning():
    """Ensure the ``maximize`` deprecation ``FutureWarning`` behaves as scheduled.

    Before the 1.7 release constructing a tracker with an explicit ``maximize``
    emits the FutureWarning; from 1.7 onwards the warning must be gone, which
    also reminds us to delete the deprecation path in that release.
    """
    if _TORCHMETRICS_GREATER_EQUAL_1_6:
        # Newer versions: construction must complete without any warning at all,
        # so escalate every warning to an error.
        with warnings.catch_warnings():
            warnings.simplefilter("error")
            MetricTracker(MeanSquaredError(), maximize=True)
    else:
        expected = "The default value for `maximize` will be changed from `True` to.*"
        with pytest.warns(FutureWarning, match=expected):
            MetricTracker(MeanSquaredError(), maximize=True)


@pytest.mark.parametrize(
    "base_metric",
    [
        MeanSquaredError(),
        MeanAbsoluteError(),
        MulticlassAccuracy(num_classes=10),
        MetricCollection([MeanSquaredError(), MeanAbsoluteError()]),
    ],
)
def test_tracker_higher_is_better_integration(base_metric):
    """Check that the maximize argument is correctly set based on the metric higher_is_better attribute."""
    tracker = MetricTracker(base_metric, maximize=None)
    # A single Metric yields a single bool; a collection yields one bool per member.
    if isinstance(base_metric, Metric):
        expected = base_metric.higher_is_better
    else:
        expected = [m.higher_is_better for m in base_metric.values()]
    assert tracker.maximize == expected

0 comments on commit 589916c

Please sign in to comment.