Skip to content

Commit 96640af

Browse files
add float support to utils.draw_bounding_boxes() (#8328)
Co-authored-by: Nicolas Hug <[email protected]> Co-authored-by: Nicolas Hug <[email protected]>
1 parent 0367c21 commit 96640af

File tree

2 files changed

+32
-11
lines changed

2 files changed

+32
-11
lines changed

test/test_utils.py

+17-3
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,23 @@ def test_draw_boxes():
116116
assert_equal(img, img_cp)
117117

118118

119+
@pytest.mark.parametrize("fill", [True, False])
120+
def test_draw_boxes_dtypes(fill):
121+
img_uint8 = torch.full((3, 100, 100), 255, dtype=torch.uint8)
122+
out_uint8 = utils.draw_bounding_boxes(img_uint8, boxes, fill=fill)
123+
124+
assert img_uint8 is not out_uint8
125+
assert out_uint8.dtype == torch.uint8
126+
127+
img_float = to_dtype(img_uint8, torch.float, scale=True)
128+
out_float = utils.draw_bounding_boxes(img_float, boxes, fill=fill)
129+
130+
assert img_float is not out_float
131+
assert out_float.is_floating_point()
132+
133+
torch.testing.assert_close(out_uint8, to_dtype(out_float, torch.uint8, scale=True), rtol=0, atol=1)
134+
135+
119136
@pytest.mark.parametrize("colors", [None, ["red", "blue", "#FF00FF", (1, 34, 122)], "red", "#FF00FF", (1, 34, 122)])
120137
def test_draw_boxes_colors(colors):
121138
img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
@@ -152,7 +169,6 @@ def test_draw_boxes_grayscale():
152169

153170
def test_draw_invalid_boxes():
154171
img_tp = ((1, 1, 1), (1, 2, 3))
155-
img_wrong1 = torch.full((3, 5, 5), 255, dtype=torch.float)
156172
img_wrong2 = torch.full((1, 3, 5, 5), 255, dtype=torch.uint8)
157173
img_correct = torch.zeros((3, 10, 10), dtype=torch.uint8)
158174
boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
@@ -162,8 +178,6 @@ def test_draw_invalid_boxes():
162178

163179
with pytest.raises(TypeError, match="Tensor expected"):
164180
utils.draw_bounding_boxes(img_tp, boxes)
165-
with pytest.raises(ValueError, match="Tensor uint8 expected"):
166-
utils.draw_bounding_boxes(img_wrong1, boxes)
167181
with pytest.raises(ValueError, match="Pass individual images, not batches"):
168182
utils.draw_bounding_boxes(img_wrong2, boxes)
169183
with pytest.raises(ValueError, match="Only grayscale and RGB images are supported"):

torchvision/utils.py

+15-8
Original file line numberDiff line numberDiff line change
@@ -164,12 +164,12 @@ def draw_bounding_boxes(
164164
) -> torch.Tensor:
165165

166166
"""
167-
Draws bounding boxes on given image.
168-
The values of the input image should be uint8 between 0 and 255.
167+
Draws bounding boxes on given RGB image.
168+
The image values should be uint8 in [0, 255] or float in [0, 1].
169169
If fill is True, Resulting Tensor should be saved as PNG image.
170170
171171
Args:
172-
image (Tensor): Tensor of shape (C x H x W) and dtype uint8.
172+
image (Tensor): Tensor of shape (C, H, W) and dtype uint8 or float.
173173
boxes (Tensor): Tensor of size (N, 4) containing bounding boxes in (xmin, ymin, xmax, ymax) format. Note that
174174
the boxes are absolute coordinates with respect to the image. In other words: `0 <= xmin < xmax < W` and
175175
`0 <= ymin < ymax < H`.
@@ -188,13 +188,14 @@ def draw_bounding_boxes(
188188
Returns:
189189
img (Tensor[C, H, W]): Image Tensor of dtype uint8 with bounding boxes plotted.
190190
"""
191+
import torchvision.transforms.v2.functional as F # noqa
191192

192193
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
193194
_log_api_usage_once(draw_bounding_boxes)
194195
if not isinstance(image, torch.Tensor):
195196
raise TypeError(f"Tensor expected, got {type(image)}")
196-
elif image.dtype != torch.uint8:
197-
raise ValueError(f"Tensor uint8 expected, got {image.dtype}")
197+
elif not (image.dtype == torch.uint8 or image.is_floating_point()):
198+
raise ValueError(f"The image dtype must be uint8 or float, got {image.dtype}")
198199
elif image.dim() != 3:
199200
raise ValueError("Pass individual images, not batches")
200201
elif image.size(0) not in {1, 3}:
@@ -230,8 +231,11 @@ def draw_bounding_boxes(
230231
if image.size(0) == 1:
231232
image = torch.tile(image, (3, 1, 1))
232233

233-
ndarr = image.permute(1, 2, 0).cpu().numpy()
234-
img_to_draw = Image.fromarray(ndarr)
234+
original_dtype = image.dtype
235+
if original_dtype.is_floating_point:
236+
image = F.to_dtype(image, dtype=torch.uint8, scale=True)
237+
238+
img_to_draw = F.to_pil_image(image)
235239
img_boxes = boxes.to(torch.int64).tolist()
236240

237241
if fill:
@@ -250,7 +254,10 @@ def draw_bounding_boxes(
250254
margin = width + 1
251255
draw.text((bbox[0] + margin, bbox[1] + margin), label, fill=color, font=txt_font)
252256

253-
return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8)
257+
out = F.pil_to_tensor(img_to_draw)
258+
if original_dtype.is_floating_point:
259+
out = F.to_dtype(out, dtype=original_dtype, scale=True)
260+
return out
254261

255262

256263
@torch.no_grad()

0 commit comments

Comments
 (0)