From 0c8e1ec2b09ba4383694bfa9c2b6308a946b8fab Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Wed, 19 Jul 2023 14:23:22 +0200 Subject: [PATCH 01/12] add code --- src/torchmetrics/detection/helpers.py | 69 ++++-- src/torchmetrics/detection/mean_ap.py | 309 ++++++++++++++++---------- 2 files changed, 238 insertions(+), 140 deletions(-) diff --git a/src/torchmetrics/detection/helpers.py b/src/torchmetrics/detection/helpers.py index c86787992f3..5f90991bcd0 100644 --- a/src/torchmetrics/detection/helpers.py +++ b/src/torchmetrics/detection/helpers.py @@ -11,20 +11,26 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Sequence +from typing import Dict, List, Literal, Sequence, Union from torch import Tensor def _input_validator( - preds: Sequence[Dict[str, Tensor]], targets: Sequence[Dict[str, Tensor]], iou_type: str = "bbox" + preds: Sequence[Dict[str, Tensor]], + targets: Sequence[Dict[str, Tensor]], + iou_type: Union[Literal["bbox", "segm"], List[str]] = "bbox", ) -> None: """Ensure the correct input format of `preds` and `targets`.""" - if iou_type == "bbox": - item_val_name = "boxes" - elif iou_type == "segm": - item_val_name = "masks" - else: + if not isinstance(iou_type, list): + iou_type = [iou_type] + + item_val_name = [] + if "bbox" in iou_type: + item_val_name.append("boxes") + if "segm" in iou_type: + item_val_name.append("masks") + if any(i not in ["bbox", "segm"] for i in iou_type): raise Exception(f"IOU type {iou_type} is not supported") if not isinstance(preds, Sequence): @@ -36,38 +42,42 @@ def _input_validator( f"Expected argument `preds` and `target` to have the same length, but got {len(preds)} and {len(targets)}" ) - for k in [item_val_name, "scores", "labels"]: + for k in [*item_val_name, "scores", "labels"]: if any(k not in p for p in preds): raise ValueError(f"Expected all dicts in `preds` to contain the `{k}` key") - for k in [item_val_name, "labels"]: + for k in [*item_val_name, "labels"]: if any(k not in p for p in targets): raise ValueError(f"Expected all dicts in `target` to contain the `{k}` key") - if any(type(pred[item_val_name]) is not Tensor for pred in preds): - raise ValueError(f"Expected all {item_val_name} in `preds` to be of type Tensor") + for ivn in item_val_name: + if any(type(pred[ivn]) is not Tensor for pred in preds): + raise ValueError(f"Expected all {ivn} in `preds` to be of type Tensor") if any(type(pred["scores"]) is not Tensor for pred in preds): raise ValueError("Expected all scores in `preds` to be of type Tensor") if any(type(pred["labels"]) is not Tensor for pred in preds): raise ValueError("Expected all labels in `preds` to be of type Tensor") - if any(type(target[item_val_name]) is not Tensor for target in targets): - raise ValueError(f"Expected all {item_val_name} in `target` to be of type Tensor") + for ivn in item_val_name: + if any(type(target[ivn]) is not Tensor for target in targets): + raise ValueError(f"Expected all {ivn} in `target` to be of type Tensor") if any(type(target["labels"]) is not Tensor for target in targets): raise ValueError("Expected all labels in `target` to be of type Tensor") for i, item in enumerate(targets): - if item[item_val_name].size(0) != item["labels"].size(0): - raise ValueError( - f"Input {item_val_name} and labels of sample {i} in targets have a" - f" different length (expected {item[item_val_name].size(0)} labels, got {item['labels'].size(0)})" - 
) + for ivn in item_val_name: + if item[ivn].size(0) != item["labels"].size(0): + raise ValueError( + f"Input {ivn} and labels of sample {i} in targets have a" + f" different length (expected {item[ivn].size(0)} labels, got {item['labels'].size(0)})" + ) for i, item in enumerate(preds): - if not (item[item_val_name].size(0) == item["labels"].size(0) == item["scores"].size(0)): - raise ValueError( - f"Input {item_val_name}, labels and scores of sample {i} in predictions have a" - f" different length (expected {item[item_val_name].size(0)} labels and scores," - f" got {item['labels'].size(0)} labels and {item['scores'].size(0)})" - ) + for ivn in item_val_name: + if not (item[ivn].size(0) == item["labels"].size(0) == item["scores"].size(0)): + raise ValueError( + f"Input {ivn}, labels and scores of sample {i} in predictions have a" + f" different length (expected {item[ivn].size(0)} labels and scores," + f" got {item['labels'].size(0)} labels and {item['scores'].size(0)})" + ) def _fix_empty_tensors(boxes: Tensor) -> Tensor: @@ -75,3 +85,14 @@ def _fix_empty_tensors(boxes: Tensor) -> Tensor: if boxes.numel() == 0 and boxes.ndim == 1: return boxes.unsqueeze(0) return boxes + + +def _validate_iou_type_arg(iou_type: Union[Literal["bbox", "segm"], List[str]] = "bbox") -> List[str]: + allowed_iou_types = ("segm", "bbox") + if not isinstance(iou_type, list): + iou_type = [iou_type] + if any(i_type not in allowed_iou_types for i_type in iou_type): + raise ValueError( + f"Expected argument `iou_type` to be one of {allowed_iou_types} or a list of, but got {iou_type}" + ) + return iou_type diff --git a/src/torchmetrics/detection/mean_ap.py b/src/torchmetrics/detection/mean_ap.py index 5542372d436..42382c8b6f6 100644 --- a/src/torchmetrics/detection/mean_ap.py +++ b/src/torchmetrics/detection/mean_ap.py @@ -22,7 +22,7 @@ from torch import distributed as dist from typing_extensions import Literal -from torchmetrics.detection.helpers import _fix_empty_tensors, _input_validator +from torchmetrics.detection.helpers import _fix_empty_tensors, _input_validator, _validate_iou_type_arg from torchmetrics.metric import Metric from torchmetrics.utilities.imports import ( _MATPLOTLIB_AVAILABLE, @@ -230,10 +230,12 @@ class MeanAveragePrecision(Metric): plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 - detections: List[Tensor] + detection_box: List[Tensor] + detection_mask: List[Tensor] detection_scores: List[Tensor] detection_labels: List[Tensor] - groundtruths: List[Tensor] + groundtruth_box: List[Tensor] + groundtruth_mask: List[Tensor] groundtruth_labels: List[Tensor] groundtruth_crowds: List[Tensor] groundtruth_area: List[Tensor] @@ -241,7 +243,7 @@ class MeanAveragePrecision(Metric): def __init__( self, box_format: Literal["xyxy", "xywh", "cxcywh"] = "xyxy", - iou_type: Literal["bbox", "segm"] = "bbox", + iou_type: Union[Literal["bbox", "segm"], List[str]] = "bbox", iou_thresholds: Optional[List[float]] = None, rec_thresholds: Optional[List[float]] = None, max_detection_thresholds: Optional[List[int]] = None, @@ -266,10 +268,7 @@ def __init__( raise ValueError(f"Expected argument `box_format` to be one of {allowed_box_formats} but got {box_format}") self.box_format = box_format - allowed_iou_types = ("segm", "bbox") - if iou_type not in allowed_iou_types: - raise ValueError(f"Expected argument `iou_type` to be one of {allowed_iou_types} but got {iou_type}") - self.iou_type = iou_type + self.iou_type = _validate_iou_type_arg(iou_type) if iou_thresholds is not None and not isinstance(iou_thresholds, 
list): raise ValueError( @@ -295,10 +294,12 @@ def __init__( raise ValueError("Expected argument `class_metrics` to be a boolean") self.class_metrics = class_metrics - self.add_state("detections", default=[], dist_reduce_fx=None) + self.add_state("detection_box", default=[], dist_reduce_fx=None) + self.add_state("detection_mask", default=[], dist_reduce_fx=None) self.add_state("detection_scores", default=[], dist_reduce_fx=None) self.add_state("detection_labels", default=[], dist_reduce_fx=None) - self.add_state("groundtruths", default=[], dist_reduce_fx=None) + self.add_state("groundtruth_box", default=[], dist_reduce_fx=None) + self.add_state("groundtruth_mask", default=[], dist_reduce_fx=None) self.add_state("groundtruth_labels", default=[], dist_reduce_fx=None) self.add_state("groundtruth_crowds", default=[], dist_reduce_fx=None) self.add_state("groundtruth_area", default=[], dist_reduce_fx=None) @@ -327,15 +328,20 @@ def update(self, preds: List[Dict[str, Tensor]], target: List[Dict[str, Tensor]] _input_validator(preds, target, iou_type=self.iou_type) for item in preds: - detections = self._get_safe_item_values(item) - - self.detections.append(detections) + bbox_detection, mask_detection = self._get_safe_item_values(item) + if bbox_detection is not None: + self.detection_box.append(bbox_detection) + if mask_detection is not None: + self.detection_mask.append(mask_detection) self.detection_labels.append(item["labels"]) self.detection_scores.append(item["scores"]) for item in target: - groundtruths = self._get_safe_item_values(item) - self.groundtruths.append(groundtruths) + bbox_groundtruth, mask_groundtruth = self._get_safe_item_values(item) + if bbox_groundtruth is not None: + self.groundtruth_box.append(bbox_groundtruth) + if mask_groundtruth is not None: + self.groundtruth_mask.append(mask_groundtruth) self.groundtruth_labels.append(item["labels"]) self.groundtruth_crowds.append(item.get("iscrowd", torch.zeros_like(item["labels"]))) self.groundtruth_area.append(item.get("area", torch.zeros_like(item["labels"]))) @@ -345,68 +351,97 @@ def compute(self) -> dict: coco_target, coco_preds = COCO(), COCO() coco_target.dataset = self._get_coco_format( - self.groundtruths, self.groundtruth_labels, crowds=self.groundtruth_crowds, area=self.groundtruth_area + labels=self.groundtruth_labels, + boxes=self.groundtruth_box if len(self.groundtruth_box) > 0 else None, + masks=self.groundtruth_mask if len(self.groundtruth_mask) > 0 else None, + crowds=self.groundtruth_crowds, + area=self.groundtruth_area, + ) + coco_preds.dataset = self._get_coco_format( + labels=self.detection_labels, + boxes=self.detection_box if len(self.detection_box) > 0 else None, + masks=self.detection_mask if len(self.detection_mask) > 0 else None, + scores=self.detection_scores, ) - coco_preds.dataset = self._get_coco_format(self.detections, self.detection_labels, scores=self.detection_scores) + result_dict = {} with contextlib.redirect_stdout(io.StringIO()): coco_target.createIndex() coco_preds.createIndex() - self.coco_eval = COCOeval(coco_target, coco_preds, iouType=self.iou_type) - self.coco_eval.params.iouThrs = np.array(self.iou_thresholds, dtype=np.float64) - self.coco_eval.params.recThrs = np.array(self.rec_thresholds, dtype=np.float64) - self.coco_eval.params.maxDets = self.max_detection_thresholds - - self.coco_eval.evaluate() - self.coco_eval.accumulate() - self.coco_eval.summarize() - stats = self.coco_eval.stats - - # if class mode is enabled, evaluate metrics per class - if self.class_metrics: - 
map_per_class_list = [] - mar_100_per_class_list = [] - for class_id in self._get_classes(): - self.coco_eval.params.catIds = [class_id] - with contextlib.redirect_stdout(io.StringIO()): - self.coco_eval.evaluate() - self.coco_eval.accumulate() - self.coco_eval.summarize() - class_stats = self.coco_eval.stats - - map_per_class_list.append(torch.tensor([class_stats[0]])) - mar_100_per_class_list.append(torch.tensor([class_stats[8]])) - - map_per_class_values = torch.tensor(map_per_class_list, dtype=torch.float32) - mar_100_per_class_values = torch.tensor(mar_100_per_class_list, dtype=torch.float32) - else: - map_per_class_values = torch.tensor([-1], dtype=torch.float32) - mar_100_per_class_values = torch.tensor([-1], dtype=torch.float32) + for i_type in self.iou_type: + if len(self.iou_type) > 1: + # the area calculation is different for bbox and segm and therefore to get the small, medium and + # large values correct we need to dynamically change the area attribute of the annotations + for anno in coco_preds.dataset["annotations"]: + anno["area"] = anno[f"area_{i_type}"] + + self.coco_eval = COCOeval(coco_target, coco_preds, iouType=i_type) + self.coco_eval.params.iouThrs = np.array(self.iou_thresholds, dtype=np.float64) + self.coco_eval.params.recThrs = np.array(self.rec_thresholds, dtype=np.float64) + self.coco_eval.params.maxDets = self.max_detection_thresholds + + self.coco_eval.evaluate() + self.coco_eval.accumulate() + self.coco_eval.summarize() + stats = self.coco_eval.stats + result_dict.update( + self._coco_stats_to_tensor_dict(stats, prefix=i_type if len(self.iou_type) > 1 else None) + ) + # if class mode is enabled, evaluate metrics per class + if self.class_metrics: + map_per_class_list = [] + mar_100_per_class_list = [] + for class_id in self._get_classes(): + self.coco_eval.params.catIds = [class_id] + with contextlib.redirect_stdout(io.StringIO()): + self.coco_eval.evaluate() + self.coco_eval.accumulate() + self.coco_eval.summarize() + class_stats = self.coco_eval.stats + + map_per_class_list.append(torch.tensor([class_stats[0]])) + mar_100_per_class_list.append(torch.tensor([class_stats[8]])) + + map_per_class_values = torch.tensor(map_per_class_list, dtype=torch.float32) + mar_100_per_class_values = torch.tensor(mar_100_per_class_list, dtype=torch.float32) + else: + map_per_class_values = torch.tensor([-1], dtype=torch.float32) + mar_100_per_class_values = torch.tensor([-1], dtype=torch.float32) + prefix = "" if len(self.iou_type) == 1 else f"{i_type}_" + result_dict.update( + { + f"{prefix}map_per_class": map_per_class_values, + f"{prefix}mar_100_per_class": mar_100_per_class_values, + }, + ) + result_dict.update({"classes": torch.tensor(self._get_classes(), dtype=torch.int32)}) + + return result_dict + @staticmethod + def _coco_stats_to_tensor_dict(stats: List[float], prefix: Optional[str] = None) -> Dict[str, Tensor]: + prefix = "" if prefix is None else prefix + "_" return { - "map": torch.tensor([stats[0]], dtype=torch.float32), - "map_50": torch.tensor([stats[1]], dtype=torch.float32), - "map_75": torch.tensor([stats[2]], dtype=torch.float32), - "map_small": torch.tensor([stats[3]], dtype=torch.float32), - "map_medium": torch.tensor([stats[4]], dtype=torch.float32), - "map_large": torch.tensor([stats[5]], dtype=torch.float32), - "mar_1": torch.tensor([stats[6]], dtype=torch.float32), - "mar_10": torch.tensor([stats[7]], dtype=torch.float32), - "mar_100": torch.tensor([stats[8]], dtype=torch.float32), - "mar_small": torch.tensor([stats[9]], dtype=torch.float32), - 
"mar_medium": torch.tensor([stats[10]], dtype=torch.float32), - "mar_large": torch.tensor([stats[11]], dtype=torch.float32), - "map_per_class": map_per_class_values, - "mar_100_per_class": mar_100_per_class_values, - "classes": torch.tensor(self._get_classes(), dtype=torch.int32), + f"{prefix}map": torch.tensor([stats[0]], dtype=torch.float32), + f"{prefix}map_50": torch.tensor([stats[1]], dtype=torch.float32), + f"{prefix}map_75": torch.tensor([stats[2]], dtype=torch.float32), + f"{prefix}map_small": torch.tensor([stats[3]], dtype=torch.float32), + f"{prefix}map_medium": torch.tensor([stats[4]], dtype=torch.float32), + f"{prefix}map_large": torch.tensor([stats[5]], dtype=torch.float32), + f"{prefix}mar_1": torch.tensor([stats[6]], dtype=torch.float32), + f"{prefix}mar_10": torch.tensor([stats[7]], dtype=torch.float32), + f"{prefix}mar_100": torch.tensor([stats[8]], dtype=torch.float32), + f"{prefix}mar_small": torch.tensor([stats[9]], dtype=torch.float32), + f"{prefix}mar_medium": torch.tensor([stats[10]], dtype=torch.float32), + f"{prefix}mar_large": torch.tensor([stats[11]], dtype=torch.float32), } @staticmethod def coco_to_tm( coco_preds: str, coco_target: str, - iou_type: Literal["bbox", "segm"] = "bbox", + iou_type: Union[Literal["bbox", "segm"], List[str]] = "bbox", ) -> Tuple[List[Dict[str, Tensor]], List[Dict[str, Tensor]]]: """Utility function for converting .json coco format files to the input format of this metric. @@ -434,6 +469,8 @@ def coco_to_tm( ... ) # doctest: +SKIP """ + iou_type = _validate_iou_type_arg(iou_type) + with contextlib.redirect_stdout(io.StringIO()): gt = COCO(coco_target) dt = gt.loadRes(coco_preds) @@ -445,14 +482,18 @@ def coco_to_tm( for t in gt_dataset: if t["image_id"] not in target: target[t["image_id"]] = { - "boxes" if iou_type == "bbox" else "masks": [], "labels": [], "iscrowd": [], "area": [], } - if iou_type == "bbox": + if "bbox" in iou_type: + target[t["image_id"]]["boxes"] = [] + if "segm" in iou_type: + target[t["image_id"]]["masks"] = [] + + if "bbox" in iou_type: target[t["image_id"]]["boxes"].append(t["bbox"]) - else: + if "segm" in iou_type: target[t["image_id"]]["masks"].append(gt.annToMask(t)) target[t["image_id"]]["labels"].append(t["category_id"]) target[t["image_id"]]["iscrowd"].append(t["iscrowd"]) @@ -461,39 +502,47 @@ def coco_to_tm( preds = {} for p in dt_dataset: if p["image_id"] not in preds: - preds[p["image_id"]] = {"boxes" if iou_type == "bbox" else "masks": [], "scores": [], "labels": []} - if iou_type == "bbox": + preds[p["image_id"]] = {"scores": [], "labels": []} + if "bbox" in iou_type: + preds[p["image_id"]]["boxes"] = [] + if "segm" in iou_type: + preds[p["image_id"]]["masks"] = [] + if "bbox" in iou_type: preds[p["image_id"]]["boxes"].append(p["bbox"]) - else: + if "segm" in iou_type: preds[p["image_id"]]["masks"].append(gt.annToMask(p)) preds[p["image_id"]]["scores"].append(p["score"]) preds[p["image_id"]]["labels"].append(p["category_id"]) for k in target: # add empty predictions for images without predictions if k not in preds: - preds[k] = {"boxes" if iou_type == "bbox" else "masks": [], "scores": [], "labels": []} + preds[k] = {"scores": [], "labels": []} + if "bbox" in iou_type: + preds[k]["boxes"] = [] + if "segm" in iou_type: + preds[k]["masks"] = [] batched_preds, batched_target = [], [] for key in target: - name = "boxes" if iou_type == "bbox" else "masks" - batched_preds.append( - { - name: torch.tensor(np.array(preds[key]["boxes"]), dtype=torch.float32) - if iou_type == "bbox" - else 
torch.tensor(np.array(preds[key]["masks"]), dtype=torch.uint8), - "scores": torch.tensor(preds[key]["scores"], dtype=torch.float32), - "labels": torch.tensor(preds[key]["labels"], dtype=torch.int32), - } - ) - batched_target.append( - { - name: torch.tensor(target[key]["boxes"], dtype=torch.float32) - if iou_type == "bbox" - else torch.tensor(np.array(target[key]["masks"]), dtype=torch.uint8), - "labels": torch.tensor(target[key]["labels"], dtype=torch.int32), - "iscrowd": torch.tensor(target[key]["iscrowd"], dtype=torch.int32), - "area": torch.tensor(target[key]["area"], dtype=torch.float32), - } - ) + bp = { + "scores": torch.tensor(preds[key]["scores"], dtype=torch.float32), + "labels": torch.tensor(preds[key]["labels"], dtype=torch.int32), + } + if "bbox" in iou_type: + bp["boxes"] = torch.tensor(np.array(preds[key]["boxes"]), dtype=torch.float32) + if "segm" in iou_type: + bp["masks"] = torch.tensor(np.array(preds[key]["masks"]), dtype=torch.uint8) + batched_preds.append(bp) + + bt = { + "labels": torch.tensor(target[key]["labels"], dtype=torch.int32), + "iscrowd": torch.tensor(target[key]["iscrowd"], dtype=torch.int32), + "area": torch.tensor(target[key]["area"], dtype=torch.float32), + } + if "bbox" in iou_type: + bt["boxes"] = torch.tensor(target[key]["boxes"], dtype=torch.float32) + if "segm" in iou_type: + bt["masks"] = torch.tensor(np.array(target[key]["masks"]), dtype=torch.uint8) + batched_target.append(bt) return batched_preds, batched_target @@ -528,8 +577,16 @@ def tm_to_coco(self, name: str = "tm_map_input") -> None: >>> metric.tm_to_coco("tm_map_input") # doctest: +SKIP """ - target_dataset = self._get_coco_format(self.groundtruths, self.groundtruth_labels) - preds_dataset = self._get_coco_format(self.detections, self.detection_labels, self.detection_scores) + target_dataset = self._get_coco_format( + labels=self.groundtruth_labels, + boxes=self.groundtruth_box, + masks=self.groundtruth_mask, + crowds=self.groundtruth_crowds, + area=self.groundtruth_area, + ) + preds_dataset = self._get_coco_format( + labels=self.detection_labels, boxes=self.detection_box, masks=self.detection_mask + ) preds_json = json.dumps(preds_dataset["annotations"], indent=4) target_json = json.dumps(target_dataset, indent=4) @@ -540,7 +597,7 @@ def tm_to_coco(self, name: str = "tm_map_input") -> None: with open(f"{name}_target.json", "w") as f: f.write(target_json) - def _get_safe_item_values(self, item: Dict[str, Any]) -> Union[Tensor, Tuple]: + def _get_safe_item_values(self, item: Dict[str, Any]) -> Tuple[Optional[Tensor], Optional[Tuple]]: """Convert and return the boxes or masks from the item depending on the iou_type. 
Args: @@ -550,18 +607,19 @@ def _get_safe_item_values(self, item: Dict[str, Any]) -> Union[Tensor, Tuple]: boxes or masks depending on the iou_type """ - if self.iou_type == "bbox": + output = [None, None] + if "bbox" in self.iou_type: boxes = _fix_empty_tensors(item["boxes"]) if boxes.numel() > 0: boxes = box_convert(boxes, in_fmt=self.box_format, out_fmt="xywh") - return boxes - if self.iou_type == "segm": + output[0] = boxes + if "segm" in self.iou_type: masks = [] for i in item["masks"].cpu().numpy(): rle = mask_utils.encode(np.asfortranarray(i)) masks.append((tuple(rle["size"]), rle["counts"])) - return tuple(masks) - raise Exception(f"IOU type {self.iou_type} is not supported") + output[1] = tuple(masks) + return output def _get_classes(self) -> List: """Return a list of unique classes found in ground truth and detection data.""" @@ -571,8 +629,9 @@ def _get_classes(self) -> List: def _get_coco_format( self, - boxes: List[torch.Tensor], labels: List[torch.Tensor], + boxes: Optional[List[torch.Tensor]] = None, + masks: Optional[List[torch.Tensor]] = None, scores: Optional[List[torch.Tensor]] = None, crowds: Optional[List[torch.Tensor]] = None, area: Optional[List[torch.Tensor]] = None, @@ -585,20 +644,28 @@ def _get_coco_format( annotations = [] annotation_id = 1 # has to start with 1, otherwise COCOEval results are wrong - for image_id, (image_boxes, image_labels) in enumerate(zip(boxes, labels)): - if self.iou_type == "segm" and len(image_boxes) == 0: - continue - - if self.iou_type == "bbox": + for image_id, image_labels in enumerate(labels): + if boxes is not None: + image_boxes = boxes[image_id] image_boxes = image_boxes.cpu().tolist() + if masks is not None: + image_masks = masks[image_id] + if len(image_masks) == 0 and boxes is None: + continue image_labels = image_labels.cpu().tolist() images.append({"id": image_id}) - if self.iou_type == "segm": - images[-1]["height"], images[-1]["width"] = image_boxes[0][0][0], image_boxes[0][0][1] + if "segm" in self.iou_type and len(image_masks) > 0: + images[-1]["height"], images[-1]["width"] = image_masks[0][0][0], image_masks[0][0][1] + + for k, image_label in enumerate(image_labels): + if boxes is not None: + image_box = image_boxes[k] + if masks is not None and len(image_masks) > 0: + image_mask = image_masks[k] + image_mask = {"size": image_mask[0], "counts": image_mask[1]} - for k, (image_box, image_label) in enumerate(zip(image_boxes, image_labels)): - if self.iou_type == "bbox" and len(image_box) != 4: + if "bbox" in self.iou_type and len(image_box) != 4: raise ValueError( f"Invalid input box of sample {image_id}, element {k} (expected 4 values, got {len(image_box)})" ) @@ -609,21 +676,31 @@ def _get_coco_format( f" (expected value of type integer, got type {type(image_label)})" ) - stat = image_box if self.iou_type == "bbox" else {"size": image_box[0], "counts": image_box[1]} - + area_stat_box = None + area_stat_mask = None if area is not None and area[image_id][k].cpu().tolist() > 0: area_stat = area[image_id][k].cpu().tolist() else: - area_stat = image_box[2] * image_box[3] if self.iou_type == "bbox" else mask_utils.area(stat) + area_stat = mask_utils.area(image_mask) if "segm" in self.iou_type else image_box[2] * image_box[3] + if len(self.iou_type) > 1: + area_stat_box = image_box[2] * image_box[3] + area_stat_mask = mask_utils.area(image_mask) annotation = { "id": annotation_id, "image_id": image_id, - "bbox" if self.iou_type == "bbox" else "segmentation": stat, "area": area_stat, "category_id": image_label, "iscrowd": 
crowds[image_id][k].cpu().tolist() if crowds is not None else 0, } + if area_stat_box is not None: + annotation["area_bbox"] = area_stat_box + annotation["area_segm"] = area_stat_mask + + if boxes is not None: + annotation["bbox"] = image_box + if masks is not None: + annotation["segmentation"] = image_mask if scores is not None: score = scores[image_id][k].cpu().tolist() @@ -707,7 +784,7 @@ def _apply(self, fn: Callable) -> torch.nn.Module: # type: ignore[override] Excludes the detections and groundtruths from the casting when the iou_type is set to `segm` as the state is no longer a tensor but a tuple. """ - return super()._apply(fn, exclude_state=("detections", "groundtruths") if self.iou_type == "segm" else "") + return super()._apply(fn, exclude_state=("detection_mask", "groundtruth_mask")) def _sync_dist(self, dist_sync_fn: Optional[Callable] = None, process_group: Optional[Any] = None) -> None: """Custom sync function. @@ -718,9 +795,9 @@ def _sync_dist(self, dist_sync_fn: Optional[Callable] = None, process_group: Opt """ super()._sync_dist(dist_sync_fn=dist_sync_fn, process_group=process_group) - if self.iou_type == "segm": - self.detections = self._gather_tuple_list(self.detections, process_group) - self.groundtruths = self._gather_tuple_list(self.groundtruths, process_group) + if "segm" in self.iou_type: + self.detections = self._gather_tuple_list(self.detection_mask, process_group) + self.groundtruths = self._gather_tuple_list(self.groundtruth_mask, process_group) @staticmethod def _gather_tuple_list(list_to_gather: List[Tuple], process_group: Optional[Any] = None) -> List[Any]: From cf9726350bd61bf1af0712799153de7177b4901a Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Wed, 19 Jul 2023 14:26:27 +0200 Subject: [PATCH 02/12] add test --- tests/unittests/detection/test_map.py | 43 +++++++++++++++++++++++++-- 1 file changed, 40 insertions(+), 3 deletions(-) diff --git a/tests/unittests/detection/test_map.py b/tests/unittests/detection/test_map.py index f1c002428a0..37c5cb39988 100644 --- a/tests/unittests/detection/test_map.py +++ b/tests/unittests/detection/test_map.py @@ -13,6 +13,7 @@ # limitations under the License. 
 import contextlib
 import io
+import json
 from collections import namedtuple
 from copy import deepcopy
 from functools import partial
@@ -54,7 +55,7 @@ def _generate_coco_inputs(iou_type):
 _coco_segm_input = _generate_coco_inputs("segm")
 
 
-def _compare_again_coco_fn(preds, target, iou_type, iou_thresholds=None, rec_thresholds=None, class_metrics=True):
+def _compare_against_coco_fn(preds, target, iou_type, iou_thresholds=None, rec_thresholds=None, class_metrics=True):
     """Taken from https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb."""
     with contextlib.redirect_stdout(io.StringIO()):
         gt = COCO(_DETECTION_VAL)
@@ -129,7 +130,7 @@ def test_map(self, iou_type, iou_thresholds, rec_thresholds, ddp):
             target=target,
             metric_class=MeanAveragePrecision,
             reference_metric=partial(
-                _compare_again_coco_fn,
+                _compare_against_coco_fn,
                 iou_type=iou_type,
                 iou_thresholds=iou_thresholds,
                 rec_thresholds=rec_thresholds,
@@ -154,13 +155,49 @@ def test_map_classwise(self, iou_type, ddp):
             preds=preds,
             target=target,
             metric_class=MeanAveragePrecision,
-            reference_metric=partial(_compare_again_coco_fn, iou_type=iou_type, class_metrics=True),
+            reference_metric=partial(_compare_against_coco_fn, iou_type=iou_type, class_metrics=True),
             metric_args={"box_format": "xywh", "iou_type": iou_type, "class_metrics": True},
             check_batch=False,
             atol=1e-1,
         )
 
 
+def test_compare_both_same_time(tmpdir):
+    """Test that the class supports evaluating both bbox and segm at the same time."""
+    with open(_DETECTION_BBOX) as f:
+        boxes = json.load(f)
+    with open(_DETECTION_SEGM) as f:
+        segmentations = json.load(f)
+    combined = [{**box, **seg} for box, seg in zip(boxes, segmentations)]
+    with open(f"{tmpdir}/combined.json", "w") as f:
+        json.dump(combined, f)
+    batched_preds, batched_target = MeanAveragePrecision.coco_to_tm(
+        f"{tmpdir}/combined.json", _DETECTION_VAL, iou_type=["bbox", "segm"]
+    )
+    batched_preds = [batched_preds[10 * i : 10 * (i + 1)] for i in range(10)]
+    batched_target = [batched_target[10 * i : 10 * (i + 1)] for i in range(10)]
+
+    metric = MeanAveragePrecision(iou_type=["bbox", "segm"], box_format="xywh")
+    for bp, bt in zip(batched_preds, batched_target):
+        metric.update(bp, bt)
+    res = metric.compute()
+
+    res1 = _compare_against_coco_fn([], [], iou_type="bbox", class_metrics=False)
+    res2 = _compare_against_coco_fn([], [], iou_type="segm", class_metrics=False)
+
+    for k, v in res1.items():
+        if k == "classes":
+            continue
+        assert f"bbox_{k}" in res
+        assert torch.allclose(res[f"bbox_{k}"], v, atol=1e-2)
+
+    for k, v in res2.items():
+        if k == "classes":
+            continue
+        assert f"segm_{k}" in res
+        assert torch.allclose(res[f"segm_{k}"], v, atol=1e-2)
+
+
 Input = namedtuple("Input", ["preds", "target"])


From d3c6fa1a16dc4930da5d824be6d3df1a75a6338b Mon Sep 17 00:00:00 2001
From: SkafteNicki
Date: Wed, 19 Jul 2023 14:40:03 +0200
Subject: [PATCH 03/12] docs changes

---
 src/torchmetrics/detection/mean_ap.py | 71 ++++++++++++++++++++++++---
 1 file changed, 64 insertions(+), 7 deletions(-)

diff --git a/src/torchmetrics/detection/mean_ap.py b/src/torchmetrics/detection/mean_ap.py
index 42382c8b6f6..e1c10faecac 100644
--- a/src/torchmetrics/detection/mean_ap.py
+++ b/src/torchmetrics/detection/mean_ap.py
@@ -83,7 +83,8 @@ class MeanAveragePrecision(Metric):
 
         - boxes: (:class:`~torch.FloatTensor`) of shape ``(num_boxes, 4)`` containing ``num_boxes`` detection
           boxes of the format specified in the constructor.
-          By default, this method expects ``(xmin, ymin, xmax, ymax)`` in absolute image coordinates.
+          By default, this method expects ``(xmin, ymin, xmax, ymax)`` in absolute image coordinates, but can be changed
+          using the ``box_format`` parameter. Only required when `iou_type="bbox"`.
         - scores: :class:`~torch.FloatTensor` of shape ``(num_boxes)`` containing detection scores for the boxes.
         - labels: :class:`~torch.IntTensor` of shape ``(num_boxes)`` containing 0-indexed detection classes for
           the boxes.
@@ -94,7 +95,7 @@ class MeanAveragePrecision(Metric):
       (each dictionary corresponds to a single image). Parameters that should be provided per dict:
 
         - boxes: :class:`~torch.FloatTensor` of shape ``(num_boxes, 4)`` containing ``num_boxes`` ground truth
-          boxes of the format specified in the constructor.
+          boxes of the format specified in the constructor. Only required when `iou_type="bbox"`.
           By default, this method expects ``(xmin, ymin, xmax, ymax)`` in absolute image coordinates.
         - labels: :class:`~torch.IntTensor` of shape ``(num_boxes)`` containing 0-indexed ground truth classes for
           the boxes.
@@ -136,7 +137,6 @@ class MeanAveragePrecision(Metric):
     .. note::
         ``map`` score is calculated with @[ IoU=self.iou_thresholds | area=all | max_dets=max_detection_thresholds ].
         Caution: If the initialization parameters are changed, dictionary keys for mAR can change as well.
-        The default properties are also accessible via fields and will raise an ``AttributeError`` if not available.
 
     .. note::
         This metric utilizes the official `pycocotools` implementation as its backend. This means that the metric
@@ -155,8 +155,8 @@ class MeanAveragePrecision(Metric):
             width and height.
 
         iou_type:
-            Type of input (either masks or bounding-boxes) used for computing IOU.
-            Supported IOU types are ``["bbox", "segm"]``. If using ``"segm"``, masks should be provided in input.
+            Type of input (either masks or bounding-boxes) used for computing IOU. Supported IOU types are
+            ``"bbox"`` or ``"segm"`` or both as a list.
         iou_thresholds:
             IoU thresholds for evaluation. If set to ``None`` it corresponds to the stepped range ``[0.5,...,0.95]``
             with step ``0.05``. Else provide a list of floats.
@@ -188,7 +188,10 @@ class MeanAveragePrecision(Metric):
         ValueError:
             If ``class_metrics`` is not a boolean
 
-    Example:
+    Example::
+
+        Basic example for when `iou_type="bbox"`:
+
         >>> from torch import tensor
         >>> from torchmetrics.detection import MeanAveragePrecision
         >>> preds = [
         ...     dict(
         ...         boxes=tensor([[258.0, 41.0, 606.0, 285.0]]),
         ...         scores=tensor([0.536]),
         ...         labels=tensor([0]),
         ...     )
         ... ]
         >>> target = [
         ...     dict(
         ...         boxes=tensor([[214.0, 41.0, 562.0, 285.0]]),
         ...         labels=tensor([0]),
         ...     )
         ... ]
-        >>> metric = MeanAveragePrecision()
+        >>> metric = MeanAveragePrecision(iou_type="bbox")
         >>> metric.update(preds, target)
         >>> from pprint import pprint
         >>> pprint(metric.compute())
         {'classes': tensor([0], dtype=torch.int32),
          'map': tensor(0.6000),
          'map_50': tensor(1.),
          'map_75': tensor(1.),
          'map_large': tensor(0.6000),
          'map_medium': tensor(-1.),
          'map_per_class': tensor(-1.),
          'map_small': tensor(-1.),
          'mar_1': tensor(0.6000),
          'mar_10': tensor(0.6000),
          'mar_100': tensor(0.6000),
          'mar_100_per_class': tensor(-1.),
          'mar_large': tensor(0.6000),
          'mar_medium': tensor(-1.),
          'mar_small': tensor(-1.)}
+
+    Example::
+
+        Basic example for when `iou_type="segm"`:
+
+        >>> import torch
+        >>> from torch import tensor
+        >>> from torchmetrics.detection import MeanAveragePrecision
+        >>> mask_pred = [
+        ...     [0, 0, 0, 0, 0],
+        ...     [0, 0, 1, 1, 0],
+        ...     [0, 0, 1, 1, 0],
+        ...     [0, 0, 0, 0, 0],
+        ...     [0, 0, 0, 0, 0],
+        ... ]
+        >>> mask_tgt = [
+        ...     [0, 0, 0, 0, 0],
+        ...     [0, 0, 1, 0, 0],
+        ...     [0, 0, 1, 1, 0],
+        ...     [0, 0, 1, 0, 0],
+        ...     [0, 0, 0, 0, 0],
+        ... ]
+        >>> preds = [
+        ...     dict(
+        ...         masks=tensor([mask_pred], dtype=torch.bool),
+        ...         scores=tensor([0.536]),
+        ...         labels=tensor([0]),
+        ...     )
+        ... ]
+        >>> target = [
+        ...     dict(
+        ...         masks=tensor([mask_tgt], dtype=torch.bool),
+        ...         labels=tensor([0]),
+        ...     )
+        ... ]
+        >>> metric = MeanAveragePrecision(iou_type="segm")
+        >>> metric.update(preds, target)
+        >>> from pprint import pprint
+        >>> pprint(metric.compute())
+        {'classes': tensor(0, dtype=torch.int32),
+         'map': tensor(0.2000),
+         'map_50': tensor(1.),
+         'map_75': tensor(0.),
+         'map_large': tensor(-1.),
+         'map_medium': tensor(-1.),
+         'map_per_class': tensor(-1.),
+         'map_small': tensor(0.2000),
+         'mar_1': tensor(0.2000),
+         'mar_10': tensor(0.2000),
+         'mar_100': tensor(0.2000),
+         'mar_100_per_class': tensor(-1.),
+         'mar_large': tensor(-1.),
+         'mar_medium': tensor(-1.),
+         'mar_small': tensor(0.2000)}
+
     """
     is_differentiable: bool = False
     higher_is_better: Optional[bool] = True

From d926d3b24dff29e2d540aa17b593414d22301957 Mon Sep 17 00:00:00 2001
From: SkafteNicki
Date: Wed, 19 Jul 2023 14:46:07 +0200
Subject: [PATCH 04/12] changelog

---
 CHANGELOG.md | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index fb1f0b5ce05..13c17ef9740 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -17,6 +17,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Added `VisualInformationFidelity` to image package ([#1830](https://github.com/Lightning-AI/torchmetrics/pull/1830))
 
 
+- Added support for evaluating `"segm"` and `"bbox"` detection in `MeanAveragePrecision` at the same time ([#1928](https://github.com/Lightning-AI/torchmetrics/pull/1928))
+
+
 ### Changed
 
 -

From e98551dd315dd34fe3ecfc49ada90f7ffdd94fc7 Mon Sep 17 00:00:00 2001
From: SkafteNicki
Date: Wed, 19 Jul 2023 14:49:10 +0200
Subject: [PATCH 05/12] add example text

---
 src/torchmetrics/detection/mean_ap.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/torchmetrics/detection/mean_ap.py b/src/torchmetrics/detection/mean_ap.py
index e1c10faecac..18a790788ef 100644
--- a/src/torchmetrics/detection/mean_ap.py
+++ b/src/torchmetrics/detection/mean_ap.py
@@ -190,7 +190,8 @@ class MeanAveragePrecision(Metric):
 
     Example::
 
-        Basic example for when `iou_type="bbox"`:
+        Basic example for when `iou_type="bbox"`. In this case the ``boxes`` key is required in the input dictionaries,
+        in addition to the ``scores`` and ``labels`` keys.
 
         >>> from torch import tensor
         >>> from torchmetrics.detection import MeanAveragePrecision
         >>> preds = [
@@ -229,7 +230,8 @@ class MeanAveragePrecision(Metric):
 
     Example::
 
-        Basic example for when `iou_type="segm"`:
+        Basic example for when `iou_type="segm"`. In this case the ``masks`` key is required in the input dictionaries,
+        in addition to the ``scores`` and ``labels`` keys.
 
         >>> from torch import tensor
         >>> from torchmetrics.detection import MeanAveragePrecision

From c0ead515c3dcd4ee29b2e9065048eeaf17d50202 Mon Sep 17 00:00:00 2001
From: Nicki Skafte Detlefsen
Date: Tue, 25 Jul 2023 10:03:10 +0200
Subject: [PATCH 06/12] Update src/torchmetrics/detection/helpers.py

Co-authored-by: vsuryamurthy
---
 src/torchmetrics/detection/helpers.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/torchmetrics/detection/helpers.py b/src/torchmetrics/detection/helpers.py
index 5f90991bcd0..6a82c1b1a7c 100644
--- a/src/torchmetrics/detection/helpers.py
+++ b/src/torchmetrics/detection/helpers.py
@@ -25,13 +25,13 @@ def _input_validator(
     if not isinstance(iou_type, list):
         iou_type = [iou_type]
 
+    if any(i not in ["bbox", "segm"] for i in iou_type):
+        raise Exception(f"IOU type {iou_type} is not supported")
     item_val_name = []
     if "bbox" in iou_type:
         item_val_name.append("boxes")
     if "segm" in iou_type:
         item_val_name.append("masks")
-    if any(i not in ["bbox", "segm"] for i in iou_type):
-        raise Exception(f"IOU type {iou_type} is not supported")
 
     if not isinstance(preds, Sequence):
         raise ValueError(f"Expected argument `preds` to be of type Sequence, but got {preds}")

From 2a47a669811a10856bbf09e9514dc4099af888db Mon Sep 17 00:00:00 2001
From: SkafteNicki
Date: Tue, 25 Jul 2023 12:24:24 +0200
Subject: [PATCH 07/12] fix synchronization error for segm

---
 src/torchmetrics/detection/mean_ap.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/torchmetrics/detection/mean_ap.py b/src/torchmetrics/detection/mean_ap.py
index 2676f0a5aed..f2b7fa69b5c 100644
--- a/src/torchmetrics/detection/mean_ap.py
+++ b/src/torchmetrics/detection/mean_ap.py
@@ -860,8 +860,8 @@ def _sync_dist(self, dist_sync_fn: Optional[Callable] = None, process_group: Opt
         super()._sync_dist(dist_sync_fn=dist_sync_fn, process_group=process_group)
 
         if "segm" in self.iou_type:
-            self.detections = self._gather_tuple_list(self.detection_mask, process_group)
-            self.groundtruths = self._gather_tuple_list(self.groundtruth_mask, process_group)
+            self.detection_mask = self._gather_tuple_list(self.detection_mask, process_group)
+            self.groundtruth_mask = self._gather_tuple_list(self.groundtruth_mask, process_group)

From 3cbab0cc8bf0dacf0b8adac6d5f0152763870720 Mon Sep 17 00:00:00 2001
From: Nicki Skafte Detlefsen
Date: Tue, 1 Aug 2023 14:47:29 +0200
Subject: [PATCH 08/12] Apply suggestions from code review

Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
---
 src/torchmetrics/detection/helpers.py | 15 ++++++---------
 1 file changed, 6 insertions(+), 9 deletions(-)

diff --git a/src/torchmetrics/detection/helpers.py b/src/torchmetrics/detection/helpers.py
index 6a82c1b1a7c..5c6046b5e70 100644
--- a/src/torchmetrics/detection/helpers.py
+++ b/src/torchmetrics/detection/helpers.py
@@ -25,13 +25,10 @@ def _input_validator(
     if not isinstance(iou_type, list):
         iou_type = [iou_type]
 
-    if any(i not in ["bbox", "segm"] for i in iou_type):
+    if any(tp not in {"bbox", "segm"} for tp in iou_type):
         raise Exception(f"IOU type {iou_type} is not supported")
-    item_val_name = []
-    if "bbox" in iou_type:
-        item_val_name.append("boxes")
-    if "segm" in iou_type:
-        item_val_name.append("masks")
+    name_map = {"bbox": "boxes", "segm": "masks"}
+    item_val_name = [name_map[tp] for tp in iou_type]
 
     if not isinstance(preds, Sequence):
         raise ValueError(f"Expected argument `preds` to be of type Sequence, but got {preds}")
@@ -67,14 +64,14 @@ def _input_validator(
         for ivn in item_val_name:
             if item[ivn].size(0) != item["labels"].size(0):
                 raise ValueError(
-                    f"Input {ivn} and labels of sample {i} in targets have a"
+                    f"Input '{ivn}' and labels of sample {i} in targets have a"
                     f" different length (expected {item[ivn].size(0)} labels, got {item['labels'].size(0)})"
                 )
     for i, item in enumerate(preds):
         for ivn in item_val_name:
             if not (item[ivn].size(0) == item["labels"].size(0) == item["scores"].size(0)):
                 raise ValueError(
-                    f"Input {ivn}, labels and scores of sample {i} in predictions have a"
+                    f"Input '{ivn}', labels and scores of sample {i} in predictions have a"
                     f" different length (expected {item[ivn].size(0)} labels and scores,"
                     f" got {item['labels'].size(0)} labels and {item['scores'].size(0)})"
                 )
@@ -91,7 +88,7 @@ def _validate_iou_type_arg(iou_type: Union[Literal["bbox", "segm"], List[str]] =
     allowed_iou_types = ("segm", "bbox")
     if not isinstance(iou_type, list):
         iou_type = [iou_type]
-    if any(i_type not in allowed_iou_types for i_type in iou_type):
+    if any(tp not in allowed_iou_types for tp in iou_type):
         raise ValueError(
             f"Expected argument `iou_type` to be one of {allowed_iou_types} or a list of, but got {iou_type}"
         )
     return iou_type

From 0e70ffb7ba89205f534b343154db6be7ba532b98 Mon Sep 17 00:00:00 2001
From: Nicki Skafte Detlefsen
Date: Wed, 2 Aug 2023 13:16:24 +0200
Subject: [PATCH 09/12] Apply suggestions from code review

Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
---
 src/torchmetrics/detection/helpers.py | 6 +++---
 src/torchmetrics/detection/mean_ap.py | 2 +-
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/torchmetrics/detection/helpers.py b/src/torchmetrics/detection/helpers.py
index 5c6046b5e70..9069b682b1c 100644
--- a/src/torchmetrics/detection/helpers.py
+++ b/src/torchmetrics/detection/helpers.py
@@ -19,10 +19,10 @@ def _input_validator(
     preds: Sequence[Dict[str, Tensor]],
     targets: Sequence[Dict[str, Tensor]],
-    iou_type: Union[Literal["bbox", "segm"], List[str]] = "bbox",
+    iou_type: Union[Literal["bbox", "segm"], Tuple[Literal["bbox", "segm"]]] = "bbox",
 ) -> None:
     """Ensure the correct input format of `preds` and `targets`."""
-    if not isinstance(iou_type, list):
+    if isinstance(iou_type, str):
         iou_type = [iou_type]
@@ -86,7 +86,7 @@ def _fix_empty_tensors(boxes: Tensor) -> Tensor:
 def _validate_iou_type_arg(iou_type: Union[Literal["bbox", "segm"], List[str]] = "bbox") -> List[str]:
     allowed_iou_types = ("segm", "bbox")
-    if not isinstance(iou_type, list):
+    if isinstance(iou_type, str):
         iou_type = [iou_type]
     if any(tp not in allowed_iou_types for tp in iou_type):
         raise ValueError(
             f"Expected argument `iou_type` to be one of {allowed_iou_types} or a list of, but got {iou_type}"
         )
diff --git a/src/torchmetrics/detection/mean_ap.py b/src/torchmetrics/detection/mean_ap.py
index f2b7fa69b5c..f7a748e0205 100644
--- a/src/torchmetrics/detection/mean_ap.py
+++ b/src/torchmetrics/detection/mean_ap.py
@@ -156,7 +156,7 @@ class MeanAveragePrecision(Metric):
 
         iou_type:
             Type of input (either masks or bounding-boxes) used for computing IOU. Supported IOU types are
-            ``"bbox"`` or ``"segm"`` or both as a list.
+            ``"bbox"`` or ``"segm"`` or both as a tuple.
         iou_thresholds:
             IoU thresholds for evaluation. If set to ``None`` it corresponds to the stepped range ``[0.5,...,0.95]``
             with step ``0.05``. Else provide a list of floats.
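For reference, the behaviour that patches 01-09 build up can be exercised end to end roughly as follows. This is a minimal illustrative sketch, not part of the patch series itself: the box and mask values and the 0.9 score are made up, it assumes `pycocotools` is installed, and it follows the input contract documented above (both the `boxes` and `masks` keys are required when both IOU types are requested):

    import torch
    from torchmetrics.detection import MeanAveragePrecision

    # One image, one predicted instance that exactly matches the ground truth.
    mask = torch.zeros(1, 10, 10, dtype=torch.bool)
    mask[0, 2:7, 2:7] = True  # a 5x5 square object

    preds = [{
        "boxes": torch.tensor([[2.0, 2.0, 7.0, 7.0]]),  # xyxy, the default box_format
        "masks": mask,
        "scores": torch.tensor([0.9]),
        "labels": torch.tensor([0]),
    }]
    target = [{
        "boxes": torch.tensor([[2.0, 2.0, 7.0, 7.0]]),
        "masks": mask.clone(),
        "labels": torch.tensor([0]),
    }]

    metric = MeanAveragePrecision(iou_type=("bbox", "segm"))
    metric.update(preds, target)
    result = metric.compute()

    # With more than one iou_type, every COCO statistic comes back prefixed with
    # the type, e.g. "bbox_map", "segm_map", "bbox_mar_100", "segm_mar_100";
    # only the "classes" entry stays unprefixed.
    print(result["bbox_map"], result["segm_map"])

Evaluating both types in a single pass reuses one set of `update` calls per batch, which is the motivation for splitting the state into `detection_box`/`detection_mask` (and their ground-truth counterparts) in patch 01.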
From da318cf9d2337ad65acc1f4d8dce4a100cd8c654 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Wed, 2 Aug 2023 13:22:26 +0200 Subject: [PATCH 10/12] smaller corrections --- src/torchmetrics/detection/helpers.py | 8 ++++---- src/torchmetrics/detection/mean_ap.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/torchmetrics/detection/helpers.py b/src/torchmetrics/detection/helpers.py index 9069b682b1c..f4397f97895 100644 --- a/src/torchmetrics/detection/helpers.py +++ b/src/torchmetrics/detection/helpers.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Literal, Sequence, Union +from typing import Dict, Literal, Sequence, Tuple, Union from torch import Tensor @@ -23,7 +23,7 @@ def _input_validator( ) -> None: """Ensure the correct input format of `preds` and `targets`.""" if isinstance(iou_type, str): - iou_type = [iou_type] + iou_type = (iou_type,) if any(tp not in {"bbox", "segm"} for tp in iou_type): raise Exception(f"IOU type {iou_type} is not supported") @@ -84,10 +84,10 @@ def _fix_empty_tensors(boxes: Tensor) -> Tensor: return boxes -def _validate_iou_type_arg(iou_type: Union[Literal["bbox", "segm"], List[str]] = "bbox") -> List[str]: +def _validate_iou_type_arg(iou_type: Union[Literal["bbox", "segm"], Tuple[str]] = "bbox") -> Tuple[str]: allowed_iou_types = ("segm", "bbox") if isinstance(iou_type, str): - iou_type = [iou_type] + iou_type = (iou_type,) if any(tp not in allowed_iou_types for tp in iou_type): raise ValueError( f"Expected argument `iou_type` to be one of {allowed_iou_types} or a list of, but got {iou_type}" diff --git a/src/torchmetrics/detection/mean_ap.py b/src/torchmetrics/detection/mean_ap.py index f7a748e0205..248cd1b052c 100644 --- a/src/torchmetrics/detection/mean_ap.py +++ b/src/torchmetrics/detection/mean_ap.py @@ -302,7 +302,7 @@ class MeanAveragePrecision(Metric): def __init__( self, box_format: Literal["xyxy", "xywh", "cxcywh"] = "xyxy", - iou_type: Union[Literal["bbox", "segm"], List[str]] = "bbox", + iou_type: Union[Literal["bbox", "segm"], Tuple[str]] = "bbox", iou_thresholds: Optional[List[float]] = None, rec_thresholds: Optional[List[float]] = None, max_detection_thresholds: Optional[List[int]] = None, From b0b05b6c88d57f27af697cfae0f848e4b0c59b35 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Mon, 7 Aug 2023 15:40:13 +0200 Subject: [PATCH 11/12] fix mistake --- CHANGELOG.md | 3 --- 1 file changed, 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 58bccc5c1d8..d994b3a444a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,9 +20,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `EditDistance` to text package ([#1906](https://github.com/Lightning-AI/torchmetrics/pull/1906)) -- Added warning to `PearsonCorrCoeff` if input has a very small variance for its given dtype ([#1926](https://github.com/Lightning-AI/torchmetrics/pull/1926)) - - - Added `top_k` argument to `RetrievalMRR` in retrieval package ([#1961](https://github.com/Lightning-AI/torchmetrics/pull/1961)) From 21fbc6c096d66e005ef08b6b16dcebb1ec1c77e0 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Mon, 7 Aug 2023 21:55:36 +0200 Subject: [PATCH 12/12] merge --- src/torchmetrics/detection/helpers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git 
a/src/torchmetrics/detection/helpers.py b/src/torchmetrics/detection/helpers.py index f4397f97895..f3681545b96 100644 --- a/src/torchmetrics/detection/helpers.py +++ b/src/torchmetrics/detection/helpers.py @@ -25,9 +25,9 @@ def _input_validator( if isinstance(iou_type, str): iou_type = (iou_type,) - if any(tp not in {"bbox", "segm"} for tp in iou_type): - raise Exception(f"IOU type {iou_type} is not supported") name_map = {"bbox": "boxes", "segm": "masks"} + if any(tp not in name_map for tp in iou_type): + raise Exception(f"IOU type {iou_type} is not supported") item_val_name = [name_map[tp] for tp in iou_type] if not isinstance(preds, Sequence):
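A closing note on the per-type area bookkeeping from patch 01: COCO assigns detections to the small/medium/large buckets by annotation area, and the tight bounding box of an object generally covers more pixels than its mask. That is why `_get_coco_format` stores both `area_bbox` and `area_segm` on every annotation when more than one iou_type is active, and why `compute` swaps the active `area` field before each per-type evaluation. A small sketch of the discrepancy (hypothetical shape, assuming `pycocotools` is available):

    import numpy as np
    from pycocotools import mask as mask_utils

    # An L-shaped object: its tight box spans 5 x 5 = 25 pixels,
    # but the mask itself covers only 9 foreground pixels.
    m = np.zeros((10, 10), dtype=np.uint8, order="F")
    m[2:7, 2] = 1  # vertical stroke of the "L"
    m[6, 2:7] = 1  # horizontal stroke of the "L"

    rle = mask_utils.encode(m)
    bbox_area = 5 * 5                        # area a "bbox" evaluation would use
    segm_area = float(mask_utils.area(rle))  # area a "segm" evaluation uses -> 9.0

Without the swap, a combined run would bucket masks by box area (or boxes by mask area), skewing the `map_small`/`map_medium`/`map_large` results for one of the two types.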