Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
104 commits
Select commit Hold shift + click to select a range
81f044d
First commit for DINO - DETR. Copy-paste files from Deformable DETR a…
Feb 28, 2025
e4f726d
Add dino detr loss
Mar 1, 2025
216c444
First implementation of dino detr post processor
Mar 1, 2025
e7ea7f9
Add misc functions to dino modeling, make config first version
Mar 2, 2025
19f79d2
Restructure dino detr modeling code, major changes are now required o…
Mar 2, 2025
f2f3d02
Match config to source code in modeling dino detr, defaults in config…
Mar 2, 2025
2e438d8
Clean up dino_detr_config
Mar 3, 2025
efcadcf
Fixed import errors, code can now be executed
Mar 3, 2025
b9f6f15
Model now can be instantiated with random weights
Mar 4, 2025
a5c9f5a
Forward pass for randomly initialized model now executes without errors
Mar 6, 2025
e464c28
Forward pass works for DINO model with pretrained weights and matches…
Mar 10, 2025
51c565c
Add dino detr conversion to pytorch script
Mar 10, 2025
e2dcc0c
Ensure model outputs for single COCO image are within 1e-3 error of r…
Mar 13, 2025
14dfc25
Add draft unit tests
Mar 13, 2025
87065ac
Apply make style
Mar 13, 2025
254adfa
Modeling integration test passes
Mar 22, 2025
b9daab5
Add integration and unit tests
Mar 24, 2025
d20be46
Remove unnecessary dependency on pretrained model class from some cla…
Apr 7, 2025
db789ed
All image processing tests pass
Apr 14, 2025
4efe590
Consistently start module names with DinoDetr
Apr 14, 2025
ebd4fa6
Pass test_hidden_states_output test
Apr 14, 2025
72689b3
Dino detr doesn't support gradient checkpointing
Apr 17, 2025
e8f055b
Pass test_model_is_small test
Apr 18, 2025
5bfe3a1
Pass test_model_outputs_equivalence test
Apr 18, 2025
35181a0
Pass test_attention_outputs
Apr 19, 2025
c76d6e3
Pass test_retain_grad_hidden_states_attentions
Apr 19, 2025
c5f3cb0
Pass safetensor tests
Apr 19, 2025
00e2932
All tests pass for the first model draft
Apr 24, 2025
f7bac4c
Convert weights to hf using a single key mapping and regex
Apr 24, 2025
6454cad
Add some missing copied_from comments, these work with the huggingfac…
Apr 27, 2025
1eec3f7
Move all asserts to the config
Apr 27, 2025
b9cfd1e
Remove some unused config parameters
Apr 27, 2025
28de022
Remove unused config parameters, simplify inits
Apr 27, 2025
9aa7b73
Clean up decoder code
Apr 29, 2025
1e6567f
Clean up deformable transformer code
Apr 29, 2025
52f9e8a
Clean DinoDetrModel code
Apr 29, 2025
4bfa659
Clean up DinoDetrForObjectDetection class
Apr 29, 2025
476c4fa
Add type hints
May 4, 2025
bc845c7
Add a model checkpoint to hub
May 4, 2025
8eef2a1
Fix copied from in image_processing_dino_detr
May 4, 2025
0ca2f6a
Some changes to make copied from mechanism work
May 5, 2025
479cdd1
Fix more code quality errors
May 5, 2025
7cf7ae0
Fill missing values config docstring
May 5, 2025
d1556d1
Add dino detr to models/__init__.py file
May 5, 2025
388cae3
DinoDetrDeformableTransformer is recognized as a model, attempt a fix
May 5, 2025
8f4dcf4
Attempt fix for DinoDetrDeformableTransformer getting classified as a…
May 5, 2025
a817f84
Add docs template
May 5, 2025
e3c3a53
Add missing model type to the constant
May 5, 2025
ca9dd3c
Remove config from loss function definition
May 7, 2025
9ab7bb8
Remove unused config parameters
May 7, 2025
6671acf
DinoDetr docs first complete draft
May 8, 2025
abb3a67
Clean docstrings
May 8, 2025
ea928f0
Add docs to toctree.yml
May 8, 2025
a059714
Fix some errors after rebase
May 11, 2025
58fa28b
Update deformable attention definition
May 11, 2025
aed7208
Implement using modular
May 11, 2025
7f5f8cd
Apply conditions on some imports
May 11, 2025
98f7f25
Remove asserts from postprocessing method
May 11, 2025
d201f21
Sort imports
May 11, 2025
c367c4e
Small fixes
Oct 9, 2025
9703b52
Style changes from make fixup
Oct 9, 2025
14d4dc8
Small fixes
Oct 9, 2025
0b818da
Address reviewer comments
Oct 9, 2025
67c1d50
Clean up dino_loss import from detection_loss when possible
Oct 12, 2025
2f4d20a
Clean up configuration_dino_detr.py
Oct 12, 2025
ca3073a
Clean up convert_dino_detr_to_pytorch.py
Oct 12, 2025
8af3f3f
Remove feature extraction file
Oct 12, 2025
9818289
Address reviewer comments
Oct 12, 2025
aa07ace
Replace _get_activation_fn with ACT2FN
Oct 12, 2025
b01762a
Use get_contrastive_denoising_training_group from rt_detr
Oct 13, 2025
c90bd59
Reuse deformable_detr image_processor
Oct 13, 2025
8d26a08
Simplify config
Oct 13, 2025
1380dd5
Clean up modular code, add comments
Oct 14, 2025
f1163d7
Cleanup of embeddings
Oct 15, 2025
02a1197
Cleanup embeddings
Oct 15, 2025
8442405
Cleanup: remove layer dropout, two_stage options, ref_token, etc
Oct 20, 2025
578585f
Clean modeling code
Oct 21, 2025
626d62d
Add DinoDetrImageProcessor to __all__
Oct 22, 2025
116d400
Make fix copies
Oct 22, 2025
323e567
Revert changes to test modeling deformable detr
Oct 23, 2025
3b3b07a
Small fix
Oct 23, 2025
922ac92
Fix docstring
Oct 23, 2025
5987ad9
Ignore docstring consistency, it throws misleading errors
Oct 23, 2025
df64313
Add architecture link to config docstring
Oct 23, 2025
348cb35
Replace return_dict with can_return_tuple
Nov 16, 2025
45b6d13
Use _can_record_outputs instead of output_attentions output_hidden_st…
Nov 18, 2025
a97f835
Simplify code
Nov 18, 2025
5cda82b
Simplify layer share
Nov 18, 2025
a4e4bc6
Make code simpler
Nov 20, 2025
b89edb0
Rename cardinality loss to cardinality error
Nov 20, 2025
c0acd49
Remove redundant imports from modular
Nov 21, 2025
01692bb
Improve docs
Nov 21, 2025
a7326d8
Fix failing tests after rebase
Nov 24, 2025
701a7ad
Fix most tests apart from test_can_use_safetensors and test_load_save…
Dec 3, 2025
8128b03
Remove tied weights
Dec 3, 2025
9f28b71
Simplify code
Dec 9, 2025
83c818f
Simplify _can_record_outputs
Dec 9, 2025
b578c84
Add auto_docstring
Dec 9, 2025
a9647a5
Simplify EmbeddingSineHW
Dec 9, 2025
b10f26a
Simplify code
Dec 9, 2025
7f6c0f5
Simplify decoder layer
Dec 9, 2025
08b1130
Remove get_clones completely
Dec 9, 2025
97c723a
Fix auto_docstring
Dec 9, 2025
5a3fec3
Fix ini_weights
Dec 9, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -767,6 +767,8 @@
title: DETR
- local: model_doc/dinat
title: DiNAT
- local: model_doc/dino_detr
title: DINO DETR
- local: model_doc/dinov2
title: DINOV2
- local: model_doc/dinov2_with_registers
Expand Down
130 changes: 130 additions & 0 deletions docs/source/en/model_doc/dino_detr.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# DINO DETR

