🚨 Refactor DETR to updated standards #41549

Merged
yonigozlan merged 53 commits into huggingface:main from
yonigozlan:refactor-detr
Feb 2, 2026

Conversation

@yonigozlan
Member

What does this PR do?

This PR refactors DETR as part of an effort to standardize vision models in the library, in the same vein as #41546.
Expect many more PRs like this for vision models as we approach v5!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Comment on lines +31 to +32
if not isinstance(line, str):
line = line.decode()
Member Author

line was a str when I tried to use this, not sure why! I can open a separate PR for it though
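For context, the guard under discussion is the usual bytes-vs-str defensive pattern; a standalone sketch (the function name is hypothetical):

```python
def normalize_line(line):
    # Readers over binary streams yield bytes; text streams yield str.
    # Decode only when needed so both work.
    if not isinstance(line, str):
        line = line.decode()
    return line
```

If the source always yields str (as observed here), the isinstance branch is dead code and can be dropped.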

Contributor

👀

Member Author

Ah yes I can try to remove this, maybe it's not an issue anymore. Thanks for the reminder 😁

Contributor

no worries haha :D can be removed for sure?

Comment on lines +1088 to +1134
if pixel_values is None and inputs_embeds is None:
raise ValueError("You have to specify either pixel_values or inputs_embeds")

if inputs_embeds is None:
batch_size, num_channels, height, width = pixel_values.shape
device = pixel_values.device

if pixel_mask is None:
pixel_mask = torch.ones(((batch_size, height, width)), device=device)
vision_features = self.backbone(pixel_values, pixel_mask)
feature_map, mask = vision_features[-1]

# Apply 1x1 conv to map (N, C, H, W) -> (N, d_model, H, W), then flatten to (N, HW, d_model)
# (feature map and position embeddings are flattened and permuted to (batch_size, sequence_length, hidden_size))
projected_feature_map = self.input_projection(feature_map)
flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1)
spatial_position_embeddings = (
self.position_embedding(shape=feature_map.shape, device=device, dtype=pixel_values.dtype, mask=mask)
.flatten(2)
.permute(0, 2, 1)
)
flattened_mask = mask.flatten(1)
else:
batch_size = inputs_embeds.shape[0]
device = inputs_embeds.device
flattened_features = inputs_embeds
# When using inputs_embeds, we need to infer spatial dimensions for position embeddings
# Assume square feature map
seq_len = inputs_embeds.shape[1]
feat_dim = int(seq_len**0.5)
# Create position embeddings for the inferred spatial size
spatial_position_embeddings = (
self.position_embedding(
shape=torch.Size([batch_size, self.config.d_model, feat_dim, feat_dim]),
device=device,
dtype=inputs_embeds.dtype,
)
.flatten(2)
.permute(0, 2, 1)
)
# If a pixel_mask is provided with inputs_embeds, interpolate it to feat_dim, then flatten.
if pixel_mask is not None:
mask = nn.functional.interpolate(pixel_mask[None].float(), size=(feat_dim, feat_dim)).to(torch.bool)[0]
flattened_mask = mask.flatten(1)
else:
# If no mask provided, assume all positions are valid
flattened_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.long)
Member Author

Now truly supports passing inputs_embeds instead of silently doing nothing with it
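The inputs_embeds branch above assumes a square feature map and infers its side length from the sequence length; a minimal sketch with hypothetical numbers:

```python
# A 7x7 feature map flattened to a sequence of 49 tokens.
seq_len = 49
feat_dim = int(seq_len ** 0.5)  # inferred spatial side length
# Holds only when the original feature map really was square:
assert feat_dim * feat_dim == seq_len
```

For non-square inputs this inference would be wrong, which is why the assumption is called out in the code comment.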

Comment on lines +1149 to +1152
if decoder_inputs_embeds is not None:
queries = decoder_inputs_embeds
else:
queries = torch.zeros_like(object_queries_position_embeddings)
Member Author

Same, truly supports decoder_inputs_embeds as input

attention_mask=None,
object_queries=object_queries,
query_position_embeddings=query_position_embeddings,
attention_mask=decoder_attention_mask,
Member Author

Supports masking of queries (as advertised)

Comment on lines 948 to +978
@@ -967,65 +960,36 @@ def forward(
intermediate = () if self.config.auxiliary_loss else None

# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None

for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.training:
dropout_probability = torch.rand([])
if dropout_probability < self.layerdrop:
continue

layer_outputs = decoder_layer(
hidden_states = decoder_layer(
hidden_states,
combined_attention_mask,
object_queries,
query_position_embeddings,
attention_mask,
spatial_position_embeddings,
object_queries_position_embeddings,
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
**kwargs,
Member Author

Truly supports attention mask on vision features (it was always None before)

@yonigozlan yonigozlan changed the title [WIP] Refactor DETR to updated standards Refactor DETR to updated standards Oct 14, 2025
@yonigozlan
Member Author

Hello @molbap @ArthurZucker!
The long overdue refactor of DETR is ready for a first review. I'm waiting for your reviews to run fix-copies, as this will have a lot of impact on other models (through # Copied from for now, modular later ;) )

@yonigozlan yonigozlan changed the title Refactor DETR to updated standards 🚨 Refactor DETR to updated standards Oct 14, 2025
Contributor

@vasqu vasqu left a comment

Some initial thoughts, focused on the masks/interface part

Comment thread src/transformers/models/detr/modeling_detr.py
Comment thread src/transformers/models/detr/modeling_detr.py Outdated
Comment thread src/transformers/models/detr/modeling_detr.py
Comment thread src/transformers/models/detr/modeling_detr.py
Comment thread src/transformers/models/detr/modeling_detr.py
Comment thread src/transformers/models/detr/modeling_detr.py Outdated
):
if use_attention_mask:
self.skipTest(
"This test uses attention masks which are not compatible with DETR. Skipping when use_attention_mask is True."
Contributor

Hmm, why tho? Are the attention masks perhaps 3D instead?

Member Author

It's more that _test_eager_matches_sdpa_inference is not adapted to the vision space (+ object queries here). It tries to add a "decoder_input_ids" entry to the inputs, and the seq lens created for the dummy masks were wrong. Seeing as the function is already quite cluttered and difficult to read, I figured trying to add support for vision models directly there would not be ideal. We can either override the tests in this model specifically, or try to have a more general test for vision models. Another option would be to parameterize the tests by providing how to find the correct seq len and input names.
I would love some help on this!

Contributor

I see, is this specific to detr or will we encounter it for other models in the vision family too? It's best not to skip too much if more come down the line. Depending on how many are affected, we should either

  • Fix the base test, e.g. with parametrization, splitting the test a bit (more models with similar problems)
  • Overwrite the test and make specific changes (low amount of models with similar problems)

Contributor

The problem is indeed with the test's base design. It will lead to more skipped tests down the line because the encoder/encoder-decoder/decoder division isn't made that clearly. The number of models with similar problems isn't "low" imo.

Member Author

Yes I think it will increase too with us fixing the attention masks for vision models, so we definitely need to improve the base test

@yonigozlan
Member Author

Thanks for the review @vasqu ! I standardized attention and masking following your advice :)

Contributor

@vasqu vasqu left a comment

Looking good from my side, amazing work! Just left some smaller comments but nothing crazy

Comment thread src/transformers/models/detr/modeling_detr.py
Comment thread src/transformers/models/detr/modeling_detr.py

_can_record_outputs = {
"hidden_states": DetrEncoderLayer,
"attentions": OutputRecorder(DetrSelfAttention, layer_name="self_attn", index=1),
Contributor

Do we need the explicit output recorder? iirc DetrSelfAttention should work fine by itself

Contributor

same question here out of curiosity :D

Member Author

No indeed I can remove it :)

Contributor

nudge as reminder


@ArthurZucker ArthurZucker removed their request for review October 16, 2025 14:11
Contributor

@molbap molbap left a comment

Looks niiiice
For the unhappy CI, let's throw the Check Copies away!

Comment thread src/transformers/models/detr/modeling_detr.py

Comment thread src/transformers/modeling_utils.py Outdated
"qwen2_5_vl",
"videollava",
"vipllava",
"detr",
Contributor

I'm not sure, do we need to add this here?

Member Author

Yes, that's what made me go crazy haha, otherwise _checkpoint_conversion_mapping doesn't work.
Note that this is temporary and will be replaced by the new way to convert weights on the fly that @ArthurZucker and @Cyrilvallez are working on.

def __init__(self, config: DetrConfig):
super().__init__()
self.embed_dim = config.d_model
self.hidden_size = config.d_model
Contributor

won't that break BC? (at least on the attribute names)

Member Author

In what way? If users access it directly? In any case I think we really need to standardize these types of variable names, it might be worth slightly breaking BC imo

Contributor

yeah in case of non-config access. I agree I prefer to standardize

Comment thread src/transformers/models/detr/modeling_detr.py
Comment on lines 962 to 965
if self.training:
dropout_probability = torch.rand([])
if dropout_probability < self.layerdrop:
continue
Contributor

not exactly the typical dropout interface, we can maybe take the occasion to update it?

Member Author

Yes 😫, I was scared of breaking BC in that case, but maybe it's not so important. It would be great to get rid of non-standard dropout elsewhere as well really

Contributor

I think it's ok to break it in here, it does not affect inference and clearly it would be an improvement to get rid of it haha
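For reference, LayerDrop (as in the snippet above) skips an entire layer with some probability during training, which is why it doesn't look like the usual nn.Dropout interface. A framework-free sketch of the pattern, with hypothetical names:

```python
import random

def run_with_layerdrop(layers, x, layerdrop=0.1, training=True):
    for layer in layers:
        # LayerDrop: with probability `layerdrop`, skip the whole layer at train time
        if training and random.random() < layerdrop:
            continue
        x = layer(x)
    return x
```

At eval time every layer runs; at train time each layer is dropped independently, so it doesn't affect inference outputs.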

Comment on lines 1026 to 1032
def freeze_backbone(self):
for name, param in self.backbone.conv_encoder.model.named_parameters():
for _, param in self.backbone.model.named_parameters():
param.requires_grad_(False)

def unfreeze_backbone(self):
for name, param in self.backbone.conv_encoder.model.named_parameters():
for _, param in self.backbone.model.named_parameters():
param.requires_grad_(True)
Contributor

these methods should really be user-side responsibilities 😨 I would be pro-removal! We can always communicate on it

Member Author

Yes agreed, we could start a deprecation cycle, or just remove it for v5. It's present in several other vision models

Contributor

Just asked @merveenoyan, who's an avid finetuner and is not using these methods anymore; I think they were good initially but they're ok to go now. Agreed it's out of scope for the current PR, will create another to remove all of it (cc @ariG23498 as we chatted on finetuning too)
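Once the helpers are gone, the user-side replacement is small; a sketch assuming a torch model exposing a `backbone` submodule:

```python
import torch.nn as nn

def set_requires_grad(module: nn.Module, flag: bool) -> None:
    # User-side equivalent of freeze_backbone / unfreeze_backbone
    for param in module.parameters():
        param.requires_grad_(flag)

# e.g. set_requires_grad(model.backbone, False)  # freeze
#      set_requires_grad(model.backbone, True)   # unfreeze
```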

Comment thread src/transformers/models/detr/modeling_detr.py Outdated
def forward(self, q, k, mask: Optional[torch.Tensor] = None):
q = self.q_linear(q)
k = nn.functional.conv2d(k, self.k_linear.weight.unsqueeze(-1).unsqueeze(-1), self.k_linear.bias)
queries_per_head = q.view(q.shape[0], q.shape[1], self.num_heads, self.hidden_dim // self.num_heads)
Contributor

on here my nit would be, if we can update a bit the single-letter variable names, that'd be great!

Member Author

Yes I think we could even try to refactor this to use the standard attention module and only take the attention weights! It could be interesting to compare the performance of eager attention vs this implementation (conv2d instead of linear for key proj, and no multiplication by value) vs other attention impl.

Contributor

ahah that's a tough one to benchmark but indeed sounds good, LMK if you want to do that in this PR or move to another

Contributor

@vasqu vasqu left a comment

Talked with @molbap internally and I think we agree that it doesn't make sense to force this merge just to split refactoring again. Let's aim for quality in this refactor

We will probably merge the model PR as is and add this to this refactor after merge. Otherwise, we will suffer on both sides - crunch time on the model PR and less quality on the refactor (e.g. another set of TODOs)

I've added a few smaller comments meanwhile

Comment thread src/transformers/models/detr/modeling_detr.py Outdated
Comment thread src/transformers/models/detr/modeling_detr.py Outdated
Comment thread src/transformers/models/detr/modeling_detr.py
batch_size, num_queries, self.n_heads, self.n_levels * self.n_points
)
attention_weights = F.softmax(attention_weights, -1).view(
attention_weights = softmax(attention_weights, -1).view(
Contributor

This is a bit weird, would like to not have a direct import

Member Author

Agreed, but I have issues with torch functional and torchvision functional aliases colliding in modular. I have this PR to fix it #43263, I'll change back when it's merged

Member Author

Fixed :)


hidden_states = inputs_embeds

encoder_states = () if output_hidden_states else None
Contributor

Hmm, can we not add this to _can_record_outputs

Member Author

@yonigozlan yonigozlan Jan 21, 2026

I haven't managed to get something clean that works for now, the issue is this line:
encoder_states = encoder_states + (hidden_states[enc_ind],)
So the encoder_states/hidden_states cannot automatically be recorded. I'll see if some refactoring of the code can fix this

Member Author

Fixed!

# https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_pytorch/src/zoo/rtdetr/rtdetr_decoder.py#L412
sources = []
for level, source in enumerate(encoder_outputs[0]):
for level, source in enumerate(encoder_outputs.last_hidden_state):
Contributor

We should force return_dict=True then for the encoder

Member Author

That would mean we need to force return_dict=True every time we want to access a named attribute of a submodule output? It doesn't look like that's what we do in the library. From my understanding, return_dict=False only applies to the top-module output, and submodules use return_dict=True by default.
Here, we pop return_dict in the top module call:

return_dict_passed = kwargs.pop("return_dict", return_dict)

@@ -1,15 +1,53 @@
import math
Member

@zucchini-nlp zucchini-nlp Jan 22, 2026

btw, let's bring back the header with license where missing

Member Author

Indeed thanks for the heads up!

@github-actions
Contributor

View the CircleCI Test Summary for this PR:

https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=41549&sha=7821c4

@yonigozlan
Member Author

Hey @molbap @vasqu ! I added a small refactor to rt_detr, and I think we can merge this for now before the PR gets too big. This should make for a good basis to continue the refactoring work on vision models :)

Comment on lines +996 to +999
class RTDetrV2AIFILayer(nn.Module):
"""
AIFI (Attention-based Intra-scale Feature Interaction) layer used in RT-DETR hybrid encoder.
"""
Contributor

nice, so this can be reused in other derived models?

Member Author

Yes that's the idea! Also it allows for automatically capturing hidden_states

x_max = x_coords_masked.flatten(start_dim=-2).max(dim=-1).values + 1
x_min = (
torch.where(mask, x_coords_masked, torch.tensor(1e8, device=mask.device, dtype=dtype))
torch.where(mask, x_coords_masked, torch.tensor(torch.finfo(dtype).max))
Member Author

Note: This was causing overflow issues in float16
Cc @zhang-prog @molbap
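To illustrate the overflow (a sketch, assuming torch): float16 tops out around 65504, so a hardcoded 1e8 sentinel becomes inf under that dtype, while torch.finfo(dtype).max stays representable:

```python
import torch

dtype = torch.float16
hardcoded = torch.tensor(1e8, dtype=dtype)                # overflows to inf in float16
safe = torch.tensor(torch.finfo(dtype).max, dtype=dtype)  # 65504.0, still finite
```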

@yonigozlan yonigozlan enabled auto-merge (squash) February 2, 2026 21:33
@yonigozlan yonigozlan disabled auto-merge February 2, 2026 23:05
@yonigozlan yonigozlan merged commit aefa23a into huggingface:main Feb 2, 2026
32 of 40 checks passed