86 changes: 64 additions & 22 deletions src/transformers/models/rt_detr/modeling_rt_detr.py
@@ -18,7 +18,7 @@
import os
import warnings
from dataclasses import dataclass
from functools import lru_cache, partial
from functools import lru_cache, partial, wraps
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

@@ -737,7 +737,9 @@ def multi_scale_deformable_attention(
) -> Tensor:
batch_size, _, num_heads, hidden_dim = value.shape
_, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
value_list = value.split([height.item() * width.item() for height, width in value_spatial_shapes], dim=1)
# Ignore copy
Contributor:
The "Copy from" looks almost useless for this function with an "Ignore copy" a few lines below 🙂
Does it make sense to spread these changes to other architectures?

Member Author:
That's true, but I'm also working on making the same changes in deformable detr, so the "Ignore copy" shouldn't be necessary soon. I'm getting some weird results with compiled deformable detr in fp16 though, which I really don't understand.
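
For context on the hunk below: a minimal sketch, with made-up names and sizes, contrasting the two split variants. Calling `.item()` on tensor entries forces a device-to-host sync and, by default, a graph break under `torch.compile`; splitting on plain Python ints taken from `spatial_shapes_list` avoids both.

```python
import torch

value = torch.randn(2, 125, 8, 32)                        # (batch, sum(H*W), heads, dim), illustrative sizes
value_spatial_shapes = torch.tensor([[10, 10], [5, 5]])   # level shapes as a tensor
spatial_shapes_list = [(10, 10), (5, 5)]                   # the same shapes as Python ints

# Old: every .item() syncs GPU -> CPU and typically breaks the compiled graph
value_list_old = value.split(
    [height.item() * width.item() for height, width in value_spatial_shapes], dim=1
)

# New: pure Python arithmetic on spatial_shapes_list, no sync, compile-friendly
value_list_new = value.split(
    [height * width for height, width in spatial_shapes_list], dim=1
)

assert all(a.shape == b.shape for a, b in zip(value_list_old, value_list_new))
```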

value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)

sampling_grids = 2 * sampling_locations - 1
sampling_value_list = []
for level_id, (height, width) in enumerate(value_spatial_shapes):
@@ -849,6 +851,7 @@ def forward(
position_embeddings: Optional[torch.Tensor] = None,
reference_points=None,
spatial_shapes=None,
spatial_shapes_list=None,
level_start_index=None,
output_attentions: bool = False,
):
@@ -858,7 +861,10 @@

batch_size, num_queries, _ = hidden_states.shape
batch_size, sequence_length, _ = encoder_hidden_states.shape
if (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length:

# Ignore copy
total_elements = sum(shape[0] * shape[1] for shape in spatial_shapes_list)
if total_elements != sequence_length:
raise ValueError(
"Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
)
@@ -893,9 +899,12 @@ def forward(
else:
raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")

if self.disable_custom_kernels:
# Ignore copy
if self.disable_custom_kernels or MultiScaleDeformableAttention is None:
# PyTorch implementation
output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights)
output = multi_scale_deformable_attention(
value, spatial_shapes_list, sampling_locations, attention_weights
)
else:
try:
# custom kernel
@@ -909,7 +918,9 @@
)
except Exception:
# PyTorch implementation
output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights)
output = multi_scale_deformable_attention(
value, spatial_shapes_list, sampling_locations, attention_weights
)
output = self.output_proj(output)

return output, attention_weights
@@ -1064,6 +1075,7 @@ def forward(
position_embeddings: Optional[torch.Tensor] = None,
reference_points=None,
spatial_shapes=None,
spatial_shapes_list=None,
level_start_index=None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
@@ -1114,6 +1126,7 @@ def forward(
position_embeddings=position_embeddings,
reference_points=reference_points,
spatial_shapes=spatial_shapes,
spatial_shapes_list=spatial_shapes_list,
level_start_index=level_start_index,
output_attentions=output_attentions,
)
@@ -1299,14 +1312,16 @@ def __init__(self, config: RTDetrConfig):
self.pan_blocks.append(RTDetrCSPRepLayer(config))

@staticmethod
def build_2d_sincos_position_embedding(width, height, embed_dim=256, temperature=10000.0):
grid_w = torch.arange(int(width), dtype=torch.float32)
grid_h = torch.arange(int(height), dtype=torch.float32)
def build_2d_sincos_position_embedding(
width, height, embed_dim=256, temperature=10000.0, device="cpu", dtype=torch.float32
):
grid_w = torch.arange(int(width), dtype=dtype, device=device)
grid_h = torch.arange(int(height), dtype=dtype, device=device)
grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij")
if embed_dim % 4 != 0:
raise ValueError("Embed dimension must be divisible by 4 for 2D sin-cos position embedding")
pos_dim = embed_dim // 4
omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
omega = torch.arange(pos_dim, dtype=dtype, device=device) / pos_dim
SangbumChoi (Contributor), Sep 11, 2024:
@yonigozlan Is this the main reason the compiled version of RT-DETR got such a major speed boost in FP16? (Since the original omega defaulted to float32 and needed to be cast to fp16 at inference?)

Member Author:
I’m quite new to torch compile, so take this with a grain of salt, but from my understanding, the main reason for the speedup (both in fp32 and fp16) is that the model now has no graph breaks. This means no CPU/GPU transfers inside the model, and lets torch compile use CUDA graphs, reducing kernel launch overhead.
I think the main boost in fp16 comes from Tensor Cores, which are optimized for half-precision and make GPU operations faster. In fact, GPU operations were already faster in the compiled models in the current version, but the gains were overshadowed by CPU/GPU transfer overhead. So I don't think this change of omega to fp16 made all the difference; it's more a global effect of no graph breaks plus Tensor Cores being optimized for fp16.

Contributor:
Thanks for the kind explanation. I'm also a newbie with this compile logic, so I just wanted to know out of curiosity.

> model now has no graph breaks

I think this could also be possible, thanks 👍🏼
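
To illustrate the "no graph breaks" point from the reply above: a minimal, hypothetical sketch (not part of this PR) of how one might check a module for graph breaks. The `Dummy` module and its `.item()`-based control flow are made up for illustration; the `torch._dynamo.explain` output fields assume PyTorch 2.1+.

```python
import torch


class Dummy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x):
        # .item() forces a GPU -> CPU sync and, with default settings,
        # a graph break under torch.compile
        if x.sum().item() > 0:
            x = x + 1
        return self.linear(x)


model = Dummy()
x = torch.randn(2, 8)

# Reports how many subgraphs were captured and why the graph broke
explanation = torch._dynamo.explain(model)(x)
print(explanation.graph_break_count, explanation.break_reasons)

# fullgraph=True makes compilation fail loudly on any graph break instead
compiled = torch.compile(model, fullgraph=True)
# compiled(x)  # would raise because of the data-dependent .item() branch
```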

omega = 1.0 / (temperature**omega)

out_w = grid_w.flatten()[..., None] @ omega[None]
@@ -1372,8 +1387,13 @@ def forward(
src_flatten = hidden_states[enc_ind].flatten(2).permute(0, 2, 1)
if self.training or self.eval_size is None:
pos_embed = self.build_2d_sincos_position_embedding(
width, height, self.encoder_hidden_dim, self.positional_encoding_temperature
).to(src_flatten.device, src_flatten.dtype)
width,
height,
self.encoder_hidden_dim,
self.positional_encoding_temperature,
device=src_flatten.device,
dtype=src_flatten.dtype,
)
else:
pos_embed = None

@@ -1441,6 +1461,7 @@ def forward(
position_embeddings=None,
reference_points=None,
spatial_shapes=None,
spatial_shapes_list=None,
level_start_index=None,
valid_ratios=None,
output_attentions=None,
@@ -1512,6 +1533,7 @@ def forward(
encoder_hidden_states=encoder_hidden_states,
reference_points=reference_points_input,
spatial_shapes=spatial_shapes,
spatial_shapes_list=spatial_shapes_list,
level_start_index=level_start_index,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
@@ -1575,6 +1597,27 @@ def forward(
)


def compile_compatible_lru_cache(*lru_args, **lru_kwargs):
def decorator(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
if not torch.compiler.is_compiling():
# Cache the function only if the model is not being compiled
# check if the function is already cached, otherwise create it
if not hasattr(self, f"_cached_{func.__name__}"):
self.__setattr__(
f"_cached_{func.__name__}", lru_cache(*lru_args, **lru_kwargs)(func.__get__(self))
)
return self.__getattribute__(f"_cached_{func.__name__}")(*args, **kwargs)
else:
# Otherwise, just call the original function
return func(self, *args, **kwargs)

return wrapper

return decorator
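
A hypothetical usage sketch (not from this PR) showing how the decorator behaves: outside compilation the wrapped method gets a per-instance `lru_cache`, while `torch.compiler.is_compiling()` is True during tracing, so the cache is bypassed and no graph break is introduced. The `Toy` class and `slow_square` method are made up; in the model the decorator is applied to `generate_anchors` below.

```python
import torch

# assumes compile_compatible_lru_cache from the diff above is in scope


class Toy:
    calls = 0

    @compile_compatible_lru_cache(maxsize=8)
    def slow_square(self, n: int):
        # arguments must be hashable (ints, tuples, ...) on the cached path
        Toy.calls += 1
        return torch.full((n,), float(n * n))


toy = Toy()
toy.slow_square(4)
toy.slow_square(4)   # second call is served from the per-instance cache
print(Toy.calls)     # 1: the body ran only once in eager mode

# While torch.compile is tracing, torch.compiler.is_compiling() is True,
# so the decorator skips the cache and the call is traced like any other.
```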


@add_start_docstrings(
"""
RT-DETR Model (consisting of a backbone and encoder-decoder) outputting raw hidden states without any head on top.
@@ -1626,7 +1669,7 @@ def __init__(self, config: RTDetrConfig):

# init encoder output anchors and valid_mask
if config.anchor_image_size:
self.anchors, self.valid_mask = self.generate_anchors()
self.anchors, self.valid_mask = self.generate_anchors(dtype=self.dtype)

# Create decoder input projection layers
# https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_pytorch/src/zoo/rtdetr/rtdetr_decoder.py#L412
@@ -1669,12 +1712,8 @@ def unfreeze_backbone(self):
for param in self.backbone.parameters():
param.requires_grad_(True)

@lru_cache(maxsize=32)
def generate_anchors(self, spatial_shapes=None, grid_size=0.05):
# We always generate anchors in float32 to preserve equivalence between
# dynamic and static anchor inference
dtype = torch.float32

@compile_compatible_lru_cache(maxsize=32)
def generate_anchors(self, spatial_shapes=None, grid_size=0.05, device="cpu", dtype=torch.float32):
if spatial_shapes is None:
spatial_shapes = [
[int(self.config.anchor_image_size[0] / s), int(self.config.anchor_image_size[1] / s)]
@@ -1683,10 +1722,12 @@ def generate_anchors(self, spatial_shapes=None, grid_size=0.05):
anchors = []
for level, (height, width) in enumerate(spatial_shapes):
grid_y, grid_x = torch.meshgrid(
torch.arange(end=height, dtype=dtype), torch.arange(end=width, dtype=dtype), indexing="ij"
torch.arange(end=height, dtype=dtype, device=device),
torch.arange(end=width, dtype=dtype, device=device),
indexing="ij",
)
grid_xy = torch.stack([grid_x, grid_y], -1)
valid_wh = torch.tensor([width, height]).to(dtype)
valid_wh = torch.tensor([width, height], device=device).to(dtype)
grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_wh
wh = torch.ones_like(grid_xy) * grid_size * (2.0**level)
anchors.append(torch.concat([grid_xy, wh], -1).reshape(-1, height * width, 4))
@@ -1826,7 +1867,7 @@ def forward(
# Pass spatial_shapes as tuple to make it hashable and make sure
# lru_cache is working for generate_anchors()
spatial_shapes_tuple = tuple(spatial_shapes_list)
anchors, valid_mask = self.generate_anchors(spatial_shapes_tuple)
anchors, valid_mask = self.generate_anchors(spatial_shapes_tuple, device=device, dtype=dtype)
else:
anchors, valid_mask = self.anchors, self.valid_mask
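
A tiny, hypothetical illustration (not from the codebase) of why the tuple conversion above matters: `lru_cache` hashes its arguments, so a list argument raises while a tuple is cached.

```python
from functools import lru_cache


@lru_cache(maxsize=32)
def total_elements(shapes):
    return sum(h * w for h, w in shapes)


total_elements(((10, 10), (5, 5)))     # fine: tuples are hashable, result is cached
# total_elements([[10, 10], [5, 5]])   # would raise TypeError: unhashable type: 'list'
```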

@@ -1873,6 +1914,7 @@ def forward(
encoder_attention_mask=attention_mask,
reference_points=init_reference_points,
spatial_shapes=spatial_shapes,
spatial_shapes_list=spatial_shapes_list,
level_start_index=level_start_index,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,