typing: mark line ignores instead of ignore whole files (#2452)
Borda authored Mar 15, 2024
1 parent 31b99b6 commit b6fe6bb
Showing 11 changed files with 69 additions and 85 deletions.
20 changes: 0 additions & 20 deletions pyproject.toml
@@ -147,26 +147,6 @@ disable_error_code = "attr-defined"
# style choices
warn_no_return = "False"

-# Ignore mypy errors for these files
-# TODO: the goal is for this to be empty
-[[tool.mypy.overrides]]
-module = [
-"torchmetrics.classification.exact_match",
-"torchmetrics.classification.f_beta",
-"torchmetrics.classification.precision_recall",
-"torchmetrics.classification.ranking",
-"torchmetrics.classification.recall_at_fixed_precision",
-"torchmetrics.classification.roc",
-"torchmetrics.classification.stat_scores",
-"torchmetrics.detection._mean_ap",
-"torchmetrics.detection.mean_ap",
-"torchmetrics.functional.image.psnr",
-"torchmetrics.functional.image.ssim",
-"torchmetrics.image.psnr",
-"torchmetrics.image.ssim",
-]
-ignore_errors = "True"

[tool.typos.default]
extend-ignore-identifiers-re = [
# *sigh* this just isn't worth the cost of fixing
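The override block deleted above silenced every mypy error in the listed modules. The replacement pattern used throughout the rest of this commit is a per-line suppression that names the specific error code, so everything else in those modules stays type-checked. A minimal sketch of the idea, with made-up names rather than anything from torchmetrics:

```python
from typing import Optional


def double(value: Optional[float]) -> float:
    """Toy example of a narrowly scoped, per-line mypy suppression."""
    # Without the trailing comment mypy reports:
    #   Unsupported operand types for * ("None" and "int")  [operator]
    # The bracketed code limits the suppression to exactly this error on exactly this line.
    return value * 2  # type: ignore[operator]
```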
2 changes: 1 addition & 1 deletion src/torchmetrics/classification/exact_match.py
@@ -393,7 +393,7 @@ class ExactMatch(_ClassificationTaskWrapper):
"""

-def __new__(
+def __new__(  # type: ignore[misc]
cls: Type["ExactMatch"],
task: Literal["binary", "multiclass", "multilabel"],
threshold: float = 0.5,
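Every task-wrapper class touched by this commit gets the same `# type: ignore[misc]` on `__new__`. mypy reports 'Incompatible return type for "__new__"' under the `misc` code when a factory-style `__new__` is annotated to return something that is not a subtype of the class itself, which is what these wrappers do when they hand back a task-specific metric (see the `return MultilabelRecall(...)` line further down in this diff). A self-contained sketch with stand-in names, not the real torchmetrics classes:

```python
class Metric:
    """Minimal stand-in base class."""


class BinaryAccuracy(Metric):
    """Concrete metric; a sibling of the wrapper below, not a subclass of it."""


class Accuracy(Metric):
    """Task wrapper whose __new__ dispatches to a sibling Metric subclass."""

    def __new__(cls, task: str) -> Metric:  # type: ignore[misc]
        # mypy: Incompatible return type for "__new__"
        #       (returns "Metric", but must return a subtype of "Accuracy")  [misc]
        if task == "binary":
            return BinaryAccuracy()
        raise ValueError(f"unknown task: {task}")
```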
4 changes: 2 additions & 2 deletions src/torchmetrics/classification/f_beta.py
@@ -1058,7 +1058,7 @@ class FBetaScore(_ClassificationTaskWrapper):
"""

-def __new__(
+def __new__(  # type: ignore[misc]
cls: Type["FBetaScore"],
task: Literal["binary", "multiclass", "multilabel"],
beta: float = 1.0,
@@ -1122,7 +1122,7 @@ class F1Score(_ClassificationTaskWrapper):
"""

-def __new__(
+def __new__(  # type: ignore[misc]
cls: Type["F1Score"],
task: Literal["binary", "multiclass", "multilabel"],
threshold: float = 0.5,
6 changes: 3 additions & 3 deletions src/torchmetrics/classification/precision_recall.py
@@ -930,7 +930,7 @@ class Precision(_ClassificationTaskWrapper):
"""

-def __new__(
+def __new__(  # type: ignore[misc]
cls: Type["Precision"],
task: Literal["binary", "multiclass", "multilabel"],
threshold: float = 0.5,
@@ -995,7 +995,7 @@ class Recall(_ClassificationTaskWrapper):
"""

-def __new__(
+def __new__(  # type: ignore[misc]
cls: Type["Recall"],
task: Literal["binary", "multiclass", "multilabel"],
threshold: float = 0.5,
@@ -1028,4 +1028,4 @@ def __new__(
if not isinstance(num_labels, int):
raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`")
return MultilabelRecall(num_labels, threshold, average, **kwargs)
-return None
+return None  # type: ignore[return-value]
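The `return None` fallthrough keeps its `[return-value]` suppression rather than being rewritten: the declared return type of `__new__` evidently does not admit `None`, and every valid `task` value returns a concrete metric before this line is reached. A toy reproduction of the same error code, with invented names:

```python
def pick_metric_name(task: str) -> str:
    """Annotated with a non-optional return type, but ending in a fallthrough."""
    if task == "binary":
        return "binary-recall"
    if task == "multilabel":
        return "multilabel-recall"
    # Invalid tasks are rejected before this point in the real code, yet mypy still checks
    # the statement and reports:
    #   Incompatible return value type (got "None", expected "str")  [return-value]
    return None  # type: ignore[return-value]
```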
8 changes: 4 additions & 4 deletions src/torchmetrics/classification/roc.py
@@ -120,7 +120,7 @@ class BinaryROC(BinaryPrecisionRecallCurve):
def compute(self) -> Tuple[Tensor, Tensor, Tensor]:
"""Compute metric."""
state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat
-return _binary_roc_compute(state, self.thresholds)
+return _binary_roc_compute(state, self.thresholds)  # type: ignore[arg-type]

def plot(
self,
@@ -290,7 +290,7 @@ class MulticlassROC(MulticlassPrecisionRecallCurve):
def compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
"""Compute metric."""
state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat
-return _multiclass_roc_compute(state, self.num_classes, self.thresholds, self.average)
+return _multiclass_roc_compute(state, self.num_classes, self.thresholds, self.average)  # type: ignore[arg-type]

def plot(
self,
@@ -449,7 +449,7 @@ class MultilabelROC(MultilabelPrecisionRecallCurve):
def compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
"""Compute metric."""
state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat
-return _multilabel_roc_compute(state, self.num_labels, self.thresholds, self.ignore_index)
+return _multilabel_roc_compute(state, self.num_labels, self.thresholds, self.ignore_index)  # type: ignore[arg-type]

def plot(
self,
@@ -564,7 +564,7 @@ class ROC(_ClassificationTaskWrapper):
"""

-def __new__(
+def __new__(  # type: ignore[misc]
cls: Type["ROC"],
task: Literal["binary", "multiclass", "multilabel"],
thresholds: Optional[Union[int, List[float], Tensor]] = None,
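The three `compute` suppressions in this file share one cause: `state` is a two-element list of concatenated tensors in the non-binned case and the running `confmat` tensor otherwise, so its inferred type is a union wider than what the private `_*_roc_compute` helpers appear to accept, and mypy flags the call as `[arg-type]`. A rough, self-contained sketch of the same situation; the helper below is a deliberately narrow stand-in, not the real functional API:

```python
from typing import Optional

from torch import Tensor


def _roc_helper(state: Tensor) -> Tensor:
    """Stand-in helper whose annotation only covers the tensor-shaped state."""
    raise NotImplementedError


def compute(
    preds: Tensor, target: Tensor, confmat: Tensor, thresholds: Optional[Tensor]
) -> Tensor:
    # `state` is a two-element list when thresholds is None and a single tensor otherwise,
    # so mypy infers Union[List[Tensor], Tensor] and reports [arg-type] at the call site;
    # the suppression stays on this one call instead of silencing the whole module.
    state = [preds, target] if thresholds is None else confmat
    return _roc_helper(state)  # type: ignore[arg-type]
```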
10 changes: 5 additions & 5 deletions src/torchmetrics/classification/stat_scores.py
@@ -69,10 +69,10 @@ def _create_state(
def _update_state(self, tp: Tensor, fp: Tensor, tn: Tensor, fn: Tensor) -> None:
"""Update states depending on multidim_average argument."""
if self.multidim_average == "samplewise":
-self.tp.append(tp)
-self.fp.append(fp)
-self.tn.append(tn)
-self.fn.append(fn)
+self.tp.append(tp)  # type: ignore[union-attr]
+self.fp.append(fp)  # type: ignore[union-attr]
+self.tn.append(tn)  # type: ignore[union-attr]
+self.fn.append(fn)  # type: ignore[union-attr]
else:
self.tp += tp
self.fp += fp
@@ -515,7 +515,7 @@ class StatScores(_ClassificationTaskWrapper):
"""

-def __new__(
+def __new__(  # type: ignore[misc]
cls: Type["StatScores"],
task: Literal["binary", "multiclass", "multilabel"],
threshold: float = 0.5,
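The four `[union-attr]` suppressions cover the samplewise branch: the `tp`/`fp`/`tn`/`fn` states hold a list of per-sample tensors when `multidim_average="samplewise"` and a single running tensor otherwise, and mypy cannot narrow that union from the flag check, so `.append` is reported even though the branch always sees a list at runtime. A toy reconstruction; class and attribute layout are illustrative only:

```python
from typing import List, Union

from torch import Tensor, tensor


class Scores:
    """Toy stand-in for the stat-score state handling."""

    def __init__(self, samplewise: bool) -> None:
        # samplewise mode collects per-sample tensors; global mode keeps one running tensor
        self.samplewise = samplewise
        self.tp: Union[Tensor, List[Tensor]] = [] if samplewise else tensor(0)

    def add(self, tp: Tensor) -> None:
        if self.samplewise:
            # mypy: Item "Tensor" of "Union[Tensor, List[Tensor]]" has no attribute "append"
            # reported under [union-attr]; the constructor guarantees a list in this mode.
            self.tp.append(tp)  # type: ignore[union-attr]
```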
42 changes: 22 additions & 20 deletions src/torchmetrics/detection/_mean_ap.py
@@ -366,18 +366,18 @@ def __init__(

def update(self, preds: List[Dict[str, Tensor]], target: List[Dict[str, Tensor]]) -> None:
"""Update state with predictions and targets."""
-_input_validator(preds, target, iou_type=self.iou_type)
+_input_validator(preds, target, iou_type=self.iou_type)  # type: ignore[arg-type]

for item in preds:
detections = self._get_safe_item_values(item)

-self.detections.append(detections)
+self.detections.append(detections)  # type: ignore[arg-type]
self.detection_labels.append(item["labels"])
self.detection_scores.append(item["scores"])

for item in target:
groundtruths = self._get_safe_item_values(item)
-self.groundtruths.append(groundtruths)
+self.groundtruths.append(groundtruths)  # type: ignore[arg-type]
self.groundtruth_labels.append(item["labels"])

def _move_list_states_to_cpu(self) -> None:
@@ -640,13 +640,13 @@ def _find_best_gt_match(
Id of current detection.
"""
-previously_matched = gt_matches[idx_iou]
+previously_matched = gt_matches[idx_iou]  # type: ignore[index]
# Remove previously matched or ignored gts
remove_mask = previously_matched | gt_ignore
gt_ious = ious[idx_det] * ~remove_mask
match_idx = gt_ious.argmax().item()
-if gt_ious[match_idx] > thr:
-return match_idx
+if gt_ious[match_idx] > thr:  # type: ignore[index]
+return match_idx  # type: ignore[return-value]
return -1

def _summarize(
@@ -713,7 +713,7 @@ def _calculate(self, class_ids: List) -> Tuple[MAPMetricResults, MARMetricResult
}

eval_imgs = [
-self._evaluate_image(img_id, class_id, area, max_detections, ious)
+self._evaluate_image(img_id, class_id, area, max_detections, ious)  # type: ignore[arg-type]
for class_id in class_ids
for area in area_ranges
for img_id in img_ids
@@ -750,7 +750,7 @@ def _calculate(self, class_ids: List) -> Tuple[MAPMetricResults, MARMetricResult
num_bbox_areas=num_bbox_areas,
)

-return precision, recall
+return precision, recall  # type: ignore[return-value]

def _summarize_results(self, precisions: Tensor, recalls: Tensor) -> Tuple[MAPMetricResults, MARMetricResults]:
"""Summarizes the precision and recall values to calculate mAP/mAR.
@@ -820,8 +820,8 @@ def __calculate_recall_precision_scores(
inds = torch.argsort(det_scores.to(dtype), descending=True)
det_scores_sorted = det_scores[inds]

-det_matches = torch.cat([e["dtMatches"][:, :max_det] for e in img_eval_cls_bbox], axis=1)[:, inds]
-det_ignore = torch.cat([e["dtIgnore"][:, :max_det] for e in img_eval_cls_bbox], axis=1)[:, inds]
+det_matches = torch.cat([e["dtMatches"][:, :max_det] for e in img_eval_cls_bbox], axis=1)[:, inds]  # type: ignore[call-overload]
+det_ignore = torch.cat([e["dtIgnore"][:, :max_det] for e in img_eval_cls_bbox], axis=1)[:, inds]  # type: ignore[call-overload]
gt_ignore = torch.cat([e["gtIgnore"] for e in img_eval_cls_bbox])
npig = torch.count_nonzero(gt_ignore == False) # noqa: E712
if npig == 0:
@@ -849,9 +849,9 @@ def __calculate_recall_precision_scores(

inds = torch.searchsorted(rc, rec_thresholds.to(rc.device), right=False)
num_inds = inds.argmax() if inds.max() >= tp_len else num_rec_thrs
-inds = inds[:num_inds]
-prec[:num_inds] = pr[inds]
-score[:num_inds] = det_scores_sorted[inds]
+inds = inds[:num_inds]  # type: ignore[misc]
+prec[:num_inds] = pr[inds]  # type: ignore[misc]
+score[:num_inds] = det_scores_sorted[inds]  # type: ignore[misc]
precision[idx, :, idx_cls, idx_bbox_area, idx_max_det_thrs] = prec
scores[idx, :, idx_cls, idx_bbox_area, idx_max_det_thrs] = score

@@ -861,7 +861,7 @@ def compute(self) -> dict:
"""Compute metric."""
classes = self._get_classes()
precisions, recalls = self._calculate(classes)
-map_val, mar_val = self._summarize_results(precisions, recalls)
+map_val, mar_val = self._summarize_results(precisions, recalls)  # type: ignore[arg-type]

# if class mode is enabled, evaluate metrics per class
map_per_class_values: Tensor = torch.tensor([-1.0])
@@ -888,7 +888,7 @@ def compute(self) -> dict:
metrics.classes = torch.tensor(classes, dtype=torch.int)
return metrics

-def _apply(self, fn: Callable) -> torch.nn.Module:
+def _apply(self, fn: Callable) -> torch.nn.Module:  # type: ignore[override]
"""Custom apply function.
Excludes the detections and groundtruths from the casting when the iou_type is set to `segm` as the state is
@@ -908,22 +908,24 @@ def _sync_dist(self, dist_sync_fn: Optional[Callable] = None, process_group: Opt
to gather the list of tuples and then convert it back to a list of tuples.
"""
-super()._sync_dist(dist_sync_fn=dist_sync_fn, process_group=process_group)
+super()._sync_dist(dist_sync_fn=dist_sync_fn, process_group=process_group)  # type: ignore[arg-type]

if self.iou_type == "segm":
-self.detections = self._gather_tuple_list(self.detections, process_group)
-self.groundtruths = self._gather_tuple_list(self.groundtruths, process_group)
+self.detections = self._gather_tuple_list(self.detections, process_group)  # type: ignore[arg-type]
+self.groundtruths = self._gather_tuple_list(self.groundtruths, process_group)  # type: ignore[arg-type]

@staticmethod
-def _gather_tuple_list(list_to_gather: List[Tuple], process_group: Optional[Any] = None) -> List[Any]:
+def _gather_tuple_list(
+list_to_gather: List[Union[tuple, Tensor]], process_group: Optional[Any] = None
+) -> List[Any]:
"""Gather a list of tuples over multiple devices."""
world_size = dist.get_world_size(group=process_group)
dist.barrier(group=process_group)

list_gathered = [None for _ in range(world_size)]
dist.all_gather_object(list_gathered, list_to_gather, group=process_group)

-return [list_gathered[rank][idx] for idx in range(len(list_gathered[0])) for rank in range(world_size)]
+return [list_gathered[rank][idx] for idx in range(len(list_gathered[0])) for rank in range(world_size)]  # type: ignore[arg-type,index]

def plot(
self, val: Optional[Union[Dict[str, Tensor], Sequence[Dict[str, Tensor]]]] = None, ax: Optional[_AX_TYPE] = None
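Two of the less common codes in this file are worth a note. `[override]` on `_apply` is what mypy emits whenever an overriding method's signature or return type is incompatible with its base-class counterpart (here presumably `Metric._apply`); the exact mismatch is not visible in this hunk, so the toy below only reproduces the error code with unrelated names. The two `torch.cat(..., axis=1)` lines instead get `[call-overload]`, which fires when no stub overload matches a call; the NumPy-style `axis=` keyword, accepted at runtime but spelled `dim=` in the stubs, is the likely culprit.

```python
from typing import List


class Base:
    def reset(self) -> None:
        """Parent contract: reset returns nothing."""


class Child(Base):
    # mypy: Return type "List[int]" of "reset" incompatible with return type "None"
    #       in supertype "Base"  [override]
    def reset(self) -> List[int]:  # type: ignore[override]
        return [0]
```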
6 changes: 4 additions & 2 deletions src/torchmetrics/detection/helpers.py
@@ -87,13 +87,15 @@ def _fix_empty_tensors(boxes: Tensor) -> Tensor:
return boxes


-def _validate_iou_type_arg(iou_type: Union[Literal["bbox", "segm"], Tuple[str]] = "bbox") -> Tuple[str]:
+def _validate_iou_type_arg(
+iou_type: Union[Literal["bbox", "segm"], Tuple[str]] = "bbox",
+) -> Tuple[str]:
"""Validate that iou type argument is correct."""
allowed_iou_types = ("segm", "bbox")
if isinstance(iou_type, str):
iou_type = (iou_type,)
if any(tp not in allowed_iou_types for tp in iou_type):
raise ValueError(
f"Expected argument `iou_type` to be one of {allowed_iou_types} or a list of, but got {iou_type}"
f"Expected argument `iou_type` to be one of {allowed_iou_types} or a tuple of, but got {iou_type}"
)
return iou_type
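For reference, the contract this hunk leaves `_validate_iou_type_arg` with: a bare string is normalized into a one-element tuple, any entry outside `("segm", "bbox")` raises `ValueError` with the corrected "or a tuple of" wording, and the tuple is returned. A short usage sketch, assuming the private helper is imported directly from `torchmetrics.detection.helpers`:

```python
from torchmetrics.detection.helpers import _validate_iou_type_arg

print(_validate_iou_type_arg("bbox"))            # ("bbox",)
print(_validate_iou_type_arg(("bbox", "segm")))  # passed through unchanged

try:
    _validate_iou_type_arg("mask")  # not an allowed iou type
except ValueError as err:
    print(err)
```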