-
Notifications
You must be signed in to change notification settings - Fork 31.7k
Optim deformable detr #33600
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Optim deformable detr #33600
Changes from 3 commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
1224075
optimize deformable detr
yonigozlan 205d1a0
fix copies
yonigozlan 6902c7b
remove deformable_detr_basline
yonigozlan b1c23db
fix hardcoded float16 and .float()
yonigozlan 6144271
[run slow] deformable-detr,grounding-dino,mask2former,oneformer,rt-detr
yonigozlan 0b16150
[run slow] deformable_detr,grounding_dino,mask2former,oneformer,rt_detr
yonigozlan File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -523,8 +523,8 @@ def __init__(self, embedding_dim=64, temperature=10000, normalize=False, scale=N | |||||
| def forward(self, pixel_values, pixel_mask): | ||||||
| if pixel_mask is None: | ||||||
| raise ValueError("No pixel mask provided") | ||||||
| y_embed = pixel_mask.cumsum(1, dtype=torch.float32) | ||||||
| x_embed = pixel_mask.cumsum(2, dtype=torch.float32) | ||||||
| y_embed = pixel_mask.cumsum(1, dtype=torch.float16) | ||||||
| x_embed = pixel_mask.cumsum(2, dtype=torch.float16) | ||||||
| if self.normalize: | ||||||
| eps = 1e-6 | ||||||
| y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale | ||||||
|
|
@@ -580,11 +580,14 @@ def build_position_encoding(config): | |||||
|
|
||||||
|
|
||||||
| def multi_scale_deformable_attention( | ||||||
| value: Tensor, value_spatial_shapes: Tensor, sampling_locations: Tensor, attention_weights: Tensor | ||||||
| value: Tensor, | ||||||
| value_spatial_shapes: Union[Tensor, List[Tuple]], | ||||||
| sampling_locations: Tensor, | ||||||
| attention_weights: Tensor, | ||||||
| ) -> Tensor: | ||||||
| batch_size, _, num_heads, hidden_dim = value.shape | ||||||
| _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape | ||||||
| value_list = value.split([height.item() * width.item() for height, width in value_spatial_shapes], dim=1) | ||||||
| value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1) | ||||||
| sampling_grids = 2 * sampling_locations - 1 | ||||||
| sampling_value_list = [] | ||||||
| for level_id, (height, width) in enumerate(value_spatial_shapes): | ||||||
|
|
@@ -695,6 +698,7 @@ def forward( | |||||
| position_embeddings: Optional[torch.Tensor] = None, | ||||||
| reference_points=None, | ||||||
| spatial_shapes=None, | ||||||
| spatial_shapes_list=None, | ||||||
| level_start_index=None, | ||||||
| output_attentions: bool = False, | ||||||
| ): | ||||||
|
|
@@ -704,7 +708,8 @@ def forward( | |||||
|
|
||||||
| batch_size, num_queries, _ = hidden_states.shape | ||||||
| batch_size, sequence_length, _ = encoder_hidden_states.shape | ||||||
| if (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length: | ||||||
| total_elements = sum(shape[0] * shape[1] for shape in spatial_shapes_list) | ||||||
|
||||||
| total_elements = sum(shape[0] * shape[1] for shape in spatial_shapes_list) | |
| total_elements = sum(height * width for height, width in spatial_shapes_list) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just wondering, why should we use float16 instead of float32? shouldn't it be
pixel_values.dtype?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oops absolutely, thanks for catching that.