Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def __init__(self, src_dir):
"plot_transforms_e2e.py",
"plot_cutmix_mixup.py",
"plot_rotated_box_transforms.py",
"plot_keypoints_transforms.py",
"plot_custom_transforms.py",
"plot_tv_tensors.py",
"plot_custom_tv_tensors.py",
Expand Down
Binary file added gallery/assets/pottery.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
7 changes: 6 additions & 1 deletion gallery/transforms/helpers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import matplotlib.pyplot as plt
import torch
from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks
from torchvision.utils import draw_bounding_boxes, draw_keypoints, draw_segmentation_masks
from torchvision import tv_tensors
from torchvision.transforms import v2
from torchvision.transforms.v2 import functional as F
Expand All @@ -18,6 +18,7 @@ def plot(imgs, row_title=None, bbox_width=3, **imshow_kwargs):
for col_idx, img in enumerate(row):
boxes = None
masks = None
points = None
if isinstance(img, tuple):
img, target = img
if isinstance(target, dict):
Expand All @@ -30,6 +31,8 @@ def plot(imgs, row_title=None, bbox_width=3, **imshow_kwargs):
# work with this specific format.
if tv_tensors.is_rotated_bounding_format(boxes.format):
boxes = v2.ConvertBoundingBoxFormat("xyxyxyxy")(boxes)
elif isinstance(target, tv_tensors.KeyPoints):
points = target
else:
raise ValueError(f"Unexpected target type: {type(target)}")
img = F.to_image(img)
Expand All @@ -44,6 +47,8 @@ def plot(imgs, row_title=None, bbox_width=3, **imshow_kwargs):
img = draw_bounding_boxes(img, boxes, colors="yellow", width=bbox_width)
if masks is not None:
img = draw_segmentation_masks(img, masks.to(torch.bool), colors=["green"] * masks.shape[0], alpha=.65)
if points is not None:
img = draw_keypoints(img, points, colors="red", radius=10)

ax = axs[row_idx, col_idx]
ax.imshow(img.permute(1, 2, 0).numpy(), **imshow_kwargs)
Expand Down
116 changes: 116 additions & 0 deletions gallery/transforms/plot_keypoints_transforms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
"""
===============================================================
Transforms on KeyPoints
===============================================================

This example illustrates how to define and use keypoints.
For this tutorial, we use a picture of a ceramic figure from the pre-Columbian period.
The image is listed as public domain (https://www.metmuseum.org/art/collection/search/502727).

.. note::
    Support for keypoints was released in TorchVision 0.23 and is
    currently a BETA feature. We don't expect the API to change, but there may
    be some rare edge-cases. If you find any issues, please report them on
    our bug tracker: https://github.com/pytorch/vision/issues?q=is:open+is:issue

First, a bit of setup code:
"""

# %%
from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt


import torch
from torchvision.tv_tensors import KeyPoints
from torchvision.transforms import v2
from helpers import plot

# Figure defaults shared by every plot in this example.
plt.rcParams.update({"figure.figsize": [10, 5], "savefig.bbox": "tight"})

# Several transforms below are random: if you change the seed, make sure that
# the transformed outputs still make sense.
torch.manual_seed(0)

# If you're trying to run this on Colab, you can download the assets and the
# helpers from https://github.com/pytorch/vision/tree/main/gallery/
orig_img = Image.open(Path('../assets', 'pottery.jpg'))

# %%
# Creating KeyPoints
# -------------------------------
# Keypoints are created by instantiating the
# :class:`~torchvision.tv_tensors.KeyPoints` class.


# Hand-annotated landmarks of the figure's face, grouped by facial feature.
face_landmarks = [
    # nose
    [445, 700],
    # left eye
    [320, 660],
    [370, 660],
    [420, 660],
    # left eyebrow
    [300, 620],
    [420, 620],
    # right eye
    [475, 665],
    [515, 665],
    [555, 655],
    # right eyebrow
    [460, 625],
    [560, 600],
    # mouth
    [370, 780],
    [450, 760],
    [540, 780],
    [450, 820],
]

# The extra pair of brackets keeps the leading dimension of size 1 around the
# list of points. ``canvas_size`` is given as (height, width) of the image.
orig_pts = KeyPoints(
    [face_landmarks],
    canvas_size=(orig_img.height, orig_img.width),
)

plot([(orig_img, orig_pts)])

# %%
# Transforms illustrations
# ------------------------
#
# Using :class:`~torchvision.transforms.v2.RandomRotation`:
rotater = v2.RandomRotation(degrees=(0, 180), expand=True)
sample = (orig_img, orig_pts)
# Each call draws a fresh random angle; image and keypoints are rotated together.
rotated_imgs = [rotater(sample) for _ in range(4)]
plot([sample, *rotated_imgs])

# %%
# Using :class:`~torchvision.transforms.v2.Pad`:
pad_widths = (30, 50, 100, 200)
# One (image, keypoints) pair per padding amount.
padded_imgs_and_points = [v2.Pad(padding=width)(orig_img, orig_pts) for width in pad_widths]
plot([(orig_img, orig_pts), *padded_imgs_and_points])

# %%
# Using :class:`~torchvision.transforms.v2.Resize`:

# PIL's ``Image.size`` is (width, height), whereas a sequence passed as
# ``size`` is read as (height, width). Build the pair explicitly so that the
# last example really is "resize to the original size" instead of swapping
# the two dimensions.
orig_hw = (orig_img.height, orig_img.width)
resized_imgs = [
    v2.Resize(size=size)(orig_img, orig_pts)
    for size in (300, 500, 1000, orig_hw)
]
plot([(orig_img, orig_pts)] + resized_imgs)

# %%
# Using :class:`~torchvision.transforms.v2.RandomPerspective`:
perspective_transformer = v2.RandomPerspective(p=1.0, distortion_scale=0.6)
# Four independent random draws, each applied jointly to image and keypoints.
perspective_imgs = [
    perspective_transformer(orig_img, orig_pts)
    for _ in range(4)
]
plot([(orig_img, orig_pts), *perspective_imgs])

# %%
# Using :class:`~torchvision.transforms.v2.CenterCrop`:

# PIL's ``Image.size`` is (width, height), whereas a sequence passed as
# ``size`` is read as (height, width). Build the pair explicitly so that the
# last example crops to the original dimensions instead of the swapped ones.
orig_hw = (orig_img.height, orig_img.width)
center_crops_and_points = [
    v2.CenterCrop(size=size)(orig_img, orig_pts)
    for size in (300, 500, 1000, orig_hw)
]
plot([(orig_img, orig_pts)] + center_crops_and_points)

# %%
# Using :class:`~torchvision.transforms.v2.RandomRotation` again, this time
# without ``expand=True``:
rotater = v2.RandomRotation(degrees=(0, 180))
sample = (orig_img, orig_pts)
# Each call draws a fresh random angle; image and keypoints are rotated together.
rotated_imgs = [rotater(sample) for _ in range(4)]
plot([sample, *rotated_imgs])
Loading