Skip to content

Commit 973f5cd

Browse files
[release/0.24] Cherry pick (#9217)
Co-authored-by: Andrei Moraru <[email protected]>
1 parent e437e35 commit 973f5cd

File tree

3 files changed

+43
-8
lines changed

3 files changed

+43
-8
lines changed
766 Bytes
Loading

test/test_utils.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,33 @@ def test_draw_boxes_with_coloured_label_backgrounds():
166166
assert_equal(result, expected)
167167

168168

169+
@pytest.mark.skipif(PILLOW_VERSION < (10, 1), reason="The reference image is only valid for PIL >= 10.1")
170+
def test_draw_boxes_with_coloured_label_text_boxes():
171+
img = torch.full((3, 100, 100), 255, dtype=torch.uint8)
172+
labels = ["a", "b", "c", "d"]
173+
colors = ["green", "#FF00FF", (0, 255, 0), "red"]
174+
label_colors = ["green", "red", (0, 255, 0), "#FF00FF"]
175+
label_background_colors = ["white", "black", "yellow", "blue"]
176+
result = utils.draw_bounding_boxes(
177+
img,
178+
boxes,
179+
labels=labels,
180+
colors=colors,
181+
fill=True,
182+
label_colors=label_colors,
183+
label_background_colors=label_background_colors,
184+
fill_labels=True,
185+
)
186+
path = os.path.join(
187+
os.path.dirname(os.path.abspath(__file__)),
188+
"assets",
189+
"fakedata",
190+
"draw_boxes_different_label_background_colors.png",
191+
)
192+
expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
193+
assert_equal(result, expected)
194+
195+
169196
@pytest.mark.skipif(PILLOW_VERSION < (10, 1), reason="The reference image is only valid for PIL >= 10.1")
170197
def test_draw_rotated_boxes():
171198
img = torch.full((3, 500, 500), 255, dtype=torch.uint8)
@@ -382,7 +409,7 @@ def test_draw_segmentation_masks_errors(device):
382409
utils.draw_segmentation_masks(image=img, masks=masks_bad_shape)
383410
with pytest.raises(ValueError, match="Number of colors must be equal or larger than the number of objects"):
384411
utils.draw_segmentation_masks(image=img, masks=masks, colors=[])
385-
with pytest.raises(ValueError, match="`colors` must be a tuple or a string, or a list thereof"):
412+
with pytest.raises(ValueError, match="colors must be a tuple or a string, or a list thereof"):
386413
bad_colors = np.array(["red", "blue"]) # should be a list
387414
utils.draw_segmentation_masks(image=img, masks=masks, colors=bad_colors)
388415
with pytest.raises(ValueError, match="If passed as tuple, colors should be an RGB triplet"):

torchvision/utils.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import torch
1111
from PIL import __version__ as PILLOW_VERSION_STRING, Image, ImageColor, ImageDraw, ImageFont
1212

13-
1413
__all__ = [
1514
"_Image_fromarray",
1615
"make_grid",
@@ -293,6 +292,7 @@ def draw_bounding_boxes(
293292
font: Optional[str] = None,
294293
font_size: Optional[int] = None,
295294
label_colors: Optional[Union[list[Union[str, tuple[int, int, int]]], str, tuple[int, int, int]]] = None,
295+
label_background_colors: Optional[Union[list[Union[str, tuple[int, int, int]]], str, tuple[int, int, int]]] = None,
296296
fill_labels: bool = False,
297297
) -> torch.Tensor:
298298
"""
@@ -320,7 +320,10 @@ def draw_bounding_boxes(
320320
font_size (int): The requested font size in points.
321321
label_colors (color or list of colors, optional): Colors for the label text. See the description of the
322322
`colors` argument for details. Defaults to the same colors used for the boxes, or to black if ``fill_labels`` is True.
323-
fill_labels (bool): If `True` fills the label background with specified box color (from the ``colors`` parameter). Default: False.
323+
label_background_colors (color or list of colors, optional): Colors for the label text box fill. Defaults to the
324+
same colors used for the boxes. Ignored when ``fill_labels`` is False.
325+
fill_labels (bool): If `True` fills the label background with specified color (from the ``label_background_colors`` parameter,
326+
or from the ``colors`` parameter if not specified). Default: False.
324327
325328
Returns:
326329
img (Tensor[C, H, W]): Image Tensor of dtype uint8 with bounding boxes plotted.
@@ -356,12 +359,17 @@ def draw_bounding_boxes(
356359
f"Number of boxes ({num_boxes}) and labels ({len(labels)}) mismatch. Please specify labels for each box."
357360
)
358361

359-
colors = _parse_colors(colors, num_objects=num_boxes)
362+
colors = _parse_colors(colors, num_objects=num_boxes) # type: ignore[assignment]
360363
if label_colors or fill_labels:
361364
label_colors = _parse_colors(label_colors if label_colors else "black", num_objects=num_boxes) # type: ignore[assignment]
362365
else:
363366
label_colors = colors.copy() # type: ignore[assignment]
364367

368+
if fill_labels and label_background_colors:
369+
label_background_colors = _parse_colors(label_background_colors, num_objects=num_boxes) # type: ignore[assignment]
370+
else:
371+
label_background_colors = colors.copy() # type: ignore[assignment]
372+
365373
if font is None:
366374
if font_size is not None:
367375
warnings.warn("Argument 'font_size' will be ignored since 'font' is not set.")
@@ -385,7 +393,7 @@ def draw_bounding_boxes(
385393
else:
386394
draw = _ImageDrawTV(img_to_draw)
387395

388-
for bbox, color, label, label_color in zip(img_boxes, colors, labels, label_colors): # type: ignore[arg-type]
396+
for bbox, color, label, label_color, label_bg_color in zip(img_boxes, colors, labels, label_colors, label_background_colors): # type: ignore[arg-type]
389397
draw_method = draw.oriented_rectangle if len(bbox) > 4 else draw.rectangle
390398
fill_color = color + (100,) if fill else None
391399
draw_method(bbox, width=width, outline=color, fill=fill_color)
@@ -396,7 +404,7 @@ def draw_bounding_boxes(
396404
if fill_labels:
397405
left, top, right, bottom = draw.textbbox((bbox[0] + margin, bbox[1] + margin), label, font=txt_font)
398406
draw.rectangle(
399-
(left - box_margin, top - box_margin, right + box_margin, bottom + box_margin), fill=color
407+
(left - box_margin, top - box_margin, right + box_margin, bottom + box_margin), fill=label_bg_color # type: ignore[arg-type]
400408
)
401409
draw.text((bbox[0] + margin, bbox[1] + margin), label, fill=label_color, font=txt_font) # type: ignore[arg-type]
402410

@@ -545,7 +553,7 @@ def draw_keypoints(
545553
if visibility.shape != keypoints.shape[:-1]:
546554
raise ValueError(
547555
"keypoints and visibility must have the same dimensionality for num_instances and K. "
548-
f"Got {visibility.shape = } and {keypoints.shape = }"
556+
f"Got {visibility.shape=} and {keypoints.shape=}"
549557
)
550558

551559
original_dtype = image.dtype
@@ -746,7 +754,7 @@ def _parse_colors(
746754
f"Number of colors must be equal or larger than the number of objects, but got {len(colors)} < {num_objects}."
747755
)
748756
elif not isinstance(colors, (tuple, str)):
749-
raise ValueError(f"`colors` must be a tuple or a string, or a list thereof, but got {colors}.")
757+
raise ValueError(f"colors must be a tuple or a string, or a list thereof, but got {colors}.")
750758
elif isinstance(colors, tuple) and len(colors) != 3:
751759
raise ValueError(f"If passed as tuple, colors should be an RGB triplet, but got {colors}.")
752760
else: # colors specifies a single color for all objects

0 commit comments

Comments
 (0)