@@ -116,6 +116,23 @@ def test_draw_boxes():
116
116
assert_equal (img , img_cp )
117
117
118
118
119
+ @pytest .mark .parametrize ("fill" , [True , False ])
120
+ def test_draw_boxes_dtypes (fill ):
121
+ img_uint8 = torch .full ((3 , 100 , 100 ), 255 , dtype = torch .uint8 )
122
+ out_uint8 = utils .draw_bounding_boxes (img_uint8 , boxes , fill = fill )
123
+
124
+ assert img_uint8 is not out_uint8
125
+ assert out_uint8 .dtype == torch .uint8
126
+
127
+ img_float = to_dtype (img_uint8 , torch .float , scale = True )
128
+ out_float = utils .draw_bounding_boxes (img_float , boxes , fill = fill )
129
+
130
+ assert img_float is not out_float
131
+ assert out_float .is_floating_point ()
132
+
133
+ torch .testing .assert_close (out_uint8 , to_dtype (out_float , torch .uint8 , scale = True ), rtol = 0 , atol = 1 )
134
+
135
+
119
136
@pytest .mark .parametrize ("colors" , [None , ["red" , "blue" , "#FF00FF" , (1 , 34 , 122 )], "red" , "#FF00FF" , (1 , 34 , 122 )])
120
137
def test_draw_boxes_colors (colors ):
121
138
img = torch .full ((3 , 100 , 100 ), 0 , dtype = torch .uint8 )
@@ -152,7 +169,6 @@ def test_draw_boxes_grayscale():
152
169
153
170
def test_draw_invalid_boxes ():
154
171
img_tp = ((1 , 1 , 1 ), (1 , 2 , 3 ))
155
- img_wrong1 = torch .full ((3 , 5 , 5 ), 255 , dtype = torch .float )
156
172
img_wrong2 = torch .full ((1 , 3 , 5 , 5 ), 255 , dtype = torch .uint8 )
157
173
img_correct = torch .zeros ((3 , 10 , 10 ), dtype = torch .uint8 )
158
174
boxes = torch .tensor ([[0 , 0 , 20 , 20 ], [0 , 0 , 0 , 0 ], [10 , 15 , 30 , 35 ], [23 , 35 , 93 , 95 ]], dtype = torch .float )
@@ -162,8 +178,6 @@ def test_draw_invalid_boxes():
162
178
163
179
with pytest .raises (TypeError , match = "Tensor expected" ):
164
180
utils .draw_bounding_boxes (img_tp , boxes )
165
- with pytest .raises (ValueError , match = "Tensor uint8 expected" ):
166
- utils .draw_bounding_boxes (img_wrong1 , boxes )
167
181
with pytest .raises (ValueError , match = "Pass individual images, not batches" ):
168
182
utils .draw_bounding_boxes (img_wrong2 , boxes )
169
183
with pytest .raises (ValueError , match = "Only grayscale and RGB images are supported" ):
0 commit comments