Add RF-DETR #36895

base: main

@@ -0,0 +1,139 @@

<!--Copyright 2026 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->

*This model was released on 2024-04-05 and added to Hugging Face Transformers on 2026-01-23.*

<div style="float: right;">
    <div class="flex flex-wrap space-x-1">
        <img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
    </div>
</div>

# RF-DETR

[RF-DETR](https://huggingface.co/papers/2407.17140) proposes a Receptive Field Detection Transformer (DETR) architecture
designed to compete with and surpass the dominant YOLO series for real-time object detection. It achieves a new
state-of-the-art balance between speed (latency) and accuracy (mAP) by combining recent transformer advances with
efficient design choices.

The RF-DETR architecture is characterized by its simple and efficient structure: a DINOv2 backbone, a projector, and a
shallow DETR decoder.
It enhances the DETR architecture for efficiency and speed through the following core modifications:

1. **DINOv2 Backbone**: Uses a powerful DINOv2 backbone for robust feature extraction.
2. **Group DETR Training**: Uses group-wise one-to-many assignment during training to accelerate convergence.
3. **Richer Input**: Aggregates multi-level features from the backbone and uses a C2f projector (similar to YOLOv8) to
   pass multi-scale features.
4. **Faster Decoder**: Employs a shallow 3-layer DETR decoder with deformable cross-attention for lower latency.
5. **Optimized Queries**: Uses a mixed-query scheme combining learnable content queries and generated spatial queries.
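
The sketch below is illustrative only and is not part of the PR: it loads the configuration of one of the checkpoints
used later on this page and prints a few attributes where design choices like the shallow decoder and the query setup
would typically surface. The attribute names are assumptions based on other DETR-family configs and may not match the
actual `RfDetrConfig` fields.

```python
from transformers import AutoConfig

# Hypothetical inspection of an RF-DETR configuration; the attribute names
# below are guesses based on other DETR-family configs and may differ.
config = AutoConfig.from_pretrained("stevenbucaille/rfdetr_small")
for name in ("decoder_layers", "num_queries", "d_model", "backbone_config"):
    if hasattr(config, name):  # only print attributes that actually exist
        print(f"{name} = {getattr(config, name)}")
```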

You can find all the available RF-DETR checkpoints under the [stevenbucaille](https://huggingface.co/stevenbucaille)
organization.
The original code can be found [here](https://github.com/roboflow/rf-detr).

> [!TIP]
> This model was contributed by [stevenbucaille](https://huggingface.co/stevenbucaille).
>
> Click on the RF-DETR models in the right sidebar for more examples of how to apply RF-DETR to different object
> detection tasks.

The example below demonstrates how to perform object detection with the [`Pipeline`] and the [`AutoModel`] class.

<hfoptions id="usage">
<hfoption id="Pipeline">

```python
from transformers import pipeline
import torch

pipeline = pipeline(
    "object-detection",
    model="stevenbucaille/rfdetr_small_60e_coco",
    dtype=torch.float16,
    device_map=0
)

pipeline("http://images.cocodataset.org/val2017/000000039769.jpg")
```

**Contributor** (inline review comment on the snippet above): does this need to be included in the pipeline?
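
As a follow-up that is not part of the original snippet, here is a minimal sketch of how the pipeline output, a list of
dictionaries with `score`, `label` and `box` keys, could be drawn onto the image with Pillow; the checkpoint name is
reused from the example above.

```python
import requests
import torch
from PIL import Image, ImageDraw
from transformers import pipeline

detector = pipeline(
    "object-detection",
    model="stevenbucaille/rfdetr_small_60e_coco",
    dtype=torch.float16,
    device_map=0,
)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# each detection is a dict: {"score": float, "label": str, "box": {"xmin", "ymin", "xmax", "ymax"}}
draw = ImageDraw.Draw(image)
for detection in detector(image):
    box = detection["box"]
    draw.rectangle((box["xmin"], box["ymin"], box["xmax"], box["ymax"]), outline="red", width=3)
    draw.text((box["xmin"], box["ymin"]), f"{detection['label']}: {detection['score']:.2f}", fill="red")
image.save("detections.png")
```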

</hfoption>
<hfoption id="AutoModel">

```python
from transformers import AutoImageProcessor, AutoModelForObjectDetection
from PIL import Image
import requests
import torch

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

image_processor = AutoImageProcessor.from_pretrained("stevenbucaille/rfdetr_small")
model = AutoModelForObjectDetection.from_pretrained("stevenbucaille/rfdetr_small")

# prepare image for the model
inputs = image_processor(images=image, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

results = image_processor.post_process_object_detection(outputs, target_sizes=torch.tensor([image.size[::-1]]), threshold=0.3)

for result in results:
    for score, label_id, box in zip(result["scores"], result["labels"], result["boxes"]):
        score, label = score.item(), label_id.item()
        box = [round(i, 2) for i in box.tolist()]
        print(f"{model.config.id2label[label]}: {score:.2f} {box}")
```
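
The following sketch is not part of the PR; it extends the snippet above to batched inference and assumes the RF-DETR
image processor can batch multiple images into a single tensor, as the other DETR-family processors in Transformers do.

```python
from transformers import AutoImageProcessor, AutoModelForObjectDetection
from PIL import Image
import requests
import torch

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
images = [image, image.rotate(90, expand=True)]  # two differently sized inputs

image_processor = AutoImageProcessor.from_pretrained("stevenbucaille/rfdetr_small")
model = AutoModelForObjectDetection.from_pretrained("stevenbucaille/rfdetr_small")

inputs = image_processor(images=images, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# one (height, width) pair per image so boxes are rescaled to each original resolution
target_sizes = torch.tensor([im.size[::-1] for im in images])
results = image_processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.3)

for i, result in enumerate(results):
    print(f"image {i}: {len(result['scores'])} detections above threshold")
```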

</hfoption>
</hfoptions>

## Resources

- Scripts for finetuning [`RfDetrForObjectDetection`] with [`Trainer`]
  or [Accelerate](https://huggingface.co/docs/accelerate/index) can be
  found [here](https://github.com/huggingface/transformers/tree/main/examples/pytorch/object-detection).
- See also: [Object detection task guide](../tasks/object_detection).

## RfDetrConfig

[[autodoc]] RfDetrConfig

## RfDetrDinov2Config

[[autodoc]] RfDetrDinov2Config

## RfDetrModel

[[autodoc]] RfDetrModel
    - forward

## RfDetrForObjectDetection

[[autodoc]] RfDetrForObjectDetection
    - forward

## RfDetrForInstanceSegmentation

[[autodoc]] RfDetrForInstanceSegmentation
    - forward

## RfDetrDinov2Backbone

[[autodoc]] RfDetrDinov2Backbone
    - forward

```diff
@@ -343,7 +343,7 @@ def LwDetrForObjectDetectionLoss(
     outputs_loss["auxiliary_outputs"] = auxiliary_outputs
     loss_dict = criterion(outputs_loss, labels)
     # Fourth: compute total loss, as a weighted sum of the various losses
-    weight_dict = {"loss_ce": 1, "loss_bbox": config.bbox_loss_coefficient}
+    weight_dict = {"loss_ce": config.class_loss_coefficient, "loss_bbox": config.bbox_loss_coefficient}
     weight_dict["loss_giou"] = config.giou_loss_coefficient
     if config.auxiliary_loss:
         aux_weight_dict = {}
```

**Contributor**: Sounds good, but this is backwards-compat breaking if a config was ill-defined - not certain whether that warrants adding 🚨 to the PR description.

**Contributor (Author)**: That's what I thought, but pulling a config from the Hub populates the attribute with default values if no value is found, right? All LW-DETR tests pass; it is just a drop-in replacement.
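
For context on the change above, the "weighted sum" mentioned in the in-code comment is typically applied along these
lines (an illustrative sketch of the DETR-family pattern with dummy values, not the actual library code):

```python
import torch

# Dummy stand-ins for the losses returned by the criterion and the
# coefficients read from the config (e.g. class/bbox/giou loss coefficients).
loss_dict = {"loss_ce": torch.tensor(0.7), "loss_bbox": torch.tensor(0.2), "loss_giou": torch.tensor(0.4)}
weight_dict = {"loss_ce": 1.0, "loss_bbox": 5.0, "loss_giou": 2.0}

# The total loss is the weighted sum over every loss that has a configured weight.
loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict if k in weight_dict)
print(loss)  # tensor(2.5000)
```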

For your question: yes, all `WeightRenaming` ops live in `conversion_mapping.py` for now; as a matter of fact, we are trying to get rid of conversion scripts entirely. Reviewing to see if we can absorb it.