Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
17d50e8
feat: Added int conversion and unwrapping
Aug 29, 2024
285c465
test: added tests for post_process_keypoint_detection of SuperPointIm…
sbucaille Aug 30, 2024
2efe61b
docs: changed docs to include post_process_keypoint_detection method …
sbucaille Aug 30, 2024
a77b870
test: changed test to not depend on SuperPointModel forward
sbucaille Aug 30, 2024
2ab79cd
test: added missing require_torch decorator
sbucaille Aug 30, 2024
419ae5d
docs: changed pyplot parameters for the keypoints to be more visible …
sbucaille Aug 30, 2024
39b32a2
tests: changed import torch location to make test_flax and test_tf
sbucaille Aug 30, 2024
144e09a
Revert "tests: changed import torch location to make test_flax and te…
sbucaille Aug 30, 2024
21dbdfc
tests: fixed import
sbucaille Aug 30, 2024
389b154
chore: applied suggestions from code review
sbucaille Sep 1, 2024
b7d672e
tests: fixed import
sbucaille Sep 1, 2024
f5d7311
tests: fixed import (bis)
sbucaille Sep 1, 2024
d89d385
tests: fixed import (ter)
sbucaille Sep 1, 2024
f9e1141
feat: added choice of type for target_size and changed tests accordingly
sbucaille Sep 1, 2024
32a2e96
docs: updated code snippet to reflect the addition of target size typ…
sbucaille Sep 1, 2024
560194e
tests: fixed imports (...)
Sep 2, 2024
2d28aba
tests: fixed imports (...)
Sep 2, 2024
bd23baa
style: formatting file
Sep 2, 2024
5bb0baf
docs: fixed typo from image[0] to image.size[0]
sbucaille Sep 2, 2024
ed28314
docs: added output image and fixed some tests
sbucaille Sep 5, 2024
192448d
Update docs/source/en/model_doc/superpoint.md
sbucaille Oct 2, 2024
e89af7f
fix: included SuperPointKeypointDescriptionOutput in TYPE_CHECKING if…
sbucaille Oct 2, 2024
4e77a4f
docs: changed SuperPoint's docs to print output instead of just acces…
sbucaille Oct 2, 2024
e9b642a
style: applied make style
sbucaille Oct 2, 2024
e085861
docs: added missing output type and precision in docstring of post_pr…
sbucaille Oct 3, 2024
9127545
perf: deleted loop to perform keypoint conversion in one statement
sbucaille Oct 3, 2024
1ffa465
fix: moved keypoint conversion at the end of model forward
sbucaille Oct 3, 2024
b0d25a3
docs: changed SuperPointInterestPointDecoder to SuperPointKeypointDec…
sbucaille Oct 3, 2024
1fb5705
fix: changed type hint
sbucaille Oct 3, 2024
13cb7e5
refactor: removed unnecessary brackets
sbucaille Oct 4, 2024
eb6a5aa
revert: SuperPointKeypointDecoder to SuperPointInterestPointDecoder
sbucaille Oct 4, 2024
4c34d75
Update docs/source/en/model_doc/superpoint.md
sbucaille Oct 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 23 additions & 14 deletions docs/source/en/model_doc/superpoint.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,24 +86,32 @@ model = SuperPointForKeypointDetection.from_pretrained("magic-leap-community/sup

inputs = processor(images, return_tensors="pt")
outputs = model(**inputs)

for i in range(len(images)):
image_mask = outputs.mask[i]
image_indices = torch.nonzero(image_mask).squeeze()
image_keypoints = outputs.keypoints[i][image_indices]
image_scores = outputs.scores[i][image_indices]
image_descriptors = outputs.descriptors[i][image_indices]
image_sizes = [(image.height, image.width) for image in images]
outputs = processor.post_process_keypoint_detection(outputs, image_sizes)

for output in outputs:
for keypoints, scores, descriptors in zip(output["keypoints"], output["scores"], output["descriptors"]):
print(f"Keypoints: {keypoints}")
print(f"Scores: {scores}")
print(f"Descriptors: {descriptors}")
```

You can then print the keypoints on the image to visualize the result :
You can then print the keypoints on the image of your choice to visualize the result:
```python
import cv2
for keypoint, score in zip(image_keypoints, image_scores):
keypoint_x, keypoint_y = int(keypoint[0].item()), int(keypoint[1].item())
color = tuple([score.item() * 255] * 3)
image = cv2.circle(image, (keypoint_x, keypoint_y), 2, color)
cv2.imwrite("output_image.png", image)
import matplotlib.pyplot as plt

plt.axis("off")
plt.imshow(image_1)
plt.scatter(
outputs[0]["keypoints"][:, 0],
outputs[0]["keypoints"][:, 1],
c=outputs[0]["scores"] * 100,
s=outputs[0]["scores"] * 50,
alpha=0.8
)
plt.savefig(f"output_image.png")
```
![image/png](https://cdn-uploads.huggingface.co/production/uploads/632885ba1558dac67c440aa8/ZtFmphEhx8tcbEQqOolyE.png)

This model was contributed by [stevenbucaille](https://huggingface.co/stevenbucaille).
The original code can be found [here](https://github.com/magicleap/SuperPointPretrainedNetwork).
Expand All @@ -123,6 +131,7 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
[[autodoc]] SuperPointImageProcessor

- preprocess
- post_process_keypoint_detection

## SuperPointForKeypointDetection

Expand Down
59 changes: 57 additions & 2 deletions src/transformers/models/superpoint/image_processing_superpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
# limitations under the License.
"""Image processor class for SuperPoint."""

from typing import Dict, Optional, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import numpy as np

from ... import is_vision_available
from ... import is_torch_available, is_vision_available
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import resize, to_channel_dimension_format
from ...image_utils import (
Expand All @@ -32,6 +32,12 @@
from ...utils import TensorType, logging, requires_backends


if is_torch_available():
import torch

if TYPE_CHECKING:
from .modeling_superpoint import SuperPointKeypointDescriptionOutput

if is_vision_available():
import PIL

Expand Down Expand Up @@ -270,3 +276,52 @@ def preprocess(
data = {"pixel_values": images}

return BatchFeature(data=data, tensor_type=return_tensors)

def post_process_keypoint_detection(
    self, outputs: "SuperPointKeypointDescriptionOutput", target_sizes: "Union[TensorType, List[Tuple]]"
) -> List[Dict[str, "torch.Tensor"]]:
    """
    Converts the raw output of [`SuperPointForKeypointDetection`] into lists of keypoints, scores and descriptors
    with coordinates absolute to the original image sizes.

    Args:
        outputs ([`SuperPointKeypointDescriptionOutput`]):
            Raw outputs of the model containing keypoints in a relative (x, y) format, with scores and descriptors.
        target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`):
            Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
            `(height, width)` of each image in the batch. This must be the original
            image size (before any processing).

    Returns:
        `List[Dict]`: A list of dictionaries, each dictionary containing the keypoints in absolute format according
        to target_sizes, scores and descriptors for an image in the batch as predicted by the model.

    Raises:
        ValueError: If the number of target sizes does not match the batch dimension of the mask, or if a
            target-size tensor does not have `(h, w)` pairs in its second dimension.
    """
    if len(outputs.mask) != len(target_sizes):
        raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the mask")

    # Accept either a Python sequence of (height, width) pairs or a ready-made (batch_size, 2) tensor.
    # Use the builtin `list`/`tuple` in the runtime check: `isinstance(x, typing.List)` is deprecated.
    if isinstance(target_sizes, (list, tuple)):
        image_sizes = torch.tensor(target_sizes)
    else:
        if target_sizes.shape[1] != 2:
            raise ValueError(
                "Each element of target_sizes must contain the size (h, w) of each image of the batch"
            )
        image_sizes = target_sizes

    # Flip (height, width) -> (width, height) so each row lines up with the (x, y) keypoint layout,
    # then rescale the relative keypoints to absolute pixel coordinates and truncate to int32.
    image_sizes = torch.flip(image_sizes, [1])
    masked_keypoints = (outputs.keypoints * image_sizes[:, None]).to(torch.int32)

    results = []
    for image_mask, keypoints, scores, descriptors in zip(
        outputs.mask, masked_keypoints, outputs.scores, outputs.descriptors
    ):
        # Keep only the valid (non-padded) detections flagged by the per-image mask.
        indices = torch.nonzero(image_mask).squeeze(1)
        keypoints = keypoints[indices]
        scores = scores[indices]
        descriptors = descriptors[indices]
        results.append({"keypoints": keypoints, "scores": scores, "descriptors": descriptors})

    return results
10 changes: 8 additions & 2 deletions src/transformers/models/superpoint/modeling_superpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,10 @@ def _get_pixel_scores(self, encoded: torch.Tensor) -> torch.Tensor:
return scores

def _extract_keypoints(self, scores: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Based on their scores, extract the pixels that represent the keypoints that will be used for descriptors computation"""
"""
Based on their scores, extract the pixels that represent the keypoints that will be used for descriptors computation.
The keypoints are in the form of relative (x, y) coordinates.
"""
_, height, width = scores.shape

