Skip to content

Commit 2597fa4

Browse files
committed
Merge branch 'main' into transforms/pad
2 parents 4eba34a + 00c119c commit 2597fa4

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

52 files changed

+613
-205
lines changed

references/classification/transforms.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,14 @@ class RandomMixup(torch.nn.Module):
2121

2222
def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
2323
super().__init__()
24-
assert num_classes > 0, "Please provide a valid positive value for the num_classes."
25-
assert alpha > 0, "Alpha param can't be zero."
24+
25+
if num_classes < 1:
26+
raise ValueError(
27+
f"Please provide a valid positive value for the num_classes. Got num_classes={num_classes}"
28+
)
29+
30+
if alpha <= 0:
31+
raise ValueError("Alpha param can't be zero.")
2632

2733
self.num_classes = num_classes
2834
self.p = p
@@ -99,8 +105,10 @@ class RandomCutmix(torch.nn.Module):
99105

100106
def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
101107
super().__init__()
102-
assert num_classes > 0, "Please provide a valid positive value for the num_classes."
103-
assert alpha > 0, "Alpha param can't be zero."
108+
if num_classes < 1:
109+
raise ValueError("Please provide a valid positive value for the num_classes.")
110+
if alpha <= 0:
111+
raise ValueError("Alpha param can't be zero.")
104112

105113
self.num_classes = num_classes
106114
self.p = p

references/detection/coco_eval.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212

1313
class CocoEvaluator:
1414
def __init__(self, coco_gt, iou_types):
15-
assert isinstance(iou_types, (list, tuple))
15+
if not isinstance(iou_types, (list, tuple)):
16+
raise TypeError(f"This constructor expects iou_types of type list or tuple, instead got {type(iou_types)}")
1617
coco_gt = copy.deepcopy(coco_gt)
1718
self.coco_gt = coco_gt
1819

references/detection/coco_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,10 @@ def _has_valid_annotation(anno):
126126
return True
127127
return False
128128

129-
assert isinstance(dataset, torchvision.datasets.CocoDetection)
129+
if not isinstance(dataset, torchvision.datasets.CocoDetection):
130+
raise TypeError(
131+
f"This function expects dataset of type torchvision.datasets.CocoDetection, instead got {type(dataset)}"
132+
)
130133
ids = []
131134
for ds_idx, img_id in enumerate(dataset.ids):
132135
ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)

references/optical_flow/transforms.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,21 @@ class ValidateModelInput(torch.nn.Module):
77
# Pass-through transform that checks the shape and dtypes to make sure the model gets what it expects
88
def forward(self, img1, img2, flow, valid_flow_mask):
99

10-
assert all(isinstance(arg, torch.Tensor) for arg in (img1, img2, flow, valid_flow_mask) if arg is not None)
11-
assert all(arg.dtype == torch.float32 for arg in (img1, img2, flow) if arg is not None)
10+
if not all(isinstance(arg, torch.Tensor) for arg in (img1, img2, flow, valid_flow_mask) if arg is not None):
11+
raise TypeError("This method expects all input arguments to be of type torch.Tensor.")
12+
if not all(arg.dtype == torch.float32 for arg in (img1, img2, flow) if arg is not None):
13+
raise TypeError("This method expects the tensors img1, img2 and flow of be of dtype torch.float32.")
1214

13-
assert img1.shape == img2.shape
15+
if img1.shape != img2.shape:
16+
raise ValueError("img1 and img2 should have the same shape.")
1417
h, w = img1.shape[-2:]
15-
if flow is not None:
16-
assert flow.shape == (2, h, w)
18+
if flow is not None and flow.shape != (2, h, w):
19+
raise ValueError(f"flow.shape should be (2, {h}, {w}) instead of {flow.shape}")
1720
if valid_flow_mask is not None:
18-
assert valid_flow_mask.shape == (h, w)
19-
assert valid_flow_mask.dtype == torch.bool
21+
if valid_flow_mask.shape != (h, w):
22+
raise ValueError(f"valid_flow_mask.shape should be ({h}, {w}) instead of {valid_flow_mask.shape}")
23+
if valid_flow_mask.dtype != torch.bool:
24+
raise TypeError("valid_flow_mask should be of dtype torch.bool instead of {valid_flow_mask.dtype}")
2025

