Restore the input-batch order of early-stopped LightGlue outputs.

The stacked `early_stops_indices` form a permutation p of the batch. The
previous version of this change computed `early_stops_indices[torch.arange(n)]`,
which is just a copy of p, and then indexed p with it, i.e. p[p]. Assuming the
stacked outputs are gathered afterwards with `tensor[early_stops_indices]`
(not visible in these hunks), the index that restores the input order is the
inverse permutation, `torch.argsort(p)`. p[p] equals that inverse only when
applying p three times is the identity: it happens to work for the three-pair
batch of the new test, but for two pairs that stop in swapped order it yields
the identity and leaves the outputs in early-stop order. That is why the
previous version had to swap the expected values in `test_inference`; with the
inverse permutation those expectations stay untouched, so that swap is dropped.

diff --git a/src/transformers/models/lightglue/modeling_lightglue.py b/src/transformers/models/lightglue/modeling_lightglue.py
index fd460e54d393..8e9faa3e4e04 100644
--- a/src/transformers/models/lightglue/modeling_lightglue.py
+++ b/src/transformers/models/lightglue/modeling_lightglue.py
@@ -628,6 +628,10 @@ def _concat_early_stopped_outputs(
         matching_scores,
     ):
         early_stops_indices = torch.stack(early_stops_indices)
+        # Rearrange tensors to have the same order as the input batch: entry k holds the input position of
+        # the k-th early-stopped output, so gathering in input order needs the inverse permutation (argsort).
+        # Indexing the permutation with itself is only its inverse in special cases (e.g. 3-cycles, not swaps).
+        early_stops_indices = torch.argsort(early_stops_indices)
         matches, final_pruned_keypoints_indices = (
             pad_sequence(tensor, batch_first=True, padding_value=-1)
             for tensor in [matches, final_pruned_keypoints_indices]
diff --git a/src/transformers/models/lightglue/modular_lightglue.py b/src/transformers/models/lightglue/modular_lightglue.py
index 64c36f21fef9..29441344c9cd 100644
--- a/src/transformers/models/lightglue/modular_lightglue.py
+++ b/src/transformers/models/lightglue/modular_lightglue.py
@@ -786,6 +786,10 @@ def _concat_early_stopped_outputs(
         matching_scores,
     ):
         early_stops_indices = torch.stack(early_stops_indices)
+        # Rearrange tensors to have the same order as the input batch: entry k holds the input position of
+        # the k-th early-stopped output, so gathering in input order needs the inverse permutation (argsort).
+        # Indexing the permutation with itself is only its inverse in special cases (e.g. 3-cycles, not swaps).
+        early_stops_indices = torch.argsort(early_stops_indices)
         matches, final_pruned_keypoints_indices = (
             pad_sequence(tensor, batch_first=True, padding_value=-1)
             for tensor in [matches, final_pruned_keypoints_indices]
diff --git a/tests/models/lightglue/test_modeling_lightglue.py b/tests/models/lightglue/test_modeling_lightglue.py
index 17276f1cdefd..9342b9a58fb8 100644
--- a/tests/models/lightglue/test_modeling_lightglue.py
+++ b/tests/models/lightglue/test_modeling_lightglue.py
@@ -375,7 +375,6 @@ def test_inference(self):
 
         Such CUDA inconsistencies can be found [here](https://github.com/huggingface/transformers/pull/33200/files#r1785980300)
         """
-
         self.assertTrue(abs(predicted_number_of_matches0 - expected_number_of_matches0) < 4)
         self.assertTrue(abs(predicted_number_of_matches1 - expected_number_of_matches1) < 4)
         self.assertTrue(
@@ -590,3 +589,28 @@ def test_inference_without_early_stop_and_keypoint_pruning(self):
         )
         self.assertTrue(torch.sum(predicted_matches_values0 != expected_matches_values0) < 4)
         self.assertTrue(torch.sum(predicted_matches_values1 != expected_matches_values1) < 4)
+
+    @slow
+    def test_inference_order_with_early_stop(self):
+        model = LightGlueForKeypointMatching.from_pretrained(
+            "ETH-CVG/lightglue_superpoint", attn_implementation="eager"
+        ).to(torch_device)
+        preprocessor = self.default_image_processor
+        images = prepare_imgs()
+        # [[image2, image0], [image1, image1]] -> [[image2, image0], [image2, image0], [image1, image1]]
+        images = [images[0]] + images  # adding a 3rd pair to test batching with early stopping
+        inputs = preprocessor(images=images, return_tensors="pt").to(torch_device)
+        with torch.no_grad():
+            outputs = model(**inputs, output_hidden_states=True, output_attentions=True)
+
+        predicted_number_of_matches_pair0 = torch.sum(outputs.matches[0][0] != -1).item()
+        predicted_number_of_matches_pair1 = torch.sum(outputs.matches[1][0] != -1).item()
+        predicted_number_of_matches_pair2 = torch.sum(outputs.matches[2][0] != -1).item()
+
+        # Pairs 0 and 1 are the same, so they should have the same number of matches.
+        # Pair 2 is [image1, image1], so it should have more matches than the first two pairs.
+        # This ensures that early stopping does not affect the order of the outputs.
+        # See: https://huggingface.co/ETH-CVG/lightglue_superpoint/discussions/6
+        # The bug made the pairs switch order when early stopping was activated.
+        self.assertEqual(predicted_number_of_matches_pair0, predicted_number_of_matches_pair1)
+        self.assertLess(predicted_number_of_matches_pair0, predicted_number_of_matches_pair2)