@@ -959,6 +959,14 @@ def area_check(box, expected, tolerance=1e-4):
959959 expected = torch .tensor ([605113.875 , 600495.1875 , 592247.25 ])
960960 area_check (box_tensor , expected )
961961
962+ def test_box_area_jit (self ):
963+ box_tensor = torch .tensor ([[0 , 0 , 100 , 100 ], [0 , 0 , 0 , 0 ]], dtype = torch .float )
964+ TOLERANCE = 1e-3
965+ expected = ops .box_area (box_tensor )
966+ scripted_fn = torch .jit .script (ops .box_area )
967+ scripted_area = scripted_fn (box_tensor )
968+ torch .testing .assert_close (scripted_area , expected , rtol = 0.0 , atol = TOLERANCE )
969+
962970
963971class TestBoxIou :
964972 def test_iou (self ):
@@ -980,6 +988,14 @@ def iou_check(box, expected, tolerance=1e-4):
980988 expected = torch .tensor ([[1.0 , 0.9933 , 0.9673 ], [0.9933 , 1.0 , 0.9737 ], [0.9673 , 0.9737 , 1.0 ]])
981989 iou_check (box_tensor , expected , tolerance = 0.002 if dtype == torch .float16 else 1e-4 )
982990
991+ def test_iou_jit (self ):
992+ box_tensor = torch .tensor ([[0 , 0 , 100 , 100 ], [0 , 0 , 50 , 50 ], [200 , 200 , 300 , 300 ]], dtype = torch .float )
993+ TOLERANCE = 1e-3
994+ expected = ops .box_iou (box_tensor , box_tensor )
995+ scripted_fn = torch .jit .script (ops .box_iou )
996+ scripted_iou = scripted_fn (box_tensor , box_tensor )
997+ torch .testing .assert_close (scripted_iou , expected , rtol = 0.0 , atol = TOLERANCE )
998+
983999
9841000class TestGenBoxIou :
9851001 def test_gen_iou (self ):
@@ -1001,6 +1017,14 @@ def gen_iou_check(box, expected, tolerance=1e-4):
10011017 expected = torch .tensor ([[1.0 , 0.9933 , 0.9673 ], [0.9933 , 1.0 , 0.9737 ], [0.9673 , 0.9737 , 1.0 ]])
10021018 gen_iou_check (box_tensor , expected , tolerance = 0.002 if dtype == torch .float16 else 1e-3 )
10031019
1020+ def test_giou_jit (self ):
1021+ box_tensor = torch .tensor ([[0 , 0 , 100 , 100 ], [0 , 0 , 50 , 50 ], [200 , 200 , 300 , 300 ]], dtype = torch .float )
1022+ TOLERANCE = 1e-3
1023+ expected = ops .generalized_box_iou (box_tensor , box_tensor )
1024+ scripted_fn = torch .jit .script (ops .generalized_box_iou )
1025+ scripted_iou = scripted_fn (box_tensor , box_tensor )
1026+ torch .testing .assert_close (scripted_iou , expected , rtol = 0.0 , atol = TOLERANCE )
1027+
10041028
10051029class TestMasksToBoxes :
10061030 def test_masks_box (self ):
0 commit comments