@@ -5675,18 +5675,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
5675
5675
5676
5676
5677
5677
class TestSanitizeBoundingBoxes :
5678
- @pytest .mark .parametrize ("min_size" , (1 , 10 ))
5679
- @pytest .mark .parametrize ("labels_getter" , ("default" , lambda inputs : inputs ["labels" ], None , lambda inputs : None ))
5680
- @pytest .mark .parametrize ("sample_type" , (tuple , dict ))
5681
- def test_transform (self , min_size , labels_getter , sample_type ):
5682
-
5683
- if sample_type is tuple and not isinstance (labels_getter , str ):
5684
- # The "lambda inputs: inputs["labels"]" labels_getter used in this test
5685
- # doesn't work if the input is a tuple.
5686
- return
5687
-
5688
- H , W = 256 , 128
5689
-
5678
+ def _get_boxes_and_valid_mask (self , H = 256 , W = 128 , min_size = 10 ):
5690
5679
boxes_and_validity = [
5691
5680
([0 , 1 , 10 , 1 ], False ), # Y1 == Y2
5692
5681
([0 , 1 , 0 , 20 ], False ), # X1 == X2
@@ -5706,18 +5695,31 @@ def test_transform(self, min_size, labels_getter, sample_type):
5706
5695
]
5707
5696
5708
5697
random .shuffle (boxes_and_validity ) # For test robustness: mix order of wrong and correct cases
5709
- boxes , is_valid_mask = zip (* boxes_and_validity )
5710
- valid_indices = [i for (i , is_valid ) in enumerate (is_valid_mask ) if is_valid ]
5711
-
5712
- boxes = torch .tensor (boxes )
5713
- labels = torch .arange (boxes .shape [0 ])
5698
+ boxes , expected_valid_mask = zip (* boxes_and_validity )
5714
5699
5715
5700
boxes = tv_tensors .BoundingBoxes (
5716
5701
boxes ,
5717
5702
format = tv_tensors .BoundingBoxFormat .XYXY ,
5718
5703
canvas_size = (H , W ),
5719
5704
)
5720
5705
5706
+ return boxes , expected_valid_mask
5707
+
5708
+ @pytest .mark .parametrize ("min_size" , (1 , 10 ))
5709
+ @pytest .mark .parametrize ("labels_getter" , ("default" , lambda inputs : inputs ["labels" ], None , lambda inputs : None ))
5710
+ @pytest .mark .parametrize ("sample_type" , (tuple , dict ))
5711
+ def test_transform (self , min_size , labels_getter , sample_type ):
5712
+
5713
+ if sample_type is tuple and not isinstance (labels_getter , str ):
5714
+ # The "lambda inputs: inputs["labels"]" labels_getter used in this test
5715
+ # doesn't work if the input is a tuple.
5716
+ return
5717
+
5718
+ H , W = 256 , 128
5719
+ boxes , expected_valid_mask = self ._get_boxes_and_valid_mask (H = H , W = W , min_size = min_size )
5720
+ valid_indices = [i for (i , is_valid ) in enumerate (expected_valid_mask ) if is_valid ]
5721
+
5722
+ labels = torch .arange (boxes .shape [0 ])
5721
5723
masks = tv_tensors .Mask (torch .randint (0 , 2 , size = (boxes .shape [0 ], H , W )))
5722
5724
whatever = torch .rand (10 )
5723
5725
input_img = torch .randint (0 , 256 , size = (1 , 3 , H , W ), dtype = torch .uint8 )
@@ -5763,6 +5765,44 @@ def test_transform(self, min_size, labels_getter, sample_type):
5763
5765
# This works because we conveniently set labels to arange(num_boxes)
5764
5766
assert out_labels .tolist () == valid_indices
5765
5767
5768
+ @pytest .mark .parametrize ("input_type" , (torch .Tensor , tv_tensors .BoundingBoxes ))
5769
+ def test_functional (self , input_type ):
5770
+ # Note: the "functional" F.sanitize_bounding_boxes was added after the class, so there is some
5771
+ # redundancy with test_transform() in terms of correctness checks. But that's OK.
5772
+
5773
+ H , W , min_size = 256 , 128 , 10
5774
+
5775
+ boxes , expected_valid_mask = self ._get_boxes_and_valid_mask (H = H , W = W , min_size = min_size )
5776
+
5777
+ if input_type is tv_tensors .BoundingBoxes :
5778
+ format = canvas_size = None
5779
+ else :
5780
+ # just passing "XYXY" explicitly to make sure we support strings
5781
+ format , canvas_size = "XYXY" , boxes .canvas_size
5782
+ boxes = boxes .as_subclass (torch .Tensor )
5783
+
5784
+ boxes , valid = F .sanitize_bounding_boxes (boxes , format = format , canvas_size = canvas_size , min_size = min_size )
5785
+
5786
+ assert_equal (valid , torch .tensor (expected_valid_mask ))
5787
+ assert type (valid ) == torch .Tensor
5788
+ assert boxes .shape [0 ] == sum (valid )
5789
+ assert isinstance (boxes , input_type )
5790
+
5791
+ def test_kernel (self ):
5792
+ H , W , min_size = 256 , 128 , 10
5793
+ boxes , _ = self ._get_boxes_and_valid_mask (H = H , W = W , min_size = min_size )
5794
+
5795
+ format , canvas_size = boxes .format , boxes .canvas_size
5796
+ boxes = boxes .as_subclass (torch .Tensor )
5797
+
5798
+ check_kernel (
5799
+ F .sanitize_bounding_boxes ,
5800
+ input = boxes ,
5801
+ format = format ,
5802
+ canvas_size = canvas_size ,
5803
+ check_batched_vs_unbatched = False ,
5804
+ )
5805
+
5766
5806
def test_no_label (self ):
5767
5807
# Non-regression test for https://github.com/pytorch/vision/issues/7878
5768
5808
@@ -5776,7 +5816,7 @@ def test_no_label(self):
5776
5816
assert isinstance (out_img , tv_tensors .Image )
5777
5817
assert isinstance (out_boxes , tv_tensors .BoundingBoxes )
5778
5818
5779
- def test_errors (self ):
5819
+ def test_errors_transform (self ):
5780
5820
good_bbox = tv_tensors .BoundingBoxes (
5781
5821
[[0 , 0 , 10 , 10 ]],
5782
5822
format = tv_tensors .BoundingBoxFormat .XYXY ,
@@ -5799,3 +5839,26 @@ def test_errors(self):
5799
5839
with pytest .raises (ValueError , match = "Number of boxes" ):
5800
5840
different_sizes = {"bbox" : good_bbox , "labels" : torch .arange (good_bbox .shape [0 ] + 3 )}
5801
5841
transforms .SanitizeBoundingBoxes ()(different_sizes )
5842
+
5843
+ def test_errors_functional (self ):
5844
+
5845
+ good_bbox = tv_tensors .BoundingBoxes (
5846
+ [[0 , 0 , 10 , 10 ]],
5847
+ format = tv_tensors .BoundingBoxFormat .XYXY ,
5848
+ canvas_size = (20 , 20 ),
5849
+ )
5850
+
5851
+ with pytest .raises (ValueError , match = "canvas_size cannot be None if bounding_boxes is a pure tensor" ):
5852
+ F .sanitize_bounding_boxes (good_bbox .as_subclass (torch .Tensor ), format = "XYXY" , canvas_size = None )
5853
+
5854
+ with pytest .raises (ValueError , match = "canvas_size cannot be None if bounding_boxes is a pure tensor" ):
5855
+ F .sanitize_bounding_boxes (good_bbox .as_subclass (torch .Tensor ), format = None , canvas_size = (10 , 10 ))
5856
+
5857
+ with pytest .raises (ValueError , match = "canvas_size must be None when bounding_boxes is a tv_tensors" ):
5858
+ F .sanitize_bounding_boxes (good_bbox , format = "XYXY" , canvas_size = None )
5859
+
5860
+ with pytest .raises (ValueError , match = "canvas_size must be None when bounding_boxes is a tv_tensors" ):
5861
+ F .sanitize_bounding_boxes (good_bbox , format = "XYXY" , canvas_size = None )
5862
+
5863
+ with pytest .raises (ValueError , match = "bouding_boxes must be a tv_tensors.BoundingBoxes instance or a" ):
5864
+ F .sanitize_bounding_boxes (good_bbox .tolist ())
0 commit comments