Add RF-DETR #36895
Conversation
Just a small message to present the architecture and what it looks like from the 🤗 transformers point of view: RF-DETR is based on LW-DETR and DeformableDETR. LW-DETR is based on DETR but replaces the encoder with a ViT instead of a CNN (like ResNet) and adds the appropriate MultiScaleProjector to link the encoder and the decoder. RF-DETR then changes the LW-DETR encoder from a plain ViT to DinoV2WithRegisters with a "window" mechanism, and replaces the classical DETR decoder with a DeformableDETR decoder. There are basically two things to write:
One difficulty I can see in advance is the following: I noticed your PR about refactoring attention in ViTs, is there any plan to add FlashAttention to other models such as Detr, RTDetr, etc.? Let me know what you guys think
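To make the description above concrete, here is a minimal sketch of how the pieces could compose (all class and attribute names below are illustrative, not the actual transformers API):

```python
import torch.nn as nn


class RFDetrSketch(nn.Module):
    """Illustrative composition only: windowed DINOv2-with-registers encoder,
    multi-scale projector, DeformableDETR-style decoder."""

    def __init__(self, backbone: nn.Module, projector: nn.Module, decoder: nn.Module):
        super().__init__()
        self.backbone = backbone    # DinoV2WithRegisters ViT encoder with the "window" mechanism
        self.projector = projector  # MultiScaleProjector linking encoder features to the decoder
        self.decoder = decoder      # DeformableDETR decoder instead of the classical DETR one

    def forward(self, pixel_values):
        features = self.backbone(pixel_values)   # multi-level ViT feature maps
        feature_maps = self.projector(features)  # project to the scales the decoder expects
        return self.decoder(feature_maps)        # object queries -> class logits and boxes
```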
Hi @sbucaille, thanks for the detailed write-up!
We can add
Not at the moment; from my experiments it was not required for DETR-based models and did not give any speedup. However, it might be more relevant for a transformer-based encoder. Let's keep it simple initially and set it to False as you suggested
We can't use the `Dinov2WithRegistersLayer` directly because of the "window" mechanism:

```python
from typing import Optional, Tuple, Union

import torch


class RFDetrBackboneLayer(Dinov2WithRegistersLayer):
    def __init__(self, config):
        super().__init__(config)
        self.num_windows = config.num_windows

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        run_full_attention: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        assert head_mask is None, "head_mask is not supported for windowed attention"
        assert not output_attentions, "output_attentions is not supported for windowed attention"

        shortcut = hidden_states
        if run_full_attention:
            # reshape x to remove windows
            B, HW, C = hidden_states.shape
            num_windows_squared = self.num_windows**2
            hidden_states = hidden_states.view(B // num_windows_squared, num_windows_squared * HW, C)

        self_attention_outputs = self.attention(
            self.norm1(hidden_states),  # in Dinov2WithRegisters, layernorm is applied before self-attention
            head_mask,
            output_attentions=output_attentions,
        )
        attention_output = self_attention_outputs[0]

        if run_full_attention:
            # reshape x to add windows back
            B, HW, C = hidden_states.shape
            num_windows_squared = self.num_windows**2
            # hidden_states = hidden_states.view(B * num_windows_squared, HW // num_windows_squared, C)
            attention_output = attention_output.view(B * num_windows_squared, HW // num_windows_squared, C)

        attention_output = self.layer_scale1(attention_output)
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        # first residual connection
        hidden_states = self.drop_path(attention_output) + shortcut

        # in Dinov2WithRegisters, layernorm is also applied after self-attention
        layer_output = self.norm2(hidden_states)
        layer_output = self.mlp(layer_output)
        layer_output = self.layer_scale2(layer_output)

        # second residual connection
        layer_output = self.drop_path(layer_output) + hidden_states

        outputs = (layer_output,) + outputs
        return outputs
```

That's why I think we necessarily need a custom Backbone class for that 🤔
Hmm, am I correct that this part was added? It looks like it is a reshape-only operation; we can return attention_output as is and reshape all the layers' outputs later, right?
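For illustration, a toy version of that "reshape later" idea, with made-up shapes (this is not code from the PR):

```python
import torch

# made-up example shapes
batch_size, num_windows_squared, tokens_per_window, channels = 2, 4, 16, 8

# per-layer outputs kept in their windowed shape
layer_outputs = [
    torch.randn(batch_size * num_windows_squared, tokens_per_window, channels) for _ in range(3)
]

# undo the windowing once over the collected outputs instead of inside every layer
full_outputs = [
    out.view(batch_size, num_windows_squared * tokens_per_window, channels) for out in layer_outputs
]
print(full_outputs[0].shape)  # torch.Size([2, 64, 8])
```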
You are right, but it is not the only example. I'll stick to my original plan until I have something running with actual results, and I'll take care of refactoring this part later. I'll ping you when it's ready.
Hey @qubvel, in the end I made the modeling files follow the `RFDetr` naming. I also had issues with the modular mechanism, described below.
Hey, let's use the RfDetr name + modular, it's ok! RfDetr is the correct naming format, while RTDetr is an exception made before modular was introduced.
Ok sorry, I confused the problems I had: I didn't have a problem with the capital letters of the class names. The issue is that writing

```python
class RfDetrModel(DeformableDetrModel):
    pass
```

generates a bunch of classes such as

```python
class RfDetrConvEncoder(DeformableDetrConvEncoder):
    pass


class RfDetrModel(DeformableDetrModel):
    def __init__(self, config: RfDetrConfig):
        super().__init__(config)
        backbone = RfDetrConvEncoder(config)
        ...
```

But the problem also appears for other classes. Should I open an issue? Maybe @ArthurZucker has some insights on this problem?
cc @Cyrilvallez re modular, you faced something similar
I'm also facing a similar issue to @sbucaille's while working on DinoDetr.
Hey! Super super sorry, I missed the ping! Indeed, models sharing the last part of their names can introduce issues. I think I have an idea that should fix it while keeping general prefix renaming sound (which is very hard in practice)! I'll try to tackle it asap and will come back to you!
Hey @sbucaille @konstantinos-p! It will be solved by #37829 🤗 I will merge asap! Sorry for the wait on this! EDIT: Just merged the PR!
@yonigozlan Ready for a first review; this branch is based on the LW-DETR one until it is merged.
yonigozlan left a comment:
Hey @sbucaille, very nice PR! I mentioned some small things to change, but once LW-DETR is merged we should be able to merge this quickly!
```python
self.register_tokens = (
    nn.Parameter(torch.zeros(1, config.num_register_tokens, config.hidden_size))
    if config.num_register_tokens > 0
    else None
)
```
It looks like num_register_tokens is 0 for all models in the convert file, so do we really need all the logic associated with it? Can we inherit from dinov2 instead?
Indeed, fixed in b89a4af
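A hypothetical modular-style sketch of the suggestion (not necessarily what b89a4af does): since register tokens are never used, the embeddings could simply be inherited from Dinov2.

```python
from transformers.models.dinov2.modeling_dinov2 import Dinov2Embeddings


class RFDetrBackboneEmbeddings(Dinov2Embeddings):
    # hypothetical class name; no register_tokens parameter needed when num_register_tokens is always 0
    pass
```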
```python
window_block_indexes = set(range(self._out_indices[-1] + 1))
window_block_indexes.difference_update(self._out_indices)
window_block_indexes = list(window_block_indexes)
self.window_block_indexes = window_block_indexes
```
this is a bit verbose and hard to read. Instead, let's hardcode the window_block_indices in the config like we do for vitdet for example
Hmmm, in my opinion it is still better this way: window_block_indexes is just the complement of out_indices, so we don't have to check whether user-provided window_block_indexes are valid and raise an error when they are not.
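For illustration, the same complement can be written compactly (the out_indices values below are made-up examples):

```python
out_indices = [2, 5, 8, 11]  # assumed example values
# every block index up to the last output index that is not itself an output index
window_block_indexes = [i for i in range(out_indices[-1] + 1) if i not in out_indices]
print(window_block_indexes)  # [0, 1, 3, 4, 6, 7, 9, 10]
```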
```python
    batch_size * num_windows**2, num_h_patches_per_window * num_w_patches_per_window, -1
)
windowed_cls_token_with_pos_embed = cls_token_with_pos_embed.repeat(num_windows**2, 1, 1)
embeddings = torch.cat((windowed_cls_token_with_pos_embed, windowed_pixel_tokens), dim=1)
```
Let's put all that in a window_partition utility, like we do for vitdet
Done in 68c8918
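A minimal sketch of what such utilities could look like for this sequence-level windowing, matching the reshapes quoted above (not necessarily the exact code in 68c8918):

```python
import torch


def window_partition(hidden_states: torch.Tensor, num_windows: int) -> torch.Tensor:
    """Fold windows into the batch dimension: (B, N, C) -> (B * num_windows**2, N // num_windows**2, C)."""
    batch_size, seq_len, channels = hidden_states.shape
    num_windows_squared = num_windows**2
    return hidden_states.view(batch_size * num_windows_squared, seq_len // num_windows_squared, channels)


def window_unpartition(hidden_states: torch.Tensor, num_windows: int) -> torch.Tensor:
    """Inverse of window_partition: (B * num_windows**2, N, C) -> (B, N * num_windows**2, C)."""
    batch_size, seq_len, channels = hidden_states.shape
    num_windows_squared = num_windows**2
    return hidden_states.view(batch_size // num_windows_squared, seq_len * num_windows_squared, channels)
```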
```python
def forward(
    self,
    hidden_states: torch.Tensor,
    remove_windows: bool = False,
```
let's add a use_global_attention attribute to the layers when we instantiate them instead of passing an arg to the forward
Refactored in 6a345f8
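A small hypothetical sketch of that pattern (toy module, not the code from 6a345f8): the flag is fixed per layer at construction time rather than threaded through forward().

```python
import torch.nn as nn


class ToyBackboneLayer(nn.Module):
    def __init__(self, use_global_attention: bool = False):
        super().__init__()
        # decided once at instantiation, so forward() does not need a remove_windows argument
        self.use_global_attention = use_global_attention


num_hidden_layers = 12
window_block_indexes = {0, 1, 3, 4, 6, 7, 9, 10}  # assumed example values
layers = nn.ModuleList(
    [
        ToyBackboneLayer(use_global_attention=i not in window_block_indexes)
        for i in range(num_hidden_layers)
    ]
)
print([layer.use_global_attention for layer in layers])
```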
```python
if remove_windows:
    # reshape x to remove windows
    B, HW, C = hidden_states.shape
    num_windows_squared = self.num_windows**2
    hidden_states = hidden_states.view(B // num_windows_squared, num_windows_squared * HW, C)

hidden_states_norm = self.norm1(hidden_states)
self_attention_output = self.attention(hidden_states_norm)

if remove_windows:
    # reshape x to add windows back
    B, HW, C = hidden_states.shape
    num_windows_squared = self.num_windows**2
    # hidden_states = hidden_states.view(B * num_windows_squared, HW // num_windows_squared, C)
    self_attention_output = self_attention_output.view(B * num_windows_squared, HW // num_windows_squared, C)
```
Let's use window_partition and window_unpartition utilities here as well
Done in 68c8918
```python
import io

import requests
from PIL import Image

from transformers import AutoImageProcessor, RFDetrBackbone, RFDetrConfig

images = ["https://media.roboflow.com/notebooks/examples/dog-2.jpeg"]
images = [Image.open(io.BytesIO(requests.get(url).content)) for url in images]

processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50")
inputs = processor(images, return_tensors="pt")

config = RFDetrConfig()
backbone = RFDetrBackbone(config=config.backbone_config)
# model = RFDetrForObjectDetection.from_config()
```
To remove
Oops, removed in 2151d0d
```diff
@@ -0,0 +1,357 @@
+from typing import Optional
```
Very nice use of modular 🤗
Thanks, you guys have done insane work on this feature; the quality of life when adding models is on another level compared to before.
```python
def num_key_value_heads(self) -> int:
    return self.decoder_self_attention_heads

@property
```
Why is this needed?
Leftovers from LW-DETR, I'll remove them in the other PR, so when I rebase in the future they will be gone.
```python
logger = logging.get_logger(__name__)


class RfDetrConfig(PretrainedConfig):
```
We can still inherit from LwDetrConfig for the properties
Fixed in 0697a29
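A hypothetical modular-style sketch of that change (LwDetrConfig comes from the in-progress LW-DETR PR; this is not necessarily what 0697a29 contains):

```python
class RfDetrConfig(LwDetrConfig):
    # inheriting keeps the shared properties (e.g. num_key_value_heads) defined once in LwDetrConfig
    model_type = "rf_detr"
```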
| ("llava_next_video", ("LlavaNextImageProcessor", "LlavaNextImageProcessorFast")), | ||
| ("llava_next_video", ("LlavaNextVideoImageProcessor", None)), | ||
| ("llava_onevision", ("LlavaOnevisionImageProcessor", "LlavaOnevisionImageProcessorFast")), | ||
| ("lw_detr", ("LwDetrImageProcessor", "LwDetrImageProcessorFast")), |
we still need a mapping for RFDETR as I mentioned in the LW-DETR PR
Added in a6d5e2c
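The added mapping presumably mirrors the lw_detr entry above; a hypothetical excerpt (exact class names and capitalization are assumptions, whatever the PR settled on):

```python
IMAGE_PROCESSOR_MAPPING_NAMES_EXCERPT = [
    ("lw_detr", ("LwDetrImageProcessor", "LwDetrImageProcessorFast")),
    ("rf_detr", ("RFDetrImageProcessor", "RFDetrImageProcessorFast")),  # assumed names
]
```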
@molbap @vasqu nvm, it's ready for a review. I have a question regarding the WeightTransforms: should they only be in the convert script, with the weights saved on the Hub without a reverse mapping like I did, or should they be in the conversion_mapping file with the weights saved as the originals?
stevhliu left a comment:
docs lgtm, thanks!
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
[For maintainers] Suggested jobs to run (before merge): run-slow: auto, lw_detr, rf_detr
View the CircleCI Test Summary for this PR: https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=36895&sha=1599f8
What does this PR do?
Implements RF-DETR
Fixes #36879
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
@qubvel