Improve compiled RT-DETR inference speed #33412
Changes from all commits: b458aa1, f85a80d, 6f2da75, 2f07ac2, 8fe3ec5
```diff
@@ -18,7 +18,7 @@
 import os
 import warnings
 from dataclasses import dataclass
-from functools import lru_cache, partial
+from functools import lru_cache, partial, wraps
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple, Union
@@ -737,7 +737,9 @@ def multi_scale_deformable_attention(
 ) -> Tensor:
     batch_size, _, num_heads, hidden_dim = value.shape
     _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
-    value_list = value.split([height.item() * width.item() for height, width in value_spatial_shapes], dim=1)
+    # Ignore copy
+    value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)

     sampling_grids = 2 * sampling_locations - 1
     sampling_value_list = []
     for level_id, (height, width) in enumerate(value_spatial_shapes):
```
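The substantive change in the `multi_scale_deformable_attention` hunk above is dropping the `.item()` calls: reading a scalar out of a GPU tensor forces a device-to-host sync and a graph break under `torch.compile`, whereas the spatial shapes are also available as plain Python ints. A toy sketch of the two patterns (illustrative only, shapes and names made up, not taken from the PR):

```python
import torch

# Two feature-map levels, 8x8 and 4x4, flattened along the sequence dimension.
value = torch.randn(1, 8 * 8 + 4 * 4, 8, 32)
shapes_tensor = torch.tensor([[8, 8], [4, 4]])  # tensor form, as in the old code path
shapes_list = [(8, 8), (4, 4)]                  # plain ints, as in spatial_shapes_list

# Old pattern: .item() copies each scalar back to the CPU, breaking the compiled graph.
value_list = value.split([h.item() * w.item() for h, w in shapes_tensor], dim=1)

# New pattern: the split sizes are already Python ints, so no device sync is needed.
value_list = value.split([h * w for h, w in shapes_list], dim=1)
```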
```diff
@@ -849,6 +851,7 @@ def forward(
         position_embeddings: Optional[torch.Tensor] = None,
         reference_points=None,
         spatial_shapes=None,
+        spatial_shapes_list=None,
         level_start_index=None,
         output_attentions: bool = False,
     ):
@@ -858,7 +861,10 @@ def forward(

         batch_size, num_queries, _ = hidden_states.shape
         batch_size, sequence_length, _ = encoder_hidden_states.shape
-        if (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length:
+
+        # Ignore copy
+        total_elements = sum(shape[0] * shape[1] for shape in spatial_shapes_list)
+        if total_elements != sequence_length:
             raise ValueError(
                 "Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
             )
@@ -893,9 +899,12 @@ def forward(
         else:
             raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")

-        if self.disable_custom_kernels:
+        # Ignore copy
+        if self.disable_custom_kernels or MultiScaleDeformableAttention is None:
             # PyTorch implementation
-            output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights)
+            output = multi_scale_deformable_attention(
+                value, spatial_shapes_list, sampling_locations, attention_weights
+            )
         else:
             try:
                 # custom kernel
@@ -909,7 +918,9 @@ def forward(
                 )
             except Exception:
                 # PyTorch implementation
-                output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights)
+                output = multi_scale_deformable_attention(
+                    value, spatial_shapes_list, sampling_locations, attention_weights
+                )
         output = self.output_proj(output)

         return output, attention_weights
@@ -1064,6 +1075,7 @@ def forward(
         position_embeddings: Optional[torch.Tensor] = None,
         reference_points=None,
         spatial_shapes=None,
+        spatial_shapes_list=None,
         level_start_index=None,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         encoder_attention_mask: Optional[torch.Tensor] = None,
@@ -1114,6 +1126,7 @@ def forward(
             position_embeddings=position_embeddings,
             reference_points=reference_points,
             spatial_shapes=spatial_shapes,
+            spatial_shapes_list=spatial_shapes_list,
             level_start_index=level_start_index,
             output_attentions=output_attentions,
         )
@@ -1299,14 +1312,16 @@ def __init__(self, config: RTDetrConfig):
             self.pan_blocks.append(RTDetrCSPRepLayer(config))

     @staticmethod
-    def build_2d_sincos_position_embedding(width, height, embed_dim=256, temperature=10000.0):
-        grid_w = torch.arange(int(width), dtype=torch.float32)
-        grid_h = torch.arange(int(height), dtype=torch.float32)
+    def build_2d_sincos_position_embedding(
+        width, height, embed_dim=256, temperature=10000.0, device="cpu", dtype=torch.float32
+    ):
+        grid_w = torch.arange(int(width), dtype=dtype, device=device)
+        grid_h = torch.arange(int(height), dtype=dtype, device=device)
         grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij")
         if embed_dim % 4 != 0:
             raise ValueError("Embed dimension must be divisible by 4 for 2D sin-cos position embedding")
         pos_dim = embed_dim // 4
-        omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
+        omega = torch.arange(pos_dim, dtype=dtype, device=device) / pos_dim
         omega = 1.0 / (temperature**omega)

         out_w = grid_w.flatten()[..., None] @ omega[None]
```
Inline comment thread on the `omega` change:

Contributor: @yonigozlan Is this the main reason the compiled version of RT-DETR got such a large boost in FP16? (Since `omega` was previously created in the default float32 and had to be converted to fp16 at inference?)

Member (PR author): I'm quite new to torch.compile, so take this with a grain of salt, but from my understanding the main reason for the speedup (both in fp32 and fp16) is that the model now has no graph breaks. This means no CPU/GPU transfers inside the model, and it lets torch.compile use CUDA graphs, reducing kernel launch overhead.

Contributor: Thanks for the kind explanation. I'm also new to this compile logic and was just asking out of curiosity. I think that could also be the case, thanks 👍🏼
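A rough way to check for graph breaks and benchmark with CUDA graphs enabled is sketched below. This assumes a recent PyTorch (for `torch._dynamo.explain` and `mode="reduce-overhead"`), a CUDA device, and uses the public `PekingU/rtdetr_r50vd` checkpoint as an example:

```python
import torch
from transformers import RTDetrForObjectDetection

model = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd").eval().cuda()
pixel_values = torch.randn(1, 3, 640, 640, device="cuda")

# Count graph breaks: 0 means the whole forward pass is captured as a single graph.
explanation = torch._dynamo.explain(model)(pixel_values)
print(explanation.graph_break_count)

# "reduce-overhead" enables CUDA graphs; fullgraph=True errors out if any break remains.
compiled_model = torch.compile(model, mode="reduce-overhead", fullgraph=True)
with torch.no_grad():
    outputs = compiled_model(pixel_values)
```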
```diff
@@ -1372,8 +1387,13 @@ def forward(
                 src_flatten = hidden_states[enc_ind].flatten(2).permute(0, 2, 1)
                 if self.training or self.eval_size is None:
                     pos_embed = self.build_2d_sincos_position_embedding(
-                        width, height, self.encoder_hidden_dim, self.positional_encoding_temperature
-                    ).to(src_flatten.device, src_flatten.dtype)
+                        width,
+                        height,
+                        self.encoder_hidden_dim,
+                        self.positional_encoding_temperature,
+                        device=src_flatten.device,
+                        dtype=src_flatten.dtype,
+                    )
                 else:
                     pos_embed = None

@@ -1441,6 +1461,7 @@ def forward(
         position_embeddings=None,
         reference_points=None,
         spatial_shapes=None,
+        spatial_shapes_list=None,
         level_start_index=None,
         valid_ratios=None,
         output_attentions=None,
@@ -1512,6 +1533,7 @@ def forward(
                 encoder_hidden_states=encoder_hidden_states,
                 reference_points=reference_points_input,
                 spatial_shapes=spatial_shapes,
+                spatial_shapes_list=spatial_shapes_list,
                 level_start_index=level_start_index,
                 encoder_attention_mask=encoder_attention_mask,
                 output_attentions=output_attentions,
@@ -1575,6 +1597,27 @@ def forward(
         )


+def compile_compatible_lru_cache(*lru_args, **lru_kwargs):
+    def decorator(func):
+        @wraps(func)
+        def wrapper(self, *args, **kwargs):
+            if not torch.compiler.is_compiling():
+                # Cache the function only if the model is not being compiled
+                # check if the function is already cached, otherwise create it
+                if not hasattr(self, f"_cached_{func.__name__}"):
+                    self.__setattr__(
+                        f"_cached_{func.__name__}", lru_cache(*lru_args, **lru_kwargs)(func.__get__(self))
+                    )
+                return self.__getattribute__(f"_cached_{func.__name__}")(*args, **kwargs)
+            else:
+                # Otherwise, just call the original function
+                return func(self, *args, **kwargs)
+
+        return wrapper
+
+    return decorator
+
+
 @add_start_docstrings(
     """
     RT-DETR Model (consisting of a backbone and encoder-decoder) outputting raw hidden states without any head on top.
```
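To make the behavior of the new `compile_compatible_lru_cache` decorator concrete, here is a minimal toy sketch. It assumes the decorator defined above is in scope and a PyTorch version that provides `torch.compiler.is_compiling()`; the `Toy` class and its call counter are invented for illustration:

```python
import torch

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.calls = 0  # counts how often the real computation runs

    @compile_compatible_lru_cache(maxsize=8)
    def expensive(self, n, device="cpu"):
        self.calls += 1
        return torch.arange(n, device=device)

toy = Toy()
toy.expensive(10)
toy.expensive(10)
print(toy.calls)  # 1: the second call is served from the per-instance lru_cache

# While torch.compile traces the model, torch.compiler.is_compiling() is True, so the
# wrapper bypasses the cache and calls the plain method; Dynamo therefore never has to
# trace through functools.lru_cache (which would otherwise force a graph break).
```

This is what lets `generate_anchors` keep its per-shape cache in eager mode while still being traceable as an ordinary function when the model is compiled.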
```diff
@@ -1626,7 +1669,7 @@ def __init__(self, config: RTDetrConfig):

         # init encoder output anchors and valid_mask
         if config.anchor_image_size:
-            self.anchors, self.valid_mask = self.generate_anchors()
+            self.anchors, self.valid_mask = self.generate_anchors(dtype=self.dtype)

         # Create decoder input projection layers
         # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_pytorch/src/zoo/rtdetr/rtdetr_decoder.py#L412
@@ -1669,12 +1712,8 @@ def unfreeze_backbone(self):
         for param in self.backbone.parameters():
             param.requires_grad_(True)

-    @lru_cache(maxsize=32)
-    def generate_anchors(self, spatial_shapes=None, grid_size=0.05):
-        # We always generate anchors in float32 to preserve equivalence between
-        # dynamic and static anchor inference
-        dtype = torch.float32
-
+    @compile_compatible_lru_cache(maxsize=32)
+    def generate_anchors(self, spatial_shapes=None, grid_size=0.05, device="cpu", dtype=torch.float32):
         if spatial_shapes is None:
             spatial_shapes = [
                 [int(self.config.anchor_image_size[0] / s), int(self.config.anchor_image_size[1] / s)]
@@ -1683,10 +1722,12 @@ def generate_anchors(self, spatial_shapes=None, grid_size=0.05):
         anchors = []
         for level, (height, width) in enumerate(spatial_shapes):
             grid_y, grid_x = torch.meshgrid(
-                torch.arange(end=height, dtype=dtype), torch.arange(end=width, dtype=dtype), indexing="ij"
+                torch.arange(end=height, dtype=dtype, device=device),
+                torch.arange(end=width, dtype=dtype, device=device),
+                indexing="ij",
             )
             grid_xy = torch.stack([grid_x, grid_y], -1)
-            valid_wh = torch.tensor([width, height]).to(dtype)
+            valid_wh = torch.tensor([width, height], device=device).to(dtype)
             grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_wh
             wh = torch.ones_like(grid_xy) * grid_size * (2.0**level)
             anchors.append(torch.concat([grid_xy, wh], -1).reshape(-1, height * width, 4))
@@ -1826,7 +1867,7 @@ def forward(
             # Pass spatial_shapes as tuple to make it hashable and make sure
             # lru_cache is working for generate_anchors()
             spatial_shapes_tuple = tuple(spatial_shapes_list)
-            anchors, valid_mask = self.generate_anchors(spatial_shapes_tuple)
+            anchors, valid_mask = self.generate_anchors(spatial_shapes_tuple, device=device, dtype=dtype)
         else:
             anchors, valid_mask = self.anchors, self.valid_mask
```
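The hashability comment in the hunk above is the reason for the `tuple(spatial_shapes_list)` conversion: `functools.lru_cache` builds its cache key from the call arguments, so they must be hashable. A standalone illustration (not from the PR):

```python
from functools import lru_cache

@lru_cache(maxsize=32)
def total_tokens(spatial_shapes):
    # sum of height * width over all feature-map levels
    return sum(height * width for height, width in spatial_shapes)

total_tokens(((80, 80), (40, 40), (20, 20)))  # fine: a tuple of tuples is hashable
total_tokens([[80, 80], [40, 40], [20, 20]])  # TypeError: unhashable type: 'list'
```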
```diff
@@ -1873,6 +1914,7 @@ def forward(
             encoder_attention_mask=attention_mask,
             reference_points=init_reference_points,
             spatial_shapes=spatial_shapes,
+            spatial_shapes_list=spatial_shapes_list,
             level_start_index=level_start_index,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
```
Reviewer: The "Copy from" looks almost useless for this function with an "Ignore copy" a few lines after 🙂 Does it make sense to spread these changes to other architectures?

PR author: That's true, but I'm also working on making the same changes in Deformable DETR, so the "Ignore copy" shouldn't be necessary for long. I'm getting some weird results with compiled Deformable DETR in fp16 though, which I really don't understand.