Skip to content

Commit 62346d6

Browse files
committed
Some fixes
1 parent a273788 commit 62346d6

File tree

2 files changed

+16
-15
lines changed

2 files changed

+16
-15
lines changed

test/test_utils.py

+4 -6
Original file line number | Diff line number | Diff line change
@@ -116,15 +116,16 @@ def test_draw_boxes():
116116
assert_equal(img, img_cp)
117117

118118

119-
def test_draw_boxes_dtypes():
119+
@pytest.mark.parametrize("fill", [True, False])
120+
def test_draw_boxes_dtypes(fill):
120121
img_uint8 = torch.full((3, 100, 100), 255, dtype=torch.uint8)
121-
out_uint8 = utils.draw_bounding_boxes(img_uint8, boxes)
122+
out_uint8 = utils.draw_bounding_boxes(img_uint8, boxes, fill=fill)
122123

123124
assert img_uint8 is not out_uint8
124125
assert out_uint8.dtype == torch.uint8
125126

126127
img_float = to_dtype(img_uint8, torch.float, scale=True)
127-
out_float = utils.draw_bounding_boxes(img_float, boxes)
128+
out_float = utils.draw_bounding_boxes(img_float, boxes, fill=fill)
128129

129130
assert img_float is not out_float
130131
assert out_float.is_floating_point()
@@ -168,7 +169,6 @@ def test_draw_boxes_grayscale():
168169

169170
def test_draw_invalid_boxes():
170171
img_tp = ((1, 1, 1), (1, 2, 3))
171-
img_wrong1 = torch.full((3, 5, 5), 255, dtype=torch.float)
172172
img_wrong2 = torch.full((1, 3, 5, 5), 255, dtype=torch.uint8)
173173
img_correct = torch.zeros((3, 10, 10), dtype=torch.uint8)
174174
boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
@@ -178,8 +178,6 @@ def test_draw_invalid_boxes():
178178

179179
with pytest.raises(TypeError, match="Tensor expected"):
180180
utils.draw_bounding_boxes(img_tp, boxes)
181-
with pytest.raises(ValueError, match="Tensor uint8 expected"):
182-
utils.draw_bounding_boxes(img_wrong1, boxes)
183181
with pytest.raises(ValueError, match="Pass individual images, not batches"):
184182
utils.draw_bounding_boxes(img_wrong2, boxes)
185183
with pytest.raises(ValueError, match="Only grayscale and RGB images are supported"):

torchvision/utils.py

+12 -9
Original file line number | Diff line number | Diff line change
@@ -169,7 +169,7 @@ def draw_bounding_boxes(
169169
If fill is True, Resulting Tensor should be saved as PNG image.
170170
171171
Args:
172-
image (Tensor): Tensor of shape (3, H, W) and dtype uint8 or float.
172+
image (Tensor): Tensor of shape (C, H, W) and dtype uint8 or float.
173173
boxes (Tensor): Tensor of size (N, 4) containing bounding boxes in (xmin, ymin, xmax, ymax) format. Note that
174174
the boxes are absolute coordinates with respect to the image. In other words: `0 <= xmin < xmax < W` and
175175
`0 <= ymin < ymax < H`.
@@ -188,6 +188,7 @@ def draw_bounding_boxes(
188188
Returns:
189189
img (Tensor[C, H, W]): Image Tensor of dtype uint8 with bounding boxes plotted.
190190
"""
191+
import torchvision.transforms.v2.functional as F # noqa
191192

192193
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
193194
_log_api_usage_once(draw_bounding_boxes)
@@ -217,11 +218,7 @@ def draw_bounding_boxes(
217218
f"Number of boxes ({num_boxes}) and labels ({len(labels)}) mismatch. Please specify labels for each box."
218219
)
219220

220-
original_dtype = image.dtype
221-
colors = [
222-
torch.tensor(color, dtype=original_dtype, device=image.device)
223-
for color in _parse_colors(colors, num_objects=num_boxes, dtype=original_dtype)
224-
]
221+
colors = _parse_colors(colors, num_objects=num_boxes)
225222

226223
if font is None:
227224
if font_size is not None:
@@ -234,8 +231,11 @@ def draw_bounding_boxes(
234231
if image.size(0) == 1:
235232
image = torch.tile(image, (3, 1, 1))
236233

237-
ndarr = image.permute(1, 2, 0).cpu().numpy()
238-
img_to_draw = Image.fromarray(ndarr)
234+
original_dtype = image.dtype
235+
if original_dtype.is_floating_point:
236+
image = F.to_dtype(image, dtype=torch.uint8, scale=True)
237+
238+
img_to_draw = F.to_pil_image(image)
239239
img_boxes = boxes.to(torch.int64).tolist()
240240

241241
if fill:
@@ -254,7 +254,10 @@ def draw_bounding_boxes(
254254
margin = width + 1
255255
draw.text((bbox[0] + margin, bbox[1] + margin), label, fill=color, font=txt_font)
256256

257-
return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=original_dtype)
257+
out = F.pil_to_tensor(img_to_draw)
258+
if original_dtype.is_floating_point:
259+
out = F.to_dtype(out, dtype=original_dtype, scale=True)
260+
return out
258261

259262

260263
@torch.no_grad()

0 commit comments

Comments
 (0)