diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index d55a6da33a32..b93f3cc9bca2 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -869,6 +869,8 @@
title: RegNet
- local: model_doc/resnet
title: ResNet
+ - local: model_doc/rf_detr
+ title: RF-DETR
- local: model_doc/rt_detr
title: RT-DETR
- local: model_doc/rt_detr_v2
diff --git a/docs/source/en/model_doc/rf_detr.md b/docs/source/en/model_doc/rf_detr.md
new file mode 100644
index 000000000000..0e11952c37cb
--- /dev/null
+++ b/docs/source/en/model_doc/rf_detr.md
@@ -0,0 +1,139 @@
+
+*This model was released on 2024-04-05 and added to Hugging Face Transformers on 2026-01-23.*
+
+
+
+
+
+
+
+
+# RF-DETR
+
+[RF-DETR](https://huggingface.co/papers/2407.17140) proposes a Receptive Field Detection Transformer (DETR) architecture
+designed to compete with and surpass the dominant YOLO series for real-time object detection. It achieves a new
+state-of-the-art balance between speed (latency) and accuracy (mAP) by combining recent transformer advances with
+efficient design choices.
+
+The RF-DETR architecture is characterized by its simple and efficient structure: a DINOv2 Backbone, a Projector, and a
+shallow DETR Decoder.
+It enhances the DETR architecture for efficiency and speed using the following core modifications (reflected in the
+configuration sketch after the list):
+
+1. **DINOv2 Backbone**: Uses a powerful DINOv2 backbone for robust feature extraction.
+2. **Group DETR Training**: Utilizes Group-Wise One-to-Many Assignment during training to accelerate convergence.
+3. **Richer Input**: Aggregates multi-level features from the backbone and uses a C2f Projector (similar to YOLOv8)
+   to pass multi-scale features to the decoder.
+4. **Faster Decoder**: Employs a shallow 3-layer DETR decoder with deformable cross-attention for lower latency.
+5. **Optimized Queries**: Uses a mixed-query scheme combining learnable content queries and generated spatial queries.
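+
+These choices map directly onto [`RfDetrConfig`] and its DINOv2 backbone sub-configuration, [`RfDetrDinov2Config`].
+The snippet below is a minimal sketch; the values shown are illustrative rather than those of a released checkpoint.
+
+```python
+from transformers import RfDetrConfig, RfDetrDinov2Config
+
+# DINOv2-style backbone with windowed attention and multi-level feature outputs
+backbone_config = RfDetrDinov2Config(
+    hidden_size=384,
+    num_attention_heads=6,
+    out_features=["stage2", "stage5", "stage8", "stage11"],
+    num_windows=4,
+)
+
+# shallow 3-layer decoder, Group DETR training and 300 object queries
+config = RfDetrConfig(
+    backbone_config=backbone_config,
+    projector_scale_factors=[2.0, 1.0, 0.5],  # multi-scale projector; only 0.5, 1.0 and 2.0 are supported
+    decoder_layers=3,
+    group_detr=13,
+    num_queries=300,
+)
+print(config.num_feature_levels, config.decoder_layers, config.group_detr)
+```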
+
+You can find all the available RF-DETR checkpoints under the [stevenbucaille](https://huggingface.co/stevenbucaille)
+organization.
+The original code can be found [here](https://github.com/roboflow/rf-detr).
+
+> [!TIP]
+> This model was contributed by [stevenbucaille](https://huggingface.co/stevenbucaille).
+>
+> Click on the RF-DETR models in the right sidebar for more examples of how to apply RF-DETR to different object
+> detection tasks.
+
+
+The example below demonstrates how to perform object detection with the [`Pipeline`] and the [`AutoModel`] class.
+
+
+
+
+```python
+from transformers import pipeline
+import torch
+
+pipeline = pipeline(
+ "object-detection",
+ model="stevenbucaille/rfdetr_small_60e_coco",
+ dtype=torch.float16,
+ device_map=0
+)
+
+pipeline("http://images.cocodataset.org/val2017/000000039769.jpg")
+```
+
+
+
+
+```python
+from transformers import AutoImageProcessor, AutoModelForObjectDetection
+from PIL import Image
+import requests
+import torch
+
+url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+image = Image.open(requests.get(url, stream=True).raw)
+
+image_processor = AutoImageProcessor.from_pretrained("stevenbucaille/rfdetr_small")
+model = AutoModelForObjectDetection.from_pretrained("stevenbucaille/rfdetr_small")
+
+# prepare image for the model
+inputs = image_processor(images=image, return_tensors="pt")
+
+with torch.no_grad():
+ outputs = model(**inputs)
+
+results = image_processor.post_process_object_detection(outputs, target_sizes=torch.tensor([image.size[::-1]]), threshold=0.3)
+
+for result in results:
+ for score, label_id, box in zip(result["scores"], result["labels"], result["boxes"]):
+ score, label = score.item(), label_id.item()
+ box = [round(i, 2) for i in box.tolist()]
+ print(f"{model.config.id2label[label]}: {score:.2f} {box}")
+```
+
+
+
+
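+[`RfDetrForInstanceSegmentation`] follows the same API and additionally predicts mask logits. The snippet below is a
+hedged sketch: the checkpoint name is hypothetical, and the output attribute names are assumed from the loss
+interface added in this PR (`logits`, `pred_boxes`, `pred_masks`).
+
+```python
+from transformers import AutoImageProcessor, RfDetrForInstanceSegmentation
+from PIL import Image
+import requests
+import torch
+
+url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+image = Image.open(requests.get(url, stream=True).raw)
+
+# hypothetical checkpoint name, for illustration only
+checkpoint = "stevenbucaille/rfdetr_seg_preview"
+image_processor = AutoImageProcessor.from_pretrained(checkpoint)
+model = RfDetrForInstanceSegmentation.from_pretrained(checkpoint)
+
+inputs = image_processor(images=image, return_tensors="pt")
+with torch.no_grad():
+    outputs = model(**inputs)
+
+# per-query class logits, boxes and mask logits (attribute names assumed from the loss interface)
+print(outputs.logits.shape, outputs.pred_boxes.shape, outputs.pred_masks.shape)
+```
+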
+## Resources
+
+
+- Scripts for finetuning [`RfDetrForObjectDetection`] with [`Trainer`]
+  or [Accelerate](https://huggingface.co/docs/accelerate/index) can be
+  found [here](https://github.com/huggingface/transformers/tree/main/examples/pytorch/object-detection); a sketch of
+  the expected label format follows this list.
+- See also: [Object detection task guide](../tasks/object_detection).
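+
+The finetuning scripts above pass COCO-format annotations through the image processor; at the model level, the loss
+added in this PR consumes one dict per image with `class_labels` and normalized `(cx, cy, w, h)` `boxes`. The snippet
+below is a minimal sketch with dummy targets, assuming the usual DETR-family `labels` interface.
+
+```python
+import torch
+from PIL import Image
+
+from transformers import AutoImageProcessor, RfDetrForObjectDetection
+
+checkpoint = "stevenbucaille/rfdetr_small"
+image_processor = AutoImageProcessor.from_pretrained(checkpoint)
+model = RfDetrForObjectDetection.from_pretrained(checkpoint)
+
+image = Image.new("RGB", (640, 480))  # dummy image
+inputs = image_processor(images=image, return_tensors="pt")
+
+# one dict per image: class indices and normalized (cx, cy, w, h) boxes, as expected by the loss
+labels = [
+    {
+        "class_labels": torch.tensor([1, 17]),
+        "boxes": torch.tensor([[0.5, 0.5, 0.2, 0.3], [0.3, 0.6, 0.1, 0.1]]),
+    }
+]
+
+outputs = model(**inputs, labels=labels)
+print(outputs.loss)
+```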
+
+## RfDetrConfig
+
+[[autodoc]] RfDetrConfig
+
+## RfDetrDinov2Config
+
+[[autodoc]] RfDetrDinov2Config
+
+## RfDetrModel
+
+[[autodoc]] RfDetrModel
+ - forward
+
+## RfDetrForObjectDetection
+
+[[autodoc]] RfDetrForObjectDetection
+ - forward
+
+## RfDetrForInstanceSegmentation
+
+[[autodoc]] RfDetrForInstanceSegmentation
+ - forward
+
+## RfDetrDinov2Backbone
+
+[[autodoc]] RfDetrDinov2Backbone
+ - forward
diff --git a/src/transformers/loss/loss_lw_detr.py b/src/transformers/loss/loss_lw_detr.py
index 73844fddae29..b346416b1ce1 100644
--- a/src/transformers/loss/loss_lw_detr.py
+++ b/src/transformers/loss/loss_lw_detr.py
@@ -343,7 +343,7 @@ def LwDetrForObjectDetectionLoss(
outputs_loss["auxiliary_outputs"] = auxiliary_outputs
loss_dict = criterion(outputs_loss, labels)
# Fourth: compute total loss, as a weighted sum of the various losses
- weight_dict = {"loss_ce": 1, "loss_bbox": config.bbox_loss_coefficient}
+ weight_dict = {"loss_ce": config.class_loss_coefficient, "loss_bbox": config.bbox_loss_coefficient}
weight_dict["loss_giou"] = config.giou_loss_coefficient
if config.auxiliary_loss:
aux_weight_dict = {}
diff --git a/src/transformers/loss/loss_rf_detr.py b/src/transformers/loss/loss_rf_detr.py
new file mode 100644
index 000000000000..a597f439b318
--- /dev/null
+++ b/src/transformers/loss/loss_rf_detr.py
@@ -0,0 +1,483 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .. import requires_backends
+from ..image_transforms import center_to_corners_format
+from ..utils import is_accelerate_available, is_scipy_available
+from .loss_for_object_detection import (
+ dice_loss,
+ generalized_box_iou,
+)
+from .loss_lw_detr import LwDetrImageLoss
+
+
+if is_scipy_available():
+ from scipy.optimize import linear_sum_assignment
+
+if is_accelerate_available():
+ from accelerate import PartialState
+ from accelerate.utils import reduce
+
+
+def sigmoid_cross_entropy_loss(
+ inputs: torch.Tensor,
+ targets: torch.Tensor,
+ num_masks: float,
+):
+ """
+ Args:
+ inputs: A float tensor of arbitrary shape.
+ The predictions for each example.
+ targets: A float tensor with the same shape as inputs. Stores the binary
+ classification label for each element in inputs
+            (0 for the negative class and 1 for the positive class).
+        num_masks: Normalization factor used to average the loss (here, the number of matched boxes).
+ Returns:
+ Loss tensor
+ """
+ loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
+ loss = loss.mean(1).sum() / num_masks
+ return loss
+
+
+def point_sample(input, point_coords, **kwargs):
+ """
+ A wrapper around :function:`torch.nn.functional.grid_sample` to support 3D point_coords tensors.
+ Unlike :function:`torch.nn.functional.grid_sample` it assumes `point_coords` to lie inside
+ [0, 1] x [0, 1] square.
+
+ Args:
+ input (Tensor): A tensor of shape (N, C, H, W) that contains features map on a H x W grid.
+ point_coords (Tensor): A tensor of shape (N, P, 2) or (N, Hgrid, Wgrid, 2) that contains
+ [0, 1] x [0, 1] normalized point coordinates.
+
+ Returns:
+ output (Tensor): A tensor of shape (N, C, P) or (N, C, Hgrid, Wgrid) that contains
+ features for points in `point_coords`. The features are obtained via bilinear
+            interpolation from `input` the same way as :function:`torch.nn.functional.grid_sample`.
+ """
+ add_dim = False
+ if point_coords.dim() == 3:
+ add_dim = True
+ point_coords = point_coords.unsqueeze(2)
+ output = F.grid_sample(input, 2.0 * point_coords - 1.0, **kwargs)
+ if add_dim:
+ output = output.squeeze(3)
+ return output
+
+
+def get_uncertain_point_coords_with_randomness(
+ coarse_logits, uncertainty_func, num_points, oversample_ratio=3, importance_sample_ratio=0.75
+):
+ """
+    Sample points in [0, 1] x [0, 1] coordinate space based on their uncertainty. The uncertainties
+ are calculated for each point using 'uncertainty_func' function that takes point's logit
+ prediction as input.
+ See PointRend paper for details.
+
+ Args:
+ coarse_logits (Tensor): A tensor of shape (N, C, Hmask, Wmask) or (N, 1, Hmask, Wmask) for
+ class-specific or class-agnostic prediction.
+ uncertainty_func: A function that takes a Tensor of shape (N, C, P) or (N, 1, P) that
+ contains logit predictions for P points and returns their uncertainties as a Tensor of
+ shape (N, 1, P).
+ num_points (int): The number of points P to sample.
+ oversample_ratio (int): Oversampling parameter.
+        importance_sample_ratio (float): Ratio of points that are sampled via importance sampling.
+
+ Returns:
+ point_coords (Tensor): A tensor of shape (N, P, 2) that contains the coordinates of P
+ sampled points.
+ """
+ assert oversample_ratio >= 1
+ assert importance_sample_ratio <= 1 and importance_sample_ratio >= 0
+ num_boxes = coarse_logits.shape[0]
+ num_sampled = int(num_points * oversample_ratio)
+ point_coords = torch.rand(num_boxes, num_sampled, 2, device=coarse_logits.device)
+ point_logits = point_sample(coarse_logits, point_coords, align_corners=False)
+ # It is crucial to calculate uncertainty based on the sampled prediction value for the points.
+ # Calculating uncertainties of the coarse predictions first and sampling them for points leads
+ # to incorrect results.
+ # To illustrate this: assume uncertainty_func(logits)=-abs(logits), a sampled point between
+ # two coarse predictions with -1 and 1 logits has 0 logits, and therefore 0 uncertainty value.
+ # However, if we calculate uncertainties for the coarse predictions first,
+ # both will have -1 uncertainty, and the sampled point will get -1 uncertainty.
+ point_uncertainties = uncertainty_func(point_logits)
+ num_uncertain_points = int(importance_sample_ratio * num_points)
+ num_random_points = num_points - num_uncertain_points
+ idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
+ shift = num_sampled * torch.arange(num_boxes, dtype=torch.long, device=coarse_logits.device)
+ idx += shift[:, None]
+ point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(num_boxes, num_uncertain_points, 2)
+ if num_random_points > 0:
+ point_coords = torch.cat(
+ [
+ point_coords,
+ torch.rand(num_boxes, num_random_points, 2, device=coarse_logits.device),
+ ],
+ dim=1,
+ )
+ return point_coords
+
+
+def calculate_uncertainty(logits):
+ """
+    We estimate uncertainty as the L1 distance between 0.0 and the logit prediction in `logits` for the
+    foreground (class-agnostic) mask.
+    Args:
+        logits (Tensor): A tensor of shape (R, 1, ...), where R is the total number of predicted masks in
+            all images. The values are logits.
+ Returns:
+ scores (Tensor): A tensor of shape (R, 1, ...) that contains uncertainty scores with
+ the most uncertain locations having the highest uncertainty score.
+ """
+ assert logits.shape[1] == 1
+ gt_class_logits = logits.clone()
+ return -(torch.abs(gt_class_logits))
+
+
+def batch_sigmoid_cross_entropy_loss(inputs: torch.Tensor, targets: torch.Tensor):
+ """
+ Args:
+ inputs: A float tensor of arbitrary shape.
+ The predictions for each example.
+ targets: A float tensor with the same shape as inputs. Stores the binary
+ classification label for each element in inputs
+ (0 for the negative class and 1 for the positive class).
+ Returns:
+ Loss tensor
+ """
+ hw = inputs.shape[1]
+
+ pos = F.binary_cross_entropy_with_logits(inputs, torch.ones_like(inputs), reduction="none")
+ neg = F.binary_cross_entropy_with_logits(inputs, torch.zeros_like(inputs), reduction="none")
+
+ loss = torch.einsum("nc,mc->nm", pos, targets) + torch.einsum("nc,mc->nm", neg, (1 - targets))
+
+ return loss / hw
+
+
+def batch_dice_loss(inputs: torch.Tensor, targets: torch.Tensor):
+ """
+ Compute the DICE loss, similar to generalized IOU for masks
+ Args:
+ inputs: A float tensor of arbitrary shape.
+ The predictions for each example.
+ targets: A float tensor with the same shape as inputs. Stores the binary
+ classification label for each element in inputs
+ (0 for the negative class and 1 for the positive class).
+ """
+ inputs = inputs.sigmoid()
+ inputs = inputs.flatten(1)
+ numerator = 2 * torch.einsum("nc,mc->nm", inputs, targets)
+ denominator = inputs.sum(-1)[:, None] + targets.sum(-1)[None, :]
+ loss = 1 - (numerator + 1) / (denominator + 1)
+ return loss
+
+
+class RfDetrHungarianMatcher(nn.Module):
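+    """
+    Computes an assignment between the targets and the predictions of the network.
+
+    The matching cost combines a focal-style classification cost, an L1 box cost, a generalized IoU cost and
+    point-sampled mask costs (sigmoid cross-entropy and dice). When Group DETR is used, each group of queries is
+    matched to the targets independently.
+    """
+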
+ def __init__(
+ self,
+ class_cost: float = 1,
+ bbox_cost: float = 1,
+ giou_cost: float = 1,
+ mask_point_sample_ratio: int = 16,
+ cost_mask_class_cost: float = 1,
+ cost_mask_dice_cost: float = 1,
+ ):
+ super().__init__()
+ requires_backends(self, ["scipy"])
+
+ self.class_cost = class_cost
+ self.bbox_cost = bbox_cost
+ self.giou_cost = giou_cost
+ if class_cost == 0 and bbox_cost == 0 and giou_cost == 0:
+ raise ValueError("All costs of the Matcher can't be 0")
+
+ self.mask_point_sample_ratio = mask_point_sample_ratio
+ self.cost_mask_class = cost_mask_class_cost
+ self.cost_mask_dice = cost_mask_dice_cost
+
+ @torch.no_grad()
+ def forward(self, outputs, targets, group_detr):
+ """
+        Performs the Hungarian matching between the predictions and the targets, independently for each Group DETR
+        group of queries.
+
+        Compared to the standard softmax-based DETR matcher, the class probabilities are obtained with a sigmoid and
+        the classification cost follows the focal-loss formulation (alpha/gamma re-weighting).
+ """
+ batch_size, num_queries = outputs["logits"].shape[:2]
+
+ # We flatten to compute the cost matrices in a batch
+ out_prob = outputs["logits"].flatten(0, 1).sigmoid() # [batch_size * num_queries, num_classes]
+ out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4]
+ out_masks = outputs["pred_masks"].flatten(0, 1) # [batch_size * num_queries, H, W]
+
+ # Also concat the target labels and boxes
+ target_ids = torch.cat([v["class_labels"] for v in targets])
+ target_bbox = torch.cat([v["boxes"] for v in targets])
+ target_masks = torch.cat([v["masks"] for v in targets])
+
+ # Compute the classification cost.
+ alpha = 0.25
+ gamma = 2.0
+ neg_cost_class = (1 - alpha) * (out_prob**gamma) * (-(1 - out_prob + 1e-8).log())
+ pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
+ class_cost = pos_cost_class[:, target_ids] - neg_cost_class[:, target_ids]
+
+ # Compute the L1 cost between boxes, cdist only supports float32
+ dtype = out_bbox.dtype
+ out_bbox = out_bbox.to(torch.float32)
+ target_bbox = target_bbox.to(torch.float32)
+ bbox_cost = torch.cdist(out_bbox, target_bbox, p=1)
+ bbox_cost = bbox_cost.to(dtype)
+
+ # Compute the giou cost between boxes
+ giou_cost = -generalized_box_iou(center_to_corners_format(out_bbox), center_to_corners_format(target_bbox))
+
+ # Compute mask cost
+ num_points = out_masks.shape[-2] * out_masks.shape[-1] // self.mask_point_sample_ratio
+ tgt_masks = target_masks.to(out_masks.dtype)
+ point_coords = torch.rand(1, num_points, 2, device=out_masks.device)
+ pred_masks_logits = point_sample(
+ out_masks.unsqueeze(1), point_coords.repeat(out_masks.shape[0], 1, 1), align_corners=False
+ ).squeeze(1)
+ tgt_masks_flat = point_sample(
+ tgt_masks.unsqueeze(1), point_coords.repeat(tgt_masks.shape[0], 1, 1), align_corners=False, mode="nearest"
+ ).squeeze(1)
+
+ cost_mask_class = batch_sigmoid_cross_entropy_loss(pred_masks_logits, tgt_masks_flat)
+ cost_mask_dice = batch_dice_loss(pred_masks_logits, tgt_masks_flat)
+
+ # Final cost matrix
+ cost_matrix = (
+ self.bbox_cost * bbox_cost
+ + self.class_cost * class_cost
+ + self.giou_cost * giou_cost
+ + self.cost_mask_class * cost_mask_class
+ + self.cost_mask_dice * cost_mask_dice
+ )
+ cost_matrix = cost_matrix.view(batch_size, num_queries, -1).cpu()
+
+ # we assume any good match will not cause NaN or Inf, so we replace them with a large value
+ max_cost = cost_matrix.max() if cost_matrix.numel() > 0 else 0
+ cost_matrix[cost_matrix.isinf() | cost_matrix.isnan()] = max_cost * 2
+
+ # Hungarian matching
+ sizes = [len(v["masks"]) for v in targets]
+ indices = []
+ group_num_queries = num_queries // group_detr
+ cost_matrix_list = cost_matrix.split(group_num_queries, dim=1)
+ for group_id in range(group_detr):
+ group_cost_matrix = cost_matrix_list[group_id]
+ group_indices = [linear_sum_assignment(c[i]) for i, c in enumerate(group_cost_matrix.split(sizes, -1))]
+ if group_id == 0:
+ indices = group_indices
+ else:
+ indices = [
+ (
+ np.concatenate([indice1[0], indice2[0] + group_num_queries * group_id]),
+ np.concatenate([indice1[1], indice2[1]]),
+ )
+ for indice1, indice2 in zip(indices, group_indices)
+ ]
+ return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
+
+
+class RfDetrImageLoss(LwDetrImageLoss):
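+    """
+    Computes the losses for RF-DETR. Extends `LwDetrImageLoss` with a point-sampled mask loss (sigmoid cross-entropy
+    and dice) used for instance segmentation.
+    """
+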
+ def __init__(self, matcher, num_classes, focal_alpha, losses, group_detr, mask_point_sample_ratio):
+ nn.Module.__init__(self)
+ self.matcher = matcher
+ self.num_classes = num_classes
+ self.focal_alpha = focal_alpha
+ self.losses = losses
+ self.group_detr = group_detr
+ self.mask_point_sample_ratio = mask_point_sample_ratio
+
+ def loss_masks(self, outputs, targets, indices, num_boxes):
+ """
+        Compute the losses related to the masks: the sigmoid cross-entropy loss and the dice loss.
+
+ Targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w].
+ """
+ if "pred_masks" not in outputs:
+ raise KeyError("No predicted masks found in outputs")
+
+ source_idx = self._get_source_permutation_idx(indices)
+ source_masks = outputs["pred_masks"]
+ source_masks = source_masks[source_idx]
+ if source_masks.numel() == 0:
+ return {
+ "loss_mask_ce": source_masks.sum(),
+ "loss_mask_dice": source_masks.sum(),
+ }
+
+ # gather matched target masks
+ target_masks = torch.cat([t["masks"][j] for t, (_, j) in zip(targets, indices)], dim=0)
+
+ source_masks = source_masks.unsqueeze(1)
+ target_masks = target_masks.unsqueeze(1).float()
+
+ num_points = max(
+ source_masks.shape[-2], source_masks.shape[-2] * source_masks.shape[-1] // self.mask_point_sample_ratio
+ )
+
+ with torch.no_grad():
+ # sample point_coords
+ point_coords = get_uncertain_point_coords_with_randomness(
+ source_masks,
+ lambda logits: calculate_uncertainty(logits),
+ num_points,
+ 3,
+ 0.75,
+ )
+ # get gt labels
+ point_labels = point_sample(
+ target_masks,
+ point_coords,
+ align_corners=False,
+ mode="nearest",
+ ).squeeze(1)
+
+ point_logits = point_sample(
+ source_masks,
+ point_coords,
+ align_corners=False,
+ ).squeeze(1)
+
+ losses = {
+ "loss_mask_ce": sigmoid_cross_entropy_loss(point_logits, point_labels, num_boxes),
+ "loss_mask_dice": dice_loss(point_logits, point_labels, num_boxes),
+ }
+ return losses
+
+ def forward(self, outputs, targets):
+ """
+ This performs the loss computation.
+
+ Args:
+ outputs (`dict`, *optional*):
+ Dictionary of tensors, see the output specification of the model for the format.
+ targets (`list[dict]`, *optional*):
+ List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depends on the
+ losses applied, see each loss' doc.
+ """
+ group_detr = self.group_detr if self.training else 1
+ outputs_without_aux_and_enc = {
+ k: v for k, v in outputs.items() if k != "enc_outputs" and k != "auxiliary_outputs"
+ }
+
+ # Retrieve the matching between the outputs of the last layer and the targets
+ indices = self.matcher(outputs_without_aux_and_enc, targets, group_detr)
+
+ # Compute the average number of target boxes across all nodes, for normalization purposes
+ num_boxes = sum(len(t["class_labels"]) for t in targets)
+ num_boxes = num_boxes * group_detr
+ num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
+ world_size = 1
+ if is_accelerate_available():
+ if PartialState._shared_state != {}:
+ num_boxes = reduce(num_boxes)
+ world_size = PartialState().num_processes
+ num_boxes = torch.clamp(num_boxes / world_size, min=1).item()
+
+ # Compute all the requested losses
+ losses = {}
+ for loss in self.losses:
+ losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))
+
+ # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
+ if "auxiliary_outputs" in outputs:
+ for i, auxiliary_outputs in enumerate(outputs["auxiliary_outputs"]):
+ indices = self.matcher(auxiliary_outputs, targets, group_detr)
+ for loss in self.losses:
+ l_dict = self.get_loss(loss, auxiliary_outputs, targets, indices, num_boxes)
+ l_dict = {k + f"_{i}": v for k, v in l_dict.items()}
+ losses.update(l_dict)
+
+ if "enc_outputs" in outputs:
+ enc_outputs = outputs["enc_outputs"]
+ indices = self.matcher(enc_outputs, targets, group_detr=group_detr)
+ for loss in self.losses:
+ l_dict = self.get_loss(loss, enc_outputs, targets, indices, num_boxes)
+ l_dict = {k + "_enc": v for k, v in l_dict.items()}
+ losses.update(l_dict)
+
+ return losses
+
+
+def _set_aux_loss(outputs_class, outputs_coord, outputs_masks):
+ return [
+ {"logits": a, "pred_boxes": b, "pred_masks": c}
+ for a, b, c in zip(outputs_class[:-1], outputs_coord[:-1], outputs_masks[:-1])
+ ]
+
+
+def RfDetrForSegmentationLoss(
+ logits,
+ labels,
+ device,
+ pred_boxes,
+ pred_masks,
+ config,
+ outputs_class=None,
+ outputs_coord=None,
+ outputs_masks=None,
+ enc_outputs_class=None,
+ enc_outputs_coord=None,
+ enc_outputs_masks=None,
+ **kwargs,
+):
+ # First: create the matcher
+ matcher = RfDetrHungarianMatcher(
+ class_cost=config.class_cost,
+ bbox_cost=config.bbox_cost,
+ giou_cost=config.giou_cost,
+ mask_point_sample_ratio=config.mask_point_sample_ratio,
+ cost_mask_class_cost=config.mask_class_loss_coefficient,
+ cost_mask_dice_cost=config.mask_dice_loss_coefficient,
+ )
+ # Second: create the criterion
+ losses = ["labels", "boxes", "cardinality", "masks"]
+ criterion = RfDetrImageLoss(
+ matcher=matcher,
+ num_classes=config.num_labels,
+ focal_alpha=config.focal_alpha,
+ losses=losses,
+ group_detr=config.group_detr,
+ mask_point_sample_ratio=config.mask_point_sample_ratio,
+ )
+ criterion.to(device)
+ # Third: compute the losses, based on outputs and labels
+ outputs_loss = {}
+ auxiliary_outputs = None
+ outputs_loss["logits"] = logits
+ outputs_loss["pred_boxes"] = pred_boxes
+ outputs_loss["pred_masks"] = pred_masks
+ outputs_loss["enc_outputs"] = {
+ "logits": enc_outputs_class,
+ "pred_boxes": enc_outputs_coord,
+ "pred_masks": enc_outputs_masks,
+ }
+ if config.auxiliary_loss:
+ auxiliary_outputs = _set_aux_loss(outputs_class, outputs_coord, outputs_masks)
+ outputs_loss["auxiliary_outputs"] = auxiliary_outputs
+ loss_dict = criterion(outputs_loss, labels)
+ # Fourth: compute total loss, as a weighted sum of the various losses
+ weight_dict = {"loss_ce": config.class_loss_coefficient, "loss_bbox": config.bbox_loss_coefficient}
+ weight_dict["loss_giou"] = config.giou_loss_coefficient
+ weight_dict["loss_mask_ce"] = config.mask_class_loss_coefficient
+ weight_dict["loss_mask_dice"] = config.mask_dice_loss_coefficient
+ if config.auxiliary_loss:
+ aux_weight_dict = {}
+ for i in range(config.decoder_layers - 1):
+ aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
+ aux_weight_dict.update({k + "_enc": v for k, v in weight_dict.items()})
+ weight_dict.update(aux_weight_dict)
+ loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict if k in weight_dict)
+ return loss, loss_dict, auxiliary_outputs
diff --git a/src/transformers/loss/loss_utils.py b/src/transformers/loss/loss_utils.py
index df269477e9ec..8b734461dbdc 100644
--- a/src/transformers/loss/loss_utils.py
+++ b/src/transformers/loss/loss_utils.py
@@ -22,6 +22,7 @@
from .loss_for_object_detection import ForObjectDetectionLoss, ForSegmentationLoss
from .loss_grounding_dino import GroundingDinoForObjectDetectionLoss
from .loss_lw_detr import LwDetrForObjectDetectionLoss
+from .loss_rf_detr import RfDetrForSegmentationLoss
from .loss_rt_detr import RTDetrForObjectDetectionLoss
@@ -165,4 +166,6 @@ def ForTokenClassification(logits: torch.Tensor, labels, config, **kwargs):
"DFineForObjectDetection": DFineForObjectDetectionLoss,
"CsmForConditionalGeneration": ForCausalLMLoss,
"LwDetrForObjectDetection": LwDetrForObjectDetectionLoss,
+ "RfDetrForObjectDetection": LwDetrForObjectDetectionLoss,
+ "RfDetrForInstanceSegmentation": RfDetrForSegmentationLoss,
}
diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py
index 89361203d5ef..0badb16c8bb5 100644
--- a/src/transformers/models/__init__.py
+++ b/src/transformers/models/__init__.py
@@ -327,6 +327,7 @@
from .regnet import *
from .rembert import *
from .resnet import *
+ from .rf_detr import *
from .roberta import *
from .roberta_prelayernorm import *
from .roc_bert import *
diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py
index 7909e3834aa9..44acb4c6f360 100644
--- a/src/transformers/models/auto/configuration_auto.py
+++ b/src/transformers/models/auto/configuration_auto.py
@@ -366,6 +366,8 @@
("regnet", "RegNetConfig"),
("rembert", "RemBertConfig"),
("resnet", "ResNetConfig"),
+ ("rf_detr", "RfDetrConfig"),
+ ("rf_detr_dinov2", "RfDetrDinov2Config"),
("roberta", "RobertaConfig"),
("roberta-prelayernorm", "RobertaPreLayerNormConfig"),
("roc_bert", "RoCBertConfig"),
@@ -844,6 +846,8 @@
("regnet", "RegNet"),
("rembert", "RemBERT"),
("resnet", "ResNet"),
+ ("rf_detr", "RF-DETR"),
+ ("rf_detr_dinov2", "RF-DETR-DINOv2"),
("roberta", "RoBERTa"),
("roberta-prelayernorm", "RoBERTa-PreLayerNorm"),
("roc_bert", "RoCBert"),
@@ -1010,6 +1014,7 @@
("smolvlm_vision", "smolvlm"),
("chinese_clip_vision_model", "chinese_clip"),
("rt_detr_resnet", "rt_detr"),
+ ("rf_detr_dinov2", "rf_detr"),
("granitevision", "llava_next"),
("internvl_vision", "internvl"),
("qwen2_5_vl_text", "qwen2_5_vl"),
diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py
index e5fe9f4400eb..90aa089f3ec2 100644
--- a/src/transformers/models/auto/image_processing_auto.py
+++ b/src/transformers/models/auto/image_processing_auto.py
@@ -179,6 +179,7 @@
("qwen3_vl", ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast")),
("regnet", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
("resnet", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
+ ("rf_detr", ("DetrImageProcessor", "DetrImageProcessorFast")),
("rt_detr", ("RTDetrImageProcessor", "RTDetrImageProcessorFast")),
("sam", ("SamImageProcessor", "SamImageProcessorFast")),
("sam2", (None, "Sam2ImageProcessorFast")),
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index d9370d78e736..5938b794d36d 100644
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -355,6 +355,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("regnet", "RegNetModel"),
("rembert", "RemBertModel"),
("resnet", "ResNetModel"),
+ ("rf_detr", "RfDetrModel"),
("roberta", "RobertaModel"),
("roberta-prelayernorm", "RobertaPreLayerNormModel"),
("roc_bert", "RoCBertModel"),
@@ -865,6 +866,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
# Model for Instance Segmentation mapping
# MaskFormerForInstanceSegmentation can be removed from this mapping in v5
("maskformer", "MaskFormerForInstanceSegmentation"),
+ ("rf_detr", "RfDetrForInstanceSegmentation"),
]
)
@@ -1030,6 +1032,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("deformable_detr", "DeformableDetrForObjectDetection"),
("detr", "DetrForObjectDetection"),
("lw_detr", "LwDetrForObjectDetection"),
+ ("rf_detr", "RfDetrForObjectDetection"),
("rt_detr", "RTDetrForObjectDetection"),
("rt_detr_v2", "RTDetrV2ForObjectDetection"),
("table-transformer", "TableTransformerForObjectDetection"),
@@ -1598,6 +1601,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("pixio", "PixioBackbone"),
("pvt_v2", "PvtV2Backbone"),
("resnet", "ResNetBackbone"),
+ ("rf_detr_dinov2", "RfDetrDinov2Backbone"),
("rt_detr_resnet", "RTDetrResNetBackbone"),
("swin", "SwinBackbone"),
("swinv2", "Swinv2Backbone"),
diff --git a/src/transformers/models/lw_detr/configuration_lw_detr.py b/src/transformers/models/lw_detr/configuration_lw_detr.py
index 3e90410ccd7c..6a89fe8a2656 100644
--- a/src/transformers/models/lw_detr/configuration_lw_detr.py
+++ b/src/transformers/models/lw_detr/configuration_lw_detr.py
@@ -233,8 +233,8 @@ class LwDetrConfig(PreTrainedConfig):
Relative weight of the L1 error of the bounding box coordinates in the Hungarian matching cost.
giou_cost (`float`, *optional*, defaults to 2):
Relative weight of the generalized IoU loss of the bounding box in the Hungarian matching cost.
- mask_loss_coefficient (`float`, *optional*, defaults to 1):
- Relative weight of the Focal loss in the panoptic segmentation loss.
+ class_loss_coefficient (`float`, *optional*, defaults to 1):
+ Relative weight of the Cross Entropy loss in the object detection loss.
dice_loss_coefficient (`float`, *optional*, defaults to 1):
Relative weight of the DICE/F-1 loss in the panoptic segmentation loss.
bbox_loss_coefficient (`float`, *optional*, defaults to 5):
@@ -297,7 +297,7 @@ def __init__(
class_cost=2,
bbox_cost=5,
giou_cost=2,
- mask_loss_coefficient=1,
+ class_loss_coefficient=1,
dice_loss_coefficient=1,
bbox_loss_coefficient=5,
giou_loss_coefficient=2,
@@ -362,6 +362,7 @@ def __init__(
self.bbox_cost = bbox_cost
self.giou_cost = giou_cost
# Loss coefficients
+ self.class_loss_coefficient = class_loss_coefficient
self.dice_loss_coefficient = dice_loss_coefficient
self.bbox_loss_coefficient = bbox_loss_coefficient
self.giou_loss_coefficient = giou_loss_coefficient
diff --git a/src/transformers/models/lw_detr/modular_lw_detr.py b/src/transformers/models/lw_detr/modular_lw_detr.py
index ac824e990b11..3eead582a85c 100644
--- a/src/transformers/models/lw_detr/modular_lw_detr.py
+++ b/src/transformers/models/lw_detr/modular_lw_detr.py
@@ -262,8 +262,8 @@ class LwDetrConfig(PreTrainedConfig):
Relative weight of the L1 error of the bounding box coordinates in the Hungarian matching cost.
giou_cost (`float`, *optional*, defaults to 2):
Relative weight of the generalized IoU loss of the bounding box in the Hungarian matching cost.
- mask_loss_coefficient (`float`, *optional*, defaults to 1):
- Relative weight of the Focal loss in the panoptic segmentation loss.
+ class_loss_coefficient (`float`, *optional*, defaults to 1):
+ Relative weight of the Cross Entropy loss in the object detection loss.
dice_loss_coefficient (`float`, *optional*, defaults to 1):
Relative weight of the DICE/F-1 loss in the panoptic segmentation loss.
bbox_loss_coefficient (`float`, *optional*, defaults to 5):
@@ -326,7 +326,7 @@ def __init__(
class_cost=2,
bbox_cost=5,
giou_cost=2,
- mask_loss_coefficient=1,
+ class_loss_coefficient=1,
dice_loss_coefficient=1,
bbox_loss_coefficient=5,
giou_loss_coefficient=2,
@@ -391,6 +391,7 @@ def __init__(
self.bbox_cost = bbox_cost
self.giou_cost = giou_cost
# Loss coefficients
+ self.class_loss_coefficient = class_loss_coefficient
self.dice_loss_coefficient = dice_loss_coefficient
self.bbox_loss_coefficient = bbox_loss_coefficient
self.giou_loss_coefficient = giou_loss_coefficient
diff --git a/src/transformers/models/rf_detr/__init__.py b/src/transformers/models/rf_detr/__init__.py
new file mode 100644
index 000000000000..730e55bdf5ed
--- /dev/null
+++ b/src/transformers/models/rf_detr/__init__.py
@@ -0,0 +1,29 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_rf_detr import *
+ from .modeling_rf_detr import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/src/transformers/models/rf_detr/configuration_rf_detr.py b/src/transformers/models/rf_detr/configuration_rf_detr.py
new file mode 100644
index 000000000000..789f43cff535
--- /dev/null
+++ b/src/transformers/models/rf_detr/configuration_rf_detr.py
@@ -0,0 +1,414 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/rf_detr/modular_rf_detr.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_rf_detr.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from ...configuration_utils import PreTrainedConfig
+from ...utils import logging
+from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
+from ..auto import CONFIG_MAPPING, AutoConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class RfDetrDinov2Config(BackboneConfigMixin, PreTrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`RfDetrDinov2Model`]. It is used to instantiate an
+ RfDetrDinov2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the DINOv2
+ [facebook/dinov2-base](https://huggingface.co/facebook/dinov2-base) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ mlp_ratio (`int`, *optional*, defaults to 4):
+ Ratio of the hidden size of the MLPs relative to the `hidden_size`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the layer normalization layers.
+ image_size (`int`, *optional*, defaults to 224):
+ The size (resolution) of each image.
+ patch_size (`int`, *optional*, defaults to 14):
+ The size (resolution) of each patch.
+ num_channels (`int`, *optional*, defaults to 3):
+ The number of input channels.
+ qkv_bias (`bool`, *optional*, defaults to `True`):
+ Whether to add a bias to the queries, keys and values.
+ layerscale_value (`float`, *optional*, defaults to 1.0):
+ Initial value to use for layer scale.
+ drop_path_rate (`float`, *optional*, defaults to 0.0):
+ Stochastic depth rate per sample (when applied in the main path of residual layers).
+ use_swiglu_ffn (`bool`, *optional*, defaults to `False`):
+ Whether to use the SwiGLU feedforward neural network.
+ out_features (`list[str]`, *optional*):
+ If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
+ (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
+ corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
+ same order as defined in the `stage_names` attribute.
+ out_indices (`list[int]`, *optional*):
+ If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
+ many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
+ If unset and `out_features` is unset, will default to the last stage. Must be in the
+ same order as defined in the `stage_names` attribute.
+ apply_layernorm (`bool`, *optional*, defaults to `True`):
+ Whether to apply layer normalization to the feature maps in case the model is used as backbone.
+ reshape_hidden_states (`bool`, *optional*, defaults to `True`):
+ Whether to reshape the feature maps to 4D tensors of shape `(batch_size, hidden_size, height, width)` in
+ case the model is used as backbone. If `False`, the feature maps will be 3D tensors of shape `(batch_size,
+ seq_len, hidden_size)`.
+ use_mask_token (`bool`, *optional*, defaults to `True`):
+ Whether to use mask_token in embeddings.
+ num_windows (`int`, *optional*, defaults to 4):
+ Number of windows to use for windowed attention. If 1, no windowed attention is used.
+ Example:
+
+ ```python
+ >>> from transformers import RfDetrDinov2Config, RfDetrDinov2Backbone
+
+ >>> # Initializing a RfDetrDinov2 base style configuration
+ >>> configuration = RfDetrDinov2Config()
+
+ >>> # Initializing a model (with random weights) from the base style configuration
+ >>> model = RfDetrDinov2Backbone(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "rf_detr_dinov2"
+
+ def __init__(
+ self,
+ hidden_size=768,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ mlp_ratio=4,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.0,
+ attention_probs_dropout_prob=0.0,
+ initializer_range=0.02,
+ layer_norm_eps=1e-6,
+ image_size=224,
+ patch_size=14,
+ num_channels=3,
+ qkv_bias=True,
+ layerscale_value=1.0,
+ drop_path_rate=0.0,
+ use_swiglu_ffn=False,
+ out_features=None,
+ out_indices=None,
+ apply_layernorm=True,
+ reshape_hidden_states=True,
+ use_mask_token=True,
+ num_windows: int = 4,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.mlp_ratio = mlp_ratio
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.qkv_bias = qkv_bias
+ self.layerscale_value = layerscale_value
+ self.drop_path_rate = drop_path_rate
+ self.use_swiglu_ffn = use_swiglu_ffn
+ self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, num_hidden_layers + 1)]
+ self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+ out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
+ )
+ self.apply_layernorm = apply_layernorm
+ self.reshape_hidden_states = reshape_hidden_states
+ self.use_mask_token = use_mask_token
+
+ self.num_windows = num_windows
+ window_block_indexes = set(range(self._out_indices[-1] + 1))
+ window_block_indexes.difference_update(self._out_indices)
+ window_block_indexes = list(window_block_indexes)
+ self.window_block_indexes = window_block_indexes
+
+
+class RfDetrConfig(PreTrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`RfDetrModel`]. It is used to instantiate
+    an RF-DETR model according to the specified arguments, defining the model architecture. Instantiating a
+    configuration with the defaults will yield a similar configuration to that of the RF-DETR
+    [stevenbucaille/RfDetr_small_60e_coco](https://huggingface.co/stevenbucaille/RfDetr_small_60e_coco) architecture.
+
+    RF-DETR is a real-time detection transformer that combines a DINOv2 backbone, a multi-scale projector and a
+    shallow deformable-attention decoder. It is designed to reach a competitive accuracy/latency trade-off compared
+    to CNN-based detectors such as YOLO while remaining computationally lightweight.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ backbone_config (`PretrainedConfig` or `dict`, *optional*):
+ The configuration of the backbone model. If not provided, will default to `RfDetrDinov2Config`
+ with a small ViT architecture optimized for detection tasks.
+ projector_scale_factors (`list[float]`, *optional*, defaults to `[]`):
+ Scale factors for the feature pyramid network. Each scale factor determines the resolution of features
+ at different levels. Supported values are 0.5, 1.0, and 2.0.
+ hidden_expansion (`float`, *optional*, defaults to 0.5):
+ Expansion factor for hidden dimensions in the projector layers.
+ c2f_num_blocks (`int`, *optional*, defaults to 3):
+ Number of blocks in the C2F layer.
+ activation_function (`str`, *optional*, defaults to `"silu"`):
+ The non-linear activation function in the projector. Supported values are `"silu"`, `"relu"`, `"gelu"`.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon value for layer normalization layers.
+ d_model (`int`, *optional*, defaults to 256):
+ Dimension of the model layers and the number of expected features in the decoder inputs.
+ dropout (`float`, *optional*, defaults to 0.1):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ decoder_ffn_dim (`int`, *optional*, defaults to 2048):
+ Dimension of the "intermediate" (often named feed-forward) layer in decoder.
+ decoder_n_points (`int`, *optional*, defaults to 4):
+ The number of sampled keys in each feature level for each attention head in the decoder.
+ decoder_layers (`int`, *optional*, defaults to 3):
+ Number of decoder layers in the transformer.
+ decoder_self_attention_heads (`int`, *optional*, defaults to 8):
+ Number of attention heads for each attention layer in the decoder self-attention.
+ decoder_cross_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the decoder cross-attention.
+ decoder_activation_function (`str`, *optional*, defaults to `"relu"`):
+ The non-linear activation function in the decoder. Supported values are `"relu"`, `"silu"`, `"gelu"`.
+ num_queries (`int`, *optional*, defaults to 300):
+ Number of object queries, i.e. detection slots. This is the maximal number of objects
+ [`RfDetrModel`] can detect in a single image.
+ attention_bias (`bool`, *optional*, defaults to `True`):
+ Whether to add bias to the attention layers.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ activation_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for activations inside the fully connected layer.
+ group_detr (`int`, *optional*, defaults to 13):
+ Number of groups for Group DETR attention mechanism, which helps reduce computational complexity.
+ init_std (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ disable_custom_kernels (`bool`, *optional*, defaults to `True`):
+ Disable the use of custom CUDA and CPU kernels. This option is necessary for the ONNX export, as custom
+ kernels are not supported by PyTorch ONNX export.
+ class_cost (`float`, *optional*, defaults to 2):
+ Relative weight of the classification error in the Hungarian matching cost.
+ bbox_cost (`float`, *optional*, defaults to 5):
+ Relative weight of the L1 error of the bounding box coordinates in the Hungarian matching cost.
+ giou_cost (`float`, *optional*, defaults to 2):
+ Relative weight of the generalized IoU loss of the bounding box in the Hungarian matching cost.
+ class_loss_coefficient (`float`, *optional*, defaults to 1):
+            Relative weight of the Cross Entropy loss in the object detection loss.
+ mask_loss_coefficient (`float`, *optional*, defaults to 1):
+ Relative weight of the Focal loss in the instance segmentation mask loss.
+ dice_loss_coefficient (`float`, *optional*, defaults to 1):
+ Relative weight of the DICE/F-1 loss in the object detection loss.
+ bbox_loss_coefficient (`float`, *optional*, defaults to 5):
+ Relative weight of the L1 bounding box loss in the object detection loss.
+ giou_loss_coefficient (`float`, *optional*, defaults to 2):
+ Relative weight of the generalized IoU loss in the object detection loss.
+ eos_coefficient (`float`, *optional*, defaults to 0.1):
+ Relative classification weight of the 'no-object' class in the object detection loss.
+ focal_alpha (`float`, *optional*, defaults to 0.25):
+ Alpha parameter in the focal loss.
+ auxiliary_loss (`bool`, *optional*, defaults to `True`):
+ Whether auxiliary decoding losses (loss at each decoder layer) are to be used.
+ mask_point_sample_ratio (`int`, *optional*, defaults to 16):
+ The ratio of points to sample for the mask loss calculation.
+ mask_downsample_ratio (`int`, *optional*, defaults to 4):
+ The downsample ratio for the segmentation masks compared to the input image resolution.
+ mask_class_loss_coefficient (`float`, *optional*, defaults to 5.0):
+ Relative weight of the Focal loss in the instance segmentation loss.
+ mask_dice_loss_coefficient (`float`, *optional*, defaults to 5.0):
+ Relative weight of the DICE/F-1 loss in the instance segmentation loss.
+ segmentation_head_activation_function (`str`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function in the segmentation head. Supported values are `"relu"`, `"silu"`, `"gelu"`.
+ Examples:
+
+ ```python
+ >>> from transformers import RfDetrConfig, RfDetrModel
+
+    >>> # Initializing an RF-DETR stevenbucaille/RfDetr_small_60e_coco style configuration
+ >>> configuration = RfDetrConfig()
+
+ >>> # Initializing a model (with random weights) from the stevenbucaille/RfDetr_small_60e_coco style configuration
+ >>> model = RfDetrModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "rf_detr"
+ sub_configs = {"backbone_config": AutoConfig}
+
+ def __init__(
+ self,
+ # backbone
+ backbone_config=None,
+ # projector
+ projector_scale_factors: list[float] = [],
+ hidden_expansion=0.5,
+ c2f_num_blocks=3,
+ activation_function="silu",
+ layer_norm_eps=1e-5,
+ # decoder
+ d_model=256,
+ dropout=0.1,
+ decoder_ffn_dim=2048,
+ decoder_n_points=4,
+ decoder_layers: int = 3,
+ decoder_self_attention_heads: int = 8,
+ decoder_cross_attention_heads: int = 16,
+ decoder_activation_function="relu",
+ # model
+ num_queries=300,
+ attention_bias=True,
+ attention_dropout=0.0,
+ activation_dropout=0.0,
+ group_detr: int = 13,
+ init_std=0.02,
+ disable_custom_kernels=True,
+ # loss
+ class_cost=2,
+ bbox_cost=5,
+ giou_cost=2,
+ class_loss_coefficient=1,
+ mask_loss_coefficient=1,
+ dice_loss_coefficient=1,
+ bbox_loss_coefficient=5,
+ giou_loss_coefficient=2,
+ eos_coefficient=0.1,
+ focal_alpha=0.25,
+ auxiliary_loss=True,
+ mask_point_sample_ratio=16,
+ # segmentation
+ mask_downsample_ratio=4,
+ mask_class_loss_coefficient=5.0,
+ mask_dice_loss_coefficient=5.0,
+ segmentation_head_activation_function="gelu",
+ **kwargs,
+ ):
+ self.layer_norm_eps = layer_norm_eps
+
+ # backbone
+ if backbone_config is None:
+ logger.info(
+ "`backbone_config` is `None`. Initializing the config with the default `RfDetrDinov2` backbone."
+ )
+ backbone_config = RfDetrDinov2Config(
+ attention_probs_dropout_prob=0.0,
+ drop_path_rate=0.0,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.0,
+ initializer_range=0.02,
+ layer_norm_eps=1e-06,
+ layerscale_value=1.0,
+ mlp_ratio=4,
+ num_attention_heads=6,
+ num_channels=3,
+ num_hidden_layers=12,
+ qkv_bias=True,
+ use_swiglu_ffn=False,
+ out_features=["stage2", "stage5", "stage8", "stage11"],
+ hidden_size=384,
+ patch_size=14,
+ num_windows=4,
+ num_register_tokens=0,
+ image_size=518,
+ **kwargs,
+ )
+ elif isinstance(backbone_config, dict):
+ backbone_model_type = backbone_config.pop("model_type")
+ config_class = CONFIG_MAPPING[backbone_model_type]
+ backbone_config = config_class.from_dict(backbone_config)
+
+ self.backbone_config = backbone_config
+
+ # projector
+ self.projector_scale_factors = projector_scale_factors
+ for scale in projector_scale_factors:
+ if scale not in [0.5, 1.0, 2.0]:
+ raise ValueError(f"Unsupported scale factor: {scale}")
+ self.projector_in_channels = [d_model] * len(projector_scale_factors)
+ self.projector_out_channels = d_model
+ self.activation_function = activation_function
+ self.hidden_expansion = hidden_expansion
+ self.c2f_num_blocks = c2f_num_blocks
+ # decoder
+ self.d_model = d_model
+ self.dropout = dropout
+ self.num_queries = num_queries
+ self.decoder_ffn_dim = decoder_ffn_dim
+ self.num_feature_levels = len(self.projector_scale_factors)
+ self.decoder_n_points = decoder_n_points
+ self.decoder_layers = decoder_layers
+ self.decoder_activation_function = decoder_activation_function
+ self.decoder_self_attention_heads = decoder_self_attention_heads
+ self.decoder_cross_attention_heads = decoder_cross_attention_heads
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+ self.activation_dropout = activation_dropout
+ # model
+ self.init_std = init_std
+ self.group_detr = group_detr
+ # Loss
+ self.auxiliary_loss = auxiliary_loss
+ # Hungarian matcher
+ self.class_cost = class_cost
+ self.bbox_cost = bbox_cost
+ self.giou_cost = giou_cost
+ # Loss coefficients
+ self.class_loss_coefficient = class_loss_coefficient
+ self.mask_loss_coefficient = mask_loss_coefficient
+ self.dice_loss_coefficient = dice_loss_coefficient
+ self.bbox_loss_coefficient = bbox_loss_coefficient
+ self.giou_loss_coefficient = giou_loss_coefficient
+ self.mask_class_loss_coefficient = mask_class_loss_coefficient
+ self.mask_dice_loss_coefficient = mask_dice_loss_coefficient
+ self.eos_coefficient = eos_coefficient
+ self.focal_alpha = focal_alpha
+ self.disable_custom_kernels = disable_custom_kernels
+ self.mask_point_sample_ratio = mask_point_sample_ratio
+ # segmentation
+ self.mask_downsample_ratio = mask_downsample_ratio
+ self.segmentation_head_activation_function = segmentation_head_activation_function
+ super().__init__(**kwargs)
+
+
+__all__ = ["RfDetrConfig", "RfDetrDinov2Config"]
diff --git a/src/transformers/models/rf_detr/convert_rf_detr_weights_to_hf.py b/src/transformers/models/rf_detr/convert_rf_detr_weights_to_hf.py
new file mode 100644
index 000000000000..f6c9f86f0a7a
--- /dev/null
+++ b/src/transformers/models/rf_detr/convert_rf_detr_weights_to_hf.py
@@ -0,0 +1,999 @@
+# Copyright 2026 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convert RF-DETR checkpoints to transformers format."""
+
+import argparse
+
+import requests
+import torch
+from PIL import Image
+from torchvision import transforms
+
+from transformers import (
+ DetrImageProcessor,
+ RfDetrConfig,
+ RfDetrForInstanceSegmentation,
+ RfDetrForObjectDetection,
+)
+from transformers.core_model_loading import (
+ Chunk,
+ WeightConverter,
+ WeightRenaming,
+ convert_and_load_state_dict_in_model,
+)
+
+
+# Mapping of model names to their checkpoint files
+HOSTED_MODELS = {
+ "rf-detr-base": "https://storage.googleapis.com/rfdetr/rf-detr-base-coco.pth",
+ # base-2 is a less converged model that may be better for finetuning but worse for inference
+ "rf-detr-base-2": "https://storage.googleapis.com/rfdetr/rf-detr-base-2.pth",
+ "rf-detr-nano": "https://storage.googleapis.com/rfdetr/nano_coco/checkpoint_best_regular.pth",
+ "rf-detr-small": "https://storage.googleapis.com/rfdetr/small_coco/checkpoint_best_regular.pth",
+ "rf-detr-medium": "https://storage.googleapis.com/rfdetr/medium_coco/checkpoint_best_regular.pth",
+ "rf-detr-large-deprecated": "https://storage.googleapis.com/rfdetr/rf-detr-large.pth",
+ "rf-detr-large": "https://storage.googleapis.com/rfdetr/rf-detr-large-2026.pth",
+ "rf-detr-seg-preview": "https://storage.googleapis.com/rfdetr/rf-detr-seg-preview.pt",
+ "rf-detr-seg-nano": "https://storage.googleapis.com/rfdetr/rf-detr-seg-n-ft.pth",
+ "rf-detr-seg-small": "https://storage.googleapis.com/rfdetr/rf-detr-seg-s-ft.pth",
+ "rf-detr-seg-medium": "https://storage.googleapis.com/rfdetr/rf-detr-seg-m-ft.pth",
+ "rf-detr-seg-large": "https://storage.googleapis.com/rfdetr/rf-detr-seg-l-ft.pth",
+ "rf-detr-seg-xlarge": "https://storage.googleapis.com/rfdetr/rf-detr-seg-xl-ft.pth",
+ "rf-detr-seg-xxlarge": "https://storage.googleapis.com/rfdetr/rf-detr-seg-2xl-ft.pth",
+}
+
+# Model configurations for different sizes
+BACKBONE_CONFIGS = {
+ "rf-detr-nano": {
+ "attention_probs_dropout_prob": 0.0,
+ "drop_path_rate": 0.0,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.0,
+ "initializer_range": 0.02,
+ "layer_norm_eps": 1e-06,
+ "layerscale_value": 1.0,
+ "mlp_ratio": 4,
+ "num_attention_heads": 6,
+ "num_channels": 3,
+ "num_hidden_layers": 12,
+ "qkv_bias": True,
+ "use_swiglu_ffn": False,
+ "out_features": ["stage3", "stage6", "stage9", "stage12"],
+ "hidden_size": 384,
+ "patch_size": 16,
+ "num_windows": 2,
+ "image_size": 384,
+ },
+ "rf-detr-small": {
+ "attention_probs_dropout_prob": 0.0,
+ "drop_path_rate": 0.0,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.0,
+ "initializer_range": 0.02,
+ "layer_norm_eps": 1e-06,
+ "layerscale_value": 1.0,
+ "mlp_ratio": 4,
+ "num_attention_heads": 6,
+ "num_channels": 3,
+ "num_hidden_layers": 12,
+ "qkv_bias": True,
+ "use_swiglu_ffn": False,
+ "out_features": ["stage3", "stage6", "stage9", "stage12"],
+ "hidden_size": 384,
+ "patch_size": 16,
+ "num_windows": 2,
+ "image_size": 512,
+ },
+ "rf-detr-base": {
+ "attention_probs_dropout_prob": 0.0,
+ "drop_path_rate": 0.0,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.0,
+ "initializer_range": 0.02,
+ "layer_norm_eps": 1e-06,
+ "layerscale_value": 1.0,
+ "mlp_ratio": 4,
+ "num_attention_heads": 6,
+ "num_channels": 3,
+ "num_hidden_layers": 12,
+ "qkv_bias": True,
+ "use_swiglu_ffn": False,
+ "out_features": ["stage2", "stage5", "stage8", "stage11"],
+ "hidden_size": 384,
+ "patch_size": 14,
+ "num_windows": 4,
+ "image_size": 518,
+ },
+ "rf-detr-medium": {
+ "attention_probs_dropout_prob": 0.0,
+ "drop_path_rate": 0.0,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.0,
+ "initializer_range": 0.02,
+ "layer_norm_eps": 1e-06,
+ "layerscale_value": 1.0,
+ "mlp_ratio": 4,
+ "num_attention_heads": 6,
+ "num_channels": 3,
+ "num_hidden_layers": 12,
+ "qkv_bias": True,
+ "use_swiglu_ffn": False,
+ "out_features": ["stage3", "stage6", "stage9", "stage12"],
+ "hidden_size": 384,
+ "patch_size": 16,
+ "num_windows": 2,
+ "image_size": 576,
+ },
+ "rf-detr-large-deprecated": {
+ "attention_probs_dropout_prob": 0.0,
+ "drop_path_rate": 0.0,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.0,
+ "initializer_range": 0.02,
+ "layer_norm_eps": 1e-06,
+ "layerscale_value": 1.0,
+ "mlp_ratio": 4,
+ "num_attention_heads": 12,
+ "num_channels": 3,
+ "num_hidden_layers": 12,
+ "qkv_bias": True,
+ "use_swiglu_ffn": False,
+ "out_features": ["stage2", "stage5", "stage8", "stage11"],
+ "hidden_size": 768,
+ "patch_size": 14,
+ "num_windows": 4,
+ "image_size": 518,
+ },
+ "rf-detr-large": {
+ "attention_probs_dropout_prob": 0.0,
+ "drop_path_rate": 0.0,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.0,
+ "initializer_range": 0.02,
+ "layer_norm_eps": 1e-06,
+ "layerscale_value": 1.0,
+ "mlp_ratio": 4,
+ "num_attention_heads": 6,
+ "num_channels": 3,
+ "num_hidden_layers": 12,
+ "qkv_bias": True,
+ "use_swiglu_ffn": False,
+ "out_features": ["stage3", "stage6", "stage9", "stage12"],
+ "hidden_size": 384,
+ "patch_size": 16,
+ "num_windows": 2,
+ "image_size": 704,
+ },
+ "rf-detr-xlarge": {
+ "attention_probs_dropout_prob": 0.0,
+ "drop_path_rate": 0.0,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.0,
+ "initializer_range": 0.02,
+ "layer_norm_eps": 1e-06,
+ "layerscale_value": 1.0,
+ "mlp_ratio": 4,
+ "num_attention_heads": 12,
+ "num_channels": 3,
+ "num_hidden_layers": 12,
+ "qkv_bias": True,
+ "use_swiglu_ffn": False,
+ "out_features": ["stage2", "stage5", "stage8", "stage11"],
+ "hidden_size": 768,
+ "patch_size": 20,
+ "num_windows": 1,
+ "image_size": 700,
+ },
+ "rf-detr-xxlarge": {
+ "attention_probs_dropout_prob": 0.0,
+ "drop_path_rate": 0.0,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.0,
+ "initializer_range": 0.02,
+ "layer_norm_eps": 1e-06,
+ "layerscale_value": 1.0,
+ "mlp_ratio": 4,
+ "num_attention_heads": 12,
+ "num_channels": 3,
+ "num_hidden_layers": 12,
+ "qkv_bias": True,
+ "use_swiglu_ffn": False,
+ "out_features": ["stage2", "stage5", "stage8", "stage11"],
+ "hidden_size": 768,
+ "patch_size": 20,
+ "num_windows": 2,
+ "image_size": 880,
+ },
+ "rf-detr-seg-preview": {
+ "attention_probs_dropout_prob": 0.0,
+ "drop_path_rate": 0.0,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.0,
+ "initializer_range": 0.02,
+ "layer_norm_eps": 1e-06,
+ "layerscale_value": 1.0,
+ "mlp_ratio": 4,
+ "num_attention_heads": 6,
+ "num_channels": 3,
+ "num_hidden_layers": 12,
+ "qkv_bias": True,
+ "use_swiglu_ffn": False,
+ "out_features": ["stage3", "stage6", "stage9", "stage12"],
+ "hidden_size": 384,
+ "patch_size": 12,
+ "num_windows": 2,
+ "image_size": 432,
+ },
+ "rf-detr-seg-nano": {
+ "attention_probs_dropout_prob": 0.0,
+ "drop_path_rate": 0.0,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.0,
+ "initializer_range": 0.02,
+ "layer_norm_eps": 1e-06,
+ "layerscale_value": 1.0,
+ "mlp_ratio": 4,
+ "num_attention_heads": 6,
+ "num_channels": 3,
+ "num_hidden_layers": 12,
+ "qkv_bias": True,
+ "use_swiglu_ffn": False,
+ "out_features": ["stage3", "stage6", "stage9", "stage12"],
+ "hidden_size": 384,
+ "patch_size": 12,
+ "num_windows": 1,
+ "image_size": 312,
+ },
+ "rf-detr-seg-small": {
+ "attention_probs_dropout_prob": 0.0,
+ "drop_path_rate": 0.0,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.0,
+ "initializer_range": 0.02,
+ "layer_norm_eps": 1e-06,
+ "layerscale_value": 1.0,
+ "mlp_ratio": 4,
+ "num_attention_heads": 6,
+ "num_channels": 3,
+ "num_hidden_layers": 12,
+ "qkv_bias": True,
+ "use_swiglu_ffn": False,
+ "out_features": ["stage3", "stage6", "stage9", "stage12"],
+ "hidden_size": 384,
+ "patch_size": 12,
+ "num_windows": 2,
+ "image_size": 384,
+ },
+ "rf-detr-seg-medium": {
+ "attention_probs_dropout_prob": 0.0,
+ "drop_path_rate": 0.0,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.0,
+ "initializer_range": 0.02,
+ "layer_norm_eps": 1e-06,
+ "layerscale_value": 1.0,
+ "mlp_ratio": 4,
+ "num_attention_heads": 6,
+ "num_channels": 3,
+ "num_hidden_layers": 12,
+ "qkv_bias": True,
+ "use_swiglu_ffn": False,
+ "out_features": ["stage3", "stage6", "stage9", "stage12"],
+ "hidden_size": 384,
+ "patch_size": 12,
+ "num_windows": 2,
+ "image_size": 432,
+ },
+ "rf-detr-seg-large": {
+ "attention_probs_dropout_prob": 0.0,
+ "drop_path_rate": 0.0,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.0,
+ "initializer_range": 0.02,
+ "layer_norm_eps": 1e-06,
+ "layerscale_value": 1.0,
+ "mlp_ratio": 4,
+ "num_attention_heads": 6,
+ "num_channels": 3,
+ "num_hidden_layers": 12,
+ "qkv_bias": True,
+ "use_swiglu_ffn": False,
+ "out_features": ["stage3", "stage6", "stage9", "stage12"],
+ "hidden_size": 384,
+ "patch_size": 12,
+ "num_windows": 2,
+ "image_size": 504,
+ },
+ "rf-detr-seg-xlarge": {
+ "attention_probs_dropout_prob": 0.0,
+ "drop_path_rate": 0.0,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.0,
+ "initializer_range": 0.02,
+ "layer_norm_eps": 1e-06,
+ "layerscale_value": 1.0,
+ "mlp_ratio": 4,
+ "num_attention_heads": 6,
+ "num_channels": 3,
+ "num_hidden_layers": 12,
+ "qkv_bias": True,
+ "use_swiglu_ffn": False,
+ "out_features": ["stage3", "stage6", "stage9", "stage12"],
+ "hidden_size": 384,
+ "patch_size": 12,
+ "num_windows": 2,
+ "image_size": 624,
+ },
+ "rf-detr-seg-xxlarge": {
+ "attention_probs_dropout_prob": 0.0,
+ "drop_path_rate": 0.0,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.0,
+ "initializer_range": 0.02,
+ "layer_norm_eps": 1e-06,
+ "layerscale_value": 1.0,
+ "mlp_ratio": 4,
+ "num_attention_heads": 6,
+ "num_channels": 3,
+ "num_hidden_layers": 12,
+ "qkv_bias": True,
+ "use_swiglu_ffn": False,
+ "out_features": ["stage3", "stage6", "stage9", "stage12"],
+ "hidden_size": 384,
+ "patch_size": 12,
+ "num_windows": 2,
+ "image_size": 768,
+ },
+}
+
+MODEL_CONFIGS = {
+ "rf-detr-nano": {
+ "d_model": 256,
+ "decoder_layers": 2,
+ "decoder_self_attention_heads": 8,
+ "decoder_cross_attention_heads": 16,
+ "decoder_n_points": 2,
+ "projector_scale_factors": [1.0],
+ "num_queries": 300,
+ },
+ "rf-detr-small": {
+ "d_model": 256,
+ "decoder_layers": 3,
+ "decoder_self_attention_heads": 8,
+ "decoder_cross_attention_heads": 16,
+ "decoder_n_points": 2,
+ "projector_scale_factors": [1.0],
+ "num_queries": 300,
+ },
+ "rf-detr-base": {
+ "d_model": 256,
+ "decoder_layers": 3,
+ "decoder_self_attention_heads": 8,
+ "decoder_cross_attention_heads": 16,
+ "decoder_n_points": 2,
+ "projector_scale_factors": [1.0],
+ "num_queries": 300,
+ },
+ "rf-detr-medium": {
+ "d_model": 256,
+ "decoder_layers": 4,
+ "decoder_self_attention_heads": 8,
+ "decoder_cross_attention_heads": 16,
+ "decoder_n_points": 2,
+ "projector_scale_factors": [1.0],
+ "num_queries": 300,
+ },
+ "rf-detr-large-deprecated": {
+ "d_model": 384,
+ "num_queries": 300,
+ "decoder_layers": 3,
+ "decoder_self_attention_heads": 12,
+ "decoder_cross_attention_heads": 24,
+ "decoder_n_points": 4,
+ "projector_scale_factors": [2.0, 0.5],
+ },
+ "rf-detr-large": {
+ "d_model": 256,
+ "num_queries": 300,
+ "decoder_layers": 4,
+ "decoder_self_attention_heads": 8,
+ "decoder_cross_attention_heads": 16,
+ "decoder_n_points": 2,
+ "projector_scale_factors": [1.0],
+ },
+ "rf-detr-xlarge": {
+ "d_model": 512,
+ "num_queries": 300,
+ "decoder_layers": 5,
+ "decoder_self_attention_heads": 16,
+ "decoder_cross_attention_heads": 32,
+ "decoder_n_points": 4,
+ "projector_scale_factors": [2.0, 0.5],
+ },
+ "rf-detr-xxlarge": {
+ "d_model": 512,
+ "num_queries": 300,
+ "decoder_layers": 5,
+ "decoder_self_attention_heads": 16,
+ "decoder_cross_attention_heads": 32,
+ "decoder_n_points": 4,
+ "projector_scale_factors": [2.0, 0.5],
+ },
+ "rf-detr-seg-preview": {
+ "d_model": 256,
+ "decoder_layers": 4,
+ "decoder_self_attention_heads": 8,
+ "decoder_cross_attention_heads": 16,
+ "decoder_n_points": 2,
+ "projector_scale_factors": [1.0],
+ "num_queries": 200,
+ "class_loss_coefficient": 5.0,
+ },
+ "rf-detr-seg-nano": {
+ "d_model": 256,
+ "decoder_layers": 4,
+ "decoder_self_attention_heads": 8,
+ "decoder_cross_attention_heads": 16,
+ "decoder_n_points": 2,
+ "projector_scale_factors": [1.0],
+ "num_queries": 100,
+ "class_loss_coefficient": 5.0,
+ },
+ "rf-detr-seg-small": {
+ "d_model": 256,
+ "decoder_layers": 4,
+ "decoder_self_attention_heads": 8,
+ "decoder_cross_attention_heads": 16,
+ "decoder_n_points": 2,
+ "projector_scale_factors": [1.0],
+ "num_queries": 100,
+ "class_loss_coefficient": 5.0,
+ },
+ "rf-detr-seg-medium": {
+ "d_model": 256,
+ "decoder_layers": 5,
+ "decoder_self_attention_heads": 8,
+ "decoder_cross_attention_heads": 16,
+ "decoder_n_points": 2,
+ "projector_scale_factors": [1.0],
+ "num_queries": 200,
+ "class_loss_coefficient": 5.0,
+ },
+ "rf-detr-seg-large": {
+ "d_model": 256,
+ "decoder_layers": 5,
+ "decoder_self_attention_heads": 8,
+ "decoder_cross_attention_heads": 16,
+ "decoder_n_points": 2,
+ "projector_scale_factors": [1.0],
+ "num_queries": 300,
+ "class_loss_coefficient": 5.0,
+ },
+ "rf-detr-seg-xlarge": {
+ "d_model": 256,
+ "decoder_layers": 6,
+ "decoder_self_attention_heads": 8,
+ "decoder_cross_attention_heads": 16,
+ "decoder_n_points": 2,
+ "projector_scale_factors": [1.0],
+ "num_queries": 300,
+ "class_loss_coefficient": 5.0,
+ },
+ "rf-detr-seg-xxlarge": {
+ "d_model": 256,
+ "decoder_layers": 6,
+ "decoder_self_attention_heads": 8,
+ "decoder_cross_attention_heads": 16,
+ "decoder_n_points": 2,
+ "projector_scale_factors": [1.0],
+ "num_queries": 300,
+ "class_loss_coefficient": 5.0,
+ },
+}
+
+IMAGE_PROCESSORS = {
+ "rf-detr-nano": {
+ "do_resize": True,
+ "size": (384, 384),
+ },
+ "rf-detr-small": {
+ "do_resize": True,
+ "size": (512, 512),
+ },
+ "rf-detr-base": {
+ "do_resize": True,
+ "size": (560, 560),
+ },
+ "rf-detr-medium": {
+ "do_resize": True,
+ "size": (576, 576),
+ },
+ "rf-detr-large-deprecated": {
+ "do_resize": True,
+ "size": (560, 560),
+ },
+ "rf-detr-large": {
+ "do_resize": True,
+ "size": (704, 704),
+ },
+ "rf-detr-xlarge": {
+ "do_resize": True,
+ "size": (560, 560),
+ },
+ "rf-detr-xxlarge": {
+ "do_resize": True,
+ "size": (560, 560),
+ },
+ "rf-detr-seg-preview": {
+ "do_resize": True,
+ "size": (432, 432),
+ },
+ "rf-detr-seg-nano": {
+ "do_resize": True,
+ "size": (312, 312),
+ },
+ "rf-detr-seg-small": {
+ "do_resize": True,
+ "size": (384, 384),
+ },
+ "rf-detr-seg-medium": {
+ "do_resize": True,
+ "size": (432, 432),
+ },
+ "rf-detr-seg-large": {
+ "do_resize": True,
+ "size": (504, 504),
+ },
+ "rf-detr-seg-xlarge": {
+ "do_resize": True,
+ "size": (624, 624),
+ },
+ "rf-detr-seg-xxlarge": {
+ "do_resize": True,
+ "size": (768, 768),
+ },
+}
+
+
+def get_model_config(model_name: str):
+ """Get the appropriate configuration for a given model size."""
+    if model_name in MODEL_CONFIGS:
+        config = MODEL_CONFIGS[model_name]
+        config["backbone_config"] = BACKBONE_CONFIGS[model_name]
+        image_processor_config = IMAGE_PROCESSORS[model_name]
+    else:
+        # Default to the base configuration
+        config = MODEL_CONFIGS["rf-detr-base"]
+        config["backbone_config"] = BACKBONE_CONFIGS["rf-detr-base"]
+        image_processor_config = IMAGE_PROCESSORS["rf-detr-base"]
+    config["backbone_config"]["model_type"] = "rf_detr_dinov2"
+
+    # Objects365 checkpoints use 366 labels; everything else defaults to COCO's 91
+    config["num_labels"] = 366 if "objects365" in model_name else 91
+
+ return config, image_processor_config
+
+
+def get_weight_mapping(
+ rf_detr_config: RfDetrConfig,
+ is_segmentation: bool,
+) -> list[WeightConverter | WeightRenaming]:
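+    """Build the renaming/conversion rules mapping original RF-DETR checkpoint keys to HF module names.
+
+    Segmentation checkpoints nest the detection model under a `rf_detr.` prefix, while detection checkpoints map
+    directly onto `RfDetrForObjectDetection`.
+    """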
+ if is_segmentation:
+ weight_mapping = [
+ # backbone RfDetrConvEncoder
+ WeightRenaming("backbone.0.encoder.encoder", "rf_detr.model.backbone.backbone"),
+ WeightRenaming("backbone.0.projector", "rf_detr.model.backbone.projector"),
+ # RfDetrDecoder
+ WeightRenaming("transformer.decoder", "rf_detr.model.decoder"),
+ # RfDetrForObjectDetection
+ WeightRenaming(r"transformer.enc_out_bbox_embed", r"rf_detr.model.enc_out_bbox_embed"),
+ WeightRenaming(r"transformer.enc_output.(\d+)", r"rf_detr.model.enc_output.\1"),
+ WeightRenaming(r"transformer.enc_output_norm.(\d+)", r"rf_detr.model.enc_output_norm.\1"),
+ WeightRenaming(r"transformer.enc_out_class_embed.(\d+)", r"rf_detr.model.enc_out_class_embed.\1"),
+ WeightRenaming(r"refpoint_embed.weight", r"rf_detr.model.reference_point_embed.weight"),
+ ]
+ else:
+ weight_mapping = [
+ # backbone RfDetrConvEncoder
+ WeightRenaming("backbone.0.encoder.encoder", "backbone.backbone"),
+ WeightRenaming("backbone.0.projector", "backbone.projector"),
+ # RfDetrDecoder
+ WeightRenaming("transformer.decoder", "decoder"),
+ # RfDetrForObjectDetection
+ WeightRenaming("transformer.enc_out_bbox_embed", "enc_out_bbox_embed"),
+ WeightRenaming(r"transformer.enc_output.(\d+)", r"enc_output.\1"),
+ WeightRenaming(r"transformer.enc_output_norm.(\d+)", r"enc_output_norm.\1"),
+ WeightRenaming(r"transformer.enc_out_class_embed.(\d+)", r"enc_out_class_embed.\1"),
+ WeightRenaming(r"refpoint_embed.weight", r"reference_point_embed.weight"),
+ ]
+
+ weight_mapping.extend(
+ [
+ # RfDetrConvEncoder
+ ## RfDetrMultiScaleProjector
+ WeightRenaming(r"projector.stages_sampling.(\d+)", r"projector.scale_layers.\1.sampling_layers"),
+ WeightRenaming(r"projector.stages.(\d+).0", r"projector.scale_layers.\1.projector_layer"),
+ WeightRenaming(r"projector.stages.(\d+).1", r"projector.scale_layers.\1.layer_norm"),
+ ## RfDetrSamplingLayer
+ WeightRenaming(r"sampling_layers.(\d+)", r"sampling_layers.\1.layers"),
+ WeightRenaming(r"layers.(\d+).bn", r"layers.\1.norm"),
+ ### RfDetrC2FLayer
+ WeightRenaming(r"projector_layer.cv1.conv", r"projector_layer.conv1.conv"),
+ WeightRenaming(r"projector_layer.cv1.bn", r"projector_layer.conv1.norm"),
+ WeightRenaming(r"projector_layer.cv2.conv", r"projector_layer.conv2.conv"),
+ WeightRenaming(r"projector_layer.cv2.bn", r"projector_layer.conv2.norm"),
+ WeightRenaming(r"projector_layer.m.(\d+)", r"projector_layer.bottlenecks.\1"),
+ #### RfDetrRepVggBlock
+ WeightRenaming(r"bottlenecks.(\d+).cv1.conv", r"bottlenecks.\1.conv1.conv"),
+ WeightRenaming(r"bottlenecks.(\d+).cv1.bn", r"bottlenecks.\1.conv1.norm"),
+ WeightRenaming(r"bottlenecks.(\d+).cv2.conv", r"bottlenecks.\1.conv2.conv"),
+ WeightRenaming(r"bottlenecks.(\d+).cv2.bn", r"bottlenecks.\1.conv2.norm"),
+ # RfDetrDecoder
+ ## RfDetrDecoderLayer
+ WeightRenaming(r"decoder.layers.(\d+).norm1", r"decoder.layers.\1.self_attn_layer_norm"),
+ WeightRenaming(r"decoder.layers.(\d+).norm2", r"decoder.layers.\1.cross_attn_layer_norm"),
+ WeightRenaming(r"decoder.layers.(\d+).linear1", r"decoder.layers.\1.mlp.fc1"),
+ WeightRenaming(r"decoder.layers.(\d+).linear2", r"decoder.layers.\1.mlp.fc2"),
+ WeightRenaming(r"decoder.layers.(\d+).norm3", r"decoder.layers.\1.layer_norm"),
+ WeightRenaming("decoder.norm", r"decoder.layernorm"),
+ ### RfDetrAttention
+ WeightRenaming(r"self_attn.out_proj", r"self_attn.o_proj"),
+ WeightConverter(
+ r"self_attn.in_proj_bias",
+ [r"self_attn.q_proj.bias", r"self_attn.k_proj.bias", r"self_attn.v_proj.bias"],
+ operations=[Chunk(dim=0)],
+ ),
+ WeightConverter(
+ r"self_attn.in_proj_weight",
+ [r"self_attn.q_proj.weight", r"self_attn.k_proj.weight", r"self_attn.v_proj.weight"],
+ operations=[Chunk(dim=0)],
+ ),
+ ]
+ )
+
+ # Indices depend on the value of projector_scale_factors
+ for i, scale in enumerate(rf_detr_config.projector_scale_factors):
+ if scale == 2.0:
+ weight_mapping.append(
+ WeightRenaming(
+ rf"projector.stages_sampling.{i}.(\d+).(\d+)",
+ rf"projector.scale_layers.{i}.sampling_layers.\1.layers.\2",
+ )
+ )
+ elif scale == 0.5:
+ weight_mapping.append(
+ WeightRenaming(
+ rf"projector.stages_sampling.{i}.(\d+).(\d+).conv.weight",
+ rf"projector.scale_layers.{i}.sampling_layers.\1.layers.\2.conv.weight",
+ )
+ )
+ weight_mapping.append(
+ WeightRenaming(
+ rf"projector.stages_sampling.{i}.(\d+).(\d+).bn",
+ rf"projector.scale_layers.{i}.sampling_layers.\1.layers.\2.norm",
+ )
+ )
+
+ if is_segmentation:
+ weight_mapping.extend(
+ [
+ # RfDetrForObjectDetection
+ WeightRenaming(r"bbox_embed.layers", "rf_detr.bbox_embed.layers"),
+ WeightRenaming(r"class_embed.(weight|bias)", r"rf_detr.class_embed.\1"),
+ WeightRenaming(r"query_feat.(weight|bias)", r"rf_detr.model.query_feat.\1"),
+ # Segmentation head
+ WeightRenaming(r"segmentation_head.blocks", r"blocks"),
+ WeightRenaming("segmentation_head.spatial_features_proj", "spatial_features_proj"),
+ WeightRenaming("segmentation_head.query_features_block", "query_features_block"),
+ WeightRenaming("segmentation_head.query_features_proj", "query_features_proj"),
+ WeightRenaming("segmentation_head.bias", "bias"),
+ ## RfDetrSegmentationMLPBlock
+ WeightRenaming("query_features_block.layers.0", "query_features_block.in_linear"),
+ WeightRenaming("query_features_block.layers.2", "query_features_block.out_linear"),
+ ## list[RfDetrSegmentationBlock]
+ WeightRenaming(r"blocks.(\d+)", r"blocks.\1"),
+ WeightRenaming(r"blocks.(\d+).norm", r"blocks.\1.layernorm"),
+ ]
+ )
+ return weight_mapping
+
+
+# We will verify our results on an image of cute cats
+def prepare_img():
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ im = Image.open(requests.get(url, stream=True).raw)
+
+ return im
+
+
+def original_preprocess_image(image, size):
+ transform = transforms.Compose(
+ [
+ transforms.Resize(list(size.values())),
+ transforms.ToTensor(),
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+ ]
+ )
+ image = transform(image)
+ return image
+
+
+def test_models_outputs(
+ model: RfDetrForObjectDetection | RfDetrForInstanceSegmentation,
+ image_processor: DetrImageProcessor,
+ model_name: str,
+):
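+    # Reference values (first five flattened logits/boxes, mask logits for segmentation checkpoints, and the loss)
+    # used as numerical regression targets for the converted models.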
+ expected_outputs = {
+ "rf-detr-nano": {
+ "logits": [-6.68004, -5.66107, -11.70373, -8.32324, -5.76176],
+ "boxes": [0.25828, 0.54991, 0.47220, 0.87432, 0.55099],
+ "loss": 14.893259,
+ },
+ "rf-detr-small": {
+ "logits": [-6.83893, -4.55097, -10.53040, -8.20657, -5.55314],
+ "boxes": [0.25782, 0.55037, 0.47922, 0.87102, 0.77074],
+ "loss": 19.771887,
+ },
+ "rf-detr-base": {
+ "logits": [-7.60410, -4.65943, -10.03144, -5.63881, -9.88291],
+ "boxes": [0.25465, 0.54864, 0.48583, 0.86991, 0.16926],
+ "loss": 21.967346,
+ },
+ "rf-detr-base-2": {
+ "logits": [-6.81648, -6.80946, -7.72004, -6.06710, -10.37419],
+ "boxes": [0.16911, 0.19784, 0.21076, 0.09273, 0.25263],
+ "loss": 21.532478,
+ },
+ "rf-detr-medium": {
+ "logits": [-6.58581, -8.07883, -12.52392, -7.78248, -10.47323],
+ "boxes": [0.16824, 0.19932, 0.21110, 0.09385, 0.77087],
+ "loss": 26.337656,
+ },
+ "rf-detr-large-deprecated": {
+ "logits": [-7.60887, -4.36907, -4.98866, -8.06598, -5.52969],
+ "boxes": [0.25576, 0.55051, 0.47765, 0.87141, 0.76966],
+ "loss": 22.116581,
+ },
+ "rf-detr-large": {
+ "logits": [-6.38973, -8.19355, -12.09174, -7.80438, -10.15835],
+ "boxes": [0.16901, 0.19936, 0.21087, 0.09311, 0.77199],
+ "loss": 27.111633,
+ },
+ "rf-detr-seg-preview": {
+ "logits": [-7.05877, -4.23362, -6.54288, -8.13384, -6.36838],
+ "boxes": [0.25603, 0.55164, 0.48340, 0.87798, 0.73861],
+ "pred_masks": [-16.72129, -16.17153, -17.06426, -17.29409, -17.78559],
+ "loss": 76.156754,
+ },
+ "rf-detr-seg-nano": {
+ "logits": [-7.38427, -5.59449, -9.97889, -11.03668, -8.62285],
+ "boxes": [0.25230, 0.54825, 0.48196, 0.86925, 0.77119],
+ "pred_masks": [-12.01641, -12.37785, -13.37312, -13.54168, -13.53435],
+ "loss": 92.973061,
+ },
+ "rf-detr-seg-small": {
+ "logits": [-7.35031, -5.09690, -9.58117, -10.80274, -8.35001],
+ "boxes": [0.25607, 0.54820, 0.48018, 0.87013, 0.90797],
+ "pred_masks": [-13.17243, -13.12057, -13.92742, -13.89896, -13.72802],
+ "loss": 87.512894,
+ },
+ "rf-detr-seg-medium": {
+ "logits": [-7.48751, -5.21394, -9.35906, -9.31897, -8.08021],
+ "boxes": [0.76891, 0.41680, 0.46182, 0.72004, 0.16810],
+ "pred_masks": [-15.67913, -17.05902, -16.72426, -17.19833, -17.18960],
+ "loss": 95.562599,
+ },
+ "rf-detr-seg-large": {
+ "logits": [-7.37005, -5.04871, -9.19777, -9.37870, -7.96562],
+ "boxes": [0.76796, 0.41489, 0.46220, 0.72197, 0.25254],
+ "pred_masks": [-15.13846, -16.88754, -16.55486, -17.23686, -17.40160],
+ "loss": 91.781540,
+ },
+ "rf-detr-seg-xlarge": {
+ "logits": [-7.42486, -4.72502, -8.16429, -8.30500, -7.21668],
+ "boxes": [0.76863, 0.41618, 0.46055, 0.72461, 0.16735],
+ "pred_masks": [-15.15330, -17.61085, -16.79776, -17.33611, -17.39120],
+ "loss": 105.279922,
+ },
+ "rf-detr-seg-xxlarge": {
+ "logits": [-7.33242, -5.11294, -6.31125, -7.06520, -7.07922],
+ "boxes": [0.25516, 0.53685, 0.49769, 0.88601, 0.76872],
+ "pred_masks": [-7.94849, -8.57010, -8.60056, -8.92854, -8.99288],
+ "loss": 99.136574,
+ },
+ }
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ image = prepare_img()
+ # Fake annotation for testing
+ annotations = {
+ "image_id": 0,
+ "annotations": [
+ {
+ "bbox": [250, 250, 350, 350],
+ "category_id": 0,
+ "iscrowd": 0,
+ "area": 122500,
+ "segments": [[0, 0, 0, 100, 100, 100, 100, 0]],
+ }
+ ],
+ }
+    is_segmentation = "seg" in model_name
+
+ original_pixel_values = original_preprocess_image(image, image_processor.size).unsqueeze(0).to(device)
+ inputs = image_processor(images=image, annotations=annotations, return_tensors="pt").to(device)
+ inputs["labels"][0]["masks"] = torch.zeros((1, original_pixel_values.shape[-1], original_pixel_values.shape[-2]))
+ torch.testing.assert_close(original_pixel_values, inputs["pixel_values"], atol=1e-6, rtol=1e-6)
+ print("Preprocessing looks ok!")
+
+ model.to(device)
+ model.eval()
+ model.config._attn_implementation = "eager"
+ outputs = model(**inputs, output_hidden_states=True, output_attentions=True)
+
+ predicted_logits = outputs.logits.flatten()[:5]
+ expected_logits = expected_outputs[model_name]["logits"]
+ predicted_boxes = outputs.pred_boxes.flatten()[:5]
+ expected_boxes = expected_outputs[model_name]["boxes"]
+ torch.testing.assert_close(predicted_logits, torch.Tensor(expected_logits), rtol=5e-3, atol=5e-3)
+ torch.testing.assert_close(predicted_boxes, torch.Tensor(expected_boxes), rtol=5e-3, atol=5e-3)
+
+ if is_segmentation:
+ predicted_mask_logits = outputs.pred_masks.flatten()[:5]
+ expected_mask_logits = expected_outputs[model_name]["pred_masks"]
+ torch.testing.assert_close(predicted_mask_logits, torch.Tensor(expected_mask_logits), rtol=5e-3, atol=5e-3)
+
+ predicted_loss = outputs.loss
+ expected_loss = expected_outputs[model_name]["loss"]
+ torch.testing.assert_close(predicted_loss, torch.tensor(expected_loss), rtol=5e-3, atol=5e-3)
+
+ print("Forward pass looks ok!")
+
+
+@torch.no_grad()
+def convert_rf_detr_checkpoint(
+ model_name: str,
+ checkpoint_url: str,
+ pytorch_dump_folder_path: str,
+ push_to_hub: bool = False,
+ organization: str = "stevenbucaille",
+):
+ """
+    Convert an RF-DETR checkpoint to HuggingFace format.
+
+    Args:
+        model_name: Name of the model (e.g., "rf-detr-small")
+        checkpoint_url: URL of the original checkpoint (defaults to the hosted checkpoint for `model_name`)
+ pytorch_dump_folder_path: Path to save the converted model
+ push_to_hub: Whether to push the model to the hub
+ organization: Organization to push the model to
+ """
+ print(f"Converting {model_name} checkpoint...")
+
+ # Get model configuration
+ config, image_processor_config = get_model_config(model_name)
+ rf_detr_config = RfDetrConfig(**config)
+
+ # Load checkpoint
+ checkpoint_url = checkpoint_url if checkpoint_url is not None else HOSTED_MODELS[model_name]
+ print(f"Loading checkpoint from {checkpoint_url}...")
+ checkpoint = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu", weights_only=False)
+ # Create model and load weights
+ print("Creating model and loading weights...")
+ is_segmentation = "seg" in model_name
+ if is_segmentation:
+ model = RfDetrForInstanceSegmentation(rf_detr_config)
+ else:
+ model = RfDetrForObjectDetection(rf_detr_config)
+
+ weight_mapping = get_weight_mapping(rf_detr_config, is_segmentation)
+
+ # Handle different checkpoint formats
+ if "state_dict" in checkpoint:
+ state_dict = checkpoint["state_dict"]
+ elif "model" in checkpoint:
+ state_dict = checkpoint["model"]
+ else:
+ state_dict = checkpoint
+
+ missing, unexpected, mismatch, _, misc = convert_and_load_state_dict_in_model(
+ model, state_dict, weight_mapping, tp_plan=None, hf_quantizer=None
+ )
+ print("Checkpoint loaded...")
+    if len(missing) > 0 or len(unexpected) > 0 or len(mismatch) > 0:
+ print("MISSING:", len(missing))
+ print("\n".join(sorted(missing)))
+ print("UNEXPECTED:", len(unexpected))
+ print("\n".join(sorted(unexpected)))
+ print("MISMATCH:", len(mismatch))
+ print(mismatch)
+
+ image_processor = DetrImageProcessor(**image_processor_config, use_fast=True)
+ test_models_outputs(model, image_processor, model_name)
+
+ repo_id = f"{organization}/{model_name}"
+ # Save model
+ print("Saving model..." + " and pushing to hub..." if push_to_hub else "")
+ model.save_pretrained(
+ pytorch_dump_folder_path,
+ save_original_format=False,
+ push_to_hub=push_to_hub,
+ repo_id=repo_id,
+ commit_message=f"Add {model_name} model",
+ )
+
+ # Save image processor
+ print("Saving image processor..." + " and pushing to hub..." if push_to_hub else "")
+ image_processor.save_pretrained(
+ pytorch_dump_folder_path,
+ push_to_hub=push_to_hub,
+ repo_id=repo_id,
+ commit_message=f"Add {model_name} image processor",
+ )
+
+ # Save config
+ print("Saving config..." + " and pushing to hub..." if push_to_hub else "")
+ rf_detr_config.save_pretrained(
+ pytorch_dump_folder_path, push_to_hub=push_to_hub, repo_id=repo_id, commit_message=f"Add {model_name} config"
+ )
+
+ if push_to_hub:
+ print("Pushed model to hub successfully!")
+
+ print(f"Conversion completed successfully for {model_name}!")
+
+
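+# Example invocation (script name and output path are illustrative; the checkpoint is downloaded from
+# HOSTED_MODELS when --checkpoint_path is not provided):
+#   python convert_rf_detr_to_hf.py --model_name rf-detr-small --pytorch_dump_folder_path ./rf-detr-small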
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--model_name",
+ type=str,
+ required=True,
+ choices=list(HOSTED_MODELS.keys()),
+ help="Name of the model to convert",
+ )
+ parser.add_argument(
+ "--pytorch_dump_folder_path", type=str, required=True, help="Path to the output PyTorch model directory"
+ )
+ parser.add_argument("--checkpoint_path", type=str, help="Path to the checkpoint file (if not using hub download)")
+ parser.add_argument("--push_to_hub", action="store_true", help="Push model to the hub")
+ parser.add_argument("--organization", type=str, default="stevenbucaille", help="Organization to push the model to")
+
+ args = parser.parse_args()
+
+ # Get checkpoint path
+ checkpoint_path = args.checkpoint_path
+
+ # Convert checkpoint
+ convert_rf_detr_checkpoint(
+ model_name=args.model_name,
+ checkpoint_url=checkpoint_path,
+ pytorch_dump_folder_path=args.pytorch_dump_folder_path,
+ push_to_hub=args.push_to_hub,
+ organization=args.organization,
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/transformers/models/rf_detr/modeling_rf_detr.py b/src/transformers/models/rf_detr/modeling_rf_detr.py
new file mode 100644
index 000000000000..c350ec634e73
--- /dev/null
+++ b/src/transformers/models/rf_detr/modeling_rf_detr.py
@@ -0,0 +1,2133 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/rf_detr/modular_rf_detr.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_rf_detr.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import collections.abc
+import math
+import warnings
+from collections.abc import Callable
+from dataclasses import dataclass
+from typing import Any
+
+import torch
+from torch import Tensor, nn
+from torch.nn import functional as F
+
+from ... import initialization as init
+from ...activations import ACT2CLS, ACT2FN
+from ...integrations import use_kernel_forward_from_hub
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BackboneOutput, BaseModelOutput
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...pytorch_utils import meshgrid
+from ...utils import auto_docstring, torch_compilable_check, torch_int
+from ...utils.backbone_utils import BackboneMixin
+from ...utils.generic import ModelOutput, TransformersKwargs, can_return_tuple, check_model_inputs
+from .configuration_rf_detr import RfDetrConfig, RfDetrDinov2Config
+
+
+class RfDetrDinov2PatchEmbeddings(nn.Module):
+ """
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
+ Transformer.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ image_size, patch_size = config.image_size, config.patch_size
+ num_channels, hidden_size = config.num_channels, config.hidden_size
+
+ image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
+ patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.num_patches = num_patches
+
+ self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
+
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+ num_channels = pixel_values.shape[1]
+ if num_channels != self.num_channels:
+ raise ValueError(
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+ f" Expected {self.num_channels} but got {num_channels}."
+ )
+ embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
+ return embeddings
+
+
+def window_partition(
+ embeddings: torch.Tensor, num_windows: int, patch_size: int, height: int, width: int
+) -> torch.Tensor:
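+    """Split the patch grid into num_windows x num_windows non-overlapping windows, replicating the CLS token (and
+    its position embedding) for every window so that windowed layers attend only within their own window."""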
+ batch_size = embeddings.shape[0]
+ num_h_patches = height // patch_size
+ num_w_patches = width // patch_size
+ cls_token_with_pos_embed = embeddings[:, :1]
+ pixel_tokens_with_pos_embed = embeddings[:, 1:]
+ pixel_tokens_with_pos_embed = pixel_tokens_with_pos_embed.view(batch_size, num_h_patches, num_w_patches, -1)
+ num_w_patches_per_window = num_w_patches // num_windows
+ num_h_patches_per_window = num_h_patches // num_windows
+ windowed_pixel_tokens = pixel_tokens_with_pos_embed.view(
+        batch_size, num_windows, num_h_patches_per_window, num_windows, num_w_patches_per_window, -1
+ )
+ windowed_pixel_tokens = windowed_pixel_tokens.permute(0, 1, 3, 2, 4, 5)
+ windowed_pixel_tokens = windowed_pixel_tokens.reshape(
+ batch_size * num_windows**2, num_h_patches_per_window * num_w_patches_per_window, -1
+ )
+ windowed_cls_token_with_pos_embed = cls_token_with_pos_embed.repeat(num_windows**2, 1, 1)
+ embeddings = torch.cat((windowed_cls_token_with_pos_embed, windowed_pixel_tokens), dim=1)
+ return embeddings
+
+
+class RfDetrDinov2Embeddings(nn.Module):
+ """
+ Construct the CLS token, mask token, position and patch embeddings.
+ """
+
+ def __init__(self, config: RfDetrDinov2Config) -> None:
+ super().__init__()
+
+ self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
+ if config.use_mask_token:
+ self.mask_token = nn.Parameter(torch.zeros(1, config.hidden_size))
+ self.patch_embeddings = RfDetrDinov2PatchEmbeddings(config)
+ num_patches = self.patch_embeddings.num_patches
+ self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ self.patch_size = config.patch_size
+ self.use_mask_token = config.use_mask_token
+ self.config = config
+
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+ """
+        Interpolate the pre-trained position encodings so the model can be used on higher-resolution images.
+        This method is also adapted to support torch.jit tracing and interpolation at torch.float32 precision.
+
+ Adapted from:
+ - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
+ - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
+ """
+
+ num_patches = embeddings.shape[1] - 1
+ num_positions = self.position_embeddings.shape[1] - 1
+
+ # always interpolate when tracing to ensure the exported model works for dynamic input shapes
+ if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
+ return self.position_embeddings
+
+ class_pos_embed = self.position_embeddings[:, :1]
+ patch_pos_embed = self.position_embeddings[:, 1:]
+
+ dim = embeddings.shape[-1]
+
+ new_height = height // self.patch_size
+ new_width = width // self.patch_size
+
+ sqrt_num_positions = torch_int(num_positions**0.5)
+ patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+ target_dtype = patch_pos_embed.dtype
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed.to(torch.float32),
+ size=(new_height, new_width),
+ mode="bicubic",
+ align_corners=False,
+ antialias=True,
+ ).to(dtype=target_dtype)
+
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+
+ return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
+
+ def forward(self, pixel_values: torch.Tensor, bool_masked_pos: torch.Tensor | None = None) -> torch.Tensor:
+ batch_size, _, height, width = pixel_values.shape
+ target_dtype = self.patch_embeddings.projection.weight.dtype
+ embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype))
+
+ if bool_masked_pos is not None:
+ embeddings = torch.where(
+ bool_masked_pos.unsqueeze(-1), self.mask_token.to(embeddings.dtype).unsqueeze(0), embeddings
+ )
+
+ # add the [CLS] token to the embedded patch tokens
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1)
+ embeddings = torch.cat((cls_tokens, embeddings), dim=1)
+
+ # add positional encoding to each token
+ embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
+
+ if self.config.num_windows > 1:
+ # reshape for windows
+ embeddings = window_partition(embeddings, self.config.num_windows, self.config.patch_size, height, width)
+ embeddings = self.dropout(embeddings)
+
+ return embeddings
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: torch.Tensor | None,
+ scaling: float | None = None,
+ dropout: float = 0.0,
+ **kwargs: Unpack[TransformersKwargs],
+):
+ if scaling is None:
+ scaling = query.size(-1) ** -0.5
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+
+ if attention_mask is not None:
+ attention_mask = attention_mask[:, :, :, : key.shape[-2]]
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+class RfDetrDinov2SelfAttention(nn.Module):
+ def __init__(self, config: RfDetrDinov2Config):
+ super().__init__()
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+ raise ValueError(
+ f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
+ f"heads {config.num_attention_heads}."
+ )
+
+ self.config = config
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+ self.dropout_prob = config.attention_probs_dropout_prob
+ self.scaling = self.attention_head_size**-0.5
+ self.is_causal = False
+
+ self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+ self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+ self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+ self.num_key_value_groups = 1
+
+ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ batch_size = hidden_states.shape[0]
+ new_shape = batch_size, -1, self.num_attention_heads, self.attention_head_size
+
+ key_layer = self.key(hidden_states).view(*new_shape).transpose(1, 2)
+ value_layer = self.value(hidden_states).view(*new_shape).transpose(1, 2)
+ query_layer = self.query(hidden_states).view(*new_shape).transpose(1, 2)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ context_layer, attention_probs = attention_interface(
+ self,
+ query_layer,
+ key_layer,
+ value_layer,
+ None,
+ is_causal=self.is_causal,
+ scaling=self.scaling,
+ dropout=0.0 if not self.training else self.dropout_prob,
+ )
+
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.reshape(new_context_layer_shape)
+
+ return context_layer, attention_probs
+
+
+class RfDetrDinov2SelfOutput(nn.Module):
+ """
+ The residual connection is defined in RfDetrDinov2Layer instead of here (as is the case with other models), due to the
+ layernorm applied before each block.
+ """
+
+ def __init__(self, config: RfDetrDinov2Config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ return hidden_states
+
+
+class RfDetrDinov2Attention(nn.Module):
+ def __init__(self, config: RfDetrDinov2Config):
+ super().__init__()
+ self.attention = RfDetrDinov2SelfAttention(config)
+ self.output = RfDetrDinov2SelfOutput(config)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ self_attn_output, _ = self.attention(hidden_states)
+ output = self.output(self_attn_output, hidden_states)
+ return output
+
+
+class RfDetrDinov2LayerScale(nn.Module):
+ def __init__(self, config) -> None:
+ super().__init__()
+ self.lambda1 = nn.Parameter(config.layerscale_value * torch.ones(config.hidden_size))
+
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
+ return hidden_state * self.lambda1
+
+
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
+ """
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+ """
+ if drop_prob == 0.0 or not training:
+ return input
+ keep_prob = 1 - drop_prob
+ shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
+ random_tensor.floor_() # binarize
+ output = input.div(keep_prob) * random_tensor
+ return output
+
+
+class RfDetrDinov2DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob: float | None = None) -> None:
+ super().__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ return drop_path(hidden_states, self.drop_prob, self.training)
+
+ def extra_repr(self) -> str:
+ return f"p={self.drop_prob}"
+
+
+class RfDetrDinov2MLP(nn.Module):
+ def __init__(self, config) -> None:
+ super().__init__()
+ in_features = out_features = config.hidden_size
+ hidden_features = int(config.hidden_size * config.mlp_ratio)
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=True)
+ if isinstance(config.hidden_act, str):
+ self.activation = ACT2FN[config.hidden_act]
+ else:
+ self.activation = config.hidden_act
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=True)
+
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
+ hidden_state = self.fc1(hidden_state)
+ hidden_state = self.activation(hidden_state)
+ hidden_state = self.fc2(hidden_state)
+ return hidden_state
+
+
+class RfDetrDinov2SwiGLUFFN(nn.Module):
+ def __init__(self, config) -> None:
+ super().__init__()
+ in_features = out_features = config.hidden_size
+ hidden_features = int(config.hidden_size * config.mlp_ratio)
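+        # use 2/3 of the MLP width, rounded up to a multiple of 8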
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
+
+ self.weights_in = nn.Linear(in_features, 2 * hidden_features, bias=True)
+ self.weights_out = nn.Linear(hidden_features, out_features, bias=True)
+
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
+ hidden_state = self.weights_in(hidden_state)
+ x1, x2 = hidden_state.chunk(2, dim=-1)
+ hidden = nn.functional.silu(x1) * x2
+ return self.weights_out(hidden)
+
+
+def window_unpartition_before_attention(hidden_states: torch.Tensor, num_windows: int) -> torch.Tensor:
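+    """Merge the per-window sequences back into one global sequence per image so that global-attention layers can
+    attend across all windows."""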
+ batch_size, seq_len, channels = hidden_states.shape
+ num_windows_squared = num_windows**2
+ hidden_states = hidden_states.view(batch_size // num_windows_squared, num_windows_squared * seq_len, channels)
+ return hidden_states
+
+
+def window_partition_after_attention(
+ hidden_states: torch.Tensor, self_attention_output: torch.Tensor, num_windows: int
+) -> torch.Tensor:
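+    """Re-split the globally attended output back into per-window sequences, matching the windowed layout of
+    `hidden_states`."""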
+ batch_size, seq_len, channels = hidden_states.shape
+ num_windows_squared = num_windows**2
+ self_attention_output = self_attention_output.view(
+ batch_size * num_windows_squared, seq_len // num_windows_squared, channels
+ )
+ return self_attention_output
+
+
+class RfDetrDinov2Layer(GradientCheckpointingLayer):
+ """This corresponds to the Block class in the original implementation."""
+
+ def __init__(self, config: RfDetrDinov2Config, layer_idx: int) -> None:
+ super().__init__()
+
+ self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.attention = RfDetrDinov2Attention(config)
+ self.layer_scale1 = RfDetrDinov2LayerScale(config)
+ self.drop_path = RfDetrDinov2DropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
+
+ self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ if config.use_swiglu_ffn:
+ self.mlp = RfDetrDinov2SwiGLUFFN(config)
+ else:
+ self.mlp = RfDetrDinov2MLP(config)
+ self.layer_scale2 = RfDetrDinov2LayerScale(config)
+ self.num_windows = config.num_windows
+ self.global_attention = layer_idx not in config.window_block_indexes
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ ) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor]:
+ shortcut = hidden_states
+ if self.global_attention:
+ hidden_states = window_unpartition_before_attention(hidden_states, self.num_windows)
+
+ hidden_states_norm = self.norm1(hidden_states)
+ self_attention_output = self.attention(hidden_states_norm)
+
+ if self.global_attention:
+ self_attention_output = window_partition_after_attention(
+ hidden_states, self_attention_output, self.num_windows
+ )
+
+ self_attention_output = self.layer_scale1(self_attention_output)
+
+ # first residual connection
+ hidden_states = self.drop_path(self_attention_output) + shortcut
+
+ # in Dinov2, layernorm is also applied after self-attention
+ layer_output = self.norm2(hidden_states)
+ layer_output = self.mlp(layer_output)
+ layer_output = self.layer_scale2(layer_output)
+
+ # second residual connection
+ layer_output = self.drop_path(layer_output) + hidden_states
+
+ return layer_output
+
+
+class RfDetrDinov2Encoder(nn.Module):
+ def __init__(self, config: RfDetrDinov2Config):
+ super().__init__()
+ self.config = config
+ self.layer = nn.ModuleList([RfDetrDinov2Layer(config, i) for i in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states: torch.Tensor, output_hidden_states: bool = False) -> BaseModelOutput:
+ all_hidden_states = [hidden_states] if output_hidden_states else None
+ for i, layer_module in enumerate(self.layer):
+ hidden_states = layer_module(hidden_states)
+ if all_hidden_states:
+ all_hidden_states.append(hidden_states)
+
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=tuple(all_hidden_states) if all_hidden_states else None,
+ )
+
+
+@auto_docstring
+class RfDetrDinov2PreTrainedModel(PreTrainedModel):
+ config: RfDetrDinov2Config
+ base_model_prefix = "rf_detr_dinov2"
+ main_input_name = "pixel_values"
+ input_modalities = ("image",)
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["RfDetrDinov2Layer"]
+ _supports_sdpa = True
+ _supports_flash_attn = True
+ _supports_flex_attn = True
+ _supports_attention_backend = True
+ _can_record_outputs = {
+ "attentions": RfDetrDinov2SelfAttention,
+ }
+
+ @torch.no_grad()
+ def _init_weights(self, module: nn.Linear | nn.Conv2d | nn.LayerNorm) -> None:
+ """Initialize the weights"""
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
+ init.trunc_normal_(module.weight, mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ init.zeros_(module.bias)
+ elif isinstance(module, nn.LayerNorm):
+ init.zeros_(module.bias)
+ init.ones_(module.weight)
+ elif isinstance(module, RfDetrDinov2Embeddings):
+ init.trunc_normal_(module.position_embeddings, mean=0.0, std=self.config.initializer_range)
+ init.trunc_normal_(module.cls_token, mean=0.0, std=self.config.initializer_range)
+ if self.config.use_mask_token:
+ init.zeros_(module.mask_token)
+ elif isinstance(module, RfDetrDinov2LayerScale):
+ init.constant_(module.lambda1, self.config.layerscale_value)
+
+
+def window_unpartition(
+ hidden_state: torch.Tensor,
+ num_windows: int,
+ num_h_patches: int,
+ num_w_patches: int,
+) -> torch.Tensor:
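+    """Rearrange window-partitioned hidden states back into a per-image layout of shape
+    (batch_size, num_windows, patches_h_per_window, num_windows, patches_w_per_window, channels); the caller then
+    reshapes this into the full patch grid."""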
+ hidden_batch_size, seq_len, channels = hidden_state.shape
+ num_windows_squared = num_windows**2
+ num_h_patches_per_window = num_h_patches // num_windows
+ num_w_patches_per_window = num_w_patches // num_windows
+ hidden_state = hidden_state.reshape(
+ hidden_batch_size // num_windows_squared, num_windows_squared * seq_len, channels
+ )
+ hidden_state = hidden_state.view(
+ hidden_batch_size // num_windows_squared,
+ num_windows,
+ num_windows,
+ num_h_patches_per_window,
+ num_w_patches_per_window,
+ channels,
+ )
+ hidden_state = hidden_state.permute(0, 1, 3, 2, 4, 5)
+ return hidden_state
+
+
+@auto_docstring(
+ custom_intro="""
+ RfDetrDinov2 backbone, to be used with frameworks like DETR and MaskFormer.
+ """
+)
+class RfDetrDinov2Backbone(RfDetrDinov2PreTrainedModel, BackboneMixin):
+ def __init__(self, config):
+ super().__init__(config)
+ super()._init_backbone(config)
+
+ self.num_features = [config.hidden_size for _ in range(config.num_hidden_layers + 1)]
+ self.embeddings = RfDetrDinov2Embeddings(config)
+ self.encoder = RfDetrDinov2Encoder(config)
+
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self) -> RfDetrDinov2PatchEmbeddings:
+ return self.embeddings.patch_embeddings
+
+ @check_model_inputs
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: torch.Tensor,
+ output_hidden_states: bool | None = None,
+ **kwargs,
+ ) -> BackboneOutput:
+ r"""
+ Examples:
+
+ ```python
+ >>> from transformers import AutoImageProcessor, AutoBackbone
+ >>> import torch
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
+ >>> model = AutoBackbone.from_pretrained(
+ ... "facebook/dinov2-base", out_features=["stage2", "stage5", "stage8", "stage11"]
+ ... )
+
+ >>> inputs = processor(image, return_tensors="pt")
+
+ >>> outputs = model(**inputs)
+ >>> feature_maps = outputs.feature_maps
+ >>> list(feature_maps[-1].shape)
+ [1, 768, 16, 16]
+ ```"""
+ if output_hidden_states is None:
+ output_hidden_states = self.config.output_hidden_states
+
+ embedding_output = self.embeddings(pixel_values)
+
+ output: BaseModelOutput = self.encoder(embedding_output, output_hidden_states=True)
+ hidden_states = output.hidden_states
+
+ feature_maps = ()
+ for stage, hidden_state in zip(self.stage_names, hidden_states):
+ if stage in self.out_features:
+ if self.config.apply_layernorm:
+ hidden_state = self.layernorm(hidden_state)
+ if self.config.reshape_hidden_states:
+ hidden_state = hidden_state[:, 1:]
+ # this was actually a bug in the original implementation that we copied here,
+                    # since normally the order is height, width
+ batch_size, _, height, width = pixel_values.shape
+ patch_size = self.config.patch_size
+
+ num_h_patches = height // patch_size
+ num_w_patches = width // patch_size
+
+ if self.config.num_windows > 1:
+ hidden_state = window_unpartition(
+ hidden_state, self.config.num_windows, num_h_patches, num_w_patches
+ )
+
+ hidden_state = hidden_state.reshape(batch_size, num_h_patches, num_w_patches, -1)
+ hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
+
+ feature_maps += (hidden_state,)
+
+ return BackboneOutput(
+ feature_maps=feature_maps,
+ hidden_states=hidden_states if output_hidden_states else None,
+ )
+
+
+class RfDetrLayerNorm(nn.LayerNorm):
+ r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
+ The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height,
+ width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width).
+ """
+
+ def __init__(self, normalized_shape, *, eps=1e-6, data_format="channels_last", **kwargs):
+ super().__init__(normalized_shape, eps=eps, **kwargs)
+ if data_format not in ["channels_last", "channels_first"]:
+ raise NotImplementedError(f"Unsupported data format: {data_format}")
+ self.data_format = data_format
+
+ def forward(self, features: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ features: Tensor of shape (batch_size, channels, height, width) OR (batch_size, height, width, channels)
+ """
+ if self.data_format == "channels_first":
+ features = features.permute(0, 2, 3, 1)
+ features = super().forward(features)
+ features = features.permute(0, 3, 1, 2)
+ else:
+ features = super().forward(features)
+ return features
+
+
+class RfDetrConvNormLayer(nn.Module):
+ def __init__(
+ self,
+ config: RfDetrConfig,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ stride: int,
+ activation: str | None = None,
+ ):
+ super().__init__()
+ self.conv = nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ padding=kernel_size // 2,
+ bias=False,
+ )
+ self.norm = RfDetrLayerNorm(out_channels, data_format="channels_first", eps=config.layer_norm_eps)
+ self.activation = nn.Identity() if activation is None else ACT2CLS[activation]()
+
+ def forward(self, hidden_state):
+ hidden_state = self.conv(hidden_state)
+ hidden_state = self.norm(hidden_state)
+ hidden_state = self.activation(hidden_state)
+ return hidden_state
+
+
+class RfDetrRepVggBlock(nn.Module):
+ def __init__(self, config: RfDetrConfig):
+ super().__init__()
+ hidden_channels = int(config.d_model * config.hidden_expansion)
+ self.conv1 = RfDetrConvNormLayer(
+ config, hidden_channels, hidden_channels, 3, 1, activation=config.activation_function
+ )
+ self.conv2 = RfDetrConvNormLayer(
+ config, hidden_channels, hidden_channels, 3, 1, activation=config.activation_function
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ y = self.conv1(x)
+ y = self.conv2(y)
+ return y
+
+
+class RfDetrC2FLayer(nn.Module):
+ # Inspired by RTDetrCSPRepLayer
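+    # YOLOv8-style C2f block: a 1x1 conv expands to 2 * hidden_channels, the result is split in half, the second
+    # half is passed through a chain of RepVgg bottlenecks, and all intermediate outputs are concatenated before a
+    # final 1x1 conv back to d_model.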
+ def __init__(self, config: RfDetrConfig, in_channels: int):
+ super().__init__()
+ num_blocks = config.c2f_num_blocks
+ activation = config.activation_function
+ out_channels = config.d_model
+
+ self.hidden_channels = int(out_channels * config.hidden_expansion)
+
+ conv1_out_channels = 2 * self.hidden_channels
+ self.conv1 = RfDetrConvNormLayer(config, in_channels, conv1_out_channels, 1, 1, activation=activation)
+
+ conv2_in_channels = (2 + num_blocks) * self.hidden_channels
+ self.conv2 = RfDetrConvNormLayer(config, conv2_in_channels, out_channels, 1, 1, activation=activation)
+
+ self.bottlenecks = nn.ModuleList(RfDetrRepVggBlock(config) for _ in range(num_blocks))
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.conv1(hidden_states)
+ all_hidden_states = list(hidden_states.split(self.hidden_channels, 1))
+ hidden_states = all_hidden_states[-1]
+
+ for bottleneck in self.bottlenecks:
+ hidden_states = bottleneck(hidden_states)
+ all_hidden_states.append(hidden_states)
+
+ hidden_states = torch.cat(all_hidden_states, 1)
+ hidden_states = self.conv2(hidden_states)
+ return hidden_states
+
+
+class RfDetrSamplingLayer(nn.Module):
+ def __init__(self, config: RfDetrConfig, channel_size: int, scale: float):
+ super().__init__()
+
+ self.scale = scale
+ self.channel_size = channel_size
+
+ layers = []
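+        # scale 2.0 upsamples with a stride-2 transposed conv (halving channels), scale 0.5 downsamples with a
+        # stride-2 conv, and scale 1.0 keeps the feature map unchanged (no layers)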
+ if scale == 2.0:
+ layers.append(nn.ConvTranspose2d(channel_size, channel_size // 2, 2, 2))
+ elif scale == 0.5:
+ layers.append(RfDetrConvNormLayer(config, channel_size, channel_size, 3, 2, activation="relu"))
+ self.layers = nn.ModuleList(layers)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ for layer in self.layers:
+ hidden_states = layer(hidden_states)
+ return hidden_states
+
+
+class RfDetrScaleProjector(nn.Module):
+ def __init__(self, config: RfDetrConfig, scale: float):
+ super().__init__()
+
+ intermediate_dims = [config.backbone_config.hidden_size] * len(config.backbone_config.out_indices)
+ sampling_layers = []
+ for channel_size in intermediate_dims:
+ sampling_layers.append(RfDetrSamplingLayer(config, channel_size, scale))
+ self.sampling_layers = nn.ModuleList(sampling_layers)
+
+ intermediate_dim = intermediate_dims[-1]
+ if scale == 2.0:
+ intermediate_dim = intermediate_dim // 2
+ projector_input_dim = intermediate_dim * len(intermediate_dims)
+
+ self.projector_layer = RfDetrC2FLayer(config, projector_input_dim)
+ self.layer_norm = RfDetrLayerNorm(config.d_model, data_format="channels_first")
+
+ def forward(self, hidden_states_tuple: tuple[torch.Tensor]) -> torch.Tensor:
+ sampled_hidden_states = []
+ for sampling_layer, hidden_states in zip(self.sampling_layers, hidden_states_tuple):
+ hidden_states = sampling_layer(hidden_states)
+ sampled_hidden_states.append(hidden_states)
+ hidden_states = torch.cat(sampled_hidden_states, dim=1)
+ hidden_states = self.projector_layer(hidden_states)
+ hidden_states = self.layer_norm(hidden_states)
+ return hidden_states
+
+
+class RfDetrMultiScaleProjector(nn.Module):
+ def __init__(self, config: RfDetrConfig):
+ super().__init__()
+
+ self.config = config
+ scale_factors = config.projector_scale_factors
+
+ self.scale_layers = nn.ModuleList([RfDetrScaleProjector(config, scale) for scale in scale_factors])
+
+ def forward(self, hidden_states: tuple[torch.Tensor]) -> list[torch.Tensor]:
+ output_hidden_states = []
+ for scale_layer in self.scale_layers:
+ output_hidden_states.append(scale_layer(hidden_states))
+ return output_hidden_states
+
+
+class RfDetrConvEncoder(nn.Module):
+ def __init__(self, config: RfDetrConfig):
+ super().__init__()
+ self.backbone = RfDetrDinov2Backbone(config.backbone_config)
+ self.projector = RfDetrMultiScaleProjector(config)
+
+ def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor):
+ # send pixel_values through the model to get list of feature maps
+ features = self.backbone(pixel_values).feature_maps
+ features = self.projector(features)
+ out = []
+ for feature_map in features:
+ # downsample pixel_mask to match shape of corresponding feature_map
+ mask = nn.functional.interpolate(pixel_mask[None].float(), size=feature_map.shape[-2:]).to(torch.bool)[0]
+ out.append((feature_map, mask))
+ return out
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+class RfDetrAttention(nn.Module):
+ def __init__(self, config: RfDetrConfig, layer_idx: int):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.head_dim = getattr(config, "head_dim", config.d_model // config.decoder_self_attention_heads)
+ self.scaling = self.head_dim**-0.5
+ self.attention_dropout = config.attention_dropout
+ self.is_causal = False
+ self.num_key_value_groups = 1
+
+ self.q_proj = nn.Linear(
+ config.d_model, config.decoder_self_attention_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.k_proj = nn.Linear(
+ config.d_model, config.decoder_self_attention_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.v_proj = nn.Linear(
+ config.d_model, config.decoder_self_attention_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.o_proj = nn.Linear(
+ config.decoder_self_attention_heads * self.head_dim, config.d_model, bias=config.attention_bias
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: torch.Tensor | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ batch_size, seq_len, _ = hidden_states.shape
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ hidden_states_original = hidden_states
+ if position_embeddings is not None:
+            hidden_states = hidden_states + position_embeddings
+
+ if self.training:
+            # during training, the group detr technique adds extra supervision by processing multiple groups of object queries (sharing the same decoder weights) at once for faster convergence
+            # during inference, only the first group of queries is used
+ hidden_states_original = torch.cat(
+ hidden_states_original.split(seq_len // self.config.group_detr, dim=1), dim=0
+ )
+ hidden_states = torch.cat(hidden_states.split(seq_len // self.config.group_detr, dim=1), dim=0)
+
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states_original).view(hidden_shape).transpose(1, 2)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask=None,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ **kwargs,
+ )
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ if self.training:
+ attn_output = torch.cat(torch.split(attn_output, batch_size, dim=0), dim=1)
+
+ return attn_output, attn_weights
+
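+# Group DETR shape sketch (illustrative only; the group and query counts below are hypothetical):
+# during training the query axis carries every group at once, so with batch_size=2, group_detr=4 and
+# 300 queries per group the self-attention input is (2, 1200, d_model). The split/cat above turns it
+# into (8, 300, d_model) so that each group only attends to its own queries, and the final cat after
+# the output projection restores (2, 1200, d_model).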
+
+@use_kernel_forward_from_hub("MultiScaleDeformableAttention")
+class MultiScaleDeformableAttention(nn.Module):
+ def forward(
+ self,
+ value: Tensor,
+ value_spatial_shapes: Tensor,
+ value_spatial_shapes_list: list[tuple],
+ level_start_index: Tensor,
+ sampling_locations: Tensor,
+ attention_weights: Tensor,
+ im2col_step: int,
+ ):
+ batch_size, _, num_heads, hidden_dim = value.shape
+ _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
+ value_list = value.split([height * width for height, width in value_spatial_shapes_list], dim=1)
+ sampling_grids = 2 * sampling_locations - 1
+ sampling_value_list = []
+ for level_id, (height, width) in enumerate(value_spatial_shapes_list):
+ # batch_size, height*width, num_heads, hidden_dim
+ # -> batch_size, height*width, num_heads*hidden_dim
+ # -> batch_size, num_heads*hidden_dim, height*width
+ # -> batch_size*num_heads, hidden_dim, height, width
+ value_l_ = (
+ value_list[level_id]
+ .flatten(2)
+ .transpose(1, 2)
+ .reshape(batch_size * num_heads, hidden_dim, height, width)
+ )
+ # batch_size, num_queries, num_heads, num_points, 2
+ # -> batch_size, num_heads, num_queries, num_points, 2
+ # -> batch_size*num_heads, num_queries, num_points, 2
+ sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1)
+ # batch_size*num_heads, hidden_dim, num_queries, num_points
+ sampling_value_l_ = nn.functional.grid_sample(
+ value_l_,
+ sampling_grid_l_,
+ mode="bilinear",
+ padding_mode="zeros",
+ align_corners=False,
+ )
+ sampling_value_list.append(sampling_value_l_)
+ # (batch_size, num_queries, num_heads, num_levels, num_points)
+ # -> (batch_size, num_heads, num_queries, num_levels, num_points)
+ # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points)
+ attention_weights = attention_weights.transpose(1, 2).reshape(
+ batch_size * num_heads, 1, num_queries, num_levels * num_points
+ )
+ output = (
+ (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
+ .sum(-1)
+ .view(batch_size, num_heads * hidden_dim, num_queries)
+ )
+ return output.transpose(1, 2).contiguous()
+
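+# Illustrative shape walk-through for the deformable sampling (hypothetical sizes): with batch_size=2,
+# 8 heads of head_dim=32, two feature levels of (32, 32) and (16, 16), 300 queries and 4 points per level,
+#     value:              (2, 32*32 + 16*16, 8, 32)
+#     sampling_locations: (2, 300, 8, 2, 4, 2)   # normalized to [0, 1], mapped to [-1, 1] for grid_sample
+#     attention_weights:  (2, 300, 8, 2, 4)
+# and the returned tensor has shape (2, 300, 8 * 32) = (2, 300, 256).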
+
+class RfDetrMultiscaleDeformableAttention(nn.Module):
+ """
+ Multiscale deformable attention as proposed in Deformable DETR.
+ """
+
+ def __init__(self, config: RfDetrConfig, num_heads: int, n_points: int):
+ super().__init__()
+
+ self.attn = MultiScaleDeformableAttention()
+
+ if config.d_model % num_heads != 0:
+ raise ValueError(
+ f"embed_dim (d_model) must be divisible by num_heads, but got {config.d_model} and {num_heads}"
+ )
+ dim_per_head = config.d_model // num_heads
+ # check if dim_per_head is power of 2
+ if not ((dim_per_head & (dim_per_head - 1) == 0) and dim_per_head != 0):
+            warnings.warn(
+                "It is recommended to set embed_dim (d_model) in RfDetrMultiscaleDeformableAttention so that the"
+                " dimension of each attention head is a power of 2, which is more efficient in the authors' CUDA"
+                " implementation."
+            )
+
+ self.im2col_step = 64
+
+ self.d_model = config.d_model
+ self.n_levels = config.num_feature_levels
+ self.n_heads = num_heads
+ self.n_points = n_points
+
+ self.sampling_offsets = nn.Linear(config.d_model, num_heads * self.n_levels * n_points * 2)
+ self.attention_weights = nn.Linear(config.d_model, num_heads * self.n_levels * n_points)
+ self.value_proj = nn.Linear(config.d_model, config.d_model)
+ self.output_proj = nn.Linear(config.d_model, config.d_model)
+
+ self.disable_custom_kernels = config.disable_custom_kernels
+
+ def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Tensor | None):
+ return tensor if position_embeddings is None else tensor + position_embeddings
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor | None = None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ position_embeddings: torch.Tensor | None = None,
+ reference_points=None,
+ spatial_shapes=None,
+ spatial_shapes_list=None,
+ level_start_index=None,
+ **kwargs: Unpack[TransformersKwargs],
+ ):
+ # add position embeddings to the hidden states before projecting to queries and keys
+ if position_embeddings is not None:
+ hidden_states = self.with_pos_embed(hidden_states, position_embeddings)
+
+ batch_size, num_queries, _ = hidden_states.shape
+ batch_size, sequence_length, _ = encoder_hidden_states.shape
+ total_elements = sum(height * width for height, width in spatial_shapes_list)
+ torch_compilable_check(
+ total_elements == sequence_length,
+ "Make sure to align the spatial shapes with the sequence length of the encoder hidden states",
+ )
+
+ value = self.value_proj(encoder_hidden_states)
+ if attention_mask is not None:
+ # we invert the attention_mask
+ value = value.masked_fill(~attention_mask[..., None], float(0))
+ value = value.view(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads)
+ sampling_offsets = self.sampling_offsets(hidden_states).view(
+ batch_size, num_queries, self.n_heads, self.n_levels, self.n_points, 2
+ )
+ attention_weights = self.attention_weights(hidden_states).view(
+ batch_size, num_queries, self.n_heads, self.n_levels * self.n_points
+ )
+ attention_weights = F.softmax(attention_weights, -1).view(
+ batch_size, num_queries, self.n_heads, self.n_levels, self.n_points
+ )
+ # batch_size, num_queries, n_heads, n_levels, n_points, 2
+ num_coordinates = reference_points.shape[-1]
+ if num_coordinates == 2:
+ offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
+ sampling_locations = (
+ reference_points[:, :, None, :, None, :]
+ + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
+ )
+ elif num_coordinates == 4:
+ sampling_locations = (
+ reference_points[:, :, None, :, None, :2]
+ + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
+ )
+ else:
+ raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")
+
+ output = self.attn(
+ value,
+ spatial_shapes,
+ spatial_shapes_list,
+ level_start_index,
+ sampling_locations,
+ attention_weights,
+ self.im2col_step,
+ )
+
+ output = self.output_proj(output)
+
+ return output, attention_weights
+
+
+class RfDetrMLP(nn.Module):
+ def __init__(self, config: RfDetrConfig):
+ super().__init__()
+ self.dropout = config.dropout
+ self.activation_fn = ACT2FN[config.decoder_activation_function]
+ self.fc1 = nn.Linear(config.d_model, config.decoder_ffn_dim)
+ self.fc2 = nn.Linear(config.decoder_ffn_dim, config.d_model)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ residual = hidden_states
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = self.fc2(hidden_states)
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+ return hidden_states
+
+
+class RfDetrDecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: RfDetrConfig, layer_idx: int):
+ nn.Module.__init__(self)
+
+ # self-attention
+ self.self_attn = RfDetrAttention(config, layer_idx=layer_idx)
+ self.dropout = config.dropout
+ self.activation_fn = ACT2FN[config.decoder_activation_function]
+ self.activation_dropout = config.activation_dropout
+ self.self_attn_layer_norm = nn.LayerNorm(config.d_model)
+
+ # cross-attention
+ self.cross_attn = RfDetrMultiscaleDeformableAttention(
+ config,
+ num_heads=config.decoder_cross_attention_heads,
+ n_points=config.decoder_n_points,
+ )
+ self.cross_attn_layer_norm = nn.LayerNorm(config.d_model)
+
+ # mlp
+ self.mlp = RfDetrMLP(config)
+ self.layer_norm = nn.LayerNorm(config.d_model)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: torch.Tensor | None = None,
+ reference_points=None,
+ spatial_shapes=None,
+ spatial_shapes_list=None,
+ level_start_index=None,
+ encoder_hidden_states: torch.Tensor | None = None,
+ encoder_attention_mask: torch.Tensor | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ):
+ self_attention_output, self_attn_weights = self.self_attn(
+ hidden_states, position_embeddings=position_embeddings, **kwargs
+ )
+
+ self_attention_output = nn.functional.dropout(self_attention_output, p=self.dropout, training=self.training)
+ hidden_states = hidden_states + self_attention_output
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+
+ cross_attention_output, cross_attn_weights = self.cross_attn(
+ hidden_states=hidden_states,
+ attention_mask=encoder_attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ position_embeddings=position_embeddings,
+ reference_points=reference_points,
+ spatial_shapes=spatial_shapes,
+ spatial_shapes_list=spatial_shapes_list,
+ level_start_index=level_start_index,
+ **kwargs,
+ )
+ cross_attention_output = nn.functional.dropout(cross_attention_output, p=self.dropout, training=self.training)
+ hidden_states = hidden_states + cross_attention_output
+ hidden_states = self.cross_attn_layer_norm(hidden_states)
+
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = self.layer_norm(hidden_states)
+
+ return hidden_states
+
+
+@auto_docstring
+class RfDetrPreTrainedModel(PreTrainedModel):
+ config: RfDetrConfig
+ base_model_prefix = "model"
+ main_input_name = "pixel_values"
+ _no_split_modules = [
+ r"RfDetrConvEncoder",
+ r"RfDetrDecoderLayer",
+ ]
+ _supports_sdpa = True
+ _supports_flash_attn = True
+ _supports_flex_attn = True
+ _supports_attention_backend = True
+ _can_record_outputs = {
+ "attentions": [RfDetrAttention, RfDetrMultiscaleDeformableAttention],
+ "hidden_states": [RfDetrDecoderLayer],
+ }
+
+ @torch.no_grad()
+ def _init_weights(self, module):
+ super()._init_weights(module)
+
+ if isinstance(module, RfDetrMultiscaleDeformableAttention):
+ init.constant_(module.sampling_offsets.weight, 0.0)
+ thetas = torch.arange(module.n_heads, dtype=torch.int64).float() * (2.0 * math.pi / module.n_heads)
+ grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
+ grid_init = (
+ (grid_init / grid_init.abs().max(-1, keepdim=True)[0])
+ .view(module.n_heads, 1, 1, 2)
+ .repeat(1, module.n_levels, module.n_points, 1)
+ )
+ for i in range(module.n_points):
+ grid_init[:, :, i, :] *= i + 1
+
+ init.copy_(module.sampling_offsets.bias, grid_init.view(-1))
+ init.constant_(module.attention_weights.weight, 0.0)
+ init.constant_(module.attention_weights.bias, 0.0)
+ init.xavier_uniform_(module.value_proj.weight)
+ init.constant_(module.value_proj.bias, 0.0)
+ init.xavier_uniform_(module.output_proj.weight)
+ init.constant_(module.output_proj.bias, 0.0)
+ if hasattr(module, "level_embed"):
+ init.normal_(module.level_embed)
+ if hasattr(module, "refpoint_embed") and module.refpoint_embed is not None:
+ init.constant_(module.refpoint_embed.weight, 0)
+ if hasattr(module, "class_embed") and module.class_embed is not None:
+ prior_prob = 0.01
+ bias_value = -math.log((1 - prior_prob) / prior_prob)
+ init.constant_(module.class_embed.bias, bias_value)
+ if hasattr(module, "bbox_embed") and module.bbox_embed is not None:
+ init.constant_(module.bbox_embed.layers[-1].weight, 0)
+ init.constant_(module.bbox_embed.layers[-1].bias, 0)
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for outputs of the RfDetr backbone-decoder model.
+ """
+)
+class RfDetrModelOutput(ModelOutput):
+ r"""
+ init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
+ Initial reference points sent through the Transformer decoder.
+ intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
+ Stacked intermediate hidden states (output of each layer of the decoder).
+ intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
+ Stacked intermediate reference points (reference points of each layer of the decoder).
+ enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
+ Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are
+ picked as region proposals in the first stage. Output of bounding box binary classification (i.e.
+ foreground and background).
+ enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
+ Logits of predicted bounding boxes coordinates in the first stage.
+    backbone_features (list of `torch.FloatTensor` of shape `(batch_size, config.d_model, height, width)`):
+        Multi-scale feature maps produced by the backbone and projector (one tensor per feature level).
+ """
+
+ init_reference_points: torch.FloatTensor | None = None
+ last_hidden_state: torch.FloatTensor | None = None
+ intermediate_hidden_states: torch.FloatTensor | None = None
+ intermediate_reference_points: torch.FloatTensor | None = None
+ enc_outputs_class: torch.FloatTensor | None = None
+ enc_outputs_coord_logits: torch.FloatTensor | None = None
+
+    backbone_features: list[torch.Tensor] | None = None
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for outputs of the RfDetrDecoder. This class adds two attributes to
+ BaseModelOutputWithCrossAttentions, namely:
+ - a stacked tensor of intermediate decoder hidden states (i.e. the output of each decoder layer)
+ - a stacked tensor of intermediate reference points.
+ """
+)
+class RfDetrDecoderOutput(ModelOutput):
+ r"""
+ intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
+ Stacked intermediate hidden states (output of each layer of the decoder).
+    intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
+ Stacked intermediate reference points (reference points of each layer of the decoder).
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax,
+ used to compute the weighted average in the cross-attention heads.
+ """
+
+ last_hidden_state: torch.FloatTensor | None = None
+ intermediate_hidden_states: torch.FloatTensor | None = None
+ intermediate_reference_points: torch.FloatTensor | None = None
+ hidden_states: tuple[torch.FloatTensor] | None = None
+ attentions: tuple[torch.FloatTensor] | None = None
+ cross_attentions: tuple[torch.FloatTensor] | None = None
+
+
+# function to generate sine positional embedding for 4d coordinates
+def gen_sine_position_embeddings(pos_tensor, hidden_size=256):
+ """
+ This function computes position embeddings using sine and cosine functions from the input positional tensor,
+ which has a shape of (batch_size, num_queries, 4).
+ The last dimension of `pos_tensor` represents the following coordinates:
+ - 0: x-coord
+ - 1: y-coord
+ - 2: width
+ - 3: height
+
+    The output shape is (batch_size, num_queries, hidden_size * 2), i.e. 512 with the default hidden_size of 256,
+    obtained by concatenating the sine and cosine embeddings of each of the four coordinates.
+ """
+ scale = 2 * math.pi
+ dim = hidden_size // 2
+ dim_t = torch.arange(dim, dtype=torch.float32, device=pos_tensor.device)
+ dim_t = 10000 ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / dim)
+ x_embed = pos_tensor[:, :, 0] * scale
+ y_embed = pos_tensor[:, :, 1] * scale
+ pos_x = x_embed[:, :, None] / dim_t
+ pos_y = y_embed[:, :, None] / dim_t
+ pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
+ pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)
+ if pos_tensor.size(-1) == 4:
+ w_embed = pos_tensor[:, :, 2] * scale
+ pos_w = w_embed[:, :, None] / dim_t
+ pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2)
+
+ h_embed = pos_tensor[:, :, 3] * scale
+ pos_h = h_embed[:, :, None] / dim_t
+ pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2)
+
+ pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
+ else:
+ raise ValueError(f"Unknown pos_tensor shape(-1):{pos_tensor.size(-1)}")
+ return pos.to(pos_tensor.dtype)
+
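+# Shape sketch (illustrative only; the sizes are assumed): with the default hidden_size=256, a tensor of
+# normalized (cx, cy, w, h) boxes of shape (batch_size, num_queries, 4) yields sine/cosine embeddings of
+# shape (batch_size, num_queries, 2 * hidden_size):
+#     boxes = torch.rand(2, 300, 4)
+#     gen_sine_position_embeddings(boxes, hidden_size=256).shape  # torch.Size([2, 300, 512])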
+
+class RfDetrMLPPredictionHead(nn.Module):
+ """
+ Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
+ height and width of a bounding box w.r.t. an image.
+
+ Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py
+
+ """
+
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
+ super().__init__()
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
+
+ def forward(self, x):
+ for i, layer in enumerate(self.layers):
+ x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ return x
+
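+# Illustrative usage (not part of the model; d_model=256 is assumed): the 3-layer variant used for box
+# regression maps d_model-dimensional query features to 4 box deltas:
+#     head = RfDetrMLPPredictionHead(input_dim=256, hidden_dim=256, output_dim=4, num_layers=3)
+#     head(torch.rand(2, 300, 256)).shape  # torch.Size([2, 300, 4])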
+
+class RfDetrDecoder(RfDetrPreTrainedModel):
+ """
+    Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`RfDetrDecoderLayer`].
+
+ The decoder updates the query embeddings through multiple self-attention and deformable cross-attention layers.
+
+    Some tweaks for RF-DETR:
+
+    - it uses the group DETR technique during training for faster convergence.
+
+ Args:
+ config: RfDetrConfig
+ """
+
+ def __init__(self, config: RfDetrConfig):
+ super().__init__(config)
+ self.dropout = config.dropout
+ self.layers = nn.ModuleList([RfDetrDecoderLayer(config, i) for i in range(config.decoder_layers)])
+ self.layernorm = nn.LayerNorm(config.d_model)
+
+ self.gradient_checkpointing = False
+
+ self.ref_point_head = RfDetrMLPPredictionHead(2 * config.d_model, config.d_model, config.d_model, num_layers=2)
+
+ self.post_init()
+
+ def get_reference(self, reference_points, valid_ratios):
+        # batch_size, num_queries, 4
+ obj_center = reference_points[..., :4]
+
+ # batch_size, num_queries, num_levels, 4
+ reference_points_inputs = obj_center[:, :, None] * torch.cat([valid_ratios, valid_ratios], -1)[:, None]
+
+ # batch_size, num_queries, d_model * 2
+ query_sine_embed = gen_sine_position_embeddings(reference_points_inputs[:, :, 0, :], self.config.d_model)
+
+ # batch_size, num_queries, d_model
+ query_pos = self.ref_point_head(query_sine_embed)
+ return reference_points_inputs, query_pos
+
+ def forward(
+ self,
+ inputs_embeds: torch.Tensor | None = None,
+ reference_points: torch.Tensor | None = None,
+ spatial_shapes: torch.Tensor | None = None,
+ spatial_shapes_list: torch.Tensor | None = None,
+ level_start_index: torch.Tensor | None = None,
+ valid_ratios: torch.Tensor | None = None,
+ encoder_hidden_states: torch.Tensor | None = None,
+ encoder_attention_mask: torch.Tensor | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ):
+ intermediate = ()
+ intermediate_reference_points = (reference_points,)
+
+ if inputs_embeds is not None:
+ hidden_states = inputs_embeds
+
+ reference_points_inputs, query_pos = self.get_reference(reference_points, valid_ratios)
+
+ for idx, decoder_layer in enumerate(self.layers):
+ hidden_states = decoder_layer(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ position_embeddings=query_pos,
+ reference_points=reference_points_inputs,
+ spatial_shapes=spatial_shapes,
+ spatial_shapes_list=spatial_shapes_list,
+ level_start_index=level_start_index,
+ **kwargs,
+ )
+ intermediate_hidden_states = self.layernorm(hidden_states)
+ intermediate += (intermediate_hidden_states,)
+
+ intermediate = torch.stack(intermediate)
+ last_hidden_state = intermediate[-1]
+ intermediate_reference_points = torch.stack(intermediate_reference_points)
+
+ return RfDetrDecoderOutput(
+ last_hidden_state=last_hidden_state,
+ intermediate_hidden_states=intermediate,
+ intermediate_reference_points=intermediate_reference_points,
+ )
+
+
+def refine_bboxes(reference_points, deltas):
+ reference_points = reference_points.to(deltas.device)
+ new_reference_points_cxcy = deltas[..., :2] * reference_points[..., 2:] + reference_points[..., :2]
+ new_reference_points_wh = deltas[..., 2:].exp() * reference_points[..., 2:]
+ new_reference_points = torch.cat((new_reference_points_cxcy, new_reference_points_wh), -1)
+ return new_reference_points
+
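+# Illustrative sketch (not part of the model): `refine_bboxes` applies a DETR-style delta parameterization
+# to boxes in normalized (cx, cy, w, h) format:
+#     new_cx = dx * w_ref + cx_ref,   new_cy = dy * h_ref + cy_ref
+#     new_w  = exp(dw) * w_ref,       new_h  = exp(dh) * h_ref
+# For example (hypothetical values):
+#     ref    = torch.tensor([[0.50, 0.50, 0.20, 0.20]])
+#     deltas = torch.tensor([[0.10, -0.10, 0.00, 0.00]])
+#     refine_bboxes(ref, deltas)  # tensor([[0.5200, 0.4800, 0.2000, 0.2000]])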
+
+@auto_docstring(
+ custom_intro="""
+    The bare RF-DETR Model (consisting of a backbone and decoder Transformer) outputting raw
+ hidden-states without any specific head on top.
+ """
+)
+class RfDetrModel(RfDetrPreTrainedModel):
+ def __init__(self, config: RfDetrConfig):
+ super().__init__(config)
+
+ # Create backbone + positional encoding
+ self.backbone = RfDetrConvEncoder(config)
+
+ self.group_detr = config.group_detr
+ self.num_queries = config.num_queries
+ hidden_dim = config.d_model
+ self.reference_point_embed = nn.Embedding(self.num_queries * self.group_detr, 4)
+ self.query_feat = nn.Embedding(self.num_queries * self.group_detr, hidden_dim)
+
+ self.decoder = RfDetrDecoder(config)
+
+ self.enc_output = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(self.group_detr)])
+ self.enc_output_norm = nn.ModuleList([nn.LayerNorm(hidden_dim) for _ in range(self.group_detr)])
+ # Should normally be None and then instantiated in the ForObjectDetection class
+ self.enc_out_bbox_embed = nn.ModuleList(
+ [RfDetrMLPPredictionHead(config.d_model, config.d_model, 4, num_layers=3) for _ in range(self.group_detr)]
+ )
+ self.enc_out_class_embed = nn.ModuleList(
+ [nn.Linear(config.d_model, config.num_labels) for _ in range(self.group_detr)]
+ )
+
+ self.post_init()
+
+    def freeze_backbone(self):
+        for name, param in self.backbone.backbone.named_parameters():
+            param.requires_grad_(False)
+
+    def unfreeze_backbone(self):
+        for name, param in self.backbone.backbone.named_parameters():
+            param.requires_grad_(True)
+
+ def get_valid_ratio(self, mask, dtype=torch.float32):
+ """Get the valid ratio of all feature maps."""
+
+ _, height, width = mask.shape
+ valid_height = torch.sum(mask[:, :, 0], 1)
+ valid_width = torch.sum(mask[:, 0, :], 1)
+ valid_ratio_height = valid_height.to(dtype) / height
+ valid_ratio_width = valid_width.to(dtype) / width
+ valid_ratio = torch.stack([valid_ratio_width, valid_ratio_height], -1)
+ return valid_ratio
+
+ def get_proposal_pos_embed(self, proposals):
+ """Get the position embedding of the proposals."""
+
+ num_pos_feats = self.config.d_model // 2
+ temperature = 10000
+ scale = 2 * math.pi
+
+ dim_t = torch.arange(num_pos_feats, dtype=proposals.dtype, device=proposals.device)
+ dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats)
+ # batch_size, num_queries, 4
+ proposals = proposals.sigmoid() * scale
+ # batch_size, num_queries, 4, 128
+ pos = proposals[:, :, :, None] / dim_t
+ # batch_size, num_queries, 4, 64, 2 -> batch_size, num_queries, 512
+ pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), dim=4).flatten(2)
+ return pos
+
+ def gen_encoder_output_proposals(self, enc_output, padding_mask, spatial_shapes):
+ """Generate the encoder output proposals from encoded enc_output.
+
+ Args:
+ enc_output (Tensor[batch_size, sequence_length, hidden_size]): Output of the encoder.
+ padding_mask (Tensor[batch_size, sequence_length]): Padding mask for `enc_output`.
+ spatial_shapes (list[tuple[int, int]]): Spatial shapes of the feature maps.
+
+ Returns:
+ `tuple(torch.FloatTensor)`: A tuple of feature map and bbox prediction.
+            - object_query (Tensor[batch_size, sequence_length, hidden_size]): Object query features. Later used to
+              directly predict a bounding box (without the need of a decoder).
+            - output_proposals (Tensor[batch_size, sequence_length, 4]): Normalized proposals, with entries at padded
+              or out-of-range positions set to `inf`.
+ """
+ batch_size = enc_output.shape[0]
+ proposals = []
+ _cur = 0
+ for level, (height, width) in enumerate(spatial_shapes):
+ mask_flatten_ = padding_mask[:, _cur : (_cur + height * width)].view(batch_size, height, width, 1)
+ valid_height = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
+ valid_width = torch.sum(~mask_flatten_[:, 0, :, 0], 1)
+
+ grid_y, grid_x = meshgrid(
+ torch.linspace(
+ 0,
+ height - 1,
+ height,
+ dtype=enc_output.dtype,
+ device=enc_output.device,
+ ),
+ torch.linspace(
+ 0,
+ width - 1,
+ width,
+ dtype=enc_output.dtype,
+ device=enc_output.device,
+ ),
+ indexing="ij",
+ )
+ grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)
+
+ scale = torch.cat([valid_width.unsqueeze(-1), valid_height.unsqueeze(-1)], 1).view(batch_size, 1, 1, 2)
+ grid = (grid.unsqueeze(0).expand(batch_size, -1, -1, -1) + 0.5) / scale
+ width_height = torch.ones_like(grid) * 0.05 * (2.0**level)
+ proposal = torch.cat((grid, width_height), -1).view(batch_size, -1, 4)
+ proposals.append(proposal)
+ _cur += height * width
+ output_proposals = torch.cat(proposals, 1)
+ output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True)
+ output_proposals = output_proposals.masked_fill(padding_mask.unsqueeze(-1), float("inf"))
+ output_proposals = output_proposals.masked_fill(~output_proposals_valid, float("inf"))
+
+ # assign each pixel as an object query
+ object_query = enc_output
+ object_query = object_query.masked_fill(padding_mask.unsqueeze(-1), float(0))
+ object_query = object_query.masked_fill(~output_proposals_valid, float(0))
+ return object_query, output_proposals
+
+ @check_model_inputs
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor,
+ pixel_mask: torch.LongTensor | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> RfDetrModelOutput:
+ r"""
+ Examples:
+
+ ```python
+        >>> from transformers import AutoImageProcessor, RfDetrModel
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> image_processor = AutoImageProcessor.from_pretrained("stevenbucaille/rfdetr_small_60e_coco")
+        >>> model = RfDetrModel.from_pretrained("stevenbucaille/rfdetr_small_60e_coco")
+
+ >>> inputs = image_processor(images=image, return_tensors="pt")
+
+ >>> outputs = model(**inputs)
+
+ >>> last_hidden_states = outputs.last_hidden_state
+ >>> list(last_hidden_states.shape)
+ [1, 200, 256]
+ ```"""
+ batch_size, num_channels, height, width = pixel_values.shape
+ device = pixel_values.device
+
+ if pixel_mask is None:
+ pixel_mask = torch.ones(((batch_size, height, width)), dtype=torch.long, device=device)
+
+ # First, retrieve feature maps from backbone
+ features = self.backbone(pixel_values, pixel_mask)
+
+ sources = []
+ masks = []
+ for level, (source, mask) in enumerate(features):
+ sources.append(source)
+ masks.append(mask)
+ if mask is None:
+ raise ValueError("No attention mask was provided")
+
+ # Get initial reference points and query features
+ if self.training:
+ reference_points = self.reference_point_embed.weight
+ query_feat = self.query_feat.weight
+ else:
+ # only use first group of reference points and query features during inference
+ # reference_points (num_queries, 4) : spatial locations of the queries
+ # query_feat (num_queries, d_model) : features of the queries
+ reference_points = self.reference_point_embed.weight[: self.num_queries]
+ query_feat = self.query_feat.weight[: self.num_queries]
+
+ # Prepare decoder inputs (by flattening)
+ source_flatten = []
+ mask_flatten = []
+ spatial_shapes_list = []
+ for source, mask in zip(sources, masks):
+ batch_size, num_channels, height, width = source.shape
+ spatial_shape = (height, width)
+ spatial_shapes_list.append(spatial_shape)
+ source = source.flatten(2).transpose(1, 2)
+ mask = mask.flatten(1)
+ source_flatten.append(source)
+ mask_flatten.append(mask)
+ # source_flatten (batch_size, sum(H*W), d_model) : flattened multi-scale feature maps
+ # mask_flatten (batch_size, sum(H*W)) : flattened mask
+ # spatial_shapes (num_levels, 2) : spatial shapes of the feature maps
+ # level_start_index (num_levels,) : start index of each level in source_flatten
+ # valid_ratios (batch_size, num_levels, 2) : valid ratios of the feature maps
+ source_flatten = torch.cat(source_flatten, 1)
+ mask_flatten = torch.cat(mask_flatten, 1)
+ spatial_shapes = torch.as_tensor(spatial_shapes_list, dtype=torch.long, device=source_flatten.device)
+ level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
+ valid_ratios = torch.stack([self.get_valid_ratio(m, dtype=source_flatten.dtype) for m in masks], 1)
+
+ # Duplicate query features and reference points for each image in the batch
+ target = query_feat.unsqueeze(0).expand(batch_size, -1, -1)
+ reference_points = reference_points.unsqueeze(0).expand(batch_size, -1, -1)
+
+ # Generate encoder output proposals
+ object_query_embedding, output_proposals = self.gen_encoder_output_proposals(
+ source_flatten, ~mask_flatten, spatial_shapes_list
+ )
+
+ group_detr = self.group_detr if self.training else 1
+ topk = self.num_queries
+ topk_coords_logits = []
+ object_query_undetach = []
+
+ # Iterate over each group of object queries to refine the object queries
+ for group_id in range(group_detr):
+ group_object_query = self.enc_output[group_id](object_query_embedding)
+ group_object_query = self.enc_output_norm[group_id](group_object_query)
+
+ group_enc_outputs_class = self.enc_out_class_embed[group_id](group_object_query)
+ group_delta_bbox = self.enc_out_bbox_embed[group_id](group_object_query)
+ group_enc_outputs_coord = refine_bboxes(output_proposals, group_delta_bbox)
+
+ group_topk_proposals = torch.topk(group_enc_outputs_class.max(-1)[0], topk, dim=1)[1]
+ group_topk_coords_logits_undetach = torch.gather(
+ group_enc_outputs_coord,
+ 1,
+ group_topk_proposals.unsqueeze(-1).repeat(1, 1, 4),
+ )
+ group_topk_coords_logits = group_topk_coords_logits_undetach.detach()
+ group_object_query_undetach = torch.gather(
+ group_object_query, 1, group_topk_proposals.unsqueeze(-1).repeat(1, 1, self.config.d_model)
+ )
+
+ topk_coords_logits.append(group_topk_coords_logits)
+ object_query_undetach.append(group_object_query_undetach)
+
+ # Concatenate the object queries and reference points from all groups
+ topk_coords_logits = torch.cat(topk_coords_logits, 1)
+ object_query_undetach = torch.cat(object_query_undetach, 1)
+
+ # Get the class and coordinate logits from the object queries
+ # enc_outputs_class (batch_size, num_queries, d_model) : object queries
+ # enc_outputs_coord_logits (batch_size, num_queries, 4) : coordinate logits of the object queries
+ enc_outputs_class = object_query_undetach
+ enc_outputs_coord_logits = topk_coords_logits
+
+ # Refine the reference points using the coordinate logits
+ two_stage_len = topk_coords_logits.shape[-2]
+ reference_points_two_stage_subset = reference_points[..., :two_stage_len, :]
+ reference_points_subset = reference_points[..., two_stage_len:, :]
+ reference_points_two_stage_subset = refine_bboxes(topk_coords_logits, reference_points_two_stage_subset)
+ reference_points = torch.cat([reference_points_two_stage_subset, reference_points_subset], dim=-2)
+ init_reference_points = reference_points
+
+ # Pass the object queries and reference points to the decoder
+ decoder_outputs = self.decoder(
+ inputs_embeds=target,
+ reference_points=reference_points,
+ spatial_shapes=spatial_shapes,
+ spatial_shapes_list=spatial_shapes_list,
+ level_start_index=level_start_index,
+ valid_ratios=valid_ratios,
+ encoder_hidden_states=source_flatten,
+ encoder_attention_mask=mask_flatten,
+ **kwargs,
+ )
+
+ # init_reference_points (batch_size, num_queries, 4) : initial reference points
+ # last_hidden_state (batch_size, num_queries, d_model) : final object queries
+ # intermediate_hidden_states (batch_size, num_decoder_layers, num_queries, d_model) : intermediate object queries
+ # intermediate_reference_points (batch_size, num_decoder_layers, num_queries, 4) : intermediate reference points
+ # backbone_features list(batch_size, num_levels, d_model, H, W) : backbone features
+ # enc_outputs_class (batch_size, num_queries, d_model) : encoder outputs object queries
+ # enc_outputs_coord_logits (batch_size, num_queries, 4) : coordinate logits of encoder object queries
+ return RfDetrModelOutput(
+ init_reference_points=init_reference_points,
+ last_hidden_state=decoder_outputs.last_hidden_state,
+ intermediate_hidden_states=decoder_outputs.intermediate_hidden_states,
+ intermediate_reference_points=decoder_outputs.intermediate_reference_points,
+ backbone_features=sources,
+ enc_outputs_class=enc_outputs_class,
+ enc_outputs_coord_logits=enc_outputs_coord_logits,
+ )
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Output type of [`RfDetrForObjectDetection`].
+ """
+)
+class RfDetrObjectDetectionOutput(ModelOutput):
+ r"""
+    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided):
+        Total loss as a linear combination of a negative log-likelihood (cross-entropy) for class prediction and a
+ bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized
+ scale-invariant IoU loss.
+ loss_dict (`Dict`, *optional*):
+ A dictionary containing the individual losses. Useful for logging.
+ logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`):
+ Classification logits (including no-object) for all queries.
+ pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
+ Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
+ values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
+        possible padding). You can use [`~DeformableDetrImageProcessor.post_process_object_detection`] to retrieve the
+ unnormalized bounding boxes.
+ auxiliary_outputs (`list[Dict]`, *optional*):
+ Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
+ and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and
+ `pred_boxes`) for each decoder layer.
+ init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
+ Initial reference points sent through the Transformer decoder.
+ intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
+ Stacked intermediate hidden states (output of each layer of the decoder).
+ intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
+ Stacked intermediate reference points (reference points of each layer of the decoder).
+ enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
+ Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are
+ picked as region proposals in the first stage. Output of bounding box binary classification (i.e.
+ foreground and background).
+ enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
+ Logits of predicted bounding boxes coordinates in the first stage.
+    backbone_features (list of `torch.FloatTensor` of shape `(batch_size, config.d_model, height, width)`):
+        Multi-scale feature maps produced by the backbone and projector (one tensor per feature level).
+ """
+
+ loss: torch.FloatTensor | None = None
+ loss_dict: dict | None = None
+ logits: torch.FloatTensor | None = None
+ pred_boxes: torch.FloatTensor | None = None
+ auxiliary_outputs: list[dict] | None = None
+ init_reference_points: torch.FloatTensor | None = None
+ last_hidden_state: torch.FloatTensor | None = None
+ intermediate_hidden_states: torch.FloatTensor | None = None
+ intermediate_reference_points: torch.FloatTensor | None = None
+    enc_outputs_class: torch.FloatTensor | None = None
+ enc_outputs_coord_logits: torch.FloatTensor | None = None
+
+    backbone_features: list[torch.Tensor] | None = None
+
+
+@auto_docstring(
+ custom_intro="""
+    RF-DETR Model (consisting of a backbone and decoder Transformer) with object detection heads on
+ top, for tasks such as COCO detection.
+ """
+)
+class RfDetrForObjectDetection(RfDetrPreTrainedModel):
+ # When using clones, all layers > 0 will be clones, but layer 0 *is* required
+ # We can't initialize the model on meta device as some weights are modified during the initialization
+ _no_split_modules = None
+ _tied_weights_keys = None
+
+ def __init__(self, config: RfDetrConfig):
+ super().__init__(config)
+ self.model = RfDetrModel(config)
+ self.class_embed = nn.Linear(config.d_model, config.num_labels)
+ self.bbox_embed = RfDetrMLPPredictionHead(config.d_model, config.d_model, 4, num_layers=3)
+
+ self.post_init()
+
+ @check_model_inputs
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor = None,
+ pixel_mask: torch.LongTensor | None = None,
+ labels: list[dict] | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> RfDetrObjectDetectionOutput:
+ r"""
+ labels (`list[Dict]` of len `(batch_size,)`, *optional*):
+ Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
+ following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch
+ respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding boxes
+ in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)`.
+
+ Examples:
+
+ ```python
+        >>> import torch
+        >>> from transformers import AutoImageProcessor, RfDetrForObjectDetection
+        >>> from PIL import Image
+        >>> import requests
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> image_processor = AutoImageProcessor.from_pretrained("stevenbucaille/rfdetr_small_60e_coco")
+        >>> model = RfDetrForObjectDetection.from_pretrained("stevenbucaille/rfdetr_small_60e_coco")
+
+ >>> inputs = image_processor(images=image, return_tensors="pt")
+ >>> outputs = model(**inputs)
+
+ >>> # convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax)
+ >>> target_sizes = torch.tensor([image.size[::-1]])
+ >>> results = image_processor.post_process_object_detection(outputs, threshold=0.5, target_sizes=target_sizes)[
+ ... 0
+ ... ]
+ >>> for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
+ ... box = [round(i, 2) for i in box.tolist()]
+ ... print(
+ ... f"Detected {model.config.id2label[label.item()]} with confidence "
+ ... f"{round(score.item(), 3)} at location {box}"
+ ... )
+ Detected cat with confidence 0.8 at location [16.5, 52.84, 318.25, 470.78]
+ Detected cat with confidence 0.789 at location [342.19, 24.3, 640.02, 372.25]
+ Detected remote with confidence 0.633 at location [40.79, 72.78, 176.76, 117.25]
+ ```"""
+ outputs = self.model(
+ pixel_values,
+ pixel_mask=pixel_mask,
+ **kwargs,
+ )
+
+ last_hidden_states = outputs.last_hidden_state
+ intermediate_reference_points = outputs.intermediate_reference_points
+ enc_outputs_class = outputs.enc_outputs_class
+ enc_outputs_boxes_logits = outputs.enc_outputs_coord_logits
+
+ # Get logits and boxes from first stage object queries
+ enc_outputs_class_logits = self.get_encoder_outputs_class_logits(enc_outputs_class)
+
+ # Get logits and boxes from second stage object queries
+ logits = self.class_embed(last_hidden_states)
+ pred_boxes_delta = self.bbox_embed(last_hidden_states)
+ pred_boxes = refine_bboxes(intermediate_reference_points[-1], pred_boxes_delta)
+
+ loss, loss_dict, auxiliary_outputs = None, None, None
+ if labels is not None:
+ outputs_class, outputs_coord = None, None
+ if self.config.auxiliary_loss:
+ intermediate_hidden_states = outputs.intermediate_hidden_states
+ outputs_coord_delta = self.bbox_embed(intermediate_hidden_states)
+ outputs_coord = refine_bboxes(intermediate_reference_points, outputs_coord_delta)
+ outputs_class = self.class_embed(intermediate_hidden_states)
+
+ loss, loss_dict, auxiliary_outputs = self.loss_function(
+ logits,
+ labels,
+ self.device,
+ pred_boxes,
+ self.config,
+ outputs_class,
+ outputs_coord,
+ enc_outputs_class_logits,
+ enc_outputs_boxes_logits,
+ )
+
+ return RfDetrObjectDetectionOutput(
+ loss=loss,
+ loss_dict=loss_dict,
+ logits=logits,
+ pred_boxes=pred_boxes,
+ auxiliary_outputs=auxiliary_outputs,
+ last_hidden_state=outputs.last_hidden_state,
+ intermediate_hidden_states=outputs.intermediate_hidden_states,
+ intermediate_reference_points=outputs.intermediate_reference_points,
+ init_reference_points=outputs.init_reference_points,
+ enc_outputs_class=enc_outputs_class_logits,
+ enc_outputs_coord_logits=enc_outputs_boxes_logits,
+ backbone_features=outputs.backbone_features,
+ )
+
+ def get_encoder_outputs_class_logits(self, enc_outputs_class_logits: torch.Tensor) -> Tensor:
+ enc_outputs_class_logits_list = enc_outputs_class_logits.split(self.config.num_queries, dim=1)
+ group_detr = self.config.group_detr if self.training else 1
+ pred_class = [
+ self.model.enc_out_class_embed[group_index](enc_outputs_class_logits_list[group_index])
+ for group_index in range(group_detr)
+ ]
+ enc_outputs_class_logits = torch.cat(pred_class, dim=1)
+ return enc_outputs_class_logits
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Output type of [`RfDetrForInstanceSegmentation`].
+ """
+)
+class RfDetrInstanceSegmentationOutput(ModelOutput):
+ r"""
+    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided):
+        Total loss as a linear combination of a negative log-likelihood (cross-entropy) for class prediction and a
+ bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized
+ scale-invariant IoU loss.
+ loss_dict (`Dict`, *optional*):
+ A dictionary containing the individual losses. Useful for logging.
+ logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`):
+ Classification logits (including no-object) for all queries.
+ pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
+ Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
+ values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
+        possible padding). You can use [`~DeformableDetrImageProcessor.post_process_object_detection`] to retrieve the
+ unnormalized bounding boxes.
+ pred_masks (`torch.FloatTensor` of shape `(batch_size, num_queries, height/4, width/4)`):
+ Segmentation masks logits for all queries. See also
+ [`~DetrImageProcessor.post_process_semantic_segmentation`] or
+ [`~DetrImageProcessor.post_process_instance_segmentation`]
+ [`~DetrImageProcessor.post_process_panoptic_segmentation`] to evaluate semantic, instance and panoptic
+ segmentation masks respectively.
+ auxiliary_outputs (`list[Dict]`, *optional*):
+ Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
+ and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and
+ `pred_boxes`) for each decoder layer.
+ init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
+ Initial reference points sent through the Transformer decoder.
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the decoder of the model.
+ intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
+ Stacked intermediate hidden states (output of each layer of the decoder).
+ intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
+ Stacked intermediate reference points (reference points of each layer of the decoder).
+ enc_outputs_mask_logits (`torch.FloatTensor` of shape `(batch_size, num_queries, width, height)`, *optional*):
+ Mask logits from the encoder for all queries.
+ """
+
+ loss: torch.FloatTensor | None = None
+ loss_dict: dict | None = None
+ logits: torch.FloatTensor | None = None
+ pred_boxes: torch.FloatTensor | None = None
+    pred_masks: torch.FloatTensor | None = None
+ auxiliary_outputs: list[dict] | None = None
+ init_reference_points: torch.FloatTensor | None = None
+ last_hidden_state: torch.FloatTensor | None = None
+ intermediate_hidden_states: torch.FloatTensor | None = None
+ intermediate_reference_points: torch.FloatTensor | None = None
+ enc_outputs_mask_logits: torch.FloatTensor | None = None
+
+
+class RfDetrSegmentationBlock(nn.Module):
+ """This corresponds to the `Block` class in the original implementation.
+
+    There are two equivalent implementations: (1) [DwConv, LayerNorm (channels_first), 1x1 Conv, GELU, 1x1 Conv];
+    all in (N, C, H, W); (2) [DwConv, Permute to (N, H, W, C), LayerNorm (channels_last), Linear, GELU, Linear]; Permute back
+
+ The authors used (2) as they find it slightly faster in PyTorch.
+
+ Args:
+ config ([`RfDetrConfig`]): Model configuration class.
+ """
+
+ def __init__(self, config: RfDetrConfig):
+ super().__init__()
+ dim = config.d_model
+ self.dwconv = nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim) # depthwise conv
+ self.layernorm = RfDetrLayerNorm(dim, eps=1e-6)
+ self.pwconv1 = nn.Linear(dim, dim) # pointwise/1x1 convs, implemented with linear layers
+ self.act = ACT2FN[config.segmentation_head_activation_function]
+
+ def forward(self, features: torch.Tensor) -> torch.Tensor:
+ residual = features
+ features = self.dwconv(features)
+ features = features.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
+ features = self.layernorm(features)
+ features = self.pwconv1(features)
+ features = self.act(features)
+ features = features.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
+ features = features + residual
+ return features
+
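+# Channels-last sketch (illustrative; the sizes are hypothetical): with config.d_model == 256 and a 64x64
+# feature map, the block applies a 3x3 depthwise conv in (N, C, H, W), permutes to (N, H, W, C) for the
+# LayerNorm and the pointwise Linear, then permutes back and adds the residual, preserving the shape:
+#     block = RfDetrSegmentationBlock(config)
+#     block(torch.rand(2, 256, 64, 64)).shape  # torch.Size([2, 256, 64, 64])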
+
+class RfDetrSegmentationMLPBlock(nn.Module):
+ def __init__(self, config: RfDetrConfig):
+ super().__init__()
+ dim = config.d_model
+ self.norm_in = nn.LayerNorm(dim)
+ self.in_linear = nn.Linear(dim, dim * 4)
+ self.act = ACT2FN[config.segmentation_head_activation_function]
+ self.out_linear = nn.Linear(dim * 4, dim)
+
+ def forward(self, features: torch.Tensor) -> torch.Tensor:
+ residual = features
+ features = self.norm_in(features)
+ features = self.in_linear(features)
+ features = self.act(features)
+ features = self.out_linear(features)
+ features = features + residual
+ return features
+
+
+class RfDetrForInstanceSegmentation(RfDetrPreTrainedModel):
+ def __init__(self, config: RfDetrConfig):
+ super().__init__(config)
+
+ self.rf_detr = RfDetrForObjectDetection(config)
+
+ num_blocks = config.decoder_layers
+ self.downsample_ratio = config.mask_downsample_ratio
+ self.blocks = nn.ModuleList([RfDetrSegmentationBlock(config) for _ in range(num_blocks)])
+ self.spatial_features_proj = nn.Conv2d(config.d_model, config.d_model, kernel_size=1)
+
+ self.query_features_block = RfDetrSegmentationMLPBlock(config)
+ self.query_features_proj = nn.Linear(config.d_model, config.d_model)
+
+ self.bias = nn.Parameter(torch.zeros(1), requires_grad=True)
+
+ self.post_init()
+
+ def segmentation_head(self, spatial_features, query_features, image_size: torch.Size, skip_blocks: bool = False):
+ # spatial features: (B, C, H, W)
+ # query features: [(B, N, C)] for each decoder layer
+        # output: (B, N, image_height // downsample_ratio, image_width // downsample_ratio)
+ target_size = (image_size[0] // self.downsample_ratio, image_size[1] // self.downsample_ratio)
+ spatial_features = F.interpolate(spatial_features, size=target_size, mode="bilinear", align_corners=False)
+ list_mask_logits = []
+ if not skip_blocks:
+ for block, qf in zip(self.blocks, query_features):
+ spatial_features = block(spatial_features)
+ spatial_features_proj = self.spatial_features_proj(spatial_features)
+ qf = self.query_features_block(qf)
+ qf = self.query_features_proj(qf)
+ mask_logits = torch.einsum("bchw,bnc->bnhw", spatial_features_proj, qf)
+ mask_logits = mask_logits + self.bias
+ list_mask_logits.append(mask_logits)
+ else:
+ query_features = self.query_features_block(query_features)
+ query_features = self.query_features_proj(query_features)
+ mask_logits = torch.einsum("bchw,bnc->bnhw", spatial_features, query_features)
+ mask_logits = mask_logits + self.bias
+ list_mask_logits.append(mask_logits)
+
+ return list_mask_logits
+
+ @check_model_inputs
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor = None,
+ pixel_mask: torch.LongTensor | None = None,
+ labels: list[dict] | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+    ) -> RfDetrInstanceSegmentationOutput:
+ image_size = pixel_values.shape[-2:]
+
+ outputs = self.rf_detr.model(
+ pixel_values,
+ pixel_mask=pixel_mask,
+ **kwargs,
+ )
+
+ spatial_features = outputs.backbone_features[-1]
+ last_hidden_states = outputs.last_hidden_state
+ intermediate_reference_points = outputs.intermediate_reference_points
+ enc_outputs_class = outputs.enc_outputs_class
+ enc_outputs_boxes_logits = outputs.enc_outputs_coord_logits
+ query_features = outputs.intermediate_hidden_states
+ last_hidden_state = outputs.last_hidden_state
+
+ # First stage segmentation proposals
+ enc_outputs_class_logits = self.rf_detr.get_encoder_outputs_class_logits(enc_outputs_class)
+ enc_outputs_masks = self.segmentation_head(spatial_features, enc_outputs_class, image_size, skip_blocks=True)
+ enc_outputs_masks = torch.cat(enc_outputs_masks, dim=1)
+
+ # Second stage segmentation proposals
+ logits = self.rf_detr.class_embed(last_hidden_states)
+ pred_boxes_delta = self.rf_detr.bbox_embed(last_hidden_states)
+ pred_boxes = refine_bboxes(intermediate_reference_points[-1], pred_boxes_delta)
+ outputs_masks = self.segmentation_head(spatial_features, query_features, image_size)
+
+ pred_masks = outputs_masks[-1]
+
+ loss, loss_dict, auxiliary_outputs = None, None, None
+ if labels is not None:
+ outputs_class, outputs_coord = None, None
+ if self.config.auxiliary_loss:
+ intermediate_hidden_states = outputs.intermediate_hidden_states
+ outputs_coord_delta = self.rf_detr.bbox_embed(intermediate_hidden_states)
+ outputs_coord = refine_bboxes(intermediate_reference_points, outputs_coord_delta)
+ outputs_class = self.rf_detr.class_embed(intermediate_hidden_states)
+ loss, loss_dict, auxiliary_outputs = self.loss_function(
+ logits,
+ labels,
+ self.device,
+ pred_boxes,
+ pred_masks,
+ self.config,
+ outputs_class,
+ outputs_coord,
+ outputs_masks,
+ enc_outputs_class_logits,
+ enc_outputs_boxes_logits,
+ enc_outputs_masks,
+ )
+
+ return RfDetrInstanceSegmentationOutput(
+ loss=loss,
+ loss_dict=loss_dict,
+ logits=logits,
+ pred_boxes=pred_boxes,
+ pred_masks=pred_masks,
+ auxiliary_outputs=auxiliary_outputs,
+ last_hidden_state=last_hidden_state,
+ intermediate_hidden_states=outputs.intermediate_hidden_states,
+ intermediate_reference_points=outputs.intermediate_reference_points,
+ init_reference_points=outputs.init_reference_points,
+ enc_outputs_mask_logits=enc_outputs_masks,
+ )
+
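+# Minimal usage sketch for the segmentation model (illustrative only; the checkpoint placeholder below is
+# hypothetical and should be replaced with an actual RF-DETR segmentation checkpoint):
+#     model = RfDetrForInstanceSegmentation.from_pretrained("<segmentation-checkpoint>")
+#     outputs = model(**image_processor(images=image, return_tensors="pt"))
+#     outputs.pred_masks.shape  # (batch_size, num_queries, height // r, width // r), r = config.mask_downsample_ratio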
+
+__all__ = [
+ "RfDetrModel",
+ "RfDetrForObjectDetection",
+ "RfDetrForInstanceSegmentation",
+ "RfDetrPreTrainedModel",
+ "RfDetrDinov2Backbone",
+ "RfDetrDinov2PreTrainedModel",
+]
diff --git a/src/transformers/models/rf_detr/modular_rf_detr.py b/src/transformers/models/rf_detr/modular_rf_detr.py
new file mode 100644
index 000000000000..0d9489c654cb
--- /dev/null
+++ b/src/transformers/models/rf_detr/modular_rf_detr.py
@@ -0,0 +1,1297 @@
+# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+
+import torch
+from torch import Tensor, nn
+from torch.nn import functional as F
+
+from ...activations import ACT2FN
+from ...configuration_utils import PreTrainedConfig
+from ...modeling_outputs import BackboneOutput, BaseModelOutput
+from ...processing_utils import Unpack
+from ...utils import auto_docstring, logging, torch_int
+from ...utils.generic import ModelOutput, TransformersKwargs, can_return_tuple, check_model_inputs
+from ..auto import CONFIG_MAPPING
+from ..convnext.modeling_convnext import ConvNextLayer
+from ..dinov2.configuration_dinov2 import Dinov2Config
+from ..dinov2.modeling_dinov2 import (
+ Dinov2Backbone,
+ Dinov2Embeddings,
+ Dinov2Encoder,
+ Dinov2Layer,
+ Dinov2PreTrainedModel,
+ Dinov2SelfAttention,
+)
+from ..lw_detr.configuration_lw_detr import LwDetrConfig
+from ..lw_detr.modeling_lw_detr import (
+ LwDetrC2FLayer,
+ LwDetrConvEncoder,
+ LwDetrConvNormLayer,
+ LwDetrForObjectDetection,
+ LwDetrLayerNorm,
+ LwDetrModel,
+ LwDetrModelOutput,
+ LwDetrObjectDetectionOutput,
+ LwDetrPreTrainedModel,
+ LwDetrSamplingLayer,
+ LwDetrScaleProjector,
+ refine_bboxes,
+)
+
+
+logger = logging.get_logger(__name__)
+
+
+class RfDetrDinov2Config(Dinov2Config):
+ r"""
+ This is the configuration class to store the configuration of a [`RfDetrDinov2Model`]. It is used to instantiate an
+ RfDetrDinov2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the DINOv2
+ [facebook/dinov2-base](https://huggingface.co/facebook/dinov2-base) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ mlp_ratio (`int`, *optional*, defaults to 4):
+ Ratio of the hidden size of the MLPs relative to the `hidden_size`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the layer normalization layers.
+ image_size (`int`, *optional*, defaults to 224):
+ The size (resolution) of each image.
+ patch_size (`int`, *optional*, defaults to 14):
+ The size (resolution) of each patch.
+ num_channels (`int`, *optional*, defaults to 3):
+ The number of input channels.
+ qkv_bias (`bool`, *optional*, defaults to `True`):
+ Whether to add a bias to the queries, keys and values.
+ layerscale_value (`float`, *optional*, defaults to 1.0):
+ Initial value to use for layer scale.
+ drop_path_rate (`float`, *optional*, defaults to 0.0):
+ Stochastic depth rate per sample (when applied in the main path of residual layers).
+ use_swiglu_ffn (`bool`, *optional*, defaults to `False`):
+ Whether to use the SwiGLU feedforward neural network.
+ out_features (`list[str]`, *optional*):
+ If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
+ (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
+ corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
+ same order as defined in the `stage_names` attribute.
+ out_indices (`list[int]`, *optional*):
+ If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
+ many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
+ If unset and `out_features` is unset, will default to the last stage. Must be in the
+ same order as defined in the `stage_names` attribute.
+ apply_layernorm (`bool`, *optional*, defaults to `True`):
+ Whether to apply layer normalization to the feature maps in case the model is used as backbone.
+ reshape_hidden_states (`bool`, *optional*, defaults to `True`):
+ Whether to reshape the feature maps to 4D tensors of shape `(batch_size, hidden_size, height, width)` in
+ case the model is used as backbone. If `False`, the feature maps will be 3D tensors of shape `(batch_size,
+ seq_len, hidden_size)`.
+ use_mask_token (`bool`, *optional*, defaults to `True`):
+ Whether to use mask_token in embeddings.
+ num_windows (`int`, *optional*, defaults to 4):
+ Number of windows to use for windowed attention. If 1, no windowed attention is used.
+ Example:
+
+ ```python
+ >>> from transformers import RfDetrDinov2Config, RfDetrDinov2Backbone
+
+ >>> # Initializing a RfDetrDinov2 base style configuration
+ >>> configuration = RfDetrDinov2Config()
+
+ >>> # Initializing a model (with random weights) from the base style configuration
+ >>> model = RfDetrDinov2Backbone(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "rf_detr_dinov2"
+
+ def __init__(self, num_windows: int = 4, **super_kwargs):
+ super().__init__(**super_kwargs)
+
+ self.num_windows = num_windows
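+        # Indices up to the last entry of out_indices that are not themselves in out_indices use windowed
+        # attention; all other block indices attend globally. E.g., out_indices [2, 5, 8, 11] yields
+        # window_block_indexes [0, 1, 3, 4, 6, 7, 9, 10].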
+ window_block_indexes = set(range(self._out_indices[-1] + 1))
+ window_block_indexes.difference_update(self._out_indices)
+ window_block_indexes = list(window_block_indexes)
+ self.window_block_indexes = window_block_indexes
+
+
+class RfDetrConfig(LwDetrConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`RfDetrModel`]. It is used to instantiate
+    an RF-DETR model according to the specified arguments, defining the model architecture. Instantiating a
+    configuration with the defaults will yield a similar configuration to that of the RF-DETR
+    [stevenbucaille/rfdetr_small_60e_coco](https://huggingface.co/stevenbucaille/rfdetr_small_60e_coco) architecture.
+
+    RF-DETR is a real-time, transformer-based object detector that combines a DINOv2 backbone, a multi-scale C2f
+    projector, and a shallow deformable-attention DETR decoder trained with Group DETR query groups.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ backbone_config (`PretrainedConfig` or `dict`, *optional*):
+ The configuration of the backbone model. If not provided, will default to `RfDetrDinov2Config`
+ with a small ViT architecture optimized for detection tasks.
+ projector_scale_factors (`list[float]`, *optional*, defaults to `[]`):
+ Scale factors for the feature pyramid network. Each scale factor determines the resolution of features
+ at different levels. Supported values are 0.5, 1.0, and 2.0.
+ hidden_expansion (`float`, *optional*, defaults to 0.5):
+ Expansion factor for hidden dimensions in the projector layers.
+ c2f_num_blocks (`int`, *optional*, defaults to 3):
+ Number of blocks in the C2F layer.
+ activation_function (`str`, *optional*, defaults to `"silu"`):
+ The non-linear activation function in the projector. Supported values are `"silu"`, `"relu"`, `"gelu"`.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon value for layer normalization layers.
+ d_model (`int`, *optional*, defaults to 256):
+ Dimension of the model layers and the number of expected features in the decoder inputs.
+ dropout (`float`, *optional*, defaults to 0.1):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ decoder_ffn_dim (`int`, *optional*, defaults to 2048):
+ Dimension of the "intermediate" (often named feed-forward) layer in decoder.
+ decoder_n_points (`int`, *optional*, defaults to 4):
+ The number of sampled keys in each feature level for each attention head in the decoder.
+ decoder_layers (`int`, *optional*, defaults to 3):
+ Number of decoder layers in the transformer.
+ decoder_self_attention_heads (`int`, *optional*, defaults to 8):
+ Number of attention heads for each attention layer in the decoder self-attention.
+ decoder_cross_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the decoder cross-attention.
+ decoder_activation_function (`str`, *optional*, defaults to `"relu"`):
+ The non-linear activation function in the decoder. Supported values are `"relu"`, `"silu"`, `"gelu"`.
+ num_queries (`int`, *optional*, defaults to 300):
+ Number of object queries, i.e. detection slots. This is the maximal number of objects
+ [`RfDetrModel`] can detect in a single image.
+ attention_bias (`bool`, *optional*, defaults to `True`):
+ Whether to add bias to the attention layers.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ activation_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for activations inside the fully connected layer.
+ group_detr (`int`, *optional*, defaults to 13):
+ Number of groups for Group DETR attention mechanism, which helps reduce computational complexity.
+ init_std (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ disable_custom_kernels (`bool`, *optional*, defaults to `True`):
+ Disable the use of custom CUDA and CPU kernels. This option is necessary for the ONNX export, as custom
+ kernels are not supported by PyTorch ONNX export.
+ class_cost (`float`, *optional*, defaults to 2):
+ Relative weight of the classification error in the Hungarian matching cost.
+ bbox_cost (`float`, *optional*, defaults to 5):
+ Relative weight of the L1 error of the bounding box coordinates in the Hungarian matching cost.
+ giou_cost (`float`, *optional*, defaults to 2):
+ Relative weight of the generalized IoU loss of the bounding box in the Hungarian matching cost.
+ class_loss_coefficient (`float`, *optional*, defaults to 1):
+ Relative weight of the classification loss in the Hungarian matching cost.
+ mask_loss_coefficient (`float`, *optional*, defaults to 1):
+ Relative weight of the Focal loss in the instance segmentation mask loss.
+ dice_loss_coefficient (`float`, *optional*, defaults to 1):
+ Relative weight of the DICE/F-1 loss in the object detection loss.
+ bbox_loss_coefficient (`float`, *optional*, defaults to 5):
+ Relative weight of the L1 bounding box loss in the object detection loss.
+ giou_loss_coefficient (`float`, *optional*, defaults to 2):
+ Relative weight of the generalized IoU loss in the object detection loss.
+ eos_coefficient (`float`, *optional*, defaults to 0.1):
+ Relative classification weight of the 'no-object' class in the object detection loss.
+ focal_alpha (`float`, *optional*, defaults to 0.25):
+ Alpha parameter in the focal loss.
+ auxiliary_loss (`bool`, *optional*, defaults to `True`):
+ Whether auxiliary decoding losses (loss at each decoder layer) are to be used.
+ mask_point_sample_ratio (`int`, *optional*, defaults to 16):
+ The ratio of points to sample for the mask loss calculation.
+ mask_downsample_ratio (`int`, *optional*, defaults to 4):
+ The downsample ratio for the segmentation masks compared to the input image resolution.
+ mask_class_loss_coefficient (`float`, *optional*, defaults to 5.0):
+ Relative weight of the Focal loss in the instance segmentation loss.
+ mask_dice_loss_coefficient (`float`, *optional*, defaults to 5.0):
+ Relative weight of the DICE/F-1 loss in the instance segmentation loss.
+ segmentation_head_activation_function (`str`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function in the segmentation head. Supported values are `"relu"`, `"silu"`, `"gelu"`.
+ Examples:
+
+ ```python
+ >>> from transformers import RfDetrConfig, RfDetrModel
+
+    >>> # Initializing an RF-DETR stevenbucaille/rfdetr_small_60e_coco style configuration
+    >>> configuration = RfDetrConfig()
+
+    >>> # Initializing a model (with random weights) from the stevenbucaille/rfdetr_small_60e_coco style configuration
+ >>> model = RfDetrModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "rf_detr"
+
+ def __init__(
+ self,
+ # backbone
+ backbone_config=None,
+ # projector
+ projector_scale_factors: list[float] = [],
+ hidden_expansion=0.5,
+ c2f_num_blocks=3,
+ activation_function="silu",
+ layer_norm_eps=1e-5,
+ # decoder
+ d_model=256,
+ dropout=0.1,
+ decoder_ffn_dim=2048,
+ decoder_n_points=4,
+ decoder_layers: int = 3,
+ decoder_self_attention_heads: int = 8,
+ decoder_cross_attention_heads: int = 16,
+ decoder_activation_function="relu",
+ # model
+ num_queries=300,
+ attention_bias=True,
+ attention_dropout=0.0,
+ activation_dropout=0.0,
+ group_detr: int = 13,
+ init_std=0.02,
+ disable_custom_kernels=True,
+ # loss
+ class_cost=2,
+ bbox_cost=5,
+ giou_cost=2,
+ class_loss_coefficient=1,
+ mask_loss_coefficient=1,
+ dice_loss_coefficient=1,
+ bbox_loss_coefficient=5,
+ giou_loss_coefficient=2,
+ eos_coefficient=0.1,
+ focal_alpha=0.25,
+ auxiliary_loss=True,
+ mask_point_sample_ratio=16,
+ # segmentation
+ mask_downsample_ratio=4,
+ mask_class_loss_coefficient=5.0,
+ mask_dice_loss_coefficient=5.0,
+ segmentation_head_activation_function="gelu",
+ **kwargs,
+ ):
+ self.layer_norm_eps = layer_norm_eps
+
+ # backbone
+ if backbone_config is None:
+ logger.info(
+ "`backbone_config` is `None`. Initializing the config with the default `RfDetrDinov2` backbone."
+ )
+ backbone_config = RfDetrDinov2Config(
+ attention_probs_dropout_prob=0.0,
+ drop_path_rate=0.0,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.0,
+ initializer_range=0.02,
+ layer_norm_eps=1e-06,
+ layerscale_value=1.0,
+ mlp_ratio=4,
+ num_attention_heads=6,
+ num_channels=3,
+ num_hidden_layers=12,
+ qkv_bias=True,
+ use_swiglu_ffn=False,
+ out_features=["stage2", "stage5", "stage8", "stage11"],
+ hidden_size=384,
+ patch_size=14,
+ num_windows=4,
+ num_register_tokens=0,
+ image_size=518,
+ **kwargs,
+ )
+ elif isinstance(backbone_config, dict):
+ backbone_model_type = backbone_config.pop("model_type")
+ config_class = CONFIG_MAPPING[backbone_model_type]
+ backbone_config = config_class.from_dict(backbone_config)
+
+ self.backbone_config = backbone_config
+
+ # projector
+ self.projector_scale_factors = projector_scale_factors
+ for scale in projector_scale_factors:
+ if scale not in [0.5, 1.0, 2.0]:
+ raise ValueError(f"Unsupported scale factor: {scale}")
+ self.projector_in_channels = [d_model] * len(projector_scale_factors)
+ self.projector_out_channels = d_model
+ self.activation_function = activation_function
+ self.hidden_expansion = hidden_expansion
+ self.c2f_num_blocks = c2f_num_blocks
+ # decoder
+ self.d_model = d_model
+ self.dropout = dropout
+ self.num_queries = num_queries
+ self.decoder_ffn_dim = decoder_ffn_dim
+ self.num_feature_levels = len(self.projector_scale_factors)
+ self.decoder_n_points = decoder_n_points
+ self.decoder_layers = decoder_layers
+ self.decoder_activation_function = decoder_activation_function
+ self.decoder_self_attention_heads = decoder_self_attention_heads
+ self.decoder_cross_attention_heads = decoder_cross_attention_heads
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+ self.activation_dropout = activation_dropout
+ # model
+ self.init_std = init_std
+ self.group_detr = group_detr
+ # Loss
+ self.auxiliary_loss = auxiliary_loss
+ # Hungarian matcher
+ self.class_cost = class_cost
+ self.bbox_cost = bbox_cost
+ self.giou_cost = giou_cost
+ # Loss coefficients
+ self.class_loss_coefficient = class_loss_coefficient
+ self.mask_loss_coefficient = mask_loss_coefficient
+ self.dice_loss_coefficient = dice_loss_coefficient
+ self.bbox_loss_coefficient = bbox_loss_coefficient
+ self.giou_loss_coefficient = giou_loss_coefficient
+ self.mask_class_loss_coefficient = mask_class_loss_coefficient
+ self.mask_dice_loss_coefficient = mask_dice_loss_coefficient
+ self.eos_coefficient = eos_coefficient
+ self.focal_alpha = focal_alpha
+ self.disable_custom_kernels = disable_custom_kernels
+ self.mask_point_sample_ratio = mask_point_sample_ratio
+ # segmentation
+ self.mask_downsample_ratio = mask_downsample_ratio
+ self.segmentation_head_activation_function = segmentation_head_activation_function
+ PreTrainedConfig.__init__(self, **kwargs)
+
+
+def window_partition(
+ embeddings: torch.Tensor, num_windows: int, patch_size: int, height: int, width: int
+) -> torch.Tensor:
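+    # Split the (CLS + patch) token sequence into num_windows**2 non-overlapping spatial windows.
+    #   input:  (batch_size, 1 + num_h_patches * num_w_patches, hidden_size)
+    #   output: (batch_size * num_windows**2, 1 + tokens_per_window, hidden_size)
+    # The CLS token (with its position embedding) is replicated into every window.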
+ batch_size = embeddings.shape[0]
+ num_h_patches = height // patch_size
+ num_w_patches = width // patch_size
+ cls_token_with_pos_embed = embeddings[:, :1]
+ pixel_tokens_with_pos_embed = embeddings[:, 1:]
+ pixel_tokens_with_pos_embed = pixel_tokens_with_pos_embed.view(batch_size, num_h_patches, num_w_patches, -1)
+ num_w_patches_per_window = num_w_patches // num_windows
+ num_h_patches_per_window = num_h_patches // num_windows
+ windowed_pixel_tokens = pixel_tokens_with_pos_embed.view(
+        batch_size, num_windows, num_h_patches_per_window, num_windows, num_w_patches_per_window, -1
+ )
+ windowed_pixel_tokens = windowed_pixel_tokens.permute(0, 1, 3, 2, 4, 5)
+ windowed_pixel_tokens = windowed_pixel_tokens.reshape(
+ batch_size * num_windows**2, num_h_patches_per_window * num_w_patches_per_window, -1
+ )
+ windowed_cls_token_with_pos_embed = cls_token_with_pos_embed.repeat(num_windows**2, 1, 1)
+ embeddings = torch.cat((windowed_cls_token_with_pos_embed, windowed_pixel_tokens), dim=1)
+ return embeddings
+
+
+class RfDetrDinov2Embeddings(Dinov2Embeddings):
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+ """
+ This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
+ images. This method is also adapted to support torch.jit tracing and interpolation at torch.float32 precision.
+
+ Adapted from:
+ - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
+ - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
+ """
+
+ num_patches = embeddings.shape[1] - 1
+ num_positions = self.position_embeddings.shape[1] - 1
+
+ # always interpolate when tracing to ensure the exported model works for dynamic input shapes
+ if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
+ return self.position_embeddings
+
+ class_pos_embed = self.position_embeddings[:, :1]
+ patch_pos_embed = self.position_embeddings[:, 1:]
+
+ dim = embeddings.shape[-1]
+
+ new_height = height // self.patch_size
+ new_width = width // self.patch_size
+
+ sqrt_num_positions = torch_int(num_positions**0.5)
+ patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+ target_dtype = patch_pos_embed.dtype
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed.to(torch.float32),
+ size=(new_height, new_width),
+ mode="bicubic",
+ align_corners=False,
+ antialias=True,
+ ).to(dtype=target_dtype)
+
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+
+ return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
+
+ def forward(self, pixel_values: torch.Tensor, bool_masked_pos: torch.Tensor | None = None) -> torch.Tensor:
+ batch_size, _, height, width = pixel_values.shape
+ target_dtype = self.patch_embeddings.projection.weight.dtype
+ embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype))
+
+ if bool_masked_pos is not None:
+ embeddings = torch.where(
+ bool_masked_pos.unsqueeze(-1), self.mask_token.to(embeddings.dtype).unsqueeze(0), embeddings
+ )
+
+ # add the [CLS] token to the embedded patch tokens
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1)
+ embeddings = torch.cat((cls_tokens, embeddings), dim=1)
+
+ # add positional encoding to each token
+ embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
+
+ if self.config.num_windows > 1:
+ # reshape for windows
+ embeddings = window_partition(embeddings, self.config.num_windows, self.config.patch_size, height, width)
+ embeddings = self.dropout(embeddings)
+
+ return embeddings
+
+
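+# The two helpers below bracket self-attention in global-attention layers: the first merges the
+# window-partitioned sequence back into one sequence per image so attention can span the whole image,
+# and the second restores the windowed layout afterwards.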
+def window_unpartition_before_attention(hidden_states: torch.Tensor, num_windows: int) -> torch.Tensor:
+ batch_size, seq_len, channels = hidden_states.shape
+ num_windows_squared = num_windows**2
+ hidden_states = hidden_states.view(batch_size // num_windows_squared, num_windows_squared * seq_len, channels)
+ return hidden_states
+
+
+def window_partition_after_attention(
+ hidden_states: torch.Tensor, self_attention_output: torch.Tensor, num_windows: int
+) -> torch.Tensor:
+ batch_size, seq_len, channels = hidden_states.shape
+ num_windows_squared = num_windows**2
+ self_attention_output = self_attention_output.view(
+ batch_size * num_windows_squared, seq_len // num_windows_squared, channels
+ )
+ return self_attention_output
+
+
+class RfDetrDinov2SelfAttention(Dinov2SelfAttention):
+ def __init__(self, config: RfDetrDinov2Config):
+ super().__init__(config)
+ self.num_key_value_groups = 1
+
+
+class RfDetrDinov2Layer(Dinov2Layer):
+ def __init__(self, config: RfDetrDinov2Config, layer_idx: int):
+ super().__init__(config)
+ self.num_windows = config.num_windows
+ self.global_attention = layer_idx not in config.window_block_indexes
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ ) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor]:
+ shortcut = hidden_states
+ if self.global_attention:
+ hidden_states = window_unpartition_before_attention(hidden_states, self.num_windows)
+
+ hidden_states_norm = self.norm1(hidden_states)
+ self_attention_output = self.attention(hidden_states_norm)
+
+ if self.global_attention:
+ self_attention_output = window_partition_after_attention(
+ hidden_states, self_attention_output, self.num_windows
+ )
+
+ self_attention_output = self.layer_scale1(self_attention_output)
+
+ # first residual connection
+ hidden_states = self.drop_path(self_attention_output) + shortcut
+
+ # in Dinov2, layernorm is also applied after self-attention
+ layer_output = self.norm2(hidden_states)
+ layer_output = self.mlp(layer_output)
+ layer_output = self.layer_scale2(layer_output)
+
+ # second residual connection
+ layer_output = self.drop_path(layer_output) + hidden_states
+
+ return layer_output
+
+
+class RfDetrDinov2Encoder(Dinov2Encoder):
+ def __init__(self, config: RfDetrDinov2Config):
+ super().__init__(config)
+ self.layer = nn.ModuleList([RfDetrDinov2Layer(config, i) for i in range(config.num_hidden_layers)])
+
+
+class RfDetrDinov2PreTrainedModel(Dinov2PreTrainedModel):
+ pass
+
+
+def window_unpartition(
+ hidden_state: torch.Tensor,
+ num_windows: int,
+ num_h_patches: int,
+ num_w_patches: int,
+) -> torch.Tensor:
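+    # Inverse of `window_partition` for the patch tokens (the CLS token has already been removed by the caller):
+    #   input:  (batch_size * num_windows**2, tokens_per_window, channels)
+    #   output: (batch_size, num_windows, num_h_patches_per_window, num_windows, num_w_patches_per_window, channels)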
+ hidden_batch_size, seq_len, channels = hidden_state.shape
+ num_windows_squared = num_windows**2
+ num_h_patches_per_window = num_h_patches // num_windows
+ num_w_patches_per_window = num_w_patches // num_windows
+ hidden_state = hidden_state.reshape(
+ hidden_batch_size // num_windows_squared, num_windows_squared * seq_len, channels
+ )
+ hidden_state = hidden_state.view(
+ hidden_batch_size // num_windows_squared,
+ num_windows,
+ num_windows,
+ num_h_patches_per_window,
+ num_w_patches_per_window,
+ channels,
+ )
+ hidden_state = hidden_state.permute(0, 1, 3, 2, 4, 5)
+ return hidden_state
+
+
+class RfDetrDinov2Backbone(Dinov2Backbone):
+ def forward(
+ self,
+ pixel_values: torch.Tensor,
+ output_hidden_states: bool | None = None,
+ **kwargs,
+ ) -> BackboneOutput:
+ r"""
+ Examples:
+
+ ```python
+ >>> from transformers import AutoImageProcessor, AutoBackbone
+ >>> import torch
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
+ >>> model = AutoBackbone.from_pretrained(
+ ... "facebook/dinov2-base", out_features=["stage2", "stage5", "stage8", "stage11"]
+ ... )
+
+ >>> inputs = processor(image, return_tensors="pt")
+
+ >>> outputs = model(**inputs)
+ >>> feature_maps = outputs.feature_maps
+ >>> list(feature_maps[-1].shape)
+ [1, 768, 16, 16]
+ ```"""
+ if output_hidden_states is None:
+ output_hidden_states = self.config.output_hidden_states
+
+ embedding_output = self.embeddings(pixel_values)
+
+ output: BaseModelOutput = self.encoder(embedding_output, output_hidden_states=True)
+ hidden_states = output.hidden_states
+
+ feature_maps = ()
+ for stage, hidden_state in zip(self.stage_names, hidden_states):
+ if stage in self.out_features:
+ if self.config.apply_layernorm:
+ hidden_state = self.layernorm(hidden_state)
+ if self.config.reshape_hidden_states:
+ hidden_state = hidden_state[:, 1:]
+                    # note: the original Dinov2 backbone reshapes in (width, height) order by mistake;
+                    # the conventional (height, width) order is used here
+ batch_size, _, height, width = pixel_values.shape
+ patch_size = self.config.patch_size
+
+ num_h_patches = height // patch_size
+ num_w_patches = width // patch_size
+
+ if self.config.num_windows > 1:
+ hidden_state = window_unpartition(
+ hidden_state, self.config.num_windows, num_h_patches, num_w_patches
+ )
+
+ hidden_state = hidden_state.reshape(batch_size, num_h_patches, num_w_patches, -1)
+ hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
+
+ feature_maps += (hidden_state,)
+
+ return BackboneOutput(
+ feature_maps=feature_maps,
+ hidden_states=hidden_states if output_hidden_states else None,
+ )
+
+
+class RfDetrLayerNorm(LwDetrLayerNorm):
+ pass
+
+
+class RfDetrConvNormLayer(LwDetrConvNormLayer):
+ def __init__(
+ self,
+ config: RfDetrConfig,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ stride: int,
+ activation: str | None = None,
+ ):
+ super().__init__(
+ config,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ activation,
+ )
+ self.norm = RfDetrLayerNorm(out_channels, data_format="channels_first", eps=config.layer_norm_eps)
+
+
+class RfDetrC2FLayer(LwDetrC2FLayer):
+ pass
+
+
+class RfDetrSamplingLayer(LwDetrSamplingLayer):
+ def __init__(self, config: RfDetrConfig, channel_size: int, scale: float):
+ nn.Module.__init__(self)
+
+ self.scale = scale
+ self.channel_size = channel_size
+
+ layers = []
+ if scale == 2.0:
+ layers.append(nn.ConvTranspose2d(channel_size, channel_size // 2, 2, 2))
+ elif scale == 0.5:
+ layers.append(RfDetrConvNormLayer(config, channel_size, channel_size, 3, 2, activation="relu"))
+ self.layers = nn.ModuleList(layers)
+
+
+class RfDetrScaleProjector(LwDetrScaleProjector):
+ def __init__(self, config: RfDetrConfig, scale: float):
+ nn.Module.__init__(self)
+
+ intermediate_dims = [config.backbone_config.hidden_size] * len(config.backbone_config.out_indices)
+ sampling_layers = []
+ for channel_size in intermediate_dims:
+ sampling_layers.append(RfDetrSamplingLayer(config, channel_size, scale))
+ self.sampling_layers = nn.ModuleList(sampling_layers)
+
+ intermediate_dim = intermediate_dims[-1]
+ if scale == 2.0:
+ intermediate_dim = intermediate_dim // 2
+ projector_input_dim = intermediate_dim * len(intermediate_dims)
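+        # e.g. with the default backbone (hidden_size=384, four output stages) and scale=2.0, each sampled
+        # feature map has 384 // 2 = 192 channels, so the C2f layer receives 192 * 4 = 768 input channels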
+
+ self.projector_layer = RfDetrC2FLayer(config, projector_input_dim)
+ self.layer_norm = RfDetrLayerNorm(config.d_model, data_format="channels_first")
+
+
+class RfDetrConvEncoder(LwDetrConvEncoder):
+ def __init__(self, config: RfDetrConfig):
+ super().__init__(config)
+ self.backbone = RfDetrDinov2Backbone(config.backbone_config)
+
+
+class RfDetrPreTrainedModel(LwDetrPreTrainedModel):
+ pass
+
+
+class RfDetrModelOutput(LwDetrModelOutput):
+ r"""
+ init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
+ Initial reference points sent through the Transformer decoder.
+ intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
+ Stacked intermediate hidden states (output of each layer of the decoder).
+ intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
+ Stacked intermediate reference points (reference points of each layer of the decoder).
+ enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
+ Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are
+ picked as region proposals in the first stage. Output of bounding box binary classification (i.e.
+ foreground and background).
+ enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
+ Logits of predicted bounding boxes coordinates in the first stage.
+ backbone_features (list of `torch.FloatTensor` of shape `(batch_size, config.num_channels, config.image_size, config.image_size)`):
+ Features from the backbone.
+ """
+
+    backbone_features: list[torch.Tensor] | None = None
+
+
+class RfDetrModel(LwDetrModel):
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor,
+ pixel_mask: torch.LongTensor | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> RfDetrModelOutput:
+ r"""
+ Examples:
+
+ ```python
+        >>> from transformers import AutoImageProcessor, RfDetrModel
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> image_processor = AutoImageProcessor.from_pretrained("stevenbucaille/rfdetr_small_60e_coco")
+        >>> model = RfDetrModel.from_pretrained("stevenbucaille/rfdetr_small_60e_coco")
+
+ >>> inputs = image_processor(images=image, return_tensors="pt")
+
+ >>> outputs = model(**inputs)
+
+ >>> last_hidden_states = outputs.last_hidden_state
+ >>> list(last_hidden_states.shape)
+ [1, 200, 256]
+ ```"""
+ batch_size, num_channels, height, width = pixel_values.shape
+ device = pixel_values.device
+
+ if pixel_mask is None:
+ pixel_mask = torch.ones(((batch_size, height, width)), dtype=torch.long, device=device)
+
+ # First, retrieve feature maps from backbone
+ features = self.backbone(pixel_values, pixel_mask)
+
+ sources = []
+ masks = []
+ for level, (source, mask) in enumerate(features):
+ sources.append(source)
+ masks.append(mask)
+ if mask is None:
+ raise ValueError("No attention mask was provided")
+
+ # Get initial reference points and query features
+ if self.training:
+ reference_points = self.reference_point_embed.weight
+ query_feat = self.query_feat.weight
+ else:
+ # only use first group of reference points and query features during inference
+ # reference_points (num_queries, 4) : spatial locations of the queries
+ # query_feat (num_queries, d_model) : features of the queries
+ reference_points = self.reference_point_embed.weight[: self.num_queries]
+ query_feat = self.query_feat.weight[: self.num_queries]
+
+ # Prepare decoder inputs (by flattening)
+ source_flatten = []
+ mask_flatten = []
+ spatial_shapes_list = []
+ for source, mask in zip(sources, masks):
+ batch_size, num_channels, height, width = source.shape
+ spatial_shape = (height, width)
+ spatial_shapes_list.append(spatial_shape)
+ source = source.flatten(2).transpose(1, 2)
+ mask = mask.flatten(1)
+ source_flatten.append(source)
+ mask_flatten.append(mask)
+ # source_flatten (batch_size, sum(H*W), d_model) : flattened multi-scale feature maps
+ # mask_flatten (batch_size, sum(H*W)) : flattened mask
+ # spatial_shapes (num_levels, 2) : spatial shapes of the feature maps
+ # level_start_index (num_levels,) : start index of each level in source_flatten
+ # valid_ratios (batch_size, num_levels, 2) : valid ratios of the feature maps
+ source_flatten = torch.cat(source_flatten, 1)
+ mask_flatten = torch.cat(mask_flatten, 1)
+ spatial_shapes = torch.as_tensor(spatial_shapes_list, dtype=torch.long, device=source_flatten.device)
+ level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
+ valid_ratios = torch.stack([self.get_valid_ratio(m, dtype=source_flatten.dtype) for m in masks], 1)
+
+ # Duplicate query features and reference points for each image in the batch
+ target = query_feat.unsqueeze(0).expand(batch_size, -1, -1)
+ reference_points = reference_points.unsqueeze(0).expand(batch_size, -1, -1)
+
+ # Generate encoder output proposals
+ object_query_embedding, output_proposals = self.gen_encoder_output_proposals(
+ source_flatten, ~mask_flatten, spatial_shapes_list
+ )
+
+ group_detr = self.group_detr if self.training else 1
+ topk = self.num_queries
+ topk_coords_logits = []
+ object_query_undetach = []
+
+ # Iterate over each group of object queries to refine the object queries
+ for group_id in range(group_detr):
+ group_object_query = self.enc_output[group_id](object_query_embedding)
+ group_object_query = self.enc_output_norm[group_id](group_object_query)
+
+ group_enc_outputs_class = self.enc_out_class_embed[group_id](group_object_query)
+ group_delta_bbox = self.enc_out_bbox_embed[group_id](group_object_query)
+ group_enc_outputs_coord = refine_bboxes(output_proposals, group_delta_bbox)
+
+ group_topk_proposals = torch.topk(group_enc_outputs_class.max(-1)[0], topk, dim=1)[1]
+ group_topk_coords_logits_undetach = torch.gather(
+ group_enc_outputs_coord,
+ 1,
+ group_topk_proposals.unsqueeze(-1).repeat(1, 1, 4),
+ )
+ group_topk_coords_logits = group_topk_coords_logits_undetach.detach()
+ group_object_query_undetach = torch.gather(
+ group_object_query, 1, group_topk_proposals.unsqueeze(-1).repeat(1, 1, self.config.d_model)
+ )
+
+ topk_coords_logits.append(group_topk_coords_logits)
+ object_query_undetach.append(group_object_query_undetach)
+
+ # Concatenate the object queries and reference points from all groups
+ topk_coords_logits = torch.cat(topk_coords_logits, 1)
+ object_query_undetach = torch.cat(object_query_undetach, 1)
+
+ # Get the class and coordinate logits from the object queries
+ # enc_outputs_class (batch_size, num_queries, d_model) : object queries
+ # enc_outputs_coord_logits (batch_size, num_queries, 4) : coordinate logits of the object queries
+ enc_outputs_class = object_query_undetach
+ enc_outputs_coord_logits = topk_coords_logits
+
+ # Refine the reference points using the coordinate logits
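+        # The first `two_stage_len` reference points are re-anchored on the top-k encoder proposals, while any
+        # remaining reference points keep their learned initialization.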
+ two_stage_len = topk_coords_logits.shape[-2]
+ reference_points_two_stage_subset = reference_points[..., :two_stage_len, :]
+ reference_points_subset = reference_points[..., two_stage_len:, :]
+ reference_points_two_stage_subset = refine_bboxes(topk_coords_logits, reference_points_two_stage_subset)
+ reference_points = torch.cat([reference_points_two_stage_subset, reference_points_subset], dim=-2)
+ init_reference_points = reference_points
+
+ # Pass the object queries and reference points to the decoder
+ decoder_outputs = self.decoder(
+ inputs_embeds=target,
+ reference_points=reference_points,
+ spatial_shapes=spatial_shapes,
+ spatial_shapes_list=spatial_shapes_list,
+ level_start_index=level_start_index,
+ valid_ratios=valid_ratios,
+ encoder_hidden_states=source_flatten,
+ encoder_attention_mask=mask_flatten,
+ **kwargs,
+ )
+
+ # init_reference_points (batch_size, num_queries, 4) : initial reference points
+ # last_hidden_state (batch_size, num_queries, d_model) : final object queries
+ # intermediate_hidden_states (batch_size, num_decoder_layers, num_queries, d_model) : intermediate object queries
+ # intermediate_reference_points (batch_size, num_decoder_layers, num_queries, 4) : intermediate reference points
+ # backbone_features list(batch_size, num_levels, d_model, H, W) : backbone features
+ # enc_outputs_class (batch_size, num_queries, d_model) : encoder outputs object queries
+ # enc_outputs_coord_logits (batch_size, num_queries, 4) : coordinate logits of encoder object queries
+ return RfDetrModelOutput(
+ init_reference_points=init_reference_points,
+ last_hidden_state=decoder_outputs.last_hidden_state,
+ intermediate_hidden_states=decoder_outputs.intermediate_hidden_states,
+ intermediate_reference_points=decoder_outputs.intermediate_reference_points,
+ backbone_features=sources,
+ enc_outputs_class=enc_outputs_class,
+ enc_outputs_coord_logits=enc_outputs_coord_logits,
+ )
+
+
+class RfDetrObjectDetectionOutput(LwDetrObjectDetectionOutput):
+ r"""
+    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided):
+        Total loss as a linear combination of a negative log-likelihood (cross-entropy) for class prediction and a
+ bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized
+ scale-invariant IoU loss.
+ loss_dict (`Dict`, *optional*):
+ A dictionary containing the individual losses. Useful for logging.
+ logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`):
+ Classification logits (including no-object) for all queries.
+ pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
+ Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
+ values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
+        possible padding). You can use [`~DetrImageProcessor.post_process_object_detection`] to retrieve the
+ unnormalized bounding boxes.
+ auxiliary_outputs (`list[Dict]`, *optional*):
+ Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
+ and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and
+ `pred_boxes`) for each decoder layer.
+ init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
+ Initial reference points sent through the Transformer decoder.
+ intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
+ Stacked intermediate hidden states (output of each layer of the decoder).
+ intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
+ Stacked intermediate reference points (reference points of each layer of the decoder).
+ enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
+ Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are
+ picked as region proposals in the first stage. Output of bounding box binary classification (i.e.
+ foreground and background).
+ enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
+ Logits of predicted bounding boxes coordinates in the first stage.
+ backbone_features (list of `torch.FloatTensor` of shape `(batch_size, config.num_channels, config.image_size, config.image_size)`):
+ Features from the backbone.
+ """
+
+    backbone_features: list[torch.Tensor] | None = None
+
+
+class RfDetrForObjectDetection(LwDetrForObjectDetection):
+ def get_encoder_outputs_class_logits(self, enc_outputs_class_logits: torch.Tensor) -> Tensor:
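+        # Encoder object queries are laid out as `group_detr` contiguous groups of `num_queries` queries during
+        # training, each scored by its own classification head; at inference only the first group is used.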
+ enc_outputs_class_logits_list = enc_outputs_class_logits.split(self.config.num_queries, dim=1)
+ group_detr = self.config.group_detr if self.training else 1
+ pred_class = [
+ self.model.enc_out_class_embed[group_index](enc_outputs_class_logits_list[group_index])
+ for group_index in range(group_detr)
+ ]
+ enc_outputs_class_logits = torch.cat(pred_class, dim=1)
+ return enc_outputs_class_logits
+
+ @check_model_inputs
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor = None,
+ pixel_mask: torch.LongTensor | None = None,
+ labels: list[dict] | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> RfDetrObjectDetectionOutput:
+ r"""
+ labels (`list[Dict]` of len `(batch_size,)`, *optional*):
+ Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
+ following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch
+ respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding boxes
+ in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)`.
+
+ Examples:
+
+ ```python
+        >>> import torch
+        >>> from transformers import AutoImageProcessor, RfDetrForObjectDetection
+        >>> from PIL import Image
+        >>> import requests
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> image_processor = AutoImageProcessor.from_pretrained("stevenbucaille/rfdetr_small_60e_coco")
+        >>> model = RfDetrForObjectDetection.from_pretrained("stevenbucaille/rfdetr_small_60e_coco")
+
+ >>> inputs = image_processor(images=image, return_tensors="pt")
+ >>> outputs = model(**inputs)
+
+ >>> # convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax)
+ >>> target_sizes = torch.tensor([image.size[::-1]])
+ >>> results = image_processor.post_process_object_detection(outputs, threshold=0.5, target_sizes=target_sizes)[
+ ... 0
+ ... ]
+ >>> for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
+ ... box = [round(i, 2) for i in box.tolist()]
+ ... print(
+ ... f"Detected {model.config.id2label[label.item()]} with confidence "
+ ... f"{round(score.item(), 3)} at location {box}"
+ ... )
+ Detected cat with confidence 0.8 at location [16.5, 52.84, 318.25, 470.78]
+ Detected cat with confidence 0.789 at location [342.19, 24.3, 640.02, 372.25]
+ Detected remote with confidence 0.633 at location [40.79, 72.78, 176.76, 117.25]
+ ```"""
+ outputs = self.model(
+ pixel_values,
+ pixel_mask=pixel_mask,
+ **kwargs,
+ )
+
+ last_hidden_states = outputs.last_hidden_state
+ intermediate_reference_points = outputs.intermediate_reference_points
+ enc_outputs_class = outputs.enc_outputs_class
+ enc_outputs_boxes_logits = outputs.enc_outputs_coord_logits
+
+ # Get logits and boxes from first stage object queries
+ enc_outputs_class_logits = self.get_encoder_outputs_class_logits(enc_outputs_class)
+
+ # Get logits and boxes from second stage object queries
+ logits = self.class_embed(last_hidden_states)
+ pred_boxes_delta = self.bbox_embed(last_hidden_states)
+ pred_boxes = refine_bboxes(intermediate_reference_points[-1], pred_boxes_delta)
+
+ loss, loss_dict, auxiliary_outputs = None, None, None
+ if labels is not None:
+ outputs_class, outputs_coord = None, None
+ if self.config.auxiliary_loss:
+ intermediate_hidden_states = outputs.intermediate_hidden_states
+ outputs_coord_delta = self.bbox_embed(intermediate_hidden_states)
+ outputs_coord = refine_bboxes(intermediate_reference_points, outputs_coord_delta)
+ outputs_class = self.class_embed(intermediate_hidden_states)
+
+ loss, loss_dict, auxiliary_outputs = self.loss_function(
+ logits,
+ labels,
+ self.device,
+ pred_boxes,
+ self.config,
+ outputs_class,
+ outputs_coord,
+ enc_outputs_class_logits,
+ enc_outputs_boxes_logits,
+ )
+
+ return RfDetrObjectDetectionOutput(
+ loss=loss,
+ loss_dict=loss_dict,
+ logits=logits,
+ pred_boxes=pred_boxes,
+ auxiliary_outputs=auxiliary_outputs,
+ last_hidden_state=outputs.last_hidden_state,
+ intermediate_hidden_states=outputs.intermediate_hidden_states,
+ intermediate_reference_points=outputs.intermediate_reference_points,
+ init_reference_points=outputs.init_reference_points,
+ enc_outputs_class=enc_outputs_class_logits,
+ enc_outputs_coord_logits=enc_outputs_boxes_logits,
+ backbone_features=outputs.backbone_features,
+ )
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Output type of [`RfDetrForInstanceSegmentation`].
+ """
+)
+class RfDetrInstanceSegmentationOutput(ModelOutput):
+ r"""
+    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided):
+        Total loss as a linear combination of a negative log-likelihood (cross-entropy) for class prediction and a
+ bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized
+ scale-invariant IoU loss.
+ loss_dict (`Dict`, *optional*):
+ A dictionary containing the individual losses. Useful for logging.
+ logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`):
+ Classification logits (including no-object) for all queries.
+ pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
+ Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
+ values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
+        possible padding). You can use [`~DetrImageProcessor.post_process_object_detection`] to retrieve the
+ unnormalized bounding boxes.
+ pred_masks (`torch.FloatTensor` of shape `(batch_size, num_queries, height/4, width/4)`):
+ Segmentation masks logits for all queries. See also
+ [`~DetrImageProcessor.post_process_semantic_segmentation`] or
+ [`~DetrImageProcessor.post_process_instance_segmentation`]
+ [`~DetrImageProcessor.post_process_panoptic_segmentation`] to evaluate semantic, instance and panoptic
+ segmentation masks respectively.
+ auxiliary_outputs (`list[Dict]`, *optional*):
+ Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
+ and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and
+ `pred_boxes`) for each decoder layer.
+ init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
+ Initial reference points sent through the Transformer decoder.
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the decoder of the model.
+ intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
+ Stacked intermediate hidden states (output of each layer of the decoder).
+ intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
+ Stacked intermediate reference points (reference points of each layer of the decoder).
+    enc_outputs_mask_logits (`torch.FloatTensor` of shape `(batch_size, num_queries, height/4, width/4)`, *optional*):
+ Mask logits from the encoder for all queries.
+ """
+
+ loss: torch.FloatTensor | None = None
+ loss_dict: dict | None = None
+ logits: torch.FloatTensor | None = None
+ pred_boxes: torch.FloatTensor | None = None
+    pred_masks: torch.FloatTensor | None = None
+ auxiliary_outputs: list[dict] | None = None
+ init_reference_points: torch.FloatTensor | None = None
+ last_hidden_state: torch.FloatTensor | None = None
+ intermediate_hidden_states: torch.FloatTensor | None = None
+ intermediate_reference_points: torch.FloatTensor | None = None
+ enc_outputs_mask_logits: torch.FloatTensor | None = None
+
+
+class RfDetrSegmentationBlock(ConvNextLayer):
+ def __init__(self, config: RfDetrConfig):
+ dim = config.d_model
+ super().__init__(config)
+ self.dwconv = nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim) # depthwise conv
+ self.layernorm = RfDetrLayerNorm(dim, eps=1e-6)
+ self.pwconv1 = nn.Linear(dim, dim) # pointwise/1x1 convs, implemented with linear layers
+ self.act = ACT2FN[config.segmentation_head_activation_function]
+ del self.pwconv2
+ del self.layer_scale_parameter
+ del self.drop_path
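+        # Compared to a standard ConvNeXt block, the second pointwise conv, layer scale and drop path are removed:
+        # depthwise conv -> layernorm -> pointwise conv -> activation -> residual.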
+
+ def forward(self, features: torch.Tensor) -> torch.Tensor:
+ residual = features
+ features = self.dwconv(features)
+ features = features.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
+ features = self.layernorm(features)
+ features = self.pwconv1(features)
+ features = self.act(features)
+ features = features.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
+ features = features + residual
+ return features
+
+
+class RfDetrSegmentationMLPBlock(nn.Module):
+ def __init__(self, config: RfDetrConfig):
+ super().__init__()
+ dim = config.d_model
+ self.norm_in = nn.LayerNorm(dim)
+ self.in_linear = nn.Linear(dim, dim * 4)
+ self.act = ACT2FN[config.segmentation_head_activation_function]
+ self.out_linear = nn.Linear(dim * 4, dim)
+
+ def forward(self, features: torch.Tensor) -> torch.Tensor:
+ residual = features
+ features = self.norm_in(features)
+ features = self.in_linear(features)
+ features = self.act(features)
+ features = self.out_linear(features)
+ features = features + residual
+ return features
+
+
+class RfDetrForInstanceSegmentation(RfDetrPreTrainedModel):
+ def __init__(self, config: RfDetrConfig):
+ super().__init__(config)
+
+ self.rf_detr = RfDetrForObjectDetection(config)
+
+ num_blocks = config.decoder_layers
+ self.downsample_ratio = config.mask_downsample_ratio
+ self.blocks = nn.ModuleList([RfDetrSegmentationBlock(config) for _ in range(num_blocks)])
+ self.spatial_features_proj = nn.Conv2d(config.d_model, config.d_model, kernel_size=1)
+
+ self.query_features_block = RfDetrSegmentationMLPBlock(config)
+ self.query_features_proj = nn.Linear(config.d_model, config.d_model)
+
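+        # learnable scalar bias added to every mask logit produced by the spatial/query feature dot product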
+ self.bias = nn.Parameter(torch.zeros(1), requires_grad=True)
+
+ self.post_init()
+
+ def segmentation_head(self, spatial_features, query_features, image_size: torch.Size, skip_blocks: bool = False):
+ # spatial features: (B, C, H, W)
+ # query features: [(B, N, C)] for each decoder layer
+ # output: (B, N, H*r, W*r)
+ target_size = (image_size[0] // self.downsample_ratio, image_size[1] // self.downsample_ratio)
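+        # e.g. a 640x640 input with mask_downsample_ratio=4 yields 160x160 mask logits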
+ spatial_features = F.interpolate(spatial_features, size=target_size, mode="bilinear", align_corners=False)
+ list_mask_logits = []
+ if not skip_blocks:
+ for block, qf in zip(self.blocks, query_features):
+ spatial_features = block(spatial_features)
+ spatial_features_proj = self.spatial_features_proj(spatial_features)
+ qf = self.query_features_block(qf)
+ qf = self.query_features_proj(qf)
+ mask_logits = torch.einsum("bchw,bnc->bnhw", spatial_features_proj, qf)
+ mask_logits = mask_logits + self.bias
+ list_mask_logits.append(mask_logits)
+ else:
+ query_features = self.query_features_block(query_features)
+ query_features = self.query_features_proj(query_features)
+ mask_logits = torch.einsum("bchw,bnc->bnhw", spatial_features, query_features)
+ mask_logits = mask_logits + self.bias
+ list_mask_logits.append(mask_logits)
+
+ return list_mask_logits
+
+ @check_model_inputs
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor = None,
+ pixel_mask: torch.LongTensor | None = None,
+ labels: list[dict] | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+    ) -> RfDetrInstanceSegmentationOutput:
+ image_size = pixel_values.shape[-2:]
+
+ outputs = self.rf_detr.model(
+ pixel_values,
+ pixel_mask=pixel_mask,
+ **kwargs,
+ )
+
+ spatial_features = outputs.backbone_features[-1]
+ last_hidden_states = outputs.last_hidden_state
+ intermediate_reference_points = outputs.intermediate_reference_points
+ enc_outputs_class = outputs.enc_outputs_class
+ enc_outputs_boxes_logits = outputs.enc_outputs_coord_logits
+ query_features = outputs.intermediate_hidden_states
+ last_hidden_state = outputs.last_hidden_state
+
+ # First stage segmentation proposals
+ enc_outputs_class_logits = self.rf_detr.get_encoder_outputs_class_logits(enc_outputs_class)
+ enc_outputs_masks = self.segmentation_head(spatial_features, enc_outputs_class, image_size, skip_blocks=True)
+ enc_outputs_masks = torch.cat(enc_outputs_masks, dim=1)
+
+ # Second stage segmentation proposals
+ logits = self.rf_detr.class_embed(last_hidden_states)
+ pred_boxes_delta = self.rf_detr.bbox_embed(last_hidden_states)
+ pred_boxes = refine_bboxes(intermediate_reference_points[-1], pred_boxes_delta)
+ outputs_masks = self.segmentation_head(spatial_features, query_features, image_size)
+
+ pred_masks = outputs_masks[-1]
+
+ loss, loss_dict, auxiliary_outputs = None, None, None
+ if labels is not None:
+ outputs_class, outputs_coord = None, None
+ if self.config.auxiliary_loss:
+ intermediate_hidden_states = outputs.intermediate_hidden_states
+ outputs_coord_delta = self.rf_detr.bbox_embed(intermediate_hidden_states)
+ outputs_coord = refine_bboxes(intermediate_reference_points, outputs_coord_delta)
+ outputs_class = self.rf_detr.class_embed(intermediate_hidden_states)
+ loss, loss_dict, auxiliary_outputs = self.loss_function(
+ logits,
+ labels,
+ self.device,
+ pred_boxes,
+ pred_masks,
+ self.config,
+ outputs_class,
+ outputs_coord,
+ outputs_masks,
+ enc_outputs_class_logits,
+ enc_outputs_boxes_logits,
+ enc_outputs_masks,
+ )
+
+ return RfDetrInstanceSegmentationOutput(
+ loss=loss,
+ loss_dict=loss_dict,
+ logits=logits,
+ pred_boxes=pred_boxes,
+ pred_masks=pred_masks,
+ auxiliary_outputs=auxiliary_outputs,
+ last_hidden_state=last_hidden_state,
+ intermediate_hidden_states=outputs.intermediate_hidden_states,
+ intermediate_reference_points=outputs.intermediate_reference_points,
+ init_reference_points=outputs.init_reference_points,
+ enc_outputs_mask_logits=enc_outputs_masks,
+ )
+
+
+__all__ = [
+ "RfDetrConfig",
+ "RfDetrModel",
+ "RfDetrForObjectDetection",
+ "RfDetrForInstanceSegmentation",
+ "RfDetrPreTrainedModel",
+ "RfDetrDinov2Config",
+ "RfDetrDinov2Backbone",
+ "RfDetrDinov2PreTrainedModel",
+]
diff --git a/tests/models/rf_detr/__init__.py b/tests/models/rf_detr/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/models/rf_detr/test_modeling_rf_detr.py b/tests/models/rf_detr/test_modeling_rf_detr.py
new file mode 100644
index 000000000000..b6c977fc37de
--- /dev/null
+++ b/tests/models/rf_detr/test_modeling_rf_detr.py
@@ -0,0 +1,863 @@
+# coding=utf-8
+# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import tempfile
+import unittest
+from functools import cached_property
+
+import numpy as np
+
+from transformers import (
+ CONFIG_NAME,
+ DetrImageProcessor,
+ RfDetrConfig,
+ RfDetrDinov2Config,
+ is_torch_available,
+ is_vision_available,
+)
+from transformers.testing_utils import (
+ Expectations,
+ require_torch,
+ require_vision,
+ slow,
+ torch_device,
+)
+
+from ...test_backbone_common import BackboneTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor
+from ...test_pipeline_mixin import PipelineTesterMixin
+
+
+if is_torch_available():
+ import torch
+
+ from transformers import RfDetrDinov2Backbone, RfDetrForInstanceSegmentation, RfDetrForObjectDetection, RfDetrModel
+
+if is_vision_available():
+ from PIL import Image
+
+CHECKPOINT = {
+ "base": "stevenbucaille/rf-detr-base",
+ "large": "stevenbucaille/rf-detr-large",
+ "segmentation": "stevenbucaille/rf-detr-segmentation",
+}
+
+
+def prepare_img():
+ image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
+ return image
+
+
+class RfDetrModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=3,
+ is_training=True,
+ image_size=256,
+ num_labels=5,
+ n_targets=4,
+ use_labels=True,
+ initializer_range=0.02,
+ layer_norm_eps=1e-5,
+ batch_norm_eps=1e-5,
+ # backbone
+ backbone_config=None,
+ # projector
+ projector_scale_factors=[0.5, 2.0],
+ # decoder
+ d_model=32,
+ decoder_ffn_dim=32,
+ decoder_layers=2,
+ decoder_self_attention_heads=2,
+ decoder_cross_attention_heads=4,
+ # model
+ num_queries=10,
+ group_detr=2,
+ dropout=0.0,
+ activation_dropout=0.0,
+ attention_dropout=0.0,
+ attn_implementation="eager",
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.is_training = is_training
+ self.num_channels = 3
+ self.image_size = image_size
+ self.num_labels = num_labels
+ self.n_targets = n_targets
+ self.use_labels = use_labels
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ self.batch_norm_eps = batch_norm_eps
+ self.backbone_config = backbone_config
+ self.projector_scale_factors = projector_scale_factors
+ self.d_model = d_model
+ self.decoder_ffn_dim = decoder_ffn_dim
+ self.decoder_layers = decoder_layers
+ self.decoder_self_attention_heads = decoder_self_attention_heads
+ self.decoder_cross_attention_heads = decoder_cross_attention_heads
+ self.num_queries = num_queries
+ self.group_detr = group_detr
+ self.dropout = dropout
+ self.activation_dropout = activation_dropout
+ self.attention_dropout = attention_dropout
+ self.attn_implementation = attn_implementation
+
+ def prepare_config_and_inputs(self):
+ pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
+ pixel_mask = torch.ones([self.batch_size, self.image_size, self.image_size], device=torch_device)
+ labels = None
+ if self.use_labels:
+ labels = []
+ for i in range(self.batch_size):
+ target = {}
+ target["class_labels"] = torch.randint(
+ high=self.num_labels, size=(self.n_targets,), device=torch_device
+ )
+ target["boxes"] = torch.rand(self.n_targets, 4, device=torch_device)
+ target["masks"] = torch.rand(self.n_targets, self.image_size, self.image_size, device=torch_device)
+ labels.append(target)
+
+ config = self.get_config()
+ config.num_labels = self.num_labels
+ return config, pixel_values, pixel_mask, labels
+
+ def get_config(self):
+ backbone_config = RfDetrDinov2Config(
+ attention_probs_dropout_prob=0.0,
+ drop_path_rate=0.0,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.0,
+ initializer_range=0.02,
+ layer_norm_eps=1e-06,
+ layerscale_value=1.0,
+ mlp_ratio=4,
+ num_attention_heads=2,
+ num_channels=3,
+ num_hidden_layers=4,
+ qkv_bias=True,
+ use_swiglu_ffn=False,
+ out_features=["stage2", "stage3"],
+ hidden_size=self.d_model,
+ patch_size=16,
+ num_windows=2,
+ image_size=self.image_size,
+ attn_implementation=self.attn_implementation,
+ )
+ return RfDetrConfig(
+ backbone_config=backbone_config,
+ d_model=self.d_model,
+ projector_scale_factors=self.projector_scale_factors,
+ decoder_ffn_dim=self.decoder_ffn_dim,
+ decoder_layers=self.decoder_layers,
+ decoder_self_attention_heads=self.decoder_self_attention_heads,
+ decoder_cross_attention_heads=self.decoder_cross_attention_heads,
+ num_queries=self.num_queries,
+ group_detr=self.group_detr,
+ dropout=self.dropout,
+ activation_dropout=self.activation_dropout,
+ attention_dropout=self.attention_dropout,
+ attn_implementation=self.attn_implementation,
+ _attn_implementation=self.attn_implementation,
+ )
+
+ def prepare_config_and_inputs_for_common(self):
+ config, pixel_values, pixel_mask, labels = self.prepare_config_and_inputs()
+ inputs_dict = {"pixel_values": pixel_values, "pixel_mask": pixel_mask}
+ return config, inputs_dict
+
+ def create_and_check_rf_detr_model(self, config, pixel_values, pixel_mask, labels):
+ model = RfDetrModel(config=config)
+ model.to(torch_device)
+ model.eval()
+
+ result = model(pixel_values=pixel_values, pixel_mask=pixel_mask)
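+ # the model should also accept pixel_values alone, without an explicit pixel mask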
+ result = model(pixel_values)
+
+ self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.num_queries, self.d_model))
+
+ def create_and_check_rf_detr_object_detection_head_model(self, config, pixel_values, pixel_mask, labels):
+ model = RfDetrForObjectDetection(config=config)
+ model.to(torch_device)
+ model.eval()
+
+ result = model(pixel_values=pixel_values, pixel_mask=pixel_mask)
+ result = model(pixel_values)
+
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels))
+ self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4))
+
+ result = model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels)
+
+ self.parent.assertEqual(result.loss.shape, ())
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels))
+ self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4))
+
+
+@require_torch
+class RfDetrModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
+ all_model_classes = (
+ (RfDetrModel, RfDetrForObjectDetection, RfDetrForInstanceSegmentation) if is_torch_available() else ()
+ )
+ pipeline_model_mapping = (
+ {
+ "image-feature-extraction": RfDetrModel,
+ "object-detection": RfDetrForObjectDetection,
+ "instance-segmentation": RfDetrForInstanceSegmentation,
+ }
+ if is_torch_available()
+ else {}
+ )
+ is_encoder_decoder = False
+ test_missing_keys = False
+ test_torch_exportable = True
+
+ # special case for head models
+ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
+ inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
+
+ if return_labels:
+ if model_class.__name__ in ["RfDetrForObjectDetection", "RfDetrForInstanceSegmentation"]:
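+ # the detection and segmentation heads expect one target dict per image, with "class_labels", "boxes" and (for segmentation) "masks"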
+ labels = []
+ for i in range(self.model_tester.batch_size):
+ target = {}
+ target["class_labels"] = torch.ones(
+ size=(self.model_tester.n_targets,), device=torch_device, dtype=torch.long
+ )
+ target["boxes"] = torch.ones(
+ self.model_tester.n_targets, 4, device=torch_device, dtype=torch.float
+ )
+ target["masks"] = torch.ones(
+ self.model_tester.n_targets,
+ self.model_tester.image_size,
+ self.model_tester.image_size,
+ device=torch_device,
+ dtype=torch.float,
+ )
+ labels.append(target)
+ inputs_dict["labels"] = labels
+
+ return inputs_dict
+
+ def setUp(self):
+ self.model_tester = RfDetrModelTester(self)
+ self.config_tester = ConfigTester(
+ self,
+ config_class=RfDetrConfig,
+ has_text_modality=False,
+ common_properties=["d_model", "decoder_self_attention_heads"],
+ )
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_rf_detr_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_rf_detr_model(*config_and_inputs)
+
+ def test_rf_detr_object_detection_head_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_rf_detr_object_detection_head_model(*config_and_inputs)
+
+ @unittest.skip(reason="RTDetr does not use inputs_embeds")
+ def test_inputs_embeds(self):
+ pass
+
+ @unittest.skip(reason="RTDetr does not use test_inputs_embeds_matches_input_ids")
+ def test_inputs_embeds_matches_input_ids(self):
+ pass
+
+ @unittest.skip(reason="RTDetr does not support input and output embeddings")
+ def test_model_get_set_embeddings(self):
+ pass
+
+ @unittest.skip(reason="RTDetr does not support input and output embeddings")
+ def test_model_common_attributes(self):
+ pass
+
+ @unittest.skip(reason="RTDetr does not use token embeddings")
+ def test_resize_tokens_embeddings(self):
+ pass
+
+ @unittest.skip(reason="Feed forward chunking is not implemented")
+ def test_feed_forward_chunking(self):
+ pass
+
+ def test_attention_outputs(self):
+ def check_attention_outputs(inputs_dict, config, model_class):
+ model = model_class._from_config(config, attn_implementation="eager")
+ config = model.config
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.attentions
+ self.assertEqual(len(attentions), self.model_tester.decoder_layers)
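+ # decoder self-attention runs over the object queries, so each map has shape [batch, heads, num_queries, num_queries]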
+ expected_attentions_shape = [
+ self.model_tester.batch_size,
+ self.model_tester.decoder_self_attention_heads,
+ self.model_tester.num_queries,
+ self.model_tester.num_queries,
+ ]
+ for i in range(self.model_tester.decoder_layers):
+ self.assertEqual(expected_attentions_shape, list(attentions[i].shape))
+
+ # check cross_attentions outputs
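+ # deformable cross-attention returns sampling weights of shape [batch, queries, heads, levels, points] instead of a dense query-to-key attention map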
+ expected_attentions_shape = [
+ self.model_tester.batch_size,
+ self.model_tester.num_queries,
+ self.model_tester.decoder_cross_attention_heads,
+ config.num_feature_levels,
+ config.decoder_n_points,
+ ]
+ cross_attentions = outputs.cross_attentions
+ self.assertEqual(len(cross_attentions), self.model_tester.decoder_layers)
+ for i in range(self.model_tester.decoder_layers):
+ self.assertEqual(expected_attentions_shape, list(cross_attentions[i].shape))
+
+ out_len = len(outputs)
+
+ if model_class.__name__ == "RfDetrModel":
+ correct_outlen = 9 # 7 + attentions + cross_attentions
+ if model_class.__name__ in "RfDetrForObjectDetection":
+ correct_outlen = 11 # 9 + attentions + cross_attentions
+ if "labels" in inputs_dict:
+ correct_outlen += 3 # loss, loss_dict and auxiliary outputs is added to beginning
+ elif model_class.__name__ == "RfDetrForInstanceSegmentation":
+ correct_outlen = 10 # 11 + attentions + cross_attentions
+ if "labels" in inputs_dict:
+ correct_outlen += 3 # loss, loss_dict and auxiliary outputs is added to beginning
+
+ self.assertEqual(correct_outlen, out_len)
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.return_dict = True
+ inputs_dict["output_hidden_states"] = False
+
+ for model_class in self.all_model_classes:
+ inputs_dict["output_attentions"] = True
+ check_attention_outputs(inputs_dict, config, model_class)
+
+ # check that output_attentions also work using config
+ del inputs_dict["output_attentions"]
+ config.output_attentions = True
+ check_attention_outputs(inputs_dict, config, model_class)
+
+ def test_hidden_states_output(self):
+ def check_hidden_states_output(inputs_dict, config, model_class):
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+ hidden_states = outputs.hidden_states
+
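+ # the decoder exposes its input embeddings plus one hidden state per layer, hence decoder_layers + 1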
+ expected_num_hidden_states = self.model_tester.decoder_layers + 1
+ self.assertEqual(len(hidden_states), expected_num_hidden_states)
+
+ for i in range(expected_num_hidden_states):
+ self.assertListEqual(
+ list(hidden_states[i].shape),
+ [
+ self.model_tester.batch_size,
+ self.model_tester.num_queries,
+ self.model_tester.d_model,
+ ],
+ )
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ inputs_dict["output_attentions"] = False
+ inputs_dict["output_hidden_states"] = True
+ check_hidden_states_output(inputs_dict, config, model_class)
+
+ # check that output_hidden_states also work using config
+ del inputs_dict["output_hidden_states"]
+ config.output_hidden_states = True
+ check_hidden_states_output(inputs_dict, config, model_class)
+
+ def test_initialization(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ configs_no_init = _config_zero_init(config)
+ for model_class in self.all_model_classes:
+ print("Model class:", model_class)
+ model = model_class(config=configs_no_init)
+ for name, param in model.named_parameters():
+ if param.requires_grad:
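+ # the parameters skipped below follow dedicated init schemes (e.g. zero-filled sampling offsets or prior-probability class biases), so the generic zero/one mean check does not apply to them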
+ if (
+ "level_embed" in name
+ or "sampling_offsets.bias" in name
+ or "value_proj" in name
+ or "output_proj" in name
+ or "reference_points" in name
+ or "class_embed" in name
+ or "gamma_1" in name
+ or "gamma_2" in name
+ ):
+ continue
+ self.assertIn(
+ ((param.data.mean() * 1e9).round() / 1e9).item(),
+ [0.0, 1.0],
+ msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+ )
+
+ def test_retain_grad_hidden_states_attentions(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.output_hidden_states = True
+ config.output_attentions = True
+
+ model_class = self.all_model_classes[0]
+ model = model_class(config)
+ model.to(torch_device)
+
+ inputs = self._prepare_for_class(inputs_dict, model_class)
+
+ outputs = model(**inputs)
+
+ # we take the first output since last_hidden_state is the first item
+ output = outputs.last_hidden_state
+
+ hidden_states = outputs.hidden_states[0]
+ attentions = outputs.attentions[0]
+ hidden_states.retain_grad()
+ attentions.retain_grad()
+
+ output.flatten()[0].backward(retain_graph=True)
+
+ self.assertIsNotNone(hidden_states.grad)
+ self.assertIsNotNone(attentions.grad)
+
+ def test_save_load(self):
+ def check_save_load(out1, out2):
+ # make sure we don't have nans
+ out_2 = out2.cpu().numpy()
+ out_2[np.isnan(out_2)] = 0
+ out_2 = out_2[~np.isneginf(out_2)]
+
+ out_1 = out1.cpu().numpy()
+ out_1[np.isnan(out_1)] = 0
+ out_1 = out_1[~np.isneginf(out_1)]
+ max_diff = np.amax(np.abs(out_1 - out_2))
+ self.assertLessEqual(max_diff, 1e-5)
+
+ for model_class in self.all_model_classes:
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ first = model(**self._prepare_for_class(inputs_dict, model_class))[0]
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+
+ # the config file (and the generation config file, if it can generate) should be saved
+ self.assertTrue(os.path.exists(os.path.join(tmpdirname, CONFIG_NAME)))
+
+ model = model_class.from_pretrained(tmpdirname)
+ model.config._attn_implementation = "eager" # TODO Have to force eager for testing, why ?
+ model.to(torch_device)
+ with torch.no_grad():
+ second = model(**self._prepare_for_class(inputs_dict, model_class))[0]
+
+ # Save and load second time because `from_pretrained` adds a bunch of new config fields
+ # so we need to make sure those fields can be loaded back after saving
+ # Simply init as `model(config)` doesn't add those fields
+ model.save_pretrained(tmpdirname)
+ model = model_class.from_pretrained(tmpdirname)
+
+ if isinstance(first, tuple) and isinstance(second, tuple):
+ for tensor1, tensor2 in zip(first, second):
+ check_save_load(tensor1, tensor2)
+ else:
+ check_save_load(first, second)
+
+ def test_forward_auxiliary_loss(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.auxiliary_loss = True
+
+ # only test for object detection and segmentation model
+ for model_class in self.all_model_classes[1:]:
+ model = model_class(config)
+ model.to(torch_device)
+
+ inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+
+ outputs = model(**inputs)
+
+ self.assertIsNotNone(outputs.auxiliary_outputs)
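+ # auxiliary outputs cover every decoder layer except the last, whose predictions are returned as the main logits and pred_boxes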
+ self.assertEqual(len(outputs.auxiliary_outputs), self.model_tester.decoder_layers - 1)
+
+
+@require_torch
+@require_vision
+class RfDetrModelIntegrationTest(unittest.TestCase):
+ @cached_property
+ def default_image_processor(self):
+ if is_vision_available():
+ return {
+ "base": DetrImageProcessor.from_pretrained(CHECKPOINT["base"]),
+ "large": DetrImageProcessor.from_pretrained(CHECKPOINT["large"]),
+ "segmentation": DetrImageProcessor.from_pretrained(CHECKPOINT["segmentation"]),
+ }
+
+ @slow
+ def test_inference_object_detection_head_base(self):
+ size = "base"
+ model = RfDetrForObjectDetection.from_pretrained(CHECKPOINT[size], attn_implementation="eager").to(
+ torch_device
+ )
+
+ image_processor = self.default_image_processor[size]
+ image = prepare_img()
+ encoding = image_processor(images=image, return_tensors="pt").to(torch_device)
+ pixel_values = encoding["pixel_values"].to(torch_device)
+ pixel_mask = encoding["pixel_mask"].to(torch_device)
+ with torch.no_grad():
+ outputs = model(pixel_values, pixel_mask)
+
+ expected_logits_shape = torch.Size((1, model.config.num_queries, model.config.num_labels))
+ self.assertEqual(outputs.logits.shape, expected_logits_shape)
+
+ expectations = Expectations(
+ {
+ ("cpu", None): [-7.60410, -4.65943, -10.03144, -5.63881, -9.88291],
+ }
+ )
+ expected_logits = torch.tensor(expectations.get_expectation()).to(torch_device)
+ torch.testing.assert_close(outputs.logits.flatten()[:5], expected_logits, rtol=2e-4, atol=2e-4)
+
+ expected_boxes_shape = torch.Size((1, model.config.num_queries, 4))
+ self.assertEqual(outputs.pred_boxes.shape, expected_boxes_shape)
+
+ expectations = Expectations(
+ {
+ ("cpu", None): [0.25465, 0.54864, 0.48583, 0.86991, 0.16926],
+ }
+ )
+ expected_boxes = torch.tensor(expectations.get_expectation()).to(torch_device)
+
+ torch.testing.assert_close(outputs.pred_boxes.flatten()[:5], expected_boxes, rtol=2e-4, atol=2e-4)
+
+ results = image_processor.post_process_object_detection(
+ outputs, threshold=0.0, target_sizes=[image.size[::-1]]
+ )[0]
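+ # post-processing rescales the normalized (cx, cy, w, h) predictions to absolute (x1, y1, x2, y2) boxes in the original image size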
+
+ expectations = Expectations(
+ {
+ ("cpu", None): [0.9829, 0.9763, 0.9780, 0.8663],
+ }
+ )
+ expected_scores = torch.tensor(expectations.get_expectation()).to(torch_device)
+
+ expected_labels = [17, 75, 17, 75]
+
+ expectations = Expectations(
+ {
+ ("cpu", None): [
+ [7.5101, 54.5667, 318.4421, 472.1259],
+ [40.7178, 72.6109, 175.9414, 117.5903],
+ [343.0202, 23.9165, 639.3325, 372.2062],
+ [333.5370, 76.9845, 370.3848, 187.3158],
+ ],
+ }
+ )
+ expected_slice_boxes = torch.tensor(expectations.get_expectation()).to(torch_device)
+
+ torch.testing.assert_close(results["scores"][:4], expected_scores, atol=1e-3, rtol=2e-4)
+ self.assertSequenceEqual(results["labels"][:4].tolist(), expected_labels)
+ torch.testing.assert_close(results["boxes"][:4], expected_slice_boxes, atol=1e-3, rtol=2e-4)
+
+ @slow
+ def test_inference_object_detection_head_large(self):
+ size = "large"
+ model = RfDetrForObjectDetection.from_pretrained(CHECKPOINT[size], attn_implementation="eager").to(
+ torch_device
+ )
+
+ image_processor = self.default_image_processor[size]
+ image = prepare_img()
+ encoding = image_processor(images=image, return_tensors="pt").to(torch_device)
+ pixel_values = encoding["pixel_values"].to(torch_device)
+ pixel_mask = encoding["pixel_mask"].to(torch_device)
+ with torch.no_grad():
+ outputs = model(pixel_values, pixel_mask)
+
+ expected_logits_shape = torch.Size((1, model.config.num_queries, model.config.num_labels))
+ self.assertEqual(outputs.logits.shape, expected_logits_shape)
+
+ expectations = Expectations(
+ {
+ ("cpu", None): [-7.60888, -4.36906, -4.98865, -8.06598, -5.52970],
+ }
+ )
+ expected_logits = torch.tensor(expectations.get_expectation()).to(torch_device)
+ torch.testing.assert_close(outputs.logits.flatten()[:5], expected_logits, rtol=2e-3, atol=2e-3)
+
+ expected_boxes_shape = torch.Size((1, model.config.num_queries, 4))
+ self.assertEqual(outputs.pred_boxes.shape, expected_boxes_shape)
+
+ expectations = Expectations(
+ {
+ ("cpu", None): [0.25576, 0.55051, 0.47765, 0.87141, 0.76966],
+ }
+ )
+ expected_boxes = torch.tensor(expectations.get_expectation()).to(torch_device)
+
+ torch.testing.assert_close(outputs.pred_boxes.flatten()[:5], expected_boxes, rtol=2e-3, atol=2e-3)
+
+ results = image_processor.post_process_object_detection(
+ outputs, threshold=0.0, target_sizes=[image.size[::-1]]
+ )[0]
+
+ expectations = Expectations(
+ {
+ ("cpu", None): [0.9820, 0.9874, 0.9918, 0.9696],
+ }
+ )
+ expected_scores = torch.tensor(expectations.get_expectation()).to(torch_device)
+
+ expected_labels = [17, 17, 75, 75]
+
+ expectations = Expectations(
+ {
+ ("cpu", None): [
+ [10.8431, 55.1121, 316.5317, 473.3869],
+ [345.1129, 24.5076, 640.0582, 373.1581],
+ [40.3736, 73.1451, 175.8807, 117.5796],
+ [333.9091, 76.8915, 370.1848, 186.7155],
+ ],
+ }
+ )
+ expected_slice_boxes = torch.tensor(expectations.get_expectation()).to(torch_device)
+
+ torch.testing.assert_close(results["scores"][:4], expected_scores, atol=1e-3, rtol=2e-4)
+ self.assertSequenceEqual(results["labels"][:4].tolist(), expected_labels)
+ torch.testing.assert_close(results["boxes"][:4], expected_slice_boxes, atol=1e-3, rtol=2e-4)
+
+ @slow
+ def test_inference_segmentation(self):
+ size = "segmentation"
+ model = RfDetrForInstanceSegmentation.from_pretrained(CHECKPOINT[size], attn_implementation="eager").to(
+ torch_device
+ )
+
+ image_processor = self.default_image_processor[size]
+ image = prepare_img()
+ encoding = image_processor(images=image, return_tensors="pt").to(torch_device)
+ pixel_values = encoding["pixel_values"].to(torch_device)
+ pixel_mask = encoding["pixel_mask"].to(torch_device)
+ with torch.no_grad():
+ outputs = model(pixel_values, pixel_mask)
+
+ expected_logits_shape = torch.Size((1, model.config.num_queries, model.config.num_labels))
+ self.assertEqual(outputs.logits.shape, expected_logits_shape)
+
+ expectations = Expectations(
+ {
+ ("cpu", None): [-7.05877, -4.23362, -6.54288, -8.13384, -6.36838],
+ }
+ )
+ expected_logits = torch.tensor(expectations.get_expectation()).to(torch_device)
+ torch.testing.assert_close(outputs.logits.flatten()[:5], expected_logits, rtol=2e-3, atol=2e-3)
+
+ expected_boxes_shape = torch.Size((1, model.config.num_queries, 4))
+ self.assertEqual(outputs.pred_boxes.shape, expected_boxes_shape)
+
+ expectations = Expectations(
+ {
+ ("cpu", None): [0.25603, 0.55164, 0.48340, 0.87798, 0.73861],
+ }
+ )
+ expected_boxes = torch.tensor(expectations.get_expectation()).to(torch_device)
+
+ torch.testing.assert_close(outputs.pred_boxes.flatten()[:5], expected_boxes, rtol=2e-3, atol=2e-3)
+
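+ # instance masks are predicted on a grid downsampled from the input resolution by config.mask_downsample_ratio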
+ expected_masks_shape = torch.Size(
+ (
+ 1,
+ model.config.num_queries,
+ pixel_values.shape[-2] // model.config.mask_downsample_ratio,
+ pixel_values.shape[-1] // model.config.mask_downsample_ratio,
+ )
+ )
+ self.assertEqual(outputs.pred_masks.shape, expected_masks_shape)
+
+ expectations = Expectations(
+ {
+ ("cpu", None): [-16.72129, -16.17153, -17.06426, -17.29409, -17.78559],
+ }
+ )
+ expected_masks = torch.tensor(expectations.get_expectation()).to(torch_device)
+
+ torch.testing.assert_close(outputs.pred_masks.flatten()[:5], expected_masks, rtol=2e-3, atol=2e-3)
+
+ results = image_processor.post_process_instance_segmentation(
+ outputs, threshold=0.0, target_sizes=[image.size[::-1]]
+ )[0]
+
+ expected_labels = [17, 75, 75, 17]
+ score_expectations = Expectations(
+ {
+ ("cpu", None): [0.98604, 0.985644, 0.950492, 0.967915],
+ }
+ )
+ expected_scores = torch.tensor(score_expectations.get_expectation()).to(torch.float32)
+
+ predicted_labels = [segment["label_id"] for segment in results["segments_info"]]
+ predicted_scores = torch.tensor([segment["score"] for segment in results["segments_info"]])
+
+ self.assertSequenceEqual(predicted_labels, expected_labels)
+ torch.testing.assert_close(predicted_scores, expected_scores, atol=1e-3, rtol=2e-4)
+
+
+class RfDetrDinov2ModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=13,
+ image_size=32,
+ patch_size=2,
+ num_channels=3,
+ is_training=True,
+ use_labels=True,
+ hidden_size=32,
+ num_hidden_layers=2,
+ num_attention_heads=4,
+ intermediate_size=37,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ type_sequence_label_size=10,
+ initializer_range=0.02,
+ mask_ratio=0.5,
+ num_windows=2,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.is_training = is_training
+ self.use_labels = use_labels
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.type_sequence_label_size = type_sequence_label_size
+ self.initializer_range = initializer_range
+
+ # in Dinov2, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
+ num_patches = (image_size // patch_size) ** 2
+ self.seq_length = num_patches + 1
+ self.mask_ratio = mask_ratio
+ self.num_masks = int(mask_ratio * self.seq_length)
+ self.mask_length = num_patches
+ self.num_windows = num_windows
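+ # num_windows is the RF-DETR-specific addition to the DINOv2 config, presumably controlling its window-based attention over the patch grid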
+
+ def prepare_config_and_inputs(self):
+ pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
+ config = self.get_config()
+
+ return config, pixel_values
+
+ def get_config(self):
+ return RfDetrDinov2Config(
+ image_size=self.image_size,
+ patch_size=self.patch_size,
+ num_channels=self.num_channels,
+ hidden_size=self.hidden_size,
+ num_hidden_layers=self.num_hidden_layers,
+ num_attention_heads=self.num_attention_heads,
+ intermediate_size=self.intermediate_size,
+ hidden_act=self.hidden_act,
+ hidden_dropout_prob=self.hidden_dropout_prob,
+ attention_probs_dropout_prob=self.attention_probs_dropout_prob,
+ is_decoder=False,
+ initializer_range=self.initializer_range,
+ num_windows=self.num_windows,
+ )
+
+ def create_and_check_backbone(self, config, pixel_values):
+ model = RfDetrDinov2Backbone(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(pixel_values)
+
+ # verify hidden states
+ self.parent.assertEqual(len(result.feature_maps), len(config.out_features))
+ expected_size = self.image_size // config.patch_size
+ self.parent.assertListEqual(
+ list(result.feature_maps[0].shape), [self.batch_size, model.channels[0], expected_size, expected_size]
+ )
+
+ # verify channels
+ self.parent.assertEqual(len(model.channels), len(config.out_features))
+
+ # verify backbone works with out_features=None
+ config.out_features = None
+ model = RfDetrDinov2Backbone(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(pixel_values)
+
+ # verify feature maps
+ self.parent.assertEqual(len(result.feature_maps), 1)
+ self.parent.assertListEqual(
+ list(result.feature_maps[0].shape), [self.batch_size, model.channels[0], expected_size, expected_size]
+ )
+
+ # verify channels
+ self.parent.assertEqual(len(model.channels), 1)
+
+ # verify backbone works with apply_layernorm=False and reshape_hidden_states=False
+ config.apply_layernorm = False
+ config.reshape_hidden_states = False
+
+ model = RfDetrDinov2Backbone(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(pixel_values)
+
+ # verify feature maps
+ self.parent.assertEqual(len(result.feature_maps), 1)
+ self.parent.assertListEqual(
+ list(result.feature_maps[0].shape), [self.batch_size, self.seq_length, self.hidden_size]
+ )
+
+ def prepare_config_and_inputs_for_common(self):
+ config, pixel_values = self.prepare_config_and_inputs()
+ inputs_dict = {"pixel_values": pixel_values}
+ return config, inputs_dict
+
+
+@require_torch
+class RfDetrDinov2BackboneTest(unittest.TestCase, BackboneTesterMixin):
+ all_model_classes = (RfDetrDinov2Backbone,) if is_torch_available() else ()
+ config_class = RfDetrDinov2Config
+
+ has_attentions = False
+
+ def setUp(self):
+ self.model_tester = RfDetrDinov2ModelTester(self)
diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py
index 996e14c5e7f5..33770b210efc 100644
--- a/utils/check_config_attributes.py
+++ b/utils/check_config_attributes.py
@@ -125,6 +125,7 @@
"IdeficsPerceiverConfig": True,
"GptOssConfig": True,
"LwDetrConfig": True,
+ "RfDetrConfig": True,
}
# Common and important attributes, even if they do not always appear in the modeling files (can be a regex pattern)