Add SuperGlue model #29886
@@ -0,0 +1,138 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the MIT License; you may not use this file except in compliance with
the License.

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# SuperGlue

## Overview

The SuperGlue model was proposed in [SuperGlue: Learning Feature Matching with Graph Neural Networks](https://arxiv.org/abs/1911.11763) by Paul-Edouard Sarlin, Daniel DeTone, Tomasz Malisiewicz and Andrew Rabinovich.

SuperGlue matches two sets of interest points detected in a pair of images. Paired with the
[SuperPoint model](https://huggingface.co/magic-leap-community/superpoint), it can be used to match two images and
estimate the pose between them. This makes it useful for tasks such as image matching and homography estimation.

The abstract from the paper is the following:

*This paper introduces SuperGlue, a neural network that matches two sets of local features by jointly finding correspondences
and rejecting non-matchable points. Assignments are estimated by solving a differentiable optimal transport problem, whose costs
are predicted by a graph neural network. We introduce a flexible context aggregation mechanism based on attention, enabling
SuperGlue to reason about the underlying 3D scene and feature assignments jointly. Compared to traditional, hand-designed heuristics,
our technique learns priors over geometric transformations and regularities of the 3D world through end-to-end training from image
pairs. SuperGlue outperforms other learned approaches and achieves state-of-the-art results on the task of pose estimation in
challenging real-world indoor and outdoor environments. The proposed method performs matching in real-time on a modern GPU and
can be readily integrated into modern SfM or SLAM systems. The code and trained weights are publicly available at this [URL](https://github.com/magicleap/SuperGluePretrainedNetwork).*
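The differentiable optimal transport problem mentioned in the abstract is solved with Sinkhorn iterations. As a toy illustration of the idea (a plain-numpy sketch; the actual model works in log-space and adds a learned "dustbin" row and column for unmatched points, both omitted here), Sinkhorn normalization turns a raw score matrix into an approximately doubly stochastic soft assignment:

```python
import numpy as np


def sinkhorn(scores, n_iters=100):
    """Alternately normalize rows and columns of exp(scores) so the
    result approaches a doubly stochastic matrix (a soft assignment)."""
    P = np.exp(scores - scores.max())  # subtract max for numerical stability
    for _ in range(n_iters):
        P /= P.sum(axis=1, keepdims=True)  # each row sums to 1
        P /= P.sum(axis=0, keepdims=True)  # each column sums to 1
    return P


rng = np.random.default_rng(0)
P = sinkhorn(rng.normal(size=(5, 5)))
print(P.sum(axis=0), P.sum(axis=1))  # both close to all-ones
```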
## How to use

Here is a quick example of using the model. Since this is an image matching model, it requires pairs of images to be matched.
The raw outputs contain the list of keypoints detected by the keypoint detector as well as the list of matches with their
corresponding matching scores.

```python
from transformers import AutoImageProcessor, AutoModel
import torch
from PIL import Image
import requests

url_image1 = "https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/refs/heads/master/assets/phototourism_sample_images/united_states_capitol_98169888_3347710852.jpg"
image1 = Image.open(requests.get(url_image1, stream=True).raw)
url_image2 = "https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/refs/heads/master/assets/phototourism_sample_images/united_states_capitol_26757027_6717084061.jpg"
image2 = Image.open(requests.get(url_image2, stream=True).raw)

images = [image1, image2]

processor = AutoImageProcessor.from_pretrained("magic-leap-community/superglue_outdoor")
model = AutoModel.from_pretrained("magic-leap-community/superglue_outdoor")

inputs = processor(images, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
```
You can use the `post_process_keypoint_matching` method from the `SuperGlueImageProcessor` to get the keypoints and matches in a more readable format:

```python
image_sizes = [[(image.height, image.width) for image in images]]
outputs = processor.post_process_keypoint_matching(outputs, image_sizes, threshold=0.2)
for i, output in enumerate(outputs):
    print("For the image pair", i)
    for keypoint0, keypoint1, matching_score in zip(
        output["keypoints0"], output["keypoints1"], output["matching_scores"]
    ):
        print(
            f"Keypoint at coordinate {keypoint0.numpy()} in the first image matches with keypoint at coordinate {keypoint1.numpy()} in the second image with a score of {matching_score}."
        )
```
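If you later want a stricter cut than the threshold passed to `post_process_keypoint_matching`, the returned matches can simply be filtered again by score. A minimal sketch with plain numpy, using synthetic coordinates and scores as stand-ins for one pair's outputs (an assumption for illustration only):

```python
import numpy as np

# Synthetic stand-ins for one image pair's matched keypoints and scores.
keypoints0 = np.array([[10.0, 20.0], [30.0, 40.0], [50.0, 60.0]])
keypoints1 = np.array([[12.0, 21.0], [33.0, 41.0], [55.0, 66.0]])
matching_scores = np.array([0.95, 0.15, 0.60])

# Keep only matches above a stricter confidence threshold.
threshold = 0.5
mask = matching_scores > threshold
kept0, kept1 = keypoints0[mask], keypoints1[mask]
print(f"kept {mask.sum()} of {len(mask)} matches")  # kept 2 of 3 matches
```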
From the outputs, you can visualize the matches between the two images using the following code:

```python
import matplotlib.pyplot as plt
import numpy as np

# Create side by side image
merged_image = np.zeros((max(image1.height, image2.height), image1.width + image2.width, 3))
merged_image[: image1.height, : image1.width] = np.array(image1) / 255.0
merged_image[: image2.height, image1.width :] = np.array(image2) / 255.0
plt.imshow(merged_image)
plt.axis("off")

# Retrieve the keypoints and matches
output = outputs[0]
keypoints0 = output["keypoints0"]
keypoints1 = output["keypoints1"]
matching_scores = output["matching_scores"]
keypoints0_x, keypoints0_y = keypoints0[:, 0].numpy(), keypoints0[:, 1].numpy()
keypoints1_x, keypoints1_y = keypoints1[:, 0].numpy(), keypoints1[:, 1].numpy()

# Plot the matches
for keypoint0_x, keypoint0_y, keypoint1_x, keypoint1_y, matching_score in zip(
    keypoints0_x, keypoints0_y, keypoints1_x, keypoints1_y, matching_scores
):
    plt.plot(
        [keypoint0_x, keypoint1_x + image1.width],
        [keypoint0_y, keypoint1_y],
        color=plt.get_cmap("RdYlGn")(matching_score.item()),
        alpha=0.9,
        linewidth=0.5,
    )
    plt.scatter(keypoint0_x, keypoint0_y, c="black", s=2)
    plt.scatter(keypoint1_x + image1.width, keypoint1_y, c="black", s=2)

# Save the plot
plt.savefig("matched_image.png", dpi=300, bbox_inches='tight')
plt.close()
```

> **Review discussion on lines +81 to +115**
>
> **Collaborator:** do you think it would make sense to add this to the image processor / processor ?
>
> **Author:** Like a
>
> **Collaborator:** Yeah a method sounds good
>
> **Author:** I thought a bit about this but I think it depends on whether you want to put visualization forward in the library or not. Here in this example we assume only a pair of images, but as a method in the processor, should it handle multiple pairs like other methods ? If so, should we visualize the pairs individually / all together ? In terms of plotting, should we force the template we have here or allow some customization ?
>
> **Contributor:** do we have
>
> **Author:** Should this be considered as resolved ?
>
> **Collaborator:** Yeah, we could have a soft dependency as well
![image](https://cdn-uploads.huggingface.co/production/uploads/632885ba1558dac67c440aa8/2nJZQlFToCYp07LNqBg8i.png)

This model was contributed by [stevenbucaille](https://huggingface.co/stevenbucaille).
The original code can be found [here](https://github.com/magicleap/SuperGluePretrainedNetwork).
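Matched keypoints are commonly fed into geometric estimation, such as fitting a homography between the two views. As an illustration only, here is a plain-numpy Direct Linear Transform sketch on synthetic correspondences; in practice you would use a robust estimator such as `cv2.findHomography` with RANSAC to cope with outlier matches:

```python
import numpy as np


def estimate_homography(pts0, pts1):
    """Direct Linear Transform: find H (up to scale) with pts1 ~ H @ pts0."""
    rows = []
    for (x, y), (u, v) in zip(pts0, pts1):
        rows.append([x, y, 1, 0, 0, 0, -u * x, -u * y, -u])
        rows.append([0, 0, 0, x, y, 1, -v * x, -v * y, -v])
    A = np.asarray(rows, dtype=float)
    # The homography is the null vector of A (smallest singular value).
    _, _, vt = np.linalg.svd(A)
    H = vt[-1].reshape(3, 3)
    return H / H[2, 2]


# Synthetic correspondences: four points translated by (5, -3).
pts0 = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
pts1 = pts0 + np.array([5.0, -3.0])
H = estimate_homography(pts0, pts1)
print(np.round(H, 3))
```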
## SuperGlueConfig

[[autodoc]] SuperGlueConfig

## SuperGlueImageProcessor

[[autodoc]] SuperGlueImageProcessor
    - preprocess

## SuperGlueForKeypointMatching

[[autodoc]] SuperGlueForKeypointMatching
    - forward
    - post_process_keypoint_matching
@@ -234,6 +234,7 @@
     squeezebert,
     stablelm,
     starcoder2,
+    superglue,
     superpoint,
     swiftformer,
     swin,
@@ -0,0 +1,28 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_superglue import *
    from .image_processing_superglue import *
    from .modeling_superglue import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
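The `_LazyModule` machinery above defers heavy imports until first use. As a rough illustration of the idea (a simplified sketch, not the actual `_LazyModule` implementation), a module-like object can import its submodules lazily on attribute access:

```python
import importlib
import types


class LazyModule(types.ModuleType):
    """Simplified sketch of lazy importing: submodules are only imported
    (and then cached) the first time an attribute is accessed."""

    def __init__(self, name, submodule_names):
        super().__init__(name)
        self._submodule_names = set(submodule_names)

    def __getattr__(self, attr):
        # Only called when normal attribute lookup fails.
        if attr in self._submodule_names:
            module = importlib.import_module(attr)
            setattr(self, attr, module)  # cache so __getattr__ is skipped next time
            return module
        raise AttributeError(f"module {self.__name__!r} has no attribute {attr!r}")


lazy = LazyModule("demo", ["json"])
print(lazy.json.dumps({"a": 1}))  # json is imported here, on first access
```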