Feature/tracker higher is better integration #2649

Merged · 12 commits · Aug 6, 2024
1 change: 1 addition & 0 deletions src/torchmetrics/utilities/imports.py
@@ -26,6 +26,7 @@
_TORCH_GREATER_EQUAL_2_0 = RequirementCache("torch>=2.0.0")
_TORCH_GREATER_EQUAL_2_1 = RequirementCache("torch>=2.1.0")
_TORCH_GREATER_EQUAL_2_2 = RequirementCache("torch>=2.2.0")
+_TORCHMETRICS_GREATER_EQUAL_1_7 = RequirementCache("torchmetrics>=1.7.0")

_NLTK_AVAILABLE = RequirementCache("nltk")
_ROUGE_SCORE_AVAILABLE = RequirementCache("rouge_score")
45 changes: 37 additions & 8 deletions src/torchmetrics/wrappers/tracker.py
@@ -105,21 +105,50 @@ class MetricTracker(ModuleList):

"""

def __init__(self, metric: Union[Metric, MetricCollection], maximize: Union[bool, List[bool]] = True) -> None:
def __init__(
self, metric: Union[Metric, MetricCollection], maximize: Optional[Union[bool, List[bool]]] = True
) -> None:
super().__init__()
if not isinstance(metric, (Metric, MetricCollection)):
raise TypeError(
"Metric arg need to be an instance of a torchmetrics"
f" `Metric` or `MetricCollection` but got {metric}"
)
self._base_metric = metric
if not isinstance(maximize, (bool, list)):
raise ValueError("Argument `maximize` should either be a single bool or list of bool")
if isinstance(maximize, list) and isinstance(metric, MetricCollection) and len(maximize) != len(metric):
raise ValueError("The len of argument `maximize` should match the length of the metric collection")
if isinstance(metric, Metric) and not isinstance(maximize, bool):
raise ValueError("Argument `maximize` should be a single bool when `metric` is a single Metric")
self.maximize = maximize

if maximize is None:
if isinstance(metric, Metric):
if not hasattr(metric, "higher_is_better"):
raise AttributeError(
f"The metric '{metric.__class__.__name__}' does not have a 'higher_is_better' attribute."
" Please provide the `maximize` argument explicitly."
)
self.maximize = metric.higher_is_better
elif isinstance(metric, MetricCollection):
self.maximize = []
for name, m in metric.items():
if not hasattr(m, "higher_is_better"):
raise AttributeError(
f"The metric '{name}' in the MetricCollection does not have a 'higher_is_better' attribute."
" Please provide the `maximize` argument explicitly."
)
self.maximize.append(m.higher_is_better)
else:
rank_zero_warn(
"The default value for `maximize` will be changed from `True` to `None` in v1.7.0 of TorchMetrics,"
"will automatically infer the value based on the `higher_is_better` attribute of the metric"
" (if such attribute exists) or raise an error if it does not. If you are explicitly setting the"
" `maximize` argument to either `True` or `False` already, you can ignore this warning.",
FutureWarning,
)

if not isinstance(maximize, (bool, list)):
raise ValueError("Argument `maximize` should either be a single bool or list of bool")
if isinstance(maximize, list) and isinstance(metric, MetricCollection) and len(maximize) != len(metric):
raise ValueError("The len of argument `maximize` should match the length of the metric collection")
if isinstance(metric, Metric) and not isinstance(maximize, bool):
raise ValueError("Argument `maximize` should be a single bool when `metric` is a single Metric")
self.maximize = maximize

self._increment_called = False
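A minimal sketch (not part of the diff) of how the new `maximize=None` behavior is expected to work once this change is applied; the `higher_is_better` values are assumptions based on standard TorchMetrics (False for MSE, True for accuracy):

import warnings

from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassAccuracy
from torchmetrics.regression import MeanSquaredError
from torchmetrics.wrappers import MetricTracker

# With maximize=None, the tracker infers the direction from each
# metric's `higher_is_better` attribute instead of assuming True.
tracker = MetricTracker(MeanSquaredError(), maximize=None)
print(tracker.maximize)  # False: lower MSE is better

collection = MetricCollection([MulticlassAccuracy(num_classes=3), MeanSquaredError()])
tracker = MetricTracker(collection, maximize=None)
print(tracker.maximize)  # one bool per metric, e.g. [True, False]

# Passing an explicit bool still works but now emits a FutureWarning
# about the default changing in v1.7.0.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    MetricTracker(MeanSquaredError(), maximize=True)
assert any(issubclass(w.category, FutureWarning) for w in caught)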

39 changes: 38 additions & 1 deletion tests/unittests/wrappers/test_tracker.py
@@ -12,16 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+import warnings
+
import pytest
import torch
-from torchmetrics import MetricCollection
+from torchmetrics import Metric, MetricCollection
from torchmetrics.classification import (
    MulticlassAccuracy,
    MulticlassConfusionMatrix,
    MulticlassPrecision,
    MulticlassRecall,
)
+from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError
+from torchmetrics.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_1_7
from torchmetrics.wrappers import MetricTracker, MultioutputWrapper

from unittests._helpers import seed_all
@@ -216,3 +219,37 @@ def test_metric_tracker_and_collection_multioutput(input_to_tracker, assert_type
    else:
        assert best_metric is None
        assert which_epoch is None


+def test_tracker_futurewarning():
+    """Check that a FutureWarning is raised for the default of the `maximize` argument.
+
+    Also serves as a reminder to remove the warning in future versions of TorchMetrics.
+
+    """
+    if _TORCHMETRICS_GREATER_EQUAL_1_7:
+        # In future versions, check that the warning has been removed
+        with warnings.catch_warnings():
+            warnings.simplefilter("error")
+            MetricTracker(MeanSquaredError(), maximize=True)
+    else:
+        with pytest.warns(FutureWarning, match="The default value for `maximize` will be changed from `True` to.*"):
+            MetricTracker(MeanSquaredError(), maximize=True)


+@pytest.mark.parametrize(
+    "base_metric",
+    [
+        MeanSquaredError(),
+        MeanAbsoluteError(),
+        MulticlassAccuracy(num_classes=10),
+        MetricCollection([MeanSquaredError(), MeanAbsoluteError()]),
+    ],
+)
+def test_tracker_higher_is_better_integration(base_metric):
+    """Check that `maximize` is correctly inferred from the metric's `higher_is_better` attribute."""
+    tracker = MetricTracker(base_metric, maximize=None)
+    if isinstance(base_metric, Metric):
+        assert tracker.maximize == base_metric.higher_is_better
+    else:
+        assert tracker.maximize == [m.higher_is_better for m in base_metric.values()]
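
As a quick local sanity check of the attribute these tests rely on (values assumed from standard TorchMetrics, where `higher_is_better` is a class attribute on every metric):

from torchmetrics.classification import MulticlassAccuracy
from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError

# `higher_is_better` encodes the optimization direction of each metric
print(MeanSquaredError.higher_is_better)    # False: lower error is better
print(MeanAbsoluteError.higher_is_better)   # False
print(MulticlassAccuracy.higher_is_better)  # True: higher accuracy is better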