
Commit 41bb1fa

SkafteNicki, Borda, and mergify[bot] authored
Fix bugs in MAP related to iou_type="segm" (#1763)
* fix + tests
* changelog
* refactor test
* fix
* fix

---------

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
1 parent cf7604d commit 41bb1fa
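
For context, the failure mode addressed by this commit shows up in roughly the following pattern (a minimal sketch loosely based on issue #1743, not the exact reproduction; it assumes pycocotools is installed): with iou_type="segm" the internal detections/groundtruths states hold tuples of encoded masks rather than tensors, so moving the metric between devices after an update, or syncing its states under DDP, used to fail.

import torch
from torchmetrics.detection.mean_ap import MeanAveragePrecision

device = "cuda" if torch.cuda.is_available() else "cpu"
metric = MeanAveragePrecision(iou_type="segm").to(device)

# One image with a single predicted and a single ground-truth instance mask.
preds = [{
    "masks": torch.ones(1, 10, 10, dtype=torch.bool, device=device),
    "scores": torch.tensor([0.9], device=device),
    "labels": torch.tensor([1], device=device),
}]
target = [{
    "masks": torch.ones(1, 10, 10, dtype=torch.bool, device=device),
    "labels": torch.tensor([1], device=device),
}]
metric.update(preds, target)

# The operations touched by this commit: a device transfer (and, under DDP, the
# state sync) no longer trips over the tuple-valued "segm" states.
metric = metric.cpu()
print(metric.compute()["map"])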

File tree: 4 files changed (+112, −3 lines)

CHANGELOG.md (+2)
@@ -175,6 +175,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Fixed states being saved in metrics that use `register_buffer` ([#1728](https://github.com/Lightning-AI/torchmetrics/pull/1728))
 
 
+- Fixed states not being correctly synced and device transfered in `MeanAveragePrecision` for `iou_type="segm"` ([#1763](https://github.com/Lightning-AI/torchmetrics/pull/1763))
+
 ## [0.11.4] - 2023-03-10
 
 ### Fixed

src/torchmetrics/detection/mean_ap.py (+44, −1)
@@ -12,10 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
 
 import numpy as np
 import torch
+import torch.distributed as dist
 from torch import IntTensor, Tensor
 
 from torchmetrics.detection.helpers import _fix_empty_tensors, _input_validator

@@ -870,6 +871,48 @@ def compute(self) -> dict:
         metrics.classes = torch.tensor(classes, dtype=torch.int)
         return metrics
 
+    def _apply(self, fn: Callable) -> torch.nn.Module:
+        """Custom apply function.
+
+        Excludes the detections and groundtruths from the casting when the iou_type is set to `segm` as the state is
+        no longer a tensor but a tuple.
+        """
+        if self.iou_type == "segm":
+            this = super()._apply(fn, exclude_state=("detections", "groundtruths"))
+        else:
+            this = super()._apply(fn)
+        return this
+
+    def _sync_dist(self, dist_sync_fn: Optional[Callable] = None, process_group: Optional[Any] = None) -> None:
+        """Custom sync function.
+
+        For the iou_type `segm` the detections and groundtruths are no longer tensors but tuples. Therefore, we need
+        to gather the list of tuples and then convert it back to a list of tuples.
+
+        """
+        super()._sync_dist(dist_sync_fn=dist_sync_fn, process_group=process_group)
+
+        if self.iou_type == "segm":
+            self.detections = self._gather_tuple_list(self.detections, process_group)
+            self.groundtruths = self._gather_tuple_list(self.groundtruths, process_group)
+
+    @staticmethod
+    def _gather_tuple_list(list_to_gather: List[Tuple], process_group: Optional[Any] = None) -> List[Any]:
+        """Gather a list of tuples over multiple devices."""
+        world_size = dist.get_world_size(group=process_group)
+        list_gathered = [None] * world_size
+        dist.all_gather_object(list_gathered, list_to_gather, group=process_group)
+
+        for rank in range(1, world_size):
+            if list_gathered[rank] != list_gathered[0]:
+                raise ValueError(f"Rank {rank} and Rank 0 have different values for the list to gather.")
+        list_merged = []
+        for idx in range(len(list_gathered[0])):
+            for rank in range(world_size):
+                list_merged.append(list_gathered[rank][idx])
+
+        return list_merged
+
     def plot(
         self, val: Optional[Union[Dict[str, Tensor], Sequence[Dict[str, Tensor]]]] = None, ax: Optional[_AX_TYPE] = None
     ) -> _PLOT_OUT_TYPE:
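
One detail of the added _gather_tuple_list worth spelling out is the merge order: after dist.all_gather_object, the per-rank lists are interleaved element by element (index 0 from every rank, then index 1, and so on) rather than concatenated rank after rank. Below is a small self-contained sketch of just that merge step; the merge_gathered helper is hypothetical, written only to mirror the loop above without needing a process group, and unlike the real method it omits the preceding check that every rank gathered the same list.

from typing import Any, List


def merge_gathered(list_gathered: List[List[Any]]) -> List[Any]:
    """Interleave per-rank lists index by index, mirroring the merge loop in
    MeanAveragePrecision._gather_tuple_list."""
    merged = []
    for idx in range(len(list_gathered[0])):
        for rank in range(len(list_gathered)):
            merged.append(list_gathered[rank][idx])
    return merged


# Two ranks, each holding two per-batch states (tuples in the "segm" case):
rank0 = [("r0_batch0",), ("r0_batch1",)]
rank1 = [("r1_batch0",), ("r1_batch1",)]
print(merge_gathered([rank0, rank1]))
# [('r0_batch0',), ('r1_batch0',), ('r0_batch1',), ('r1_batch1',)]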

src/torchmetrics/metric.py (+9, −1)
@@ -703,11 +703,16 @@ def set_dtype(self, dst_type: Union[str, torch.dtype]) -> "Metric":
         out._dtype_convert = False
         return out
 
-    def _apply(self, fn: Callable) -> Module:
+    def _apply(self, fn: Callable, exclude_state: Sequence[str] = "") -> Module:
         """Overwrite _apply function such that we can also move metric states to the correct device.
 
         This method is called by the base ``nn.Module`` class whenever `.to`, `.cuda`, `.float`, `.half` etc. methods
         are called. Dtype conversion is garded and will only happen through the special `set_dtype` method.
+
+        Args:
+            fn: the function to apply
+            exclude_state: list of state variables to exclude from applying the function, that then needs to be handled
+                by the metric class itself.
         """
         this = super()._apply(fn)
         fs = str(fn)

@@ -717,6 +722,9 @@ def _apply(self, fn: Callable) -> Module:
 
         # Also apply fn to metric states and defaults
         for key, value in this._defaults.items():
+            if key in exclude_state:
+                continue
+
             if isinstance(value, Tensor):
                 this._defaults[key] = fn(value)
             elif isinstance(value, Sequence):
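
The new exclude_state argument is what allows MeanAveragePrecision to keep its tuple-valued detections/groundtruths out of the generic casting loop and handle them itself in the override added to mean_ap.py above. Below is a schematic sketch of the same pattern from a subclass's point of view; EncodedStateMetric is a hypothetical metric, not part of torchmetrics, and only illustrates how a non-tensor list state would be excluded.

from typing import Callable

import torch
from torchmetrics import Metric


class EncodedStateMetric(Metric):
    """Hypothetical metric whose list state holds tuples of non-tensor data,
    similar to the encoded masks MeanAveragePrecision stores for iou_type="segm"."""

    def __init__(self) -> None:
        super().__init__()
        self.add_state("encoded", default=[], dist_reduce_fx=None)
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, size: tuple, counts: bytes) -> None:
        self.encoded.append((size, counts))  # a tuple, not a tensor
        self.total += 1

    def compute(self) -> torch.Tensor:
        return self.total

    def _apply(self, fn: Callable) -> torch.nn.Module:
        # A device or dtype cast cannot be applied to the tuples stored in `encoded`,
        # so keep that state out of the generic loop; `total` is still cast as usual.
        return super()._apply(fn, exclude_state=("encoded",))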

tests/unittests/detection/test_map.py (+57, −1)
@@ -50,6 +50,18 @@ def _create_inputs_masks() -> Input:
                     "labels": IntTensor([3, 2]),
                 },  # 73
             ],
+            [
+                {
+                    "masks": _mask_unsqueeze_bool(inputs_json["preds"][0]),
+                    "scores": Tensor([0.236]),
+                    "labels": IntTensor([4]),
+                },
+                {
+                    "masks": _masks_stack_bool([inputs_json["preds"][1], inputs_json["preds"][2]]),
+                    "scores": Tensor([0.318, 0.726]),
+                    "labels": IntTensor([3, 2]),
+                },  # 73
+            ],
         ],
         target=[
             [

@@ -59,6 +71,13 @@ def _create_inputs_masks() -> Input:
                     "labels": IntTensor([2, 2]),
                 },  # 73
             ],
+            [
+                {"masks": _mask_unsqueeze_bool(inputs_json["targets"][0]), "labels": IntTensor([4])},  # 42
+                {
+                    "masks": _masks_stack_bool([inputs_json["targets"][1], inputs_json["targets"][2]]),
+                    "labels": IntTensor([2, 2]),
+                },  # 73
+            ],
         ],
     )

@@ -357,7 +376,7 @@ def test_map_bbox(self, compute_on_cpu, ddp):
             metric_args={"class_metrics": True, "compute_on_cpu": compute_on_cpu},
         )
 
-    @pytest.mark.parametrize("ddp", [False])
+    @pytest.mark.parametrize("ddp", [False, True])
     def test_map_segm(self, compute_on_cpu, ddp):
         """Test modular implementation for correctness."""
         _inputs_masks = _create_inputs_masks()

@@ -660,3 +679,40 @@ def test_error_on_wrong_input():
         [{"boxes": Tensor(), "scores": Tensor(), "labels": IntTensor()}],
         [{"boxes": Tensor(), "labels": []}],
     )
+
+
+def _generate_random_segm_input(device):
+    """Generate random inputs for mAP when iou_type=segm."""
+    preds = []
+    targets = []
+    for _ in range(2):
+        result = {}
+        num_preds = torch.randint(0, 10, (1,)).item()
+        result["scores"] = torch.rand((num_preds,), device=device)
+        result["labels"] = torch.randint(0, 10, (num_preds,), device=device)
+        result["masks"] = torch.randint(0, 2, (num_preds, 10, 10), device=device).bool()
+        preds.append(result)
+        gt = {}
+        num_gt = torch.randint(0, 10, (1,)).item()
+        gt["labels"] = torch.randint(0, 10, (num_gt,), device=device)
+        gt["masks"] = torch.randint(0, 2, (num_gt, 10, 10), device=device).bool()
+        targets.append(gt)
+    return preds, targets
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
+def test_device_changing():
+    """See issue: https://github.com/Lightning-AI/torchmetrics/issues/1743.
+
+    Checks that the custom apply function of the metric works as expected.
+    """
+    device = "cuda"
+    metric = MeanAveragePrecision(iou_type="segm").to(device)
+
+    for _ in range(2):
+        preds, targets = _generate_random_segm_input(device)
+        metric.update(preds, targets)
+
+    metric = metric.cpu()
+    val = metric.compute()
+    assert isinstance(val, dict)
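
Beyond exercising the device transfer, _generate_random_segm_input doubles as a compact reference for the iou_type="segm" input layout: per image, boolean masks of shape (num_instances, H, W), float scores for predictions, and integer labels. The sketch below feeds inputs of that shape through the metric outside the test harness; shapes and counts are arbitrary, and pycocotools is required just as for the tests.

import torch
from torchmetrics.detection.mean_ap import MeanAveragePrecision

device = "cuda" if torch.cuda.is_available() else "cpu"
metric = MeanAveragePrecision(iou_type="segm").to(device)

for _ in range(2):  # a couple of updates, as in test_device_changing
    preds = [{
        "scores": torch.rand(3, device=device),
        "labels": torch.randint(0, 10, (3,), device=device),
        "masks": torch.randint(0, 2, (3, 10, 10), device=device).bool(),
    }]
    target = [{
        "labels": torch.randint(0, 10, (2,), device=device),
        "masks": torch.randint(0, 2, (2, 10, 10), device=device).bool(),
    }]
    metric.update(preds, target)

val = metric.cpu().compute()  # the device move exercised by the new test
assert isinstance(val, dict)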
