diff --git a/src/transformers/models/rt_detr/modeling_rt_detr.py b/src/transformers/models/rt_detr/modeling_rt_detr.py
index ab83a81f5067..4e32434901cd 100644
--- a/src/transformers/models/rt_detr/modeling_rt_detr.py
+++ b/src/transformers/models/rt_detr/modeling_rt_detr.py
@@ -18,7 +18,7 @@
 import os
 import warnings
 from dataclasses import dataclass
-from functools import lru_cache, partial
+from functools import lru_cache, partial, wraps
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple, Union
 
@@ -737,7 +737,9 @@ def multi_scale_deformable_attention(
 ) -> Tensor:
     batch_size, _, num_heads, hidden_dim = value.shape
     _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
-    value_list = value.split([height.item() * width.item() for height, width in value_spatial_shapes], dim=1)
+    # Ignore copy
+    value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
+
     sampling_grids = 2 * sampling_locations - 1
     sampling_value_list = []
     for level_id, (height, width) in enumerate(value_spatial_shapes):
@@ -849,6 +851,7 @@ def forward(
         position_embeddings: Optional[torch.Tensor] = None,
         reference_points=None,
         spatial_shapes=None,
+        spatial_shapes_list=None,
         level_start_index=None,
         output_attentions: bool = False,
     ):
@@ -858,7 +861,10 @@ def forward(
 
         batch_size, num_queries, _ = hidden_states.shape
         batch_size, sequence_length, _ = encoder_hidden_states.shape
-        if (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length:
+
+        # Ignore copy
+        total_elements = sum(shape[0] * shape[1] for shape in spatial_shapes_list)
+        if total_elements != sequence_length:
             raise ValueError(
                 "Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
             )
@@ -893,9 +899,12 @@ def forward(
         else:
             raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")
 
-        if self.disable_custom_kernels:
+        # Ignore copy
+        if self.disable_custom_kernels or MultiScaleDeformableAttention is None:
             # PyTorch implementation
-            output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights)
+            output = multi_scale_deformable_attention(
+                value, spatial_shapes_list, sampling_locations, attention_weights
+            )
         else:
             try:
                 # custom kernel
@@ -909,7 +918,9 @@ def forward(
                 )
             except Exception:
                 # PyTorch implementation
-                output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights)
+                output = multi_scale_deformable_attention(
+                    value, spatial_shapes_list, sampling_locations, attention_weights
+                )
         output = self.output_proj(output)
 
         return output, attention_weights
@@ -1064,6 +1075,7 @@ def forward(
         position_embeddings: Optional[torch.Tensor] = None,
         reference_points=None,
         spatial_shapes=None,
+        spatial_shapes_list=None,
         level_start_index=None,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         encoder_attention_mask: Optional[torch.Tensor] = None,
@@ -1114,6 +1126,7 @@ def forward(
             position_embeddings=position_embeddings,
             reference_points=reference_points,
             spatial_shapes=spatial_shapes,
+            spatial_shapes_list=spatial_shapes_list,
             level_start_index=level_start_index,
             output_attentions=output_attentions,
         )
@@ -1299,14 +1312,16 @@ def __init__(self, config: RTDetrConfig):
             self.pan_blocks.append(RTDetrCSPRepLayer(config))
 
     @staticmethod
-    def build_2d_sincos_position_embedding(width, height, embed_dim=256, temperature=10000.0):
-        grid_w = torch.arange(int(width), dtype=torch.float32)
-        grid_h = torch.arange(int(height), dtype=torch.float32)
+    def build_2d_sincos_position_embedding(
+        width, height, embed_dim=256, temperature=10000.0, device="cpu", dtype=torch.float32
+    ):
+        grid_w = torch.arange(int(width), dtype=dtype, device=device)
+        grid_h = torch.arange(int(height), dtype=dtype, device=device)
         grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij")
         if embed_dim % 4 != 0:
             raise ValueError("Embed dimension must be divisible by 4 for 2D sin-cos position embedding")
         pos_dim = embed_dim // 4
-        omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
+        omega = torch.arange(pos_dim, dtype=dtype, device=device) / pos_dim
         omega = 1.0 / (temperature**omega)
 
         out_w = grid_w.flatten()[..., None] @ omega[None]
@@ -1372,8 +1387,13 @@ def forward(
                 src_flatten = hidden_states[enc_ind].flatten(2).permute(0, 2, 1)
                 if self.training or self.eval_size is None:
                     pos_embed = self.build_2d_sincos_position_embedding(
-                        width, height, self.encoder_hidden_dim, self.positional_encoding_temperature
-                    ).to(src_flatten.device, src_flatten.dtype)
+                        width,
+                        height,
+                        self.encoder_hidden_dim,
+                        self.positional_encoding_temperature,
+                        device=src_flatten.device,
+                        dtype=src_flatten.dtype,
+                    )
                 else:
                     pos_embed = None
 
@@ -1441,6 +1461,7 @@ def forward(
         position_embeddings=None,
         reference_points=None,
         spatial_shapes=None,
+        spatial_shapes_list=None,
         level_start_index=None,
         valid_ratios=None,
         output_attentions=None,
@@ -1512,6 +1533,7 @@ def forward(
                 encoder_hidden_states=encoder_hidden_states,
                 reference_points=reference_points_input,
                 spatial_shapes=spatial_shapes,
+                spatial_shapes_list=spatial_shapes_list,
                 level_start_index=level_start_index,
                 encoder_attention_mask=encoder_attention_mask,
                 output_attentions=output_attentions,
@@ -1575,6 +1597,27 @@ def forward(
         )
 
 
+def compile_compatible_lru_cache(*lru_args, **lru_kwargs):
+    def decorator(func):
+        @wraps(func)
+        def wrapper(self, *args, **kwargs):
+            if not torch.compiler.is_compiling():
+                # Cache the function only if the model is not being compiled
+                # check if the function is already cached, otherwise create it
+                if not hasattr(self, f"_cached_{func.__name__}"):
+                    self.__setattr__(
+                        f"_cached_{func.__name__}", lru_cache(*lru_args, **lru_kwargs)(func.__get__(self))
+                    )
+                return self.__getattribute__(f"_cached_{func.__name__}")(*args, **kwargs)
+            else:
+                # Otherwise, just call the original function
+                return func(self, *args, **kwargs)
+
+        return wrapper
+
+    return decorator
+
+
 @add_start_docstrings(
     """
     RT-DETR Model (consisting of a backbone and encoder-decoder) outputting raw hidden states without any head on top.
@@ -1626,7 +1669,7 @@ def __init__(self, config: RTDetrConfig):
 
         # init encoder output anchors and valid_mask
         if config.anchor_image_size:
-            self.anchors, self.valid_mask = self.generate_anchors()
+            self.anchors, self.valid_mask = self.generate_anchors(dtype=self.dtype)
 
         # Create decoder input projection layers
         # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_pytorch/src/zoo/rtdetr/rtdetr_decoder.py#L412
@@ -1669,12 +1712,8 @@ def unfreeze_backbone(self):
         for param in self.backbone.parameters():
             param.requires_grad_(True)
 
-    @lru_cache(maxsize=32)
-    def generate_anchors(self, spatial_shapes=None, grid_size=0.05):
-        # We always generate anchors in float32 to preserve equivalence between
-        # dynamic and static anchor inference
-        dtype = torch.float32
-
+    @compile_compatible_lru_cache(maxsize=32)
+    def generate_anchors(self, spatial_shapes=None, grid_size=0.05, device="cpu", dtype=torch.float32):
         if spatial_shapes is None:
             spatial_shapes = [
                 [int(self.config.anchor_image_size[0] / s), int(self.config.anchor_image_size[1] / s)]
@@ -1683,10 +1722,12 @@ def generate_anchors(self, spatial_shapes=None, grid_size=0.05):
         anchors = []
         for level, (height, width) in enumerate(spatial_shapes):
             grid_y, grid_x = torch.meshgrid(
-                torch.arange(end=height, dtype=dtype), torch.arange(end=width, dtype=dtype), indexing="ij"
+                torch.arange(end=height, dtype=dtype, device=device),
+                torch.arange(end=width, dtype=dtype, device=device),
+                indexing="ij",
             )
             grid_xy = torch.stack([grid_x, grid_y], -1)
-            valid_wh = torch.tensor([width, height]).to(dtype)
+            valid_wh = torch.tensor([width, height], device=device).to(dtype)
             grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_wh
             wh = torch.ones_like(grid_xy) * grid_size * (2.0**level)
             anchors.append(torch.concat([grid_xy, wh], -1).reshape(-1, height * width, 4))
@@ -1826,7 +1867,7 @@ def forward(
             # Pass spatial_shapes as tuple to make it hashable and make sure
             # lru_cache is working for generate_anchors()
            spatial_shapes_tuple = tuple(spatial_shapes_list)
-            anchors, valid_mask = self.generate_anchors(spatial_shapes_tuple)
+            anchors, valid_mask = self.generate_anchors(spatial_shapes_tuple, device=device, dtype=dtype)
         else:
             anchors, valid_mask = self.anchors, self.valid_mask
 
@@ -1873,6 +1914,7 @@ def forward(
             encoder_attention_mask=attention_mask,
             reference_points=init_reference_points,
             spatial_shapes=spatial_shapes,
+            spatial_shapes_list=spatial_shapes_list,
            level_start_index=level_start_index,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
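Below is a minimal, self-contained sketch (not part of the patch itself) of how the new `compile_compatible_lru_cache` decorator behaves, assuming a PyTorch build that provides `torch.compiler.is_compiling()`. Outside of `torch.compile`, the wrapped method is memoized per instance through a lazily created `_cached_<method name>` attribute; under compilation the original method runs uncached. The decorator body is copied from the hunk above, while `AnchorToy`, its `generate` method, and the `shapes` tuple are hypothetical names used only for illustration.

```python
import torch
from functools import lru_cache, wraps


def compile_compatible_lru_cache(*lru_args, **lru_kwargs):
    # Same body as in the diff: memoize per instance unless torch.compile is tracing.
    def decorator(func):
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            if not torch.compiler.is_compiling():
                # Cache the function only if the model is not being compiled
                # check if the function is already cached, otherwise create it
                if not hasattr(self, f"_cached_{func.__name__}"):
                    self.__setattr__(
                        f"_cached_{func.__name__}", lru_cache(*lru_args, **lru_kwargs)(func.__get__(self))
                    )
                return self.__getattribute__(f"_cached_{func.__name__}")(*args, **kwargs)
            else:
                # Otherwise, just call the original function
                return func(self, *args, **kwargs)

        return wrapper

    return decorator


class AnchorToy:
    # Hypothetical stand-in for RTDetrModel.generate_anchors, used only to show the caching behavior.
    def __init__(self):
        self.calls = 0

    @compile_compatible_lru_cache(maxsize=32)
    def generate(self, spatial_shapes, device="cpu", dtype=torch.float32):
        self.calls += 1
        # Arguments must be hashable (tuples, strings, dtypes) for lru_cache to key on them.
        return torch.zeros(sum(h * w for h, w in spatial_shapes), 4, device=device, dtype=dtype)


toy = AnchorToy()
shapes = ((80, 80), (40, 40), (20, 20))  # a tuple, not a list, so it can be hashed
toy.generate(shapes)
toy.generate(shapes)  # served from the per-instance cache; the method body does not run again
print(toy.calls)  # 1
```

This per-instance cache keeps each model's memoized anchors on the instance itself rather than in a class-level `lru_cache`, and it is why `RTDetrModel.forward` converts `spatial_shapes_list` to a tuple before calling `generate_anchors` (see the "Pass spatial_shapes as tuple to make it hashable" comment in the diff); the uncached path presumably exists because the lazy `setattr`/`lru_cache` state would not trace cleanly under `torch.compile`.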