From 4c77453ea5eeca614a2284539d83cc565da5d918 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Tue, 10 Sep 2024 20:15:57 +0200 Subject: [PATCH] Fix how `prefix`/`posfix` works in `MultitaskWrapper` (#2722) * implementation * tests * changelog * fix mypy (cherry picked from commit eecc55bc59395f66a4d50eb1e359ba38d59da30e) --- CHANGELOG.md | 3 + src/torchmetrics/wrappers/multitask.py | 102 ++++++++++++++------- tests/unittests/wrappers/test_multitask.py | 28 ++++-- 3 files changed, 93 insertions(+), 40 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bea26412357..f16920d11cb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,6 +30,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Correct the padding related calculation errors in SSIM ([#2721](https://github.com/Lightning-AI/torchmetrics/pull/2721)) +- Fixed how `prefix`/`postfix` works in `MultitaskWrapper` ([#2722](https://github.com/Lightning-AI/torchmetrics/pull/2722)) + + ## [1.4.1] - 2024-08-02 ### Changed diff --git a/src/torchmetrics/wrappers/multitask.py b/src/torchmetrics/wrappers/multitask.py index 556a7183638..fa2f04db97d 100644 --- a/src/torchmetrics/wrappers/multitask.py +++ b/src/torchmetrics/wrappers/multitask.py @@ -38,12 +38,27 @@ class MultitaskWrapper(WrapperMetric): task_metrics: Dictionary associating each task to a Metric or a MetricCollection. The keys of the dictionary represent the names of the tasks, and the values represent the metrics to use for each task. + prefix: + A string to append in front of the metric keys. If not provided, will default to an empty string. + postfix: + A string to append after the keys of the output dict. If not provided, will default to an empty string. + + .. note:: + The use of prefix and postfix allows for easily creating task wrappers for training, validation and test. + The arguments are only changing the output keys of the computed metrics and not the input keys. 
This means + that a ``MultitaskWrapper`` initialized as ``MultitaskWrapper({"task": Metric()}, prefix="train_")`` will + still expect the input to be a dictionary with the key "task", but the output will be a dictionary with the key + "train_task". Raises: TypeError: If argument `task_metrics` is not an dictionary TypeError: If not all values in the `task_metrics` dictionary is instances of `Metric` or `MetricCollection` + ValueError: + If `prefix` is not a string + ValueError: + If `postfix` is not a string Example (with a single metric per class): >>> import torch @@ -91,18 +106,59 @@ class MultitaskWrapper(WrapperMetric): {'Classification': {'BinaryAccuracy': tensor(0.3333), 'BinaryF1Score': tensor(0.)}, 'Regression': {'MeanSquaredError': tensor(0.8333), 'MeanAbsoluteError': tensor(0.6667)}} + Example (with a prefix and postfix): + >>> import torch + >>> from torchmetrics.wrappers import MultitaskWrapper + >>> from torchmetrics.regression import MeanSquaredError + >>> from torchmetrics.classification import BinaryAccuracy + >>> + >>> classification_target = torch.tensor([0, 1, 0]) + >>> regression_target = torch.tensor([2.5, 5.0, 4.0]) + >>> targets = {"Classification": classification_target, "Regression": regression_target} + >>> classification_preds = torch.tensor([0, 0, 1]) + >>> regression_preds = torch.tensor([3.0, 5.0, 2.5]) + >>> preds = {"Classification": classification_preds, "Regression": regression_preds} + >>> + >>> metrics = MultitaskWrapper({ + ... "Classification": BinaryAccuracy(), + ... "Regression": MeanSquaredError() + ... 
}, prefix="train_") + >>> metrics.update(preds, targets) + >>> metrics.compute() + {'train_Classification': tensor(0.3333), 'train_Regression': tensor(0.8333)} + """ - is_differentiable = False + is_differentiable: bool = False def __init__( self, task_metrics: Dict[str, Union[Metric, MetricCollection]], + prefix: Optional[str] = None, + postfix: Optional[str] = None, ) -> None: - self._check_task_metrics_type(task_metrics) super().__init__() + + if not isinstance(task_metrics, dict): + raise TypeError(f"Expected argument `task_metrics` to be a dict. Found task_metrics = {task_metrics}") + + for metric in task_metrics.values(): + if not (isinstance(metric, (Metric, MetricCollection))): + raise TypeError( + "Expected each task's metric to be a Metric or a MetricCollection. " + f"Found a metric of type {type(metric)}" + ) + self.task_metrics = nn.ModuleDict(task_metrics) + if prefix is not None and not isinstance(prefix, str): + raise ValueError(f"Expected argument `prefix` to either be `None` or a string but got {prefix}") + self._prefix = prefix or "" + + if postfix is not None and not isinstance(postfix, str): + raise ValueError(f"Expected argument `postfix` to either be `None` or a string but got {postfix}") + self._postfix = postfix or "" + def items(self, flatten: bool = True) -> Iterable[Tuple[str, nn.Module]]: """Iterate over task and task metrics. @@ -114,9 +170,9 @@ def items(self, flatten: bool = True) -> Iterable[Tuple[str, nn.Module]]: for task_name, metric in self.task_metrics.items(): if flatten and isinstance(metric, MetricCollection): for sub_metric_name, sub_metric in metric.items(): - yield f"{task_name}_{sub_metric_name}", sub_metric + yield f"{self._prefix}{task_name}_{sub_metric_name}{self._postfix}", sub_metric else: - yield task_name, metric + yield f"{self._prefix}{task_name}{self._postfix}", metric def keys(self, flatten: bool = True) -> Iterable[str]: """Iterate over task names. 
@@ -129,9 +185,9 @@ def keys(self, flatten: bool = True) -> Iterable[str]: for task_name, metric in self.task_metrics.items(): if flatten and isinstance(metric, MetricCollection): for sub_metric_name in metric: - yield f"{task_name}_{sub_metric_name}" + yield f"{self._prefix}{task_name}_{sub_metric_name}{self._postfix}" else: - yield task_name + yield f"{self._prefix}{task_name}{self._postfix}" def values(self, flatten: bool = True) -> Iterable[nn.Module]: """Iterate over task metrics. @@ -147,18 +203,6 @@ def values(self, flatten: bool = True) -> Iterable[nn.Module]: else: yield metric - @staticmethod - def _check_task_metrics_type(task_metrics: Dict[str, Union[Metric, MetricCollection]]) -> None: - if not isinstance(task_metrics, dict): - raise TypeError(f"Expected argument `task_metrics` to be a dict. Found task_metrics = {task_metrics}") - - for metric in task_metrics.values(): - if not (isinstance(metric, (Metric, MetricCollection))): - raise TypeError( - "Expected each task's metric to be a Metric or a MetricCollection. " - f"Found a metric of type {type(metric)}" - ) - def update(self, task_preds: Dict[str, Any], task_targets: Dict[str, Any]) -> None: """Update each task's metric with its corresponding pred and target. 
@@ -179,9 +223,13 @@ def update(self, task_preds: Dict[str, Any], task_targets: Dict[str, Any]) -> No target = task_targets[task_name] metric.update(pred, target) + def _convert_output(self, output: Dict[str, Any]) -> Dict[str, Any]: + """Convert the output of the underlying metrics to a dictionary with the task names as keys.""" + return {f"{self._prefix}{task_name}{self._postfix}": task_output for task_name, task_output in output.items()} + def compute(self) -> Dict[str, Any]: """Compute metrics for all tasks.""" - return {task_name: metric.compute() for task_name, metric in self.task_metrics.items()} + return self._convert_output({task_name: metric.compute() for task_name, metric in self.task_metrics.items()}) def forward(self, task_preds: Dict[str, Tensor], task_targets: Dict[str, Tensor]) -> Dict[str, Any]: """Call underlying forward methods for all tasks and return the result as a dictionary.""" @@ -189,10 +237,10 @@ def forward(self, task_preds: Dict[str, Tensor], task_targets: Dict[str, Tensor] # value of full_state_update, and that also accumulates the results. Here, all computations are handled by the # underlying metrics, which all have their own value of full_state_update, and which all accumulate the results # by themselves. 
- return { + return self._convert_output({ task_name: metric(task_preds[task_name], task_targets[task_name]) for task_name, metric in self.task_metrics.items() - } + }) def reset(self) -> None: """Reset all underlying metrics.""" @@ -215,16 +263,8 @@ def clone(self, prefix: Optional[str] = None, postfix: Optional[str] = None) -> """ multitask_copy = deepcopy(self) - if prefix is not None: - prefix = self._check_arg(prefix, "prefix") - multitask_copy.task_metrics = nn.ModuleDict({ - prefix + key: value for key, value in multitask_copy.task_metrics.items() - }) - if postfix is not None: - postfix = self._check_arg(postfix, "postfix") - multitask_copy.task_metrics = nn.ModuleDict({ - key + postfix: value for key, value in multitask_copy.task_metrics.items() - }) + multitask_copy._prefix = self._check_arg(prefix, "prefix") or "" + multitask_copy._postfix = self._check_arg(postfix, "postfix") or "" return multitask_copy def plot( diff --git a/tests/unittests/wrappers/test_multitask.py b/tests/unittests/wrappers/test_multitask.py index 63af6f31b35..069a4472d64 100644 --- a/tests/unittests/wrappers/test_multitask.py +++ b/tests/unittests/wrappers/test_multitask.py @@ -248,14 +248,24 @@ def test_key_value_items_method(method, flatten): def test_clone_with_prefix_and_postfix(): """Check that the clone method works with prefix and postfix arguments.""" - multitask_metrics = MultitaskWrapper({"Classification": BinaryAccuracy(), "Regression": MeanSquaredError()}) - cloned_metrics_with_prefix = multitask_metrics.clone(prefix="prefix_") - cloned_metrics_with_postfix = multitask_metrics.clone(postfix="_postfix") + multitask_metrics = MultitaskWrapper( + {"Classification": BinaryAccuracy(), "Regression": MeanSquaredError()}, + prefix="prefix_", + postfix="_postfix", + ) + assert set(multitask_metrics.keys()) == {"prefix_Classification_postfix", "prefix_Regression_postfix"} - # Check if the cloned metrics have the expected keys - assert 
set(cloned_metrics_with_prefix.task_metrics.keys()) == {"prefix_Classification", "prefix_Regression"} - assert set(cloned_metrics_with_postfix.task_metrics.keys()) == {"Classification_postfix", "Regression_postfix"} + output = multitask_metrics( + {"Classification": _classification_preds, "Regression": _regression_preds}, + {"Classification": _classification_target, "Regression": _regression_target}, + ) + assert set(output.keys()) == {"prefix_Classification_postfix", "prefix_Regression_postfix"} - # Check if the cloned metrics have the expected values - assert isinstance(cloned_metrics_with_prefix.task_metrics["prefix_Classification"], BinaryAccuracy) - assert isinstance(cloned_metrics_with_prefix.task_metrics["prefix_Regression"], MeanSquaredError) + cloned_metrics = multitask_metrics.clone(prefix="new_prefix_", postfix="_new_postfix") + assert set(cloned_metrics.keys()) == {"new_prefix_Classification_new_postfix", "new_prefix_Regression_new_postfix"} + + output = cloned_metrics( + {"Classification": _classification_preds, "Regression": _regression_preds}, + {"Classification": _classification_target, "Regression": _regression_target}, + ) + assert set(output.keys()) == {"new_prefix_Classification_new_postfix", "new_prefix_Regression_new_postfix"}