diff --git a/CHANGELOG.md b/CHANGELOG.md index 5aedcd2ec69..2b0b0022476 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,6 +29,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Tracker higher is better integration ([#2649](https://github.com/Lightning-AI/torchmetrics/pull/2649)) +- Updated `InfoLM` class to dynamically set `higher_is_better` ([#2674](https://github.com/Lightning-AI/torchmetrics/pull/2674)) + + ### Removed - diff --git a/src/torchmetrics/text/infolm.py b/src/torchmetrics/text/infolm.py index b5c2de893f7..31fea4adc23 100644 --- a/src/torchmetrics/text/infolm.py +++ b/src/torchmetrics/text/infolm.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Union import torch from torch import Tensor @@ -111,12 +111,25 @@ class InfoLM(Metric): """ is_differentiable = False - higher_is_better = True preds_input_ids: List[Tensor] preds_attention_mask: List[Tensor] target_input_ids: List[Tensor] target_attention_mask: List[Tensor] + _information_measure_higher_is_better: ClassVar = { + # the following measures yield values < 0, so a higher value is better + "kl_divergence": True, + "alpha_divergence": True, + # the following measures yield values > 0, so a lower value is better + "beta_divergence": False, + "ab_divergence": False, + "renyi_divergence": False, + "l1_distance": False, + "l2_distance": False, + "l_infinity_distance": False, + "fisher_rao_distance": False, + } + def __init__( self, model_name_or_path: Union[str, os.PathLike] = "bert-base-uncased", @@ -156,6 +169,15 @@ def __init__( self.add_state("target_input_ids", [], dist_reduce_fx="cat") self.add_state("target_attention_mask", [], dist_reduce_fx="cat") + @property + def higher_is_better(self) -> bool: # type: ignore[override] + """Returns a bool indicating whether a higher value of the information measure is better. 
+ + This is a property rather than a class attribute because the answer depends on whether the selected information measure takes negative or positive values. + + """ + return self._information_measure_higher_is_better[self.information_measure] + def update(self, preds: Union[str, Sequence[str]], target: Union[str, Sequence[str]]) -> None: """Update state with predictions and targets.""" preds_input_ids, preds_attention_mask, target_input_ids, target_attention_mask = _infolm_update( diff --git a/tests/unittests/text/test_infolm.py b/tests/unittests/text/test_infolm.py index 1ee45cde02e..b3fd26026ca 100644 --- a/tests/unittests/text/test_infolm.py +++ b/tests/unittests/text/test_infolm.py @@ -182,3 +182,18 @@ def test_infolm_differentiability(self, preds, targets, information_measure, idf metric_functional=infolm, metric_args=metric_args, ) + + @skip_on_connection_issues() + def test_infolm_higher_is_better_property(self, preds, targets, information_measure, idf, alpha, beta): + """Test the `higher_is_better` property of the metric.""" + metric_args = { + "model_name_or_path": MODEL_NAME, + "information_measure": information_measure, + "idf": idf, + "alpha": alpha, + "beta": beta, + "max_length": MAX_LENGTH, + } + + metric = InfoLM(**metric_args) + assert metric.higher_is_better == metric._information_measure_higher_is_better[information_measure]