<div class="flex flex-wrap space-x-1">
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
</div>

## Overview

DINO DETR (DETR with Improved DeNoising Anchor Boxes) is a state-of-the-art end-to-end object detection model introduced in the paper [DINO: DETR with Improved DeNoising Anchor Boxes for End-to-End Object Detection](https://arxiv.org/abs/2203.03605) by Hao Zhang, Feng Li, Shilong Liu, Lei Zhang, Hang Su, Jun Zhu, Lionel M. Ni and Heung-Yeung Shum. It builds upon the original DETR framework by addressing key challenges in convergence speed and detection accuracy.

DINO DETR enhances the DETR architecture through three main innovations:

* **Mixed Query Selection for Anchor Initialization**: In DINO DETR decoder queries consist of query locations and query features. The decoder query locations are selected by passing the encoder features through a classification head and selecting the topk locations in terms of max class probability. The decoder query features are learnable weights shared across all samples. In this way, the query locations are initialized close to interesting objects and the query features are made robust to different object classes.

* **Contrastive Denoising Training**: DINO Detr uses denoising queries together with "standard" DETR queries. The denoising queries are strongly perturbed versions of standard queries, assigned to negative labels. This improves the model's robustness and convergence speed.

* **Look Forward Twice Scheme for Box Prediction**: This term means that the bounding boxes are refined iteratively from one decoder layer to the next, by adding corrections to the previous layer bounding boxes. This improves training stability.

These advancements enable DINO DETR to achieve significant performance improvements over previous DETR-like models. For instance, with a ResNet-50 backbone and multi-scale features, DINO attains 49.4 AP in 12 epochs and 51.3 AP in 24 epochs on the COCO dataset, marking a substantial enhancement in detection performance.

The abstract of the paper is the following:

*We present DINO (DETR with Improved deNoising anchOr boxes), a state-of-the-art end-to-end object detector. DINO improves over previous DETR-like models in performance and efficiency by using a contrastive way for denoising training, a mixed query selection method for anchor initialization, and a look forward twice scheme for box prediction. DINO achieves 49.4AP in 12 epochs and 51.3AP in 24 epochs on COCO with a ResNet-50 backbone and multi-scale features, yielding a significant improvement of +6.0AP and +2.7AP, respectively, compared to DN-DETR, the previous best DETR-like model. DINO scales well in both model size and data size. Without bells and whistles, after pre-training on the Objects365 dataset with a SwinL backbone, DINO obtains the best results on both COCO val2017 (63.2AP) and test-dev (63.3AP). Compared to other models on the leaderboard, DINO significantly reduces its model size and pre-training data size while achieving better results.*

This implementation is contributed by [kostaspitas](https://huggingface.co/kostaspitas) and is based on the official code available at [https://github.com/IDEA-Research/DINO](https://github.com/IDEA-Research/DINO).

### Model Architecture

DINO DETR with Improved DeNoising Anchor Boxes (DINO) follows a similar architecture to the original DETR but introduces enhancements for improved performance and efficiency.

#### 1. Backbone

An input image is processed through a pre-trained convolutional backbone, such as ResNet-50 or ResNet-101. This backbone extracts multi-scale feature maps, which are then projected to match the hidden dimension of the Transformer encoder. For instance, a feature map of shape `(batch_size, 2048, height/32, width/32)` is transformed to `(batch_size, 256, height/32, width/32)` using a convolutional layer. These feature maps are then flattened and transposed to obtain a tensor of shape `(batch_size, seq_len, d_model)`, where `seq_len` is the number of spatial locations and `d_model` is the model dimension (e.g., 256).

#### 2. Transformer Encoder

The flattened feature maps are passed through a multi-layer transformer encoder. The self-attention implementation is deformable attention as first introduced in the [Deformable DETR](https://arxiv.org/abs/2010.04159) paper. This encoder processes the input sequence and outputs `encoder_hidden_states`, which serve as the image features for the subsequent decoder.

#### 3. Object Queries and Decoder

DINO introduces dynamic anchor boxes as object queries. These queries are initialized based on anchor box coordinates and are updated through the decoder layers. The decoder receives these object queries along with the encoder outputs and processes them through multiple self-attention and encoder-decoder cross-attention layers. The decoder cross attention layers perform deformable attention while self-attention is standard multi-head self-attention. The result is `decoder_hidden_states`, which are then used for prediction.

#### 4. Prediction Heads

On top of the decoder outputs, DINO adds two prediction heads:

* **Classification Head**: A linear layer that classifies each object query into one of the object categories or "no object".

* **Bounding Box Head**: A multi-layer perceptron (MLP) that predicts the bounding box coordinates for each object query.

#### 5. Training with Bipartite Matching Loss

DINO employs a bipartite matching loss during training. The predicted classes and bounding boxes for each object query are compared to the ground truth annotations, padded to the same length. The Hungarian algorithm is used to find an optimal one-to-one mapping between the predicted and ground truth annotations. The loss function combines focal loss for classification and a linear combination of L1 loss and generalized Intersection over Union (IoU) loss for bounding box regression. In addition to the original DETR queries, DINO DETR adds a denoising query set which feeds noised groundtruth labels and boxes into the decoder to provide an aixiliary denoising loss. The denoising loss effectively stabilizes and speeds up the DINO DETR training.
Comment on lines +67 to +69
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since there are new losses introduced, would be cool to have a code snippet here to guide people through training as well! Would add a lot of value to the release


## Usage Tips

- **Object Queries**: DINO utilizes dynamic anchor boxes as object queries to detect objects within an image. The number of queries (`num_queries`) determines the maximum number of objects that can be detected in a single image. By default, this is set to 900 (e.g., 300 queries × 3 patterns) to enhance detection performance.
- **Decoder Parallelism**: Similar to the original DETR, DINO's decoder updates object queries in parallel, differing from autoregressive models like GPT-2. Consequently, no causal attention mask is employed.
- **Position Embeddings**: DINO adds position embeddings to the hidden states at each self-attention and cross-attention layer before projecting to queries and keys. For image position embeddings, you can choose between fixed sinusoidal or learned absolute position embeddings. By default, the `position_embedding_type` parameter in `DinoDetrConfig` is set to `"SineWH"`.
- **Auxiliary Losses**: During training, employing auxiliary losses in the decoder can be beneficial, especially for improving the model's ability to predict the correct number of objects per class. Setting the `auxiliary_loss` parameter in `DinoDetrConfig` to `True` adds an auxiliary_loss after each decoder layer.
- **Distributed Training**: When training the model across multiple nodes, it's important to update the `num_boxes` variable in the `DinoLoss` class to reflect the average number of target boxes across all nodes. This adjustment ensures proper loss computation during distributed training.
- **Backbone Initialization**: `DinoDetrForObjectDetection` can be initialized with any convolutional backbone available in the [timm library](https://github.com/rwightman/pytorch-image-models). For instance, to use a MobileNet backbone, set the `backbone` attribute in `DinoDetrConfig` to `"tf_mobilenetv3_small_075"` and initialize the model with this configuration.
- **Image Preprocessing**: DINO resizes input images such that the shortest side is at least a certain number of pixels, while the longest side is at most 1333 pixels. During training, scale augmentation is applied, randomly setting the shortest side to between 480 and 800 pixels and its longer side to be at most 1333 pixels. At inference time, the shortest side is set to 800 pixels. Use `DinoDetrImageProcessor` to prepare images (and optional annotations in COCO format) for the model. Due to resizing, images in a batch may have different sizes. DinoDetr addresses this by padding images to the largest size in the batch and creating a pixel mask to differentiate real pixels from padding. Alternatively, you can define a custom `collate_fn` to batch images using `DinoDetrImageProcessor.pad_and_create_pixel_mask`.
- **Batch Size Considerations**: The size of input images affects memory usage and, consequently, the `batch_size`.

### Model Initialization Options

There are three ways to instantiate a DinoDetr model:

**Option 1**: Instantiate DinoDetr with pre-trained weights for the entire model
```python
>>> from transformers import DinoDetrForObjectDetection

>>> model = DinoDetrForObjectDetection.from_pretrained("IDEA-Research/dino-resnet-50")
```
**Option 2**: Instantiate DinoDetr with randomly initialized Transformer weights but pre-trained backbone weights
```python
>>> from transformers import DinoDetrConfig, DinoDetrForObjectDetection

>>> config = DinoDetrConfig()
>>> model = DinoDetrForObjectDetection(config)
```
**Option 3**: Instantiate DinoDetr with randomly initialized weights for both backbone and Transformer
```python
>>> config = DinoDetrConfig(use_pretrained_backbone=False)
>>> model = DinoDetrForObjectDetection(config)
```
One should prepare the data in COCO detection format, then use
[`~transformers.DetrImageProcessor`] to create `pixel_values`, `pixel_mask` and optional
`labels`, which can then be used to train (or fine-tune) a model. For evaluation, one should first convert the
outputs of the model using one of the postprocessing methods of [`~transformers.DetrImageProcessor`]. These can
be provided to either `CocoEvaluator` which allow you to calculate metrics like
mean Average Precision (mAP).

## DinoDetrModel

[[autodoc]] DinoDetrModel

## DinoDetrForObjectDetection

[[autodoc]] DinoDetrForObjectDetection

## DinoDetrConfig

[[autodoc]] DinoDetrConfig

## DinoDetrImageProcessor

[[autodoc]] DinoDetrImageProcessor

## DinoDetrFeatureExtractor

[[autodoc]] DinoDetrFeatureExtractor

1 change: 0 additions & 1 deletion src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,6 @@
_import_structure["trainer_pt_utils"] = ["torch_distributed_zero_first"]
_import_structure["trainer_seq2seq"] = ["Seq2SeqTrainer"]


# Direct imports for type-checking
if TYPE_CHECKING:
# All modeling imports
Expand Down
Loading