2126
return img1, img2, flow, valid_flow_mask
2227

@@ -109,7 +114,8 @@ class RandomErasing(T.RandomErasing):
109114
def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False, max_erase=1):
110115
super().__init__(p=p, scale=scale, ratio=ratio, value=value, inplace=inplace)
111116
self.max_erase = max_erase
112-
assert self.max_erase > 0
117+
if self.max_erase <= 0:
118+
raise ValueError("max_raise should be greater than 0")
113119

114120
def forward(self, img1, img2, flow, valid_flow_mask):
115121
if torch.rand(1) > self.p:

references/optical_flow/utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,10 @@ def update(self, **kwargs):
7171
for k, v in kwargs.items():
7272
if isinstance(v, torch.Tensor):
7373
v = v.item()
74-
assert isinstance(v, (float, int))
74+
if not isinstance(v, (float, int)):
75+
raise TypeError(
76+
f"This method expects the value of the input arguments to be of type float or int, instead got {type(v)}"
77+
)
7578
self.meters[k].update(v)
7679

7780
def __getattr__(self, attr):

references/segmentation/coco_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,11 @@ def _has_valid_annotation(anno):
6868
# if more than 1k pixels occupied in the image
6969
return sum(obj["area"] for obj in anno) > 1000
7070

71-
assert isinstance(dataset, torchvision.datasets.CocoDetection)
71+
if not isinstance(dataset, torchvision.datasets.CocoDetection):
72+
raise TypeError(
73+
f"This function expects dataset of type torchvision.datasets.CocoDetection, instead got {type(dataset)}"
74+
)
75+
7276
ids = []
7377
for ds_idx, img_id in enumerate(dataset.ids):
7478
ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)

references/segmentation/utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,10 @@ def update(self, **kwargs):
118118
for k, v in kwargs.items():
119119
if isinstance(v, torch.Tensor):
120120
v = v.item()
121-
assert isinstance(v, (float, int))
121+
if not isinstance(v, (float, int)):
122+
raise TypeError(
123+
f"This method expects the value of the input arguments to be of type float or int, instead got {type(v)}"
124+
)
122125
self.meters[k].update(v)
123126

124127
def __getattr__(self, attr):

references/similarity/sampler.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ def __init__(self, groups, p, k):
4747
self.groups = create_groups(groups, self.k)
4848

4949
# Ensures there are enough classes to sample from
50-
assert len(self.groups) >= p
50+
if len(self.groups) < p:
51+
raise ValueError("There are not enought classes to sample from")
5152

5253
def __iter__(self):
5354
# Shuffle samples within groups

references/video_classification/utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,10 @@ def update(self, **kwargs):
7676
for k, v in kwargs.items():
7777
if isinstance(v, torch.Tensor):
7878
v = v.item()
79-
assert isinstance(v, (float, int))
79+
if not isinstance(v, (float, int)):
80+
raise TypeError(
81+
f"This method expects the value of the input arguments to be of type float or int, instead got {type(v)}"
82+
)
8083
self.meters[k].update(v)
8184

8285
def __getattr__(self, attr):

test/test_backbone_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -144,16 +144,16 @@ def test_build_fx_feature_extractor(self, model_name):
144144
model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
145145
)
146146
# Check must specify return nodes
147-
with pytest.raises(AssertionError):
147+
with pytest.raises(ValueError):
148148
self._create_feature_extractor(model)
149149
# Check return_nodes and train_return_nodes / eval_return nodes
150150
# mutual exclusivity
151-
with pytest.raises(AssertionError):
151+
with pytest.raises(ValueError):
152152
self._create_feature_extractor(
153153
model, return_nodes=train_return_nodes, train_return_nodes=train_return_nodes
154154
)
155155
# Check train_return_nodes / eval_return nodes must both be specified
156-
with pytest.raises(AssertionError):
156+
with pytest.raises(ValueError):
157157
self._create_feature_extractor(model, train_return_nodes=train_return_nodes)
158158
# Check invalid node name raises ValueError
159159
with pytest.raises(ValueError):

0 commit comments

Comments
 (0)