Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
b2229ba
feat: Add RfDetr initial commit
sbucaille Mar 26, 2025
0ce186d
tests: RfDetr and RfDetrDinov2WithRegisters tests
sbucaille Sep 26, 2025
be572a0
docs: added RfDetr docs
sbucaille Sep 26, 2025
6b6a813
refactor: apply lwdetr changes to rfdetr
sbucaille Oct 22, 2025
6fcd8ca
refactor: apply changes after lwdetr rebase
sbucaille Oct 30, 2025
8bd6b65
tests: added missing _prepare_for_class method
sbucaille Oct 30, 2025
7b8da5a
refactor: removed registers and moved from DinoV2WithRegisters to sim…
sbucaille Nov 7, 2025
42a996a
feat: updated RfDetr and RfDetrDinov2 to comply with original impleme…
sbucaille Nov 7, 2025
77c3d41
chore: make style
sbucaille Nov 7, 2025
fee5d03
refactor: extracted window partition and unpartition in dedicated fun…
sbucaille Nov 7, 2025
87e979f
refactor: changed from remove_window parameter to global_attention at…
sbucaille Nov 7, 2025
ea7c04e
chore: removed mistakenly committed file
sbucaille Nov 7, 2025
f340136
feat: added rf_detr image processing reference
sbucaille Nov 7, 2025
9175132
refactor: added LwDetrConfig as superclass for RfDetrConfig
sbucaille Nov 7, 2025
83976de
tests: add large RfDetr integration test
sbucaille Nov 7, 2025
0372f0e
chore: make style
sbucaille Nov 7, 2025
8075163
chore: make repo-consistency
sbucaille Nov 8, 2025
d7cac5a
chore: make quality
sbucaille Nov 8, 2025
a30a422
fix: rebased on lw_detr branch
sbucaille Nov 25, 2025
8017a12
chore: moved RfDetrDinov2 to a single modular_rf_detr.py file
sbucaille Dec 10, 2025
310d93a
fix: fixed backbone name in weight for convert script
sbucaille Dec 10, 2025
681c9e3
refactor: fixed modular file after fix
sbucaille Dec 12, 2025
b6d2bca
docs: fix licences
sbucaille Jan 12, 2026
7491860
fix: moved variables to config for ScaleProjector
sbucaille Jan 13, 2026
98abf6b
tests: remove deprecated tests
sbucaille Jan 13, 2026
8ba3382
docs: remove unnecessary utf-8
sbucaille Jan 13, 2026
7f83d9a
docs: add RF-DETR to toctree
sbucaille Jan 13, 2026
6ddfcdd
style
sbucaille Jan 13, 2026
b5a8a78
tests: add cpu device in expectations
sbucaille Jan 13, 2026
245b13f
style: reordered classes
sbucaille Jan 13, 2026
07b926f
fix: removed unnecessary variable
sbucaille Jan 13, 2026
82e112a
style: make style
sbucaille Jan 13, 2026
e0ad128
fix: removed backbone API related attributes from config
sbucaille Jan 13, 2026
b1fc550
fix: added class_loss_coefficient to LWDetr config and loss
sbucaille Jan 21, 2026
403f8cf
feat: add RfDetrForInstanceSegmentation model
sbucaille Jan 22, 2026
28f35f1
docs: polish docs
sbucaille Jan 22, 2026
613c2f0
fix: lwdetr unnecessary config attribute
sbucaille Jan 23, 2026
fb0198a
chore: apply modular to rfdetr
sbucaille Jan 23, 2026
e2f06dc
fix: change from deformabledetr to detr image processor
sbucaille Jan 23, 2026
d55ca21
Update docs/source/en/model_doc/rf_detr.md
sbucaille Jan 23, 2026
0c8202d
docs: update convert script docs
sbucaille Jan 23, 2026
1599f8b
feat: add new segmentation models
sbucaille Jan 25, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -869,6 +869,8 @@
title: RegNet
- local: model_doc/resnet
title: ResNet
- local: model_doc/rf_detr
title: RF-DETR
- local: model_doc/rt_detr
title: RT-DETR
- local: model_doc/rt_detr_v2
Expand Down
139 changes: 139 additions & 0 deletions docs/source/en/model_doc/rf_detr.md
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For your question: yes, all WeigtRenaming ops live in conversion_mapping.py for now, as a matter of fact we try to get rid of conversion scripts entirely. Reviewing to see if we can absorb it

Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
<!--Copyright 2026 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
*This model was released on 2024-04-05 and added to Hugging Face Transformers on 2026-01-23.*

<div style="float: right;">
<div class="flex flex-wrap space-x-1">
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
</div>
</div>

# RF-DETR

[RF-DETR](https://huggingface.co/papers/2407.17140) proposes a Receptive Field Detection Transformer (DETR) architecture
designed to compete with and surpass the dominant YOLO series for real-time object detection. It achieves a new
state-of-the-art balance between speed (latency) and accuracy (mAP) by combining recent transformer advances with
efficient design choices.

The RF-DETR architecture is characterized by its simple and efficient structure: a DINOv2 Backbone, a Projector, and a
shallow DETR Decoder.
It enhances the DETR architecture for efficiency and speed using the following core modifications:

1. **DINOv2 Backbone**: Uses a powerful DINOv2 backbone for robust feature extraction.
2. **Group DETR Training**: Utilizes Group-Wise One-to-Many Assignment during training to accelerate convergence.
3. **Richer Input**: Aggregates multi-level features from the backbone and uses a C2f Projector (similarly to YOLOv8) to
pass multi-scale features.
4. **Faster Decoder**: Employs a shallow 3-layer DETR decoder with deformable cross-attention for lower latency.
5. **Optimized Queries**: Uses a mixed-query scheme combining learnable content queries and generated spatial queries.

You can find all the available RF-DETR checkpoints under the [stevenbucaille](https://huggingface.co/stevenbucaille)
organization.
The original code can be found [here](https://github.com/roboflow/rf-detr).

> [!TIP]
> This model was contributed by [stevenbucaille](https://huggingface.co/stevenbucaille).
>
> Click on the RF-DETR models in the right sidebar for more examples of how to apply RF-DETR to different object
> detection tasks.

The example below demonstrates how to perform object detection with the [`Pipeline`] and the [`AutoModel`] class.

<hfoptions id="usage">
<hfoption id="Pipeline">

```python
from transformers import pipeline
import torch

pipeline = pipeline(
"object-detection",
model="stevenbucaille/rfdetr_small_60e_coco",
dtype=torch.float16,
device_map=0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this need to be included in the pipeline?

)

pipeline("http://images.cocodataset.org/val2017/000000039769.jpg")
```

</hfoption>
<hfoption id="AutoModel">

```python
from transformers import AutoImageProcessor, AutoModelForObjectDetection
from PIL import Image
import requests
import torch

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

image_processor = AutoImageProcessor.from_pretrained("stevenbucaille/rfdetr_small")
model = AutoModelForObjectDetection.from_pretrained("stevenbucaille/rfdetr_small")

# prepare image for the model
inputs = image_processor(images=image, return_tensors="pt")

with torch.no_grad():
outputs = model(**inputs)

results = image_processor.post_process_object_detection(outputs, target_sizes=torch.tensor([image.size[::-1]]), threshold=0.3)

for result in results:
for score, label_id, box in zip(result["scores"], result["labels"], result["boxes"]):
score, label = score.item(), label_id.item()
box = [round(i, 2) for i in box.tolist()]
print(f"{model.config.id2label[label]}: {score:.2f} {box}")
```

</hfoption>
</hfoptions>

## Resources


- Scripts for finetuning [`RfDetrForObjectDetection`] with [`Trainer`]
or [Accelerate](https://huggingface.co/docs/accelerate/index) can be
found [here](https://github.com/huggingface/transformers/tree/main/examples/pytorch/object-detection).
- See also: [Object detection task guide](../tasks/object_detection).

## RfDetrConfig

[[autodoc]] RfDetrConfig

## RfDetrDinov2Config

[[autodoc]] RfDetrDinov2Config

## RfDetrModel

[[autodoc]] RfDetrModel
- forward

## RfDetrForObjectDetection

[[autodoc]] RfDetrForObjectDetection
- forward

## RfDetrForInstanceSegmentation

[[autodoc]] RfDetrForInstanceSegmentation
- forward

## RfDetrDinov2Backbone

[[autodoc]] RfDetrDinov2Backbone
- forward
2 changes: 1 addition & 1 deletion src/transformers/loss/loss_lw_detr.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ def LwDetrForObjectDetectionLoss(
outputs_loss["auxiliary_outputs"] = auxiliary_outputs
loss_dict = criterion(outputs_loss, labels)
# Fourth: compute total loss, as a weighted sum of the various losses
weight_dict = {"loss_ce": 1, "loss_bbox": config.bbox_loss_coefficient}
weight_dict = {"loss_ce": config.class_loss_coefficient, "loss_bbox": config.bbox_loss_coefficient}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds good, but backwards compat breaking if a config was ill-defined - not certain if that warrants adding 🚨 to the PR description

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's what I thought but pulling a config from the hub populates attribute with default values if no value is found right ? All LWDetr tests pass, it is just a drop-in replacement

weight_dict["loss_giou"] = config.giou_loss_coefficient
if config.auxiliary_loss:
aux_weight_dict = {}
Expand Down
Loading