# Threshold keypoints by score value
Expand Down Expand Up @@ -447,7 +450,7 @@ def forward(

pixel_values = self.extract_one_channel_pixel_values(pixel_values)

batch_size = pixel_values.shape[0]
batch_size, _, height, width = pixel_values.shape

encoder_outputs = self.encoder(
pixel_values,
Expand Down Expand Up @@ -485,6 +488,9 @@ def forward(
descriptors[i, : _descriptors.shape[0]] = _descriptors
mask[i, : _scores.shape[0]] = 1

# Convert to relative coordinates
keypoints = keypoints / torch.tensor([width, height], device=keypoints.device)

hidden_states = encoder_outputs[1] if output_hidden_states else None
if not return_dict:
return tuple(v for v in [loss, keypoints, scores, descriptors, mask, hidden_states] if v is not None)
Expand Down
54 changes: 53 additions & 1 deletion tests/models/superpoint/test_image_processing_superpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,19 @@
import numpy as np

from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_vision_available
from transformers.utils import is_torch_available, is_vision_available

from ...test_image_processing_common import (
ImageProcessingTestMixin,
prepare_image_inputs,
)


if is_torch_available():
import torch

from transformers.models.superpoint.modeling_superpoint import SuperPointKeypointDescriptionOutput

if is_vision_available():
from transformers import SuperPointImageProcessor

Expand Down Expand Up @@ -70,6 +75,23 @@ def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=F
torchify=torchify,
)

def prepare_keypoint_detection_output(self, pixel_values):
    """
    Build a synthetic [`SuperPointKeypointDescriptionOutput`] for a batch of pixel values.

    Each image in the batch is given a random number of "valid" keypoints (marked via the mask);
    the padded tail of every tensor stays zero, mimicking the model's padded batch output.
    """
    max_number_keypoints = 50
    batch_size = len(pixel_values)

    # Pre-allocate fully padded (all-zero) tensors, then fill in a random-length valid prefix per image.
    mask = torch.zeros((batch_size, max_number_keypoints))
    keypoints = torch.zeros((batch_size, max_number_keypoints, 2))
    scores = torch.zeros((batch_size, max_number_keypoints))
    descriptors = torch.zeros((batch_size, max_number_keypoints, 16))

    for image_index in range(batch_size):
        num_valid = np.random.randint(0, max_number_keypoints)
        mask[image_index, :num_valid] = 1
        keypoints[image_index, :num_valid] = torch.rand((num_valid, 2))
        scores[image_index, :num_valid] = torch.rand((num_valid,))
        descriptors[image_index, :num_valid] = torch.rand((num_valid, 16))

    return SuperPointKeypointDescriptionOutput(
        loss=None, keypoints=keypoints, scores=scores, descriptors=descriptors, mask=mask, hidden_states=None
    )


@require_torch
@require_vision
Expand Down Expand Up @@ -110,3 +132,33 @@ def test_input_image_properly_converted_to_grayscale(self):
pre_processed_images = image_processor.preprocess(image_inputs)
for image in pre_processed_images["pixel_values"]:
self.assertTrue(np.all(image[0, ...] == image[1, ...]) and np.all(image[1, ...] == image[2, ...]))

@require_torch
def test_post_processing_keypoint_detection(self):
    """
    Check that post_process_keypoint_detection rescales relative keypoints into absolute coordinates
    bounded by the original image sizes, for both list-of-tuples and tensor `target_sizes` inputs.
    """
    image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
    image_inputs = self.image_processor_tester.prepare_image_inputs()
    pre_processed_images = image_processor.preprocess(image_inputs, return_tensors="pt")
    outputs = self.image_processor_tester.prepare_keypoint_detection_output(**pre_processed_images)

    def check_post_processed_output(post_processed_outputs, image_sizes):
        # image_sizes entries are (height, width); keypoints are absolute (x, y) coordinates.
        for post_processed_output, image_size in zip(post_processed_outputs, image_sizes):
            self.assertTrue("keypoints" in post_processed_output)
            self.assertTrue("descriptors" in post_processed_output)
            self.assertTrue("scores" in post_processed_output)
            keypoints = post_processed_output["keypoints"]
            all_below_image_size = torch.all(keypoints[:, 0] <= image_size[1]) and torch.all(
                keypoints[:, 1] <= image_size[0]
            )
            all_above_zero = torch.all(keypoints[:, 0] >= 0) and torch.all(keypoints[:, 1] >= 0)
            self.assertTrue(all_below_image_size)
            self.assertTrue(all_above_zero)

    # PIL's Image.size is (width, height); post_process_keypoint_detection expects (height, width),
    # so swap the components when building the tuple sizes.
    tuple_image_sizes = [(image.size[1], image.size[0]) for image in image_inputs]
    tuple_post_processed_outputs = image_processor.post_process_keypoint_detection(outputs, tuple_image_sizes)

    check_post_processed_output(tuple_post_processed_outputs, tuple_image_sizes)

    # Same swap for the tensor form: flip (width, height) rows into (height, width).
    tensor_image_sizes = torch.tensor([image.size for image in image_inputs]).flip(1)
    tensor_post_processed_outputs = image_processor.post_process_keypoint_detection(outputs, tensor_image_sizes)

    check_post_processed_output(tensor_post_processed_outputs, tensor_image_sizes)
10 changes: 6 additions & 4 deletions tests/models/superpoint/test_modeling_superpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def test_inference(self):
inputs = preprocessor(images=images, return_tensors="pt").to(torch_device)
with torch.no_grad():
outputs = model(**inputs)
expected_number_keypoints_image0 = 567
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wondering why the number of keypoints changed :)

expected_number_keypoints_image0 = 568
expected_number_keypoints_image1 = 830
expected_max_number_keypoints = max(expected_number_keypoints_image0, expected_number_keypoints_image1)
expected_keypoints_shape = torch.Size((len(images), expected_max_number_keypoints, 2))
Expand All @@ -275,11 +275,13 @@ def test_inference(self):
self.assertEqual(outputs.keypoints.shape, expected_keypoints_shape)
self.assertEqual(outputs.scores.shape, expected_scores_shape)
self.assertEqual(outputs.descriptors.shape, expected_descriptors_shape)
expected_keypoints_image0_values = torch.tensor([[480.0, 9.0], [494.0, 9.0], [489.0, 16.0]]).to(torch_device)
expected_keypoints_image0_values = torch.tensor([[0.75, 0.0188], [0.7719, 0.0188], [0.7641, 0.0333]]).to(
torch_device
)
expected_scores_image0_values = torch.tensor(
[0.0064, 0.0137, 0.0589, 0.0723, 0.5166, 0.0174, 0.1515, 0.2054, 0.0334]
[0.0064, 0.0139, 0.0591, 0.0727, 0.5170, 0.0175, 0.1526, 0.2057, 0.0335]
).to(torch_device)
expected_descriptors_image0_value = torch.tensor(-0.1096).to(torch_device)
expected_descriptors_image0_value = torch.tensor(-0.1095).to(torch_device)
predicted_keypoints_image0_values = outputs.keypoints[0, :3]
predicted_scores_image0_values = outputs.scores[0, :9]
predicted_descriptors_image0_value = outputs.descriptors[0, 0, 0]
Expand Down