Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
17d50e8
feat: Added int conversion and unwrapping
Aug 29, 2024
285c465
test: added tests for post_process_keypoint_detection of SuperPointIm…
sbucaille Aug 30, 2024
2efe61b
docs: changed docs to include post_process_keypoint_detection method …
sbucaille Aug 30, 2024
a77b870
test: changed test to not depend on SuperPointModel forward
sbucaille Aug 30, 2024
2ab79cd
test: added missing require_torch decorator
sbucaille Aug 30, 2024
419ae5d
docs: changed pyplot parameters for the keypoints to be more visible …
sbucaille Aug 30, 2024
39b32a2
tests: changed import torch location to make test_flax and test_tf
sbucaille Aug 30, 2024
144e09a
Revert "tests: changed import torch location to make test_flax and te…
sbucaille Aug 30, 2024
21dbdfc
tests: fixed import
sbucaille Aug 30, 2024
389b154
chore: applied suggestions from code review
sbucaille Sep 1, 2024
b7d672e
tests: fixed import
sbucaille Sep 1, 2024
f5d7311
tests: fixed import (bis)
sbucaille Sep 1, 2024
d89d385
tests: fixed import (ter)
sbucaille Sep 1, 2024
f9e1141
feat: added choice of type for target_size and changed tests accordingly
sbucaille Sep 1, 2024
32a2e96
docs: updated code snippet to reflect the addition of target size typ…
sbucaille Sep 1, 2024
560194e
tests: fixed imports (...)
Sep 2, 2024
2d28aba
tests: fixed imports (...)
Sep 2, 2024
bd23baa
style: formatting file
Sep 2, 2024
5bb0baf
docs: fixed typo from image[0] to image.size[0]
sbucaille Sep 2, 2024
ed28314
docs: added output image and fixed some tests
sbucaille Sep 5, 2024
192448d
Update docs/source/en/model_doc/superpoint.md
sbucaille Oct 2, 2024
e89af7f
fix: included SuperPointKeypointDescriptionOutput in TYPE_CHECKING if…
sbucaille Oct 2, 2024
4e77a4f
docs: changed SuperPoint's docs to print output instead of just acces…
sbucaille Oct 2, 2024
e9b642a
style: applied make style
sbucaille Oct 2, 2024
e085861
docs: added missing output type and precision in docstring of post_pr…
sbucaille Oct 3, 2024
9127545
perf: deleted loop to perform keypoint conversion in one statement
sbucaille Oct 3, 2024
1ffa465
fix: moved keypoint conversion at the end of model forward
sbucaille Oct 3, 2024
b0d25a3
docs: changed SuperPointInterestPointDecoder to SuperPointKeypointDec…
sbucaille Oct 3, 2024
1fb5705
fix: changed type hint
sbucaille Oct 3, 2024
13cb7e5
refactor: removed unnecessary brackets
sbucaille Oct 4, 2024
eb6a5aa
revert: SuperPointKeypointDecoder to SuperPointInterestPointDecoder
sbucaille Oct 4, 2024
4c34d75
Update docs/source/en/model_doc/superpoint.md
sbucaille Oct 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 19 additions & 13 deletions docs/source/en/model_doc/superpoint.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,23 +86,28 @@ model = SuperPointForKeypointDetection.from_pretrained("magic-leap-community/sup

inputs = processor(images, return_tensors="pt")
outputs = model(**inputs)
image_sizes = torch.tensor([image.size for image in images]).flip(1)
outputs = processor.post_process_keypoint_detection(outputs, image_sizes)

for i in range(len(images)):
image_mask = outputs.mask[i]
image_indices = torch.nonzero(image_mask).squeeze()
image_keypoints = outputs.keypoints[i][image_indices]
image_scores = outputs.scores[i][image_indices]
image_descriptors = outputs.descriptors[i][image_indices]
for output in outputs:
keypoints = output["keypoints"]
scores = output["scores"]
descriptors = output["descriptors"]
Comment thread
qubvel marked this conversation as resolved.
Outdated
```

You can then print the keypoints on the image to visualize the result :
You can then plot the keypoints on the image of your choice to visualize the result:
Comment thread
sbucaille marked this conversation as resolved.
Outdated
```python
import cv2
for keypoint, score in zip(image_keypoints, image_scores):
keypoint_x, keypoint_y = int(keypoint[0].item()), int(keypoint[1].item())
color = tuple([score.item() * 255] * 3)
image = cv2.circle(image, (keypoint_x, keypoint_y), 2, color)
cv2.imwrite("output_image.png", image)
import matplotlib.pyplot as plt
plt.axis("off")
Comment thread
sbucaille marked this conversation as resolved.
plt.imshow(image)
plt.scatter(
keypoints[:, 0],
keypoints[:, 1],
c=scores * 100,
s=scores * 50,
alpha=0.8
)
Comment thread
sbucaille marked this conversation as resolved.
Outdated
plt.savefig(f"output_image.png")
Comment thread
qubvel marked this conversation as resolved.
```

This model was contributed by [stevenbucaille](https://huggingface.co/stevenbucaille).
Expand All @@ -123,6 +128,7 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
[[autodoc]] SuperPointImageProcessor

- preprocess
- post_process_keypoint_detection

## SuperPointForKeypointDetection

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import numpy as np

from ... import is_vision_available
from ... import is_torch_available, is_vision_available
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import resize, to_channel_dimension_format
from ...image_utils import (
Expand All @@ -30,8 +30,12 @@
valid_images,
)
from ...utils import TensorType, logging, requires_backends
from .modeling_superpoint import SuperPointKeypointDescriptionOutput


if is_torch_available():
import torch

if is_vision_available():
import PIL

Expand Down Expand Up @@ -270,3 +274,46 @@ def preprocess(
data = {"pixel_values": images}

return BatchFeature(data=data, tensor_type=return_tensors)

def post_process_keypoint_detection(
    self, outputs: "SuperPointKeypointDescriptionOutput", target_sizes: "torch.Tensor"
):
    """
    Converts the raw output of [`SuperPointForKeypointDetection`] into lists of keypoints, scores and descriptors
    with coordinates absolute to the original image sizes.

    Args:
        outputs ([`SuperPointKeypointDescriptionOutput`]):
            Raw outputs of the model. Its `keypoints` are multiplied by the image sizes, so they are expected
            to be relative (x, y) coordinates.
        target_sizes (`torch.Tensor` of shape `(batch_size, 2)`, or a list of `(h, w)` pairs):
            The size (h, w) of each image of the batch. This must be the original image size (before any
            processing).

    Returns:
        `List[Dict]`: A list of dictionaries, one per image in the batch, each with the keys `"keypoints"`
        (absolute (x, y) coordinates as `torch.int32`), `"scores"` and `"descriptors"`, restricted to the
        entries whose mask value is non-zero.

    Raises:
        ValueError: If the number of target sizes differs from the batch dimension of the mask, or if
            `target_sizes` is not made of `(h, w)` pairs.
    """
    # The annotations above are quoted on purpose: `torch` is only imported when it is available, so
    # evaluating `torch.Tensor` at definition time would break the import of this module without torch.
    if len(outputs.mask) != len(target_sizes):
        raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the mask")

    relative_keypoints = outputs.keypoints
    # Accept tensors as well as plain sequences, and keep the sizes on the same device as the keypoints
    # so that the broadcasted multiplication below is valid.
    image_sizes = torch.as_tensor(target_sizes, device=relative_keypoints.device)
    if image_sizes.ndim != 2 or image_sizes.shape[1] != 2:
        raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")

    # Keypoints are stored as (x, y) while sizes are (h, w): flip the sizes to (w, h) and scale the whole
    # batch at once. The result is a new tensor, `outputs.keypoints` is left untouched.
    scale = image_sizes.flip(1).to(relative_keypoints.dtype)
    absolute_keypoints = (relative_keypoints * scale[:, None, :]).to(torch.int32)

    results = []
    for image_mask, image_keypoints, image_scores, image_descriptors in zip(
        outputs.mask, absolute_keypoints, outputs.scores, outputs.descriptors
    ):
        # Keep only the entries flagged as valid by the mask.
        indices = torch.nonzero(image_mask).squeeze(1)
        results.append(
            {
                "keypoints": image_keypoints[indices],
                "scores": image_scores[indices],
                "descriptors": image_descriptors[indices],
            }
        )

    return results
3 changes: 3 additions & 0 deletions src/transformers/models/superpoint/modeling_superpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,9 @@ def _extract_keypoints(self, scores: torch.Tensor) -> Tuple[torch.Tensor, torch.
# Convert (y, x) to (x, y)
keypoints = torch.flip(keypoints, [1]).float()
Comment thread
qubvel marked this conversation as resolved.

# Convert to relative coordinates
keypoints = keypoints / torch.tensor([width, height], device=keypoints.device)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Technically a breaking change - but as this is more of a fix, I think it's OK


return keypoints, scores


Expand Down
47 changes: 45 additions & 2 deletions tests/models/superpoint/test_image_processing_superpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,21 @@

import numpy as np

from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_vision_available
from transformers.testing_utils import require_torch, require_vision, slow
from transformers.utils import is_torch_available, is_vision_available

from ...test_image_processing_common import (
ImageProcessingTestMixin,
prepare_image_inputs,
)


if is_torch_available():
import torch

if is_vision_available():
from transformers import SuperPointImageProcessor
from transformers.models.superpoint.modeling_superpoint import SuperPointKeypointDescriptionOutput


class SuperPointImageProcessingTester(unittest.TestCase):
Expand Down Expand Up @@ -70,6 +74,23 @@ def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=F
torchify=torchify,
)

def prepare_keypoint_detection_output(self, pixel_values):
    """
    Build a synthetic `SuperPointKeypointDescriptionOutput` for a batch, without running the model.

    Every image gets a random number of valid keypoints (strictly fewer than 50). The valid slots hold
    uniformly random keypoints, scores and 16-dimensional descriptors, the remaining slots stay at zero
    and are flagged as padding in the mask.

    Args:
        pixel_values: Batch of images; only its length (the batch size) is used.

    Returns:
        A `SuperPointKeypointDescriptionOutput` with `loss` and `hidden_states` set to `None`.
    """
    max_number_keypoints = 50
    descriptor_size = 16
    batch_size = len(pixel_values)

    # Start from fully padded tensors and fill in the valid prefix of every image below.
    mask = torch.zeros((batch_size, max_number_keypoints))
    keypoints = torch.zeros((batch_size, max_number_keypoints, 2))
    scores = torch.zeros((batch_size, max_number_keypoints))
    descriptors = torch.zeros((batch_size, max_number_keypoints, descriptor_size))

    for image_index in range(batch_size):
        number_valid = np.random.randint(0, max_number_keypoints)
        valid = slice(None, number_valid)
        mask[image_index, valid] = 1
        keypoints[image_index, valid] = torch.rand((number_valid, 2))
        scores[image_index, valid] = torch.rand((number_valid,))
        descriptors[image_index, valid] = torch.rand((number_valid, descriptor_size))

    return SuperPointKeypointDescriptionOutput(
        loss=None,
        keypoints=keypoints,
        scores=scores,
        descriptors=descriptors,
        mask=mask,
        hidden_states=None,
    )


@require_torch
@require_vision
Expand Down Expand Up @@ -110,3 +131,25 @@ def test_input_image_properly_converted_to_grayscale(self):
pre_processed_images = image_processor.preprocess(image_inputs)
for image in pre_processed_images["pixel_values"]:
self.assertTrue(np.all(image[0, ...] == image[1, ...]) and np.all(image[1, ...] == image[2, ...]))

# NOTE(review): nothing visible in this test loads weights or runs a model forward pass, so `@slow` may
# be keeping it out of the regular test run for no reason - confirm before removing the marker.
@slow
@require_torch
def test_post_processing_keypoint_detection(self):
    """
    Check that `post_process_keypoint_detection` returns one result per image, with the expected keys and
    with keypoints lying inside the bounds of the corresponding original image.
    """
    image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
    image_inputs = self.image_processor_tester.prepare_image_inputs()
    pre_processed_images = image_processor.preprocess(image_inputs, return_tensors="pt")
    outputs = self.image_processor_tester.prepare_keypoint_detection_output(**pre_processed_images)
    # `image.size` is presumably PIL's (width, height); flipping yields the (height, width) order that
    # the post-processing method documents for its target sizes.
    image_sizes = torch.tensor([image.size for image in image_inputs]).flip(1)
    post_processed_outputs = image_processor.post_process_keypoint_detection(outputs, image_sizes)

    self.assertEqual(len(post_processed_outputs), self.image_processor_tester.batch_size)
    for post_processed_output, image_size in zip(post_processed_outputs, image_sizes):
        for key in ("keypoints", "descriptors", "scores"):
            self.assertIn(key, post_processed_output)

        keypoints = post_processed_output["keypoints"]
        keypoints_x, keypoints_y = keypoints[:, 0], keypoints[:, 1]
        height, width = image_size[0], image_size[1]
        self.assertTrue(
            bool(torch.all(keypoints_x <= width)) and bool(torch.all(keypoints_y <= height)),
            "Some keypoints lie beyond the image size",
        )
        self.assertTrue(
            bool(torch.all(keypoints_x >= 0)) and bool(torch.all(keypoints_y >= 0)),
            "Some keypoints have negative coordinates",
        )