10
10
import torch
11
11
from PIL import __version__ as PILLOW_VERSION_STRING , Image , ImageColor , ImageDraw , ImageFont
12
12
13
-
14
13
__all__ = [
15
14
"_Image_fromarray" ,
16
15
"make_grid" ,
@@ -293,6 +292,7 @@ def draw_bounding_boxes(
293
292
font : Optional [str ] = None ,
294
293
font_size : Optional [int ] = None ,
295
294
label_colors : Optional [Union [list [Union [str , tuple [int , int , int ]]], str , tuple [int , int , int ]]] = None ,
295
+ label_background_colors : Optional [Union [list [Union [str , tuple [int , int , int ]]], str , tuple [int , int , int ]]] = None ,
296
296
fill_labels : bool = False ,
297
297
) -> torch .Tensor :
298
298
"""
@@ -320,7 +320,10 @@ def draw_bounding_boxes(
320
320
font_size (int): The requested font size in points.
321
321
label_colors (color or list of colors, optional): Colors for the label text. See the description of the
322
322
`colors` argument for details. Defaults to the same colors used for the boxes, or to black if ``fill_labels`` is True.
323
- fill_labels (bool): If `True` fills the label background with specified box color (from the ``colors`` parameter). Default: False.
323
+ label_background_colors (color or list of colors, optional): Colors for the label text box fill. Defaults to the
324
+ same colors used for the boxes. Ignored when ``fill_labels`` is False.
325
+ fill_labels (bool): If `True` fills the label background with specified color (from the ``label_background_colors`` parameter,
326
+ or from the ``colors`` parameter if not specified). Default: False.
324
327
325
328
Returns:
326
329
img (Tensor[C, H, W]): Image Tensor of dtype uint8 with bounding boxes plotted.
@@ -356,12 +359,17 @@ def draw_bounding_boxes(
356
359
f"Number of boxes ({ num_boxes } ) and labels ({ len (labels )} ) mismatch. Please specify labels for each box."
357
360
)
358
361
359
- colors = _parse_colors (colors , num_objects = num_boxes )
362
+ colors = _parse_colors (colors , num_objects = num_boxes ) # type: ignore[assignment]
360
363
if label_colors or fill_labels :
361
364
label_colors = _parse_colors (label_colors if label_colors else "black" , num_objects = num_boxes ) # type: ignore[assignment]
362
365
else :
363
366
label_colors = colors .copy () # type: ignore[assignment]
364
367
368
+ if fill_labels and label_background_colors :
369
+ label_background_colors = _parse_colors (label_background_colors , num_objects = num_boxes ) # type: ignore[assignment]
370
+ else :
371
+ label_background_colors = colors .copy () # type: ignore[assignment]
372
+
365
373
if font is None :
366
374
if font_size is not None :
367
375
warnings .warn ("Argument 'font_size' will be ignored since 'font' is not set." )
@@ -385,7 +393,7 @@ def draw_bounding_boxes(
385
393
else :
386
394
draw = _ImageDrawTV (img_to_draw )
387
395
388
- for bbox , color , label , label_color in zip (img_boxes , colors , labels , label_colors ): # type: ignore[arg-type]
396
+ for bbox , color , label , label_color , label_bg_color in zip (img_boxes , colors , labels , label_colors , label_background_colors ): # type: ignore[arg-type]
389
397
draw_method = draw .oriented_rectangle if len (bbox ) > 4 else draw .rectangle
390
398
fill_color = color + (100 ,) if fill else None
391
399
draw_method (bbox , width = width , outline = color , fill = fill_color )
@@ -396,7 +404,7 @@ def draw_bounding_boxes(
396
404
if fill_labels :
397
405
left , top , right , bottom = draw .textbbox ((bbox [0 ] + margin , bbox [1 ] + margin ), label , font = txt_font )
398
406
draw .rectangle (
399
- (left - box_margin , top - box_margin , right + box_margin , bottom + box_margin ), fill = color
407
+ (left - box_margin , top - box_margin , right + box_margin , bottom + box_margin ), fill = label_bg_color # type: ignore[arg-type]
400
408
)
401
409
draw .text ((bbox [0 ] + margin , bbox [1 ] + margin ), label , fill = label_color , font = txt_font ) # type: ignore[arg-type]
402
410
@@ -545,7 +553,7 @@ def draw_keypoints(
545
553
if visibility .shape != keypoints .shape [:- 1 ]:
546
554
raise ValueError (
547
555
"keypoints and visibility must have the same dimensionality for num_instances and K. "
548
- f"Got { visibility .shape = } and { keypoints .shape = } "
556
+ f"Got { visibility .shape = } and { keypoints .shape = } "
549
557
)
550
558
551
559
original_dtype = image .dtype
@@ -746,7 +754,7 @@ def _parse_colors(
746
754
f"Number of colors must be equal or larger than the number of objects, but got { len (colors )} < { num_objects } ."
747
755
)
748
756
elif not isinstance (colors , (tuple , str )):
749
- raise ValueError (f"` colors` must be a tuple or a string, or a list thereof, but got { colors } ." )
757
+ raise ValueError (f"colors must be a tuple or a string, or a list thereof, but got { colors } ." )
750
758
elif isinstance (colors , tuple ) and len (colors ) != 3 :
751
759
raise ValueError (f"If passed as tuple, colors should be an RGB triplet, but got { colors } ." )
752
760
else : # colors specifies a single color for all objects
0 commit comments