diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index d55a6da33a32..b93f3cc9bca2 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -869,6 +869,8 @@ title: RegNet - local: model_doc/resnet title: ResNet + - local: model_doc/rf_detr + title: RF-DETR - local: model_doc/rt_detr title: RT-DETR - local: model_doc/rt_detr_v2 diff --git a/docs/source/en/model_doc/rf_detr.md b/docs/source/en/model_doc/rf_detr.md new file mode 100644 index 000000000000..0e11952c37cb --- /dev/null +++ b/docs/source/en/model_doc/rf_detr.md @@ -0,0 +1,139 @@ + +*This model was released on 2024-04-05 and added to Hugging Face Transformers on 2026-01-23.* + +
+<div style="float: right;">
+    <div class="flex flex-wrap space-x-1">
+        <img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
+    </div>
+</div>
+ +# RF-DETR + +[RF-DETR](https://huggingface.co/papers/2407.17140) proposes a Receptive Field Detection Transformer (DETR) architecture +designed to compete with and surpass the dominant YOLO series for real-time object detection. It achieves a new +state-of-the-art balance between speed (latency) and accuracy (mAP) by combining recent transformer advances with +efficient design choices. + +The RF-DETR architecture is characterized by its simple and efficient structure: a DINOv2 Backbone, a Projector, and a +shallow DETR Decoder. +It enhances the DETR architecture for efficiency and speed using the following core modifications: + +1. **DINOv2 Backbone**: Uses a powerful DINOv2 backbone for robust feature extraction. +2. **Group DETR Training**: Utilizes Group-Wise One-to-Many Assignment during training to accelerate convergence. +3. **Richer Input**: Aggregates multi-level features from the backbone and uses a C2f Projector (similarly to YOLOv8) to + pass multi-scale features. +4. **Faster Decoder**: Employs a shallow 3-layer DETR decoder with deformable cross-attention for lower latency. +5. **Optimized Queries**: Uses a mixed-query scheme combining learnable content queries and generated spatial queries. + +You can find all the available RF-DETR checkpoints under the [stevenbucaille](https://huggingface.co/stevenbucaille) +organization. +The original code can be found [here](https://github.com/roboflow/rf-detr). + +> [!TIP] +> This model was contributed by [stevenbucaille](https://huggingface.co/stevenbucaille). +> +> Click on the RF-DETR models in the right sidebar for more examples of how to apply RF-DETR to different object +> detection tasks. + + +The example below demonstrates how to perform object detection with the [`Pipeline`] and the [`AutoModel`] class. + + + + +```python +from transformers import pipeline +import torch + +pipeline = pipeline( + "object-detection", + model="stevenbucaille/rfdetr_small_60e_coco", + dtype=torch.float16, + device_map=0 +) + +pipeline("http://images.cocodataset.org/val2017/000000039769.jpg") +``` + + + + +```python +from transformers import AutoImageProcessor, AutoModelForObjectDetection +from PIL import Image +import requests +import torch + +url = "http://images.cocodataset.org/val2017/000000039769.jpg" +image = Image.open(requests.get(url, stream=True).raw) + +image_processor = AutoImageProcessor.from_pretrained("stevenbucaille/rfdetr_small") +model = AutoModelForObjectDetection.from_pretrained("stevenbucaille/rfdetr_small") + +# prepare image for the model +inputs = image_processor(images=image, return_tensors="pt") + +with torch.no_grad(): + outputs = model(**inputs) + +results = image_processor.post_process_object_detection(outputs, target_sizes=torch.tensor([image.size[::-1]]), threshold=0.3) + +for result in results: + for score, label_id, box in zip(result["scores"], result["labels"], result["boxes"]): + score, label = score.item(), label_id.item() + box = [round(i, 2) for i in box.tolist()] + print(f"{model.config.id2label[label]}: {score:.2f} {box}") +``` + + + + +## Resources + + +- Scripts for finetuning [`RfDetrForObjectDetection`] with [`Trainer`] + or [Accelerate](https://huggingface.co/docs/accelerate/index) can be + found [here](https://github.com/huggingface/transformers/tree/main/examples/pytorch/object-detection). +- See also: [Object detection task guide](../tasks/object_detection). 
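+
+This release also adds an instance segmentation head, [`RfDetrForInstanceSegmentation`]. The snippet below is a
+minimal sketch of how it could be called: the checkpoint name is hypothetical, and the output attributes
+(`logits`, `pred_masks`) follow the loss code in this PR rather than a documented post-processing API, so adapt it
+to the checkpoints that actually get published.
+
+```python
+from transformers import AutoImageProcessor, RfDetrForInstanceSegmentation
+from PIL import Image
+import requests
+import torch
+
+url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+image = Image.open(requests.get(url, stream=True).raw)
+
+# hypothetical segmentation checkpoint, mirroring the detection examples above
+checkpoint = "stevenbucaille/rfdetr_seg_preview"
+image_processor = AutoImageProcessor.from_pretrained(checkpoint)
+model = RfDetrForInstanceSegmentation.from_pretrained(checkpoint)
+
+inputs = image_processor(images=image, return_tensors="pt")
+with torch.no_grad():
+    outputs = model(**inputs)
+
+# RF-DETR classifies with a sigmoid head, so per-query confidences come from sigmoid(logits)
+scores, labels = outputs.logits.sigmoid().max(-1)
+keep = scores[0] > 0.5
+masks = outputs.pred_masks[0][keep].sigmoid() > 0.5  # binarize the low-resolution mask logits
+for score, label in zip(scores[0][keep], labels[0][keep]):
+    print(f"{model.config.id2label[label.item()]}: {score.item():.2f}")
+```
+
+Note that the predicted masks come out at a reduced resolution (`mask_downsample_ratio` in [`RfDetrConfig`],
+4 by default), so they need to be resized back to the input image size before visualization.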
+ +## RfDetrConfig + +[[autodoc]] RfDetrConfig + +## RfDetrDinov2Config + +[[autodoc]] RfDetrDinov2Config + +## RfDetrModel + +[[autodoc]] RfDetrModel + - forward + +## RfDetrForObjectDetection + +[[autodoc]] RfDetrForObjectDetection + - forward + +## RfDetrForInstanceSegmentation + +[[autodoc]] RfDetrForInstanceSegmentation + - forward + +## RfDetrDinov2Backbone + +[[autodoc]] RfDetrDinov2Backbone + - forward diff --git a/src/transformers/loss/loss_lw_detr.py b/src/transformers/loss/loss_lw_detr.py index 73844fddae29..b346416b1ce1 100644 --- a/src/transformers/loss/loss_lw_detr.py +++ b/src/transformers/loss/loss_lw_detr.py @@ -343,7 +343,7 @@ def LwDetrForObjectDetectionLoss( outputs_loss["auxiliary_outputs"] = auxiliary_outputs loss_dict = criterion(outputs_loss, labels) # Fourth: compute total loss, as a weighted sum of the various losses - weight_dict = {"loss_ce": 1, "loss_bbox": config.bbox_loss_coefficient} + weight_dict = {"loss_ce": config.class_loss_coefficient, "loss_bbox": config.bbox_loss_coefficient} weight_dict["loss_giou"] = config.giou_loss_coefficient if config.auxiliary_loss: aux_weight_dict = {} diff --git a/src/transformers/loss/loss_rf_detr.py b/src/transformers/loss/loss_rf_detr.py new file mode 100644 index 000000000000..a597f439b318 --- /dev/null +++ b/src/transformers/loss/loss_rf_detr.py @@ -0,0 +1,483 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .. import requires_backends +from ..image_transforms import center_to_corners_format +from ..utils import is_accelerate_available, is_scipy_available, is_vision_available +from .loss_for_object_detection import ( + dice_loss, + generalized_box_iou, +) +from .loss_lw_detr import LwDetrImageLoss + + +if is_vision_available(): + pass + +if is_scipy_available(): + from scipy.optimize import linear_sum_assignment + +if is_accelerate_available(): + from accelerate import PartialState + from accelerate.utils import reduce + + +def sigmoid_cross_entropy_loss( + inputs: torch.Tensor, + targets: torch.Tensor, + num_masks: float, +): + """ + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + Returns: + Loss tensor + """ + loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") + loss = loss.mean(1).sum() / num_masks + return loss + + +def point_sample(input, point_coords, **kwargs): + """ + A wrapper around :function:`torch.nn.functional.grid_sample` to support 3D point_coords tensors. + Unlike :function:`torch.nn.functional.grid_sample` it assumes `point_coords` to lie inside + [0, 1] x [0, 1] square. + + Args: + input (Tensor): A tensor of shape (N, C, H, W) that contains features map on a H x W grid. + point_coords (Tensor): A tensor of shape (N, P, 2) or (N, Hgrid, Wgrid, 2) that contains + [0, 1] x [0, 1] normalized point coordinates. + + Returns: + output (Tensor): A tensor of shape (N, C, P) or (N, C, Hgrid, Wgrid) that contains + features for points in `point_coords`. The features are obtained via bilinear + interplation from `input` the same way as :function:`torch.nn.functional.grid_sample`. 
+ """ + add_dim = False + if point_coords.dim() == 3: + add_dim = True + point_coords = point_coords.unsqueeze(2) + output = F.grid_sample(input, 2.0 * point_coords - 1.0, **kwargs) + if add_dim: + output = output.squeeze(3) + return output + + +def get_uncertain_point_coords_with_randomness( + coarse_logits, uncertainty_func, num_points, oversample_ratio=3, importance_sample_ratio=0.75 +): + """ + Sample points in [0, 1] x [0, 1] coordinate space based on their uncertainty. The unceratinties + are calculated for each point using 'uncertainty_func' function that takes point's logit + prediction as input. + See PointRend paper for details. + + Args: + coarse_logits (Tensor): A tensor of shape (N, C, Hmask, Wmask) or (N, 1, Hmask, Wmask) for + class-specific or class-agnostic prediction. + uncertainty_func: A function that takes a Tensor of shape (N, C, P) or (N, 1, P) that + contains logit predictions for P points and returns their uncertainties as a Tensor of + shape (N, 1, P). + num_points (int): The number of points P to sample. + oversample_ratio (int): Oversampling parameter. + importance_sample_ratio (float): Ratio of points that are sampled via importnace sampling. + + Returns: + point_coords (Tensor): A tensor of shape (N, P, 2) that contains the coordinates of P + sampled points. + """ + assert oversample_ratio >= 1 + assert importance_sample_ratio <= 1 and importance_sample_ratio >= 0 + num_boxes = coarse_logits.shape[0] + num_sampled = int(num_points * oversample_ratio) + point_coords = torch.rand(num_boxes, num_sampled, 2, device=coarse_logits.device) + point_logits = point_sample(coarse_logits, point_coords, align_corners=False) + # It is crucial to calculate uncertainty based on the sampled prediction value for the points. + # Calculating uncertainties of the coarse predictions first and sampling them for points leads + # to incorrect results. + # To illustrate this: assume uncertainty_func(logits)=-abs(logits), a sampled point between + # two coarse predictions with -1 and 1 logits has 0 logits, and therefore 0 uncertainty value. + # However, if we calculate uncertainties for the coarse predictions first, + # both will have -1 uncertainty, and the sampled point will get -1 uncertainty. + point_uncertainties = uncertainty_func(point_logits) + num_uncertain_points = int(importance_sample_ratio * num_points) + num_random_points = num_points - num_uncertain_points + idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1] + shift = num_sampled * torch.arange(num_boxes, dtype=torch.long, device=coarse_logits.device) + idx += shift[:, None] + point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(num_boxes, num_uncertain_points, 2) + if num_random_points > 0: + point_coords = torch.cat( + [ + point_coords, + torch.rand(num_boxes, num_random_points, 2, device=coarse_logits.device), + ], + dim=1, + ) + return point_coords + + +def calculate_uncertainty(logits): + """ + We estimate uncerainty as L1 distance between 0.0 and the logit prediction in 'logits' for the + foreground class in `classes`. + Args: + logits (Tensor): A tensor of shape (R, 1, ...) for class-specific or + class-agnostic, where R is the total number of predicted masks in all images and C is + the number of foreground classes. The values are logits. + Returns: + scores (Tensor): A tensor of shape (R, 1, ...) that contains uncertainty scores with + the most uncertain locations having the highest uncertainty score. 
+ """ + assert logits.shape[1] == 1 + gt_class_logits = logits.clone() + return -(torch.abs(gt_class_logits)) + + +def batch_sigmoid_cross_entropy_loss(inputs: torch.Tensor, targets: torch.Tensor): + """ + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + Returns: + Loss tensor + """ + hw = inputs.shape[1] + + pos = F.binary_cross_entropy_with_logits(inputs, torch.ones_like(inputs), reduction="none") + neg = F.binary_cross_entropy_with_logits(inputs, torch.zeros_like(inputs), reduction="none") + + loss = torch.einsum("nc,mc->nm", pos, targets) + torch.einsum("nc,mc->nm", neg, (1 - targets)) + + return loss / hw + + +def batch_dice_loss(inputs: torch.Tensor, targets: torch.Tensor): + """ + Compute the DICE loss, similar to generalized IOU for masks + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + """ + inputs = inputs.sigmoid() + inputs = inputs.flatten(1) + numerator = 2 * torch.einsum("nc,mc->nm", inputs, targets) + denominator = inputs.sum(-1)[:, None] + targets.sum(-1)[None, :] + loss = 1 - (numerator + 1) / (denominator + 1) + return loss + + +class RfDetrHungarianMatcher(nn.Module): + def __init__( + self, + class_cost: float = 1, + bbox_cost: float = 1, + giou_cost: float = 1, + mask_point_sample_ratio: int = 16, + cost_mask_class_cost: float = 1, + cost_mask_dice_cost: float = 1, + ): + super().__init__() + requires_backends(self, ["scipy"]) + + self.class_cost = class_cost + self.bbox_cost = bbox_cost + self.giou_cost = giou_cost + if class_cost == 0 and bbox_cost == 0 and giou_cost == 0: + raise ValueError("All costs of the Matcher can't be 0") + + self.mask_point_sample_ratio = mask_point_sample_ratio + self.cost_mask_class = cost_mask_class_cost + self.cost_mask_dice = cost_mask_dice_cost + + @torch.no_grad() + def forward(self, outputs, targets, group_detr): + """ + Differences: + - out_prob = outputs["logits"].flatten(0, 1).sigmoid() instead of softmax + - class_cost uses alpha and gamma + """ + batch_size, num_queries = outputs["logits"].shape[:2] + + # We flatten to compute the cost matrices in a batch + out_prob = outputs["logits"].flatten(0, 1).sigmoid() # [batch_size * num_queries, num_classes] + out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4] + out_masks = outputs["pred_masks"].flatten(0, 1) # [batch_size * num_queries, H, W] + + # Also concat the target labels and boxes + target_ids = torch.cat([v["class_labels"] for v in targets]) + target_bbox = torch.cat([v["boxes"] for v in targets]) + target_masks = torch.cat([v["masks"] for v in targets]) + + # Compute the classification cost. 
+ alpha = 0.25 + gamma = 2.0 + neg_cost_class = (1 - alpha) * (out_prob**gamma) * (-(1 - out_prob + 1e-8).log()) + pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log()) + class_cost = pos_cost_class[:, target_ids] - neg_cost_class[:, target_ids] + + # Compute the L1 cost between boxes, cdist only supports float32 + dtype = out_bbox.dtype + out_bbox = out_bbox.to(torch.float32) + target_bbox = target_bbox.to(torch.float32) + bbox_cost = torch.cdist(out_bbox, target_bbox, p=1) + bbox_cost = bbox_cost.to(dtype) + + # Compute the giou cost between boxes + giou_cost = -generalized_box_iou(center_to_corners_format(out_bbox), center_to_corners_format(target_bbox)) + + # Compute mask cost + num_points = out_masks.shape[-2] * out_masks.shape[-1] // self.mask_point_sample_ratio + tgt_masks = target_masks.to(out_masks.dtype) + torch.manual_seed(0) # TODO REMOVE + point_coords = torch.rand(1, num_points, 2, device=out_masks.device) + pred_masks_logits = point_sample( + out_masks.unsqueeze(1), point_coords.repeat(out_masks.shape[0], 1, 1), align_corners=False + ).squeeze(1) + tgt_masks_flat = point_sample( + tgt_masks.unsqueeze(1), point_coords.repeat(tgt_masks.shape[0], 1, 1), align_corners=False, mode="nearest" + ).squeeze(1) + + cost_mask_class = batch_sigmoid_cross_entropy_loss(pred_masks_logits, tgt_masks_flat) + cost_mask_dice = batch_dice_loss(pred_masks_logits, tgt_masks_flat) + + # Final cost matrix + cost_matrix = ( + self.bbox_cost * bbox_cost + + self.class_cost * class_cost + + self.giou_cost * giou_cost + + self.cost_mask_class * cost_mask_class + + self.cost_mask_dice * cost_mask_dice + ) + cost_matrix = cost_matrix.view(batch_size, num_queries, -1).cpu() + + # we assume any good match will not cause NaN or Inf, so we replace them with a large value + max_cost = cost_matrix.max() if cost_matrix.numel() > 0 else 0 + cost_matrix[cost_matrix.isinf() | cost_matrix.isnan()] = max_cost * 2 + + # Hungarian matching + sizes = [len(v["masks"]) for v in targets] + indices = [] + group_num_queries = num_queries // group_detr + cost_matrix_list = cost_matrix.split(group_num_queries, dim=1) + for group_id in range(group_detr): + group_cost_matrix = cost_matrix_list[group_id] + group_indices = [linear_sum_assignment(c[i]) for i, c in enumerate(group_cost_matrix.split(sizes, -1))] + if group_id == 0: + indices = group_indices + else: + indices = [ + ( + np.concatenate([indice1[0], indice2[0] + group_num_queries * group_id]), + np.concatenate([indice1[1], indice2[1]]), + ) + for indice1, indice2 in zip(indices, group_indices) + ] + return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices] + + +class RfDetrImageLoss(LwDetrImageLoss): + def __init__(self, matcher, num_classes, focal_alpha, losses, group_detr, mask_point_sample_ratio): + nn.Module.__init__(self) + self.matcher = matcher + self.num_classes = num_classes + self.focal_alpha = focal_alpha + self.losses = losses + self.group_detr = group_detr + self.mask_point_sample_ratio = mask_point_sample_ratio + + def loss_masks(self, outputs, targets, indices, num_boxes): + """ + Compute the losses related to the masks: the focal loss and the dice loss. + + Targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]. 
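+
+        Both losses are evaluated on a PointRend-style subset of sampled points (uncertainty-based
+        importance sampling mixed with uniformly random points) rather than on every pixel of the
+        predicted masks.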
+ """ + if "pred_masks" not in outputs: + raise KeyError("No predicted masks found in outputs") + + source_idx = self._get_source_permutation_idx(indices) + source_masks = outputs["pred_masks"] + source_masks = source_masks[source_idx] + if source_masks.numel() == 0: + return { + "loss_mask_ce": source_masks.sum(), + "loss_mask_dice": source_masks.sum(), + } + + # gather matched target masks + target_masks = torch.cat([t["masks"][j] for t, (_, j) in zip(targets, indices)], dim=0) + + source_masks = source_masks.unsqueeze(1) + target_masks = target_masks.unsqueeze(1).float() + + num_points = max( + source_masks.shape[-2], source_masks.shape[-2] * source_masks.shape[-1] // self.mask_point_sample_ratio + ) + + with torch.no_grad(): + # sample point_coords + point_coords = get_uncertain_point_coords_with_randomness( + source_masks, + lambda logits: calculate_uncertainty(logits), + num_points, + 3, + 0.75, + ) + # get gt labels + point_labels = point_sample( + target_masks, + point_coords, + align_corners=False, + mode="nearest", + ).squeeze(1) + + point_logits = point_sample( + source_masks, + point_coords, + align_corners=False, + ).squeeze(1) + + losses = { + "loss_mask_ce": sigmoid_cross_entropy_loss(point_logits, point_labels, num_boxes), + "loss_mask_dice": dice_loss(point_logits, point_labels, num_boxes), + } + return losses + + def forward(self, outputs, targets): + """ + This performs the loss computation. + + Args: + outputs (`dict`, *optional*): + Dictionary of tensors, see the output specification of the model for the format. + targets (`list[dict]`, *optional*): + List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depends on the + losses applied, see each loss' doc. + """ + group_detr = self.group_detr if self.training else 1 + outputs_without_aux_and_enc = { + k: v for k, v in outputs.items() if k != "enc_outputs" and k != "auxiliary_outputs" + } + + # Retrieve the matching between the outputs of the last layer and the targets + indices = self.matcher(outputs_without_aux_and_enc, targets, group_detr) + + # Compute the average number of target boxes across all nodes, for normalization purposes + num_boxes = sum(len(t["class_labels"]) for t in targets) + num_boxes = num_boxes * group_detr + num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device) + world_size = 1 + if is_accelerate_available(): + if PartialState._shared_state != {}: + num_boxes = reduce(num_boxes) + world_size = PartialState().num_processes + num_boxes = torch.clamp(num_boxes / world_size, min=1).item() + + # Compute all the requested losses + losses = {} + for loss in self.losses: + losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes)) + + # In case of auxiliary losses, we repeat this process with the output of each intermediate layer. 
+ if "auxiliary_outputs" in outputs: + for i, auxiliary_outputs in enumerate(outputs["auxiliary_outputs"]): + indices = self.matcher(auxiliary_outputs, targets, group_detr) + for loss in self.losses: + l_dict = self.get_loss(loss, auxiliary_outputs, targets, indices, num_boxes) + l_dict = {k + f"_{i}": v for k, v in l_dict.items()} + losses.update(l_dict) + + if "enc_outputs" in outputs: + enc_outputs = outputs["enc_outputs"] + indices = self.matcher(enc_outputs, targets, group_detr=group_detr) + for loss in self.losses: + l_dict = self.get_loss(loss, enc_outputs, targets, indices, num_boxes) + l_dict = {k + "_enc": v for k, v in l_dict.items()} + losses.update(l_dict) + + return losses + + +def _set_aux_loss(outputs_class, outputs_coord, outputs_masks): + return [ + {"logits": a, "pred_boxes": b, "pred_masks": c} + for a, b, c in zip(outputs_class[:-1], outputs_coord[:-1], outputs_masks[:-1]) + ] + + +def RfDetrForSegmentationLoss( + logits, + labels, + device, + pred_boxes, + pred_masks, + config, + outputs_class=None, + outputs_coord=None, + outputs_masks=None, + enc_outputs_class=None, + enc_outputs_coord=None, + enc_outputs_masks=None, + **kwargs, +): + # First: create the matcher + matcher = RfDetrHungarianMatcher( + class_cost=config.class_cost, + bbox_cost=config.bbox_cost, + giou_cost=config.giou_cost, + mask_point_sample_ratio=config.mask_point_sample_ratio, + cost_mask_class_cost=config.mask_class_loss_coefficient, + cost_mask_dice_cost=config.mask_dice_loss_coefficient, + ) + # Second: create the criterion + losses = ["labels", "boxes", "cardinality", "masks"] + criterion = RfDetrImageLoss( + matcher=matcher, + num_classes=config.num_labels, + focal_alpha=config.focal_alpha, + losses=losses, + group_detr=config.group_detr, + mask_point_sample_ratio=config.mask_point_sample_ratio, + ) + criterion.to(device) + # Third: compute the losses, based on outputs and labels + outputs_loss = {} + auxiliary_outputs = None + outputs_loss["logits"] = logits + outputs_loss["pred_boxes"] = pred_boxes + outputs_loss["pred_masks"] = pred_masks + outputs_loss["enc_outputs"] = { + "logits": enc_outputs_class, + "pred_boxes": enc_outputs_coord, + "pred_masks": enc_outputs_masks, + } + if config.auxiliary_loss: + auxiliary_outputs = _set_aux_loss(outputs_class, outputs_coord, outputs_masks) + outputs_loss["auxiliary_outputs"] = auxiliary_outputs + loss_dict = criterion(outputs_loss, labels) + # Fourth: compute total loss, as a weighted sum of the various losses + weight_dict = {"loss_ce": config.class_loss_coefficient, "loss_bbox": config.bbox_loss_coefficient} + weight_dict["loss_giou"] = config.giou_loss_coefficient + weight_dict["loss_mask_ce"] = config.mask_class_loss_coefficient + weight_dict["loss_mask_dice"] = config.mask_dice_loss_coefficient + if config.auxiliary_loss: + aux_weight_dict = {} + for i in range(config.decoder_layers - 1): + aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()}) + aux_weight_dict.update({k + "_enc": v for k, v in weight_dict.items()}) + weight_dict.update(aux_weight_dict) + loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict if k in weight_dict) + return loss, loss_dict, auxiliary_outputs diff --git a/src/transformers/loss/loss_utils.py b/src/transformers/loss/loss_utils.py index df269477e9ec..8b734461dbdc 100644 --- a/src/transformers/loss/loss_utils.py +++ b/src/transformers/loss/loss_utils.py @@ -22,6 +22,7 @@ from .loss_for_object_detection import ForObjectDetectionLoss, ForSegmentationLoss from .loss_grounding_dino import 
GroundingDinoForObjectDetectionLoss from .loss_lw_detr import LwDetrForObjectDetectionLoss +from .loss_rf_detr import RfDetrForSegmentationLoss from .loss_rt_detr import RTDetrForObjectDetectionLoss @@ -165,4 +166,6 @@ def ForTokenClassification(logits: torch.Tensor, labels, config, **kwargs): "DFineForObjectDetection": DFineForObjectDetectionLoss, "CsmForConditionalGeneration": ForCausalLMLoss, "LwDetrForObjectDetection": LwDetrForObjectDetectionLoss, + "RfDetrForObjectDetection": LwDetrForObjectDetectionLoss, + "RfDetrForInstanceSegmentation": RfDetrForSegmentationLoss, } diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 89361203d5ef..0badb16c8bb5 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -327,6 +327,7 @@ from .regnet import * from .rembert import * from .resnet import * + from .rf_detr import * from .roberta import * from .roberta_prelayernorm import * from .roc_bert import * diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 7909e3834aa9..44acb4c6f360 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -366,6 +366,8 @@ ("regnet", "RegNetConfig"), ("rembert", "RemBertConfig"), ("resnet", "ResNetConfig"), + ("rf_detr", "RfDetrConfig"), + ("rf_detr_dinov2", "RfDetrDinov2Config"), ("roberta", "RobertaConfig"), ("roberta-prelayernorm", "RobertaPreLayerNormConfig"), ("roc_bert", "RoCBertConfig"), @@ -844,6 +846,8 @@ ("regnet", "RegNet"), ("rembert", "RemBERT"), ("resnet", "ResNet"), + ("rf_detr", "RF-DETR"), + ("rf_detr_dinov2", "RF-DETR-DINOv2"), ("roberta", "RoBERTa"), ("roberta-prelayernorm", "RoBERTa-PreLayerNorm"), ("roc_bert", "RoCBert"), @@ -1010,6 +1014,7 @@ ("smolvlm_vision", "smolvlm"), ("chinese_clip_vision_model", "chinese_clip"), ("rt_detr_resnet", "rt_detr"), + ("rf_detr_dinov2", "rf_detr"), ("granitevision", "llava_next"), ("internvl_vision", "internvl"), ("qwen2_5_vl_text", "qwen2_5_vl"), diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index e5fe9f4400eb..90aa089f3ec2 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -179,6 +179,7 @@ ("qwen3_vl", ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast")), ("regnet", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")), ("resnet", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")), + ("rf_detr", ("DetrImageProcessor", "DetrImageProcessorFast")), ("rt_detr", ("RTDetrImageProcessor", "RTDetrImageProcessorFast")), ("sam", ("SamImageProcessor", "SamImageProcessorFast")), ("sam2", (None, "Sam2ImageProcessorFast")), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index d9370d78e736..5938b794d36d 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -355,6 +355,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("regnet", "RegNetModel"), ("rembert", "RemBertModel"), ("resnet", "ResNetModel"), + ("rf_detr", "RfDetrModel"), ("roberta", "RobertaModel"), ("roberta-prelayernorm", "RobertaPreLayerNormModel"), ("roc_bert", "RoCBertModel"), @@ -865,6 +866,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): # Model for Instance Segmentation mapping # MaskFormerForInstanceSegmentation can be 
removed from this mapping in v5 ("maskformer", "MaskFormerForInstanceSegmentation"), + ("rf_detr", "RfDetrForInstanceSegmentation"), ] ) @@ -1030,6 +1032,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("deformable_detr", "DeformableDetrForObjectDetection"), ("detr", "DetrForObjectDetection"), ("lw_detr", "LwDetrForObjectDetection"), + ("rf_detr", "RfDetrForObjectDetection"), ("rt_detr", "RTDetrForObjectDetection"), ("rt_detr_v2", "RTDetrV2ForObjectDetection"), ("table-transformer", "TableTransformerForObjectDetection"), @@ -1598,6 +1601,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("pixio", "PixioBackbone"), ("pvt_v2", "PvtV2Backbone"), ("resnet", "ResNetBackbone"), + ("rf_detr_dinov2", "RfDetrDinov2Backbone"), ("rt_detr_resnet", "RTDetrResNetBackbone"), ("swin", "SwinBackbone"), ("swinv2", "Swinv2Backbone"), diff --git a/src/transformers/models/lw_detr/configuration_lw_detr.py b/src/transformers/models/lw_detr/configuration_lw_detr.py index 3e90410ccd7c..6a89fe8a2656 100644 --- a/src/transformers/models/lw_detr/configuration_lw_detr.py +++ b/src/transformers/models/lw_detr/configuration_lw_detr.py @@ -233,8 +233,8 @@ class LwDetrConfig(PreTrainedConfig): Relative weight of the L1 error of the bounding box coordinates in the Hungarian matching cost. giou_cost (`float`, *optional*, defaults to 2): Relative weight of the generalized IoU loss of the bounding box in the Hungarian matching cost. - mask_loss_coefficient (`float`, *optional*, defaults to 1): - Relative weight of the Focal loss in the panoptic segmentation loss. + class_loss_coefficient (`float`, *optional*, defaults to 1): + Relative weight of the Cross Entropy loss in the object detection loss. dice_loss_coefficient (`float`, *optional*, defaults to 1): Relative weight of the DICE/F-1 loss in the panoptic segmentation loss. bbox_loss_coefficient (`float`, *optional*, defaults to 5): @@ -297,7 +297,7 @@ def __init__( class_cost=2, bbox_cost=5, giou_cost=2, - mask_loss_coefficient=1, + class_loss_coefficient=1, dice_loss_coefficient=1, bbox_loss_coefficient=5, giou_loss_coefficient=2, @@ -362,6 +362,7 @@ def __init__( self.bbox_cost = bbox_cost self.giou_cost = giou_cost # Loss coefficients + self.class_loss_coefficient = class_loss_coefficient self.dice_loss_coefficient = dice_loss_coefficient self.bbox_loss_coefficient = bbox_loss_coefficient self.giou_loss_coefficient = giou_loss_coefficient diff --git a/src/transformers/models/lw_detr/modular_lw_detr.py b/src/transformers/models/lw_detr/modular_lw_detr.py index ac824e990b11..3eead582a85c 100644 --- a/src/transformers/models/lw_detr/modular_lw_detr.py +++ b/src/transformers/models/lw_detr/modular_lw_detr.py @@ -262,8 +262,8 @@ class LwDetrConfig(PreTrainedConfig): Relative weight of the L1 error of the bounding box coordinates in the Hungarian matching cost. giou_cost (`float`, *optional*, defaults to 2): Relative weight of the generalized IoU loss of the bounding box in the Hungarian matching cost. - mask_loss_coefficient (`float`, *optional*, defaults to 1): - Relative weight of the Focal loss in the panoptic segmentation loss. + class_loss_coefficient (`float`, *optional*, defaults to 1): + Relative weight of the Cross Entropy loss in the object detection loss. dice_loss_coefficient (`float`, *optional*, defaults to 1): Relative weight of the DICE/F-1 loss in the panoptic segmentation loss. 
bbox_loss_coefficient (`float`, *optional*, defaults to 5): @@ -326,7 +326,7 @@ def __init__( class_cost=2, bbox_cost=5, giou_cost=2, - mask_loss_coefficient=1, + class_loss_coefficient=1, dice_loss_coefficient=1, bbox_loss_coefficient=5, giou_loss_coefficient=2, @@ -391,6 +391,7 @@ def __init__( self.bbox_cost = bbox_cost self.giou_cost = giou_cost # Loss coefficients + self.class_loss_coefficient = class_loss_coefficient self.dice_loss_coefficient = dice_loss_coefficient self.bbox_loss_coefficient = bbox_loss_coefficient self.giou_loss_coefficient = giou_loss_coefficient diff --git a/src/transformers/models/rf_detr/__init__.py b/src/transformers/models/rf_detr/__init__.py new file mode 100644 index 000000000000..730e55bdf5ed --- /dev/null +++ b/src/transformers/models/rf_detr/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_rf_detr import * + from .modeling_rf_detr import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/rf_detr/configuration_rf_detr.py b/src/transformers/models/rf_detr/configuration_rf_detr.py new file mode 100644 index 000000000000..789f43cff535 --- /dev/null +++ b/src/transformers/models/rf_detr/configuration_rf_detr.py @@ -0,0 +1,414 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/rf_detr/modular_rf_detr.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_rf_detr.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from ...configuration_utils import PreTrainedConfig +from ...utils import logging +from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices +from ..auto import CONFIG_MAPPING, AutoConfig + + +logger = logging.get_logger(__name__) + + +class RfDetrDinov2Config(BackboneConfigMixin, PreTrainedConfig): + r""" + This is the configuration class to store the configuration of a [`RfDetrDinov2Model`]. It is used to instantiate an + RfDetrDinov2 model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the DINOv2 + [facebook/dinov2-base](https://huggingface.co/facebook/dinov2-base) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + mlp_ratio (`int`, *optional*, defaults to 4): + Ratio of the hidden size of the MLPs relative to the `hidden_size`. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 14): + The size (resolution) of each patch. + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + qkv_bias (`bool`, *optional*, defaults to `True`): + Whether to add a bias to the queries, keys and values. + layerscale_value (`float`, *optional*, defaults to 1.0): + Initial value to use for layer scale. + drop_path_rate (`float`, *optional*, defaults to 0.0): + Stochastic depth rate per sample (when applied in the main path of residual layers). + use_swiglu_ffn (`bool`, *optional*, defaults to `False`): + Whether to use the SwiGLU feedforward neural network. + out_features (`list[str]`, *optional*): + If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc. + (depending on how many stages the model has). If unset and `out_indices` is set, will default to the + corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the + same order as defined in the `stage_names` attribute. + out_indices (`list[int]`, *optional*): + If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how + many stages the model has). 
If unset and `out_features` is set, will default to the corresponding stages. + If unset and `out_features` is unset, will default to the last stage. Must be in the + same order as defined in the `stage_names` attribute. + apply_layernorm (`bool`, *optional*, defaults to `True`): + Whether to apply layer normalization to the feature maps in case the model is used as backbone. + reshape_hidden_states (`bool`, *optional*, defaults to `True`): + Whether to reshape the feature maps to 4D tensors of shape `(batch_size, hidden_size, height, width)` in + case the model is used as backbone. If `False`, the feature maps will be 3D tensors of shape `(batch_size, + seq_len, hidden_size)`. + use_mask_token (`bool`, *optional*, defaults to `True`): + Whether to use mask_token in embeddings. + num_windows (`int`, *optional*, defaults to 4): + Number of windows to use for windowed attention. If 1, no windowed attention is used. + Example: + + ```python + >>> from transformers import RfDetrDinov2Config, RfDetrDinov2Backbone + + >>> # Initializing a RfDetrDinov2 base style configuration + >>> configuration = RfDetrDinov2Config() + + >>> # Initializing a model (with random weights) from the base style configuration + >>> model = RfDetrDinov2Backbone(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "rf_detr_dinov2" + + def __init__( + self, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + mlp_ratio=4, + hidden_act="gelu", + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + initializer_range=0.02, + layer_norm_eps=1e-6, + image_size=224, + patch_size=14, + num_channels=3, + qkv_bias=True, + layerscale_value=1.0, + drop_path_rate=0.0, + use_swiglu_ffn=False, + out_features=None, + out_indices=None, + apply_layernorm=True, + reshape_hidden_states=True, + use_mask_token=True, + num_windows: int = 4, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.mlp_ratio = mlp_ratio + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.qkv_bias = qkv_bias + self.layerscale_value = layerscale_value + self.drop_path_rate = drop_path_rate + self.use_swiglu_ffn = use_swiglu_ffn + self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, num_hidden_layers + 1)] + self._out_features, self._out_indices = get_aligned_output_features_output_indices( + out_features=out_features, out_indices=out_indices, stage_names=self.stage_names + ) + self.apply_layernorm = apply_layernorm + self.reshape_hidden_states = reshape_hidden_states + self.use_mask_token = use_mask_token + + self.num_windows = num_windows + window_block_indexes = set(range(self._out_indices[-1] + 1)) + window_block_indexes.difference_update(self._out_indices) + window_block_indexes = list(window_block_indexes) + self.window_block_indexes = window_block_indexes + + +class RfDetrConfig(PreTrainedConfig): + r""" + This is the configuration class to store the configuration of a [`RfDetrModel`]. It is used to instantiate + a LW-DETR model according to the specified arguments, defining the model architecture. 
Instantiating a + configuration with the defaults will yield a similar configuration to that of the LW-DETR + [stevenbucaille/RfDetr_small_60e_coco](https://huggingface.co/stevenbucaille/RfDetr_small_60e_coco) architecture. + + LW-DETR (Lightweight Detection Transformer) is a transformer-based object detection model designed for real-time + detection tasks. It replaces traditional CNN-based detectors like YOLO with a more efficient transformer architecture + that achieves competitive performance while being computationally lightweight. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + backbone_config (`PretrainedConfig` or `dict`, *optional*): + The configuration of the backbone model. If not provided, will default to `RfDetrDinov2Config` + with a small ViT architecture optimized for detection tasks. + projector_scale_factors (`list[float]`, *optional*, defaults to `[]`): + Scale factors for the feature pyramid network. Each scale factor determines the resolution of features + at different levels. Supported values are 0.5, 1.0, and 2.0. + hidden_expansion (`float`, *optional*, defaults to 0.5): + Expansion factor for hidden dimensions in the projector layers. + c2f_num_blocks (`int`, *optional*, defaults to 3): + Number of blocks in the C2F layer. + activation_function (`str`, *optional*, defaults to `"silu"`): + The non-linear activation function in the projector. Supported values are `"silu"`, `"relu"`, `"gelu"`. + layer_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon value for layer normalization layers. + d_model (`int`, *optional*, defaults to 256): + Dimension of the model layers and the number of expected features in the decoder inputs. + dropout (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + decoder_ffn_dim (`int`, *optional*, defaults to 2048): + Dimension of the "intermediate" (often named feed-forward) layer in decoder. + decoder_n_points (`int`, *optional*, defaults to 4): + The number of sampled keys in each feature level for each attention head in the decoder. + decoder_layers (`int`, *optional*, defaults to 3): + Number of decoder layers in the transformer. + decoder_self_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the decoder self-attention. + decoder_cross_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the decoder cross-attention. + decoder_activation_function (`str`, *optional*, defaults to `"relu"`): + The non-linear activation function in the decoder. Supported values are `"relu"`, `"silu"`, `"gelu"`. + num_queries (`int`, *optional*, defaults to 300): + Number of object queries, i.e. detection slots. This is the maximal number of objects + [`RfDetrModel`] can detect in a single image. + attention_bias (`bool`, *optional*, defaults to `True`): + Whether to add bias to the attention layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + activation_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for activations inside the fully connected layer. + group_detr (`int`, *optional*, defaults to 13): + Number of groups for Group DETR attention mechanism, which helps reduce computational complexity. 
+ init_std (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + disable_custom_kernels (`bool`, *optional*, defaults to `True`): + Disable the use of custom CUDA and CPU kernels. This option is necessary for the ONNX export, as custom + kernels are not supported by PyTorch ONNX export. + class_cost (`float`, *optional*, defaults to 2): + Relative weight of the classification error in the Hungarian matching cost. + bbox_cost (`float`, *optional*, defaults to 5): + Relative weight of the L1 error of the bounding box coordinates in the Hungarian matching cost. + giou_cost (`float`, *optional*, defaults to 2): + Relative weight of the generalized IoU loss of the bounding box in the Hungarian matching cost. + class_loss_coefficient (`float`, *optional*, defaults to 1): + Relative weight of the classification loss in the Hungarian matching cost. + mask_loss_coefficient (`float`, *optional*, defaults to 1): + Relative weight of the Focal loss in the instance segmentation mask loss. + dice_loss_coefficient (`float`, *optional*, defaults to 1): + Relative weight of the DICE/F-1 loss in the object detection loss. + bbox_loss_coefficient (`float`, *optional*, defaults to 5): + Relative weight of the L1 bounding box loss in the object detection loss. + giou_loss_coefficient (`float`, *optional*, defaults to 2): + Relative weight of the generalized IoU loss in the object detection loss. + eos_coefficient (`float`, *optional*, defaults to 0.1): + Relative classification weight of the 'no-object' class in the object detection loss. + focal_alpha (`float`, *optional*, defaults to 0.25): + Alpha parameter in the focal loss. + auxiliary_loss (`bool`, *optional*, defaults to `True`): + Whether auxiliary decoding losses (loss at each decoder layer) are to be used. + mask_point_sample_ratio (`int`, *optional*, defaults to 16): + The ratio of points to sample for the mask loss calculation. + mask_downsample_ratio (`int`, *optional*, defaults to 4): + The downsample ratio for the segmentation masks compared to the input image resolution. + mask_class_loss_coefficient (`float`, *optional*, defaults to 5.0): + Relative weight of the Focal loss in the instance segmentation loss. + mask_dice_loss_coefficient (`float`, *optional*, defaults to 5.0): + Relative weight of the DICE/F-1 loss in the instance segmentation loss. + segmentation_head_activation_function (`str`, *optional*, defaults to `"gelu"`): + The non-linear activation function in the segmentation head. Supported values are `"relu"`, `"silu"`, `"gelu"`. 
+ Examples: + + ```python + >>> from transformers import RfDetrConfig, RfDetrModel + + >>> # Initializing a LW-DETR stevenbucaille/RfDetr_small_60e_coco style configuration + >>> configuration = RfDetrConfig() + + >>> # Initializing a model (with random weights) from the stevenbucaille/RfDetr_small_60e_coco style configuration + >>> model = RfDetrModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "rf_detr" + sub_configs = {"backbone_config": AutoConfig} + + def __init__( + self, + # backbone + backbone_config=None, + # projector + projector_scale_factors: list[float] = [], + hidden_expansion=0.5, + c2f_num_blocks=3, + activation_function="silu", + layer_norm_eps=1e-5, + # decoder + d_model=256, + dropout=0.1, + decoder_ffn_dim=2048, + decoder_n_points=4, + decoder_layers: int = 3, + decoder_self_attention_heads: int = 8, + decoder_cross_attention_heads: int = 16, + decoder_activation_function="relu", + # model + num_queries=300, + attention_bias=True, + attention_dropout=0.0, + activation_dropout=0.0, + group_detr: int = 13, + init_std=0.02, + disable_custom_kernels=True, + # loss + class_cost=2, + bbox_cost=5, + giou_cost=2, + class_loss_coefficient=1, + mask_loss_coefficient=1, + dice_loss_coefficient=1, + bbox_loss_coefficient=5, + giou_loss_coefficient=2, + eos_coefficient=0.1, + focal_alpha=0.25, + auxiliary_loss=True, + mask_point_sample_ratio=16, + # segmentation + mask_downsample_ratio=4, + mask_class_loss_coefficient=5.0, + mask_dice_loss_coefficient=5.0, + segmentation_head_activation_function="gelu", + **kwargs, + ): + self.layer_norm_eps = layer_norm_eps + + # backbone + if backbone_config is None: + logger.info( + "`backbone_config` is `None`. Initializing the config with the default `RfDetrDinov2` backbone." 
+ ) + backbone_config = RfDetrDinov2Config( + attention_probs_dropout_prob=0.0, + drop_path_rate=0.0, + hidden_act="gelu", + hidden_dropout_prob=0.0, + initializer_range=0.02, + layer_norm_eps=1e-06, + layerscale_value=1.0, + mlp_ratio=4, + num_attention_heads=6, + num_channels=3, + num_hidden_layers=12, + qkv_bias=True, + use_swiglu_ffn=False, + out_features=["stage2", "stage5", "stage8", "stage11"], + hidden_size=384, + patch_size=14, + num_windows=4, + num_register_tokens=0, + image_size=518, + **kwargs, + ) + elif isinstance(backbone_config, dict): + backbone_model_type = backbone_config.pop("model_type") + config_class = CONFIG_MAPPING[backbone_model_type] + backbone_config = config_class.from_dict(backbone_config) + + self.backbone_config = backbone_config + + # projector + self.projector_scale_factors = projector_scale_factors + for scale in projector_scale_factors: + if scale not in [0.5, 1.0, 2.0]: + raise ValueError(f"Unsupported scale factor: {scale}") + self.projector_in_channels = [d_model] * len(projector_scale_factors) + self.projector_out_channels = d_model + self.activation_function = activation_function + self.hidden_expansion = hidden_expansion + self.c2f_num_blocks = c2f_num_blocks + # decoder + self.d_model = d_model + self.dropout = dropout + self.num_queries = num_queries + self.decoder_ffn_dim = decoder_ffn_dim + self.num_feature_levels = len(self.projector_scale_factors) + self.decoder_n_points = decoder_n_points + self.decoder_layers = decoder_layers + self.decoder_activation_function = decoder_activation_function + self.decoder_self_attention_heads = decoder_self_attention_heads + self.decoder_cross_attention_heads = decoder_cross_attention_heads + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + # model + self.init_std = init_std + self.group_detr = group_detr + # Loss + self.auxiliary_loss = auxiliary_loss + # Hungarian matcher + self.class_cost = class_cost + self.bbox_cost = bbox_cost + self.giou_cost = giou_cost + # Loss coefficients + self.class_loss_coefficient = class_loss_coefficient + self.mask_loss_coefficient = mask_loss_coefficient + self.dice_loss_coefficient = dice_loss_coefficient + self.bbox_loss_coefficient = bbox_loss_coefficient + self.giou_loss_coefficient = giou_loss_coefficient + self.mask_class_loss_coefficient = mask_class_loss_coefficient + self.mask_dice_loss_coefficient = mask_dice_loss_coefficient + self.eos_coefficient = eos_coefficient + self.focal_alpha = focal_alpha + self.disable_custom_kernels = disable_custom_kernels + self.mask_point_sample_ratio = mask_point_sample_ratio + # segmentation + self.mask_downsample_ratio = mask_downsample_ratio + self.segmentation_head_activation_function = segmentation_head_activation_function + super().__init__(**kwargs) + + +__all__ = ["RfDetrConfig", "RfDetrDinov2Config"] diff --git a/src/transformers/models/rf_detr/convert_rf_detr_weights_to_hf.py b/src/transformers/models/rf_detr/convert_rf_detr_weights_to_hf.py new file mode 100644 index 000000000000..f6c9f86f0a7a --- /dev/null +++ b/src/transformers/models/rf_detr/convert_rf_detr_weights_to_hf.py @@ -0,0 +1,999 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert RF-DETR checkpoints to transformers format.""" + +import argparse + +import requests +import torch +from PIL import Image +from torchvision import transforms + +from transformers import ( + DetrImageProcessor, + RfDetrConfig, + RfDetrForInstanceSegmentation, + RfDetrForObjectDetection, +) +from transformers.core_model_loading import ( + Chunk, + WeightConverter, + WeightRenaming, + convert_and_load_state_dict_in_model, +) + + +# Mapping of model names to their checkpoint files +HOSTED_MODELS = { + "rf-detr-base": "https://storage.googleapis.com/rfdetr/rf-detr-base-coco.pth", + # base-2 is a less converged model that may be better for finetuning but worse for inference + "rf-detr-base-2": "https://storage.googleapis.com/rfdetr/rf-detr-base-2.pth", + "rf-detr-nano": "https://storage.googleapis.com/rfdetr/nano_coco/checkpoint_best_regular.pth", + "rf-detr-small": "https://storage.googleapis.com/rfdetr/small_coco/checkpoint_best_regular.pth", + "rf-detr-medium": "https://storage.googleapis.com/rfdetr/medium_coco/checkpoint_best_regular.pth", + "rf-detr-large-deprecated": "https://storage.googleapis.com/rfdetr/rf-detr-large.pth", + "rf-detr-large": "https://storage.googleapis.com/rfdetr/rf-detr-large-2026.pth", + "rf-detr-seg-preview": "https://storage.googleapis.com/rfdetr/rf-detr-seg-preview.pt", + "rf-detr-seg-nano": "https://storage.googleapis.com/rfdetr/rf-detr-seg-n-ft.pth", + "rf-detr-seg-small": "https://storage.googleapis.com/rfdetr/rf-detr-seg-s-ft.pth", + "rf-detr-seg-medium": "https://storage.googleapis.com/rfdetr/rf-detr-seg-m-ft.pth", + "rf-detr-seg-large": "https://storage.googleapis.com/rfdetr/rf-detr-seg-l-ft.pth", + "rf-detr-seg-xlarge": "https://storage.googleapis.com/rfdetr/rf-detr-seg-xl-ft.pth", + "rf-detr-seg-xxlarge": "https://storage.googleapis.com/rfdetr/rf-detr-seg-2xl-ft.pth", +} + +# Model configurations for different sizes +BACKBONE_CONFIGS = { + "rf-detr-nano": { + "attention_probs_dropout_prob": 0.0, + "drop_path_rate": 0.0, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.0, + "initializer_range": 0.02, + "layer_norm_eps": 1e-06, + "layerscale_value": 1.0, + "mlp_ratio": 4, + "num_attention_heads": 6, + "num_channels": 3, + "num_hidden_layers": 12, + "qkv_bias": True, + "use_swiglu_ffn": False, + "out_features": ["stage3", "stage6", "stage9", "stage12"], + "hidden_size": 384, + "patch_size": 16, + "num_windows": 2, + "image_size": 384, + }, + "rf-detr-small": { + "attention_probs_dropout_prob": 0.0, + "drop_path_rate": 0.0, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.0, + "initializer_range": 0.02, + "layer_norm_eps": 1e-06, + "layerscale_value": 1.0, + "mlp_ratio": 4, + "num_attention_heads": 6, + "num_channels": 3, + "num_hidden_layers": 12, + "qkv_bias": True, + "use_swiglu_ffn": False, + "out_features": ["stage3", "stage6", "stage9", "stage12"], + "hidden_size": 384, + "patch_size": 16, + "num_windows": 2, + "image_size": 512, + }, + "rf-detr-base": { + "attention_probs_dropout_prob": 0.0, + "drop_path_rate": 0.0, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.0, + "initializer_range": 0.02, + "layer_norm_eps": 1e-06, + 
"layerscale_value": 1.0, + "mlp_ratio": 4, + "num_attention_heads": 6, + "num_channels": 3, + "num_hidden_layers": 12, + "qkv_bias": True, + "use_swiglu_ffn": False, + "out_features": ["stage2", "stage5", "stage8", "stage11"], + "hidden_size": 384, + "patch_size": 14, + "num_windows": 4, + "image_size": 518, + }, + "rf-detr-medium": { + "attention_probs_dropout_prob": 0.0, + "drop_path_rate": 0.0, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.0, + "initializer_range": 0.02, + "layer_norm_eps": 1e-06, + "layerscale_value": 1.0, + "mlp_ratio": 4, + "num_attention_heads": 6, + "num_channels": 3, + "num_hidden_layers": 12, + "qkv_bias": True, + "use_swiglu_ffn": False, + "out_features": ["stage3", "stage6", "stage9", "stage12"], + "hidden_size": 384, + "patch_size": 16, + "num_windows": 2, + "image_size": 576, + }, + "rf-detr-large-deprecated": { + "attention_probs_dropout_prob": 0.0, + "drop_path_rate": 0.0, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.0, + "initializer_range": 0.02, + "layer_norm_eps": 1e-06, + "layerscale_value": 1.0, + "mlp_ratio": 4, + "num_attention_heads": 12, + "num_channels": 3, + "num_hidden_layers": 12, + "qkv_bias": True, + "use_swiglu_ffn": False, + "out_features": ["stage2", "stage5", "stage8", "stage11"], + "hidden_size": 768, + "patch_size": 14, + "num_windows": 4, + "image_size": 518, + }, + "rf-detr-large": { + "attention_probs_dropout_prob": 0.0, + "drop_path_rate": 0.0, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.0, + "initializer_range": 0.02, + "layer_norm_eps": 1e-06, + "layerscale_value": 1.0, + "mlp_ratio": 4, + "num_attention_heads": 6, + "num_channels": 3, + "num_hidden_layers": 12, + "qkv_bias": True, + "use_swiglu_ffn": False, + "out_features": ["stage3", "stage6", "stage9", "stage12"], + "hidden_size": 384, + "patch_size": 16, + "num_windows": 2, + "image_size": 704, + }, + "rf-detr-xlarge": { + "attention_probs_dropout_prob": 0.0, + "drop_path_rate": 0.0, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.0, + "initializer_range": 0.02, + "layer_norm_eps": 1e-06, + "layerscale_value": 1.0, + "mlp_ratio": 4, + "num_attention_heads": 12, + "num_channels": 3, + "num_hidden_layers": 12, + "qkv_bias": True, + "use_swiglu_ffn": False, + "out_features": ["stage2", "stage5", "stage8", "stage11"], + "hidden_size": 768, + "patch_size": 20, + "num_windows": 1, + "image_size": 700, + }, + "rf-detr-xxlarge": { + "attention_probs_dropout_prob": 0.0, + "drop_path_rate": 0.0, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.0, + "initializer_range": 0.02, + "layer_norm_eps": 1e-06, + "layerscale_value": 1.0, + "mlp_ratio": 4, + "num_attention_heads": 12, + "num_channels": 3, + "num_hidden_layers": 12, + "qkv_bias": True, + "use_swiglu_ffn": False, + "out_features": ["stage2", "stage5", "stage8", "stage11"], + "hidden_size": 768, + "patch_size": 20, + "num_windows": 2, + "image_size": 880, + }, + "rf-detr-seg-preview": { + "attention_probs_dropout_prob": 0.0, + "drop_path_rate": 0.0, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.0, + "initializer_range": 0.02, + "layer_norm_eps": 1e-06, + "layerscale_value": 1.0, + "mlp_ratio": 4, + "num_attention_heads": 6, + "num_channels": 3, + "num_hidden_layers": 12, + "qkv_bias": True, + "use_swiglu_ffn": False, + "out_features": ["stage3", "stage6", "stage9", "stage12"], + "hidden_size": 384, + "patch_size": 12, + "num_windows": 2, + "image_size": 432, + }, + "rf-detr-seg-nano": { + "attention_probs_dropout_prob": 0.0, + "drop_path_rate": 0.0, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.0, + 
"initializer_range": 0.02, + "layer_norm_eps": 1e-06, + "layerscale_value": 1.0, + "mlp_ratio": 4, + "num_attention_heads": 6, + "num_channels": 3, + "num_hidden_layers": 12, + "qkv_bias": True, + "use_swiglu_ffn": False, + "out_features": ["stage3", "stage6", "stage9", "stage12"], + "hidden_size": 384, + "patch_size": 12, + "num_windows": 1, + "image_size": 312, + }, + "rf-detr-seg-small": { + "attention_probs_dropout_prob": 0.0, + "drop_path_rate": 0.0, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.0, + "initializer_range": 0.02, + "layer_norm_eps": 1e-06, + "layerscale_value": 1.0, + "mlp_ratio": 4, + "num_attention_heads": 6, + "num_channels": 3, + "num_hidden_layers": 12, + "qkv_bias": True, + "use_swiglu_ffn": False, + "out_features": ["stage3", "stage6", "stage9", "stage12"], + "hidden_size": 384, + "patch_size": 12, + "num_windows": 2, + "image_size": 384, + }, + "rf-detr-seg-medium": { + "attention_probs_dropout_prob": 0.0, + "drop_path_rate": 0.0, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.0, + "initializer_range": 0.02, + "layer_norm_eps": 1e-06, + "layerscale_value": 1.0, + "mlp_ratio": 4, + "num_attention_heads": 6, + "num_channels": 3, + "num_hidden_layers": 12, + "qkv_bias": True, + "use_swiglu_ffn": False, + "out_features": ["stage3", "stage6", "stage9", "stage12"], + "hidden_size": 384, + "patch_size": 12, + "num_windows": 2, + "image_size": 432, + }, + "rf-detr-seg-large": { + "attention_probs_dropout_prob": 0.0, + "drop_path_rate": 0.0, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.0, + "initializer_range": 0.02, + "layer_norm_eps": 1e-06, + "layerscale_value": 1.0, + "mlp_ratio": 4, + "num_attention_heads": 6, + "num_channels": 3, + "num_hidden_layers": 12, + "qkv_bias": True, + "use_swiglu_ffn": False, + "out_features": ["stage3", "stage6", "stage9", "stage12"], + "hidden_size": 384, + "patch_size": 12, + "num_windows": 2, + "image_size": 504, + }, + "rf-detr-seg-xlarge": { + "attention_probs_dropout_prob": 0.0, + "drop_path_rate": 0.0, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.0, + "initializer_range": 0.02, + "layer_norm_eps": 1e-06, + "layerscale_value": 1.0, + "mlp_ratio": 4, + "num_attention_heads": 6, + "num_channels": 3, + "num_hidden_layers": 12, + "qkv_bias": True, + "use_swiglu_ffn": False, + "out_features": ["stage3", "stage6", "stage9", "stage12"], + "hidden_size": 384, + "patch_size": 12, + "num_windows": 2, + "image_size": 624, + }, + "rf-detr-seg-xxlarge": { + "attention_probs_dropout_prob": 0.0, + "drop_path_rate": 0.0, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.0, + "initializer_range": 0.02, + "layer_norm_eps": 1e-06, + "layerscale_value": 1.0, + "mlp_ratio": 4, + "num_attention_heads": 6, + "num_channels": 3, + "num_hidden_layers": 12, + "qkv_bias": True, + "use_swiglu_ffn": False, + "out_features": ["stage3", "stage6", "stage9", "stage12"], + "hidden_size": 384, + "patch_size": 12, + "num_windows": 2, + "image_size": 768, + }, +} + +MODEL_CONFIGS = { + "rf-detr-nano": { + "d_model": 256, + "decoder_layers": 2, + "decoder_self_attention_heads": 8, + "decoder_cross_attention_heads": 16, + "decoder_n_points": 2, + "projector_scale_factors": [1.0], + "num_queries": 300, + }, + "rf-detr-small": { + "d_model": 256, + "decoder_layers": 3, + "decoder_self_attention_heads": 8, + "decoder_cross_attention_heads": 16, + "decoder_n_points": 2, + "projector_scale_factors": [1.0], + "num_queries": 300, + }, + "rf-detr-base": { + "d_model": 256, + "decoder_layers": 3, + "decoder_self_attention_heads": 8, + 
"decoder_cross_attention_heads": 16, + "decoder_n_points": 2, + "projector_scale_factors": [1.0], + "num_queries": 300, + }, + "rf-detr-medium": { + "d_model": 256, + "decoder_layers": 4, + "decoder_self_attention_heads": 8, + "decoder_cross_attention_heads": 16, + "decoder_n_points": 2, + "projector_scale_factors": [1.0], + "num_queries": 300, + }, + "rf-detr-large-deprecated": { + "d_model": 384, + "num_queries": 300, + "decoder_layers": 3, + "decoder_self_attention_heads": 12, + "decoder_cross_attention_heads": 24, + "decoder_n_points": 4, + "projector_scale_factors": [2.0, 0.5], + }, + "rf-detr-large": { + "d_model": 256, + "num_queries": 300, + "decoder_layers": 4, + "decoder_self_attention_heads": 8, + "decoder_cross_attention_heads": 16, + "decoder_n_points": 2, + "projector_scale_factors": [1.0], + }, + "rf-detr-xlarge": { + "d_model": 512, + "num_queries": 300, + "decoder_layers": 5, + "decoder_self_attention_heads": 16, + "decoder_cross_attention_heads": 32, + "decoder_n_points": 4, + "projector_scale_factors": [2.0, 0.5], + }, + "rf-detr-xxlarge": { + "d_model": 512, + "num_queries": 300, + "decoder_layers": 5, + "decoder_self_attention_heads": 16, + "decoder_cross_attention_heads": 32, + "decoder_n_points": 4, + "projector_scale_factors": [2.0, 0.5], + }, + "rf-detr-seg-preview": { + "d_model": 256, + "decoder_layers": 4, + "decoder_self_attention_heads": 8, + "decoder_cross_attention_heads": 16, + "decoder_n_points": 2, + "projector_scale_factors": [1.0], + "num_queries": 200, + "class_loss_coefficient": 5.0, + }, + "rf-detr-seg-nano": { + "d_model": 256, + "decoder_layers": 4, + "decoder_self_attention_heads": 8, + "decoder_cross_attention_heads": 16, + "decoder_n_points": 2, + "projector_scale_factors": [1.0], + "num_queries": 100, + "class_loss_coefficient": 5.0, + }, + "rf-detr-seg-small": { + "d_model": 256, + "decoder_layers": 4, + "decoder_self_attention_heads": 8, + "decoder_cross_attention_heads": 16, + "decoder_n_points": 2, + "projector_scale_factors": [1.0], + "num_queries": 100, + "class_loss_coefficient": 5.0, + }, + "rf-detr-seg-medium": { + "d_model": 256, + "decoder_layers": 5, + "decoder_self_attention_heads": 8, + "decoder_cross_attention_heads": 16, + "decoder_n_points": 2, + "projector_scale_factors": [1.0], + "num_queries": 200, + "class_loss_coefficient": 5.0, + }, + "rf-detr-seg-large": { + "d_model": 256, + "decoder_layers": 5, + "decoder_self_attention_heads": 8, + "decoder_cross_attention_heads": 16, + "decoder_n_points": 2, + "projector_scale_factors": [1.0], + "num_queries": 300, + "class_loss_coefficient": 5.0, + }, + "rf-detr-seg-xlarge": { + "d_model": 256, + "decoder_layers": 6, + "decoder_self_attention_heads": 8, + "decoder_cross_attention_heads": 16, + "decoder_n_points": 2, + "projector_scale_factors": [1.0], + "num_queries": 300, + "class_loss_coefficient": 5.0, + }, + "rf-detr-seg-xxlarge": { + "d_model": 256, + "decoder_layers": 6, + "decoder_self_attention_heads": 8, + "decoder_cross_attention_heads": 16, + "decoder_n_points": 2, + "projector_scale_factors": [1.0], + "num_queries": 300, + "class_loss_coefficient": 5.0, + }, +} + +IMAGE_PROCESSORS = { + "rf-detr-nano": { + "do_resize": True, + "size": (384, 384), + }, + "rf-detr-small": { + "do_resize": True, + "size": (512, 512), + }, + "rf-detr-base": { + "do_resize": True, + "size": (560, 560), + }, + "rf-detr-medium": { + "do_resize": True, + "size": (576, 576), + }, + "rf-detr-large-deprecated": { + "do_resize": True, + "size": (560, 560), + }, + "rf-detr-large": { + "do_resize": 
True, + "size": (704, 704), + }, + "rf-detr-xlarge": { + "do_resize": True, + "size": (560, 560), + }, + "rf-detr-xxlarge": { + "do_resize": True, + "size": (560, 560), + }, + "rf-detr-seg-preview": { + "do_resize": True, + "size": (432, 432), + }, + "rf-detr-seg-nano": { + "do_resize": True, + "size": (312, 312), + }, + "rf-detr-seg-small": { + "do_resize": True, + "size": (384, 384), + }, + "rf-detr-seg-medium": { + "do_resize": True, + "size": (432, 432), + }, + "rf-detr-seg-large": { + "do_resize": True, + "size": (504, 504), + }, + "rf-detr-seg-xlarge": { + "do_resize": True, + "size": (624, 624), + }, + "rf-detr-seg-xxlarge": { + "do_resize": True, + "size": (768, 768), + }, +} + + +def get_model_config(model_name: str): + """Get the appropriate configuration for a given model size.""" + config = None + image_processor_config = None + sizes = MODEL_CONFIGS.keys() + for size in sizes: + if size == model_name: + config = MODEL_CONFIGS[size] + config["backbone_config"] = BACKBONE_CONFIGS[size] + image_processor_config = IMAGE_PROCESSORS[size] + break + + # Default to base configuration + if config is None: + config = MODEL_CONFIGS["base"] + config["backbone_config"] = BACKBONE_CONFIGS["base"] + image_processor_config = IMAGE_PROCESSORS["base"] + config["backbone_config"]["model_type"] = "rf_detr_dinov2" + + if "objects365" in model_name: + config["num_labels"] = 366 + elif "coco" in model_name: + config["num_labels"] = 91 + else: + config["num_labels"] = 91 + + return config, image_processor_config + + +def get_weight_mapping( + rf_detr_config: RfDetrConfig, + is_segmentation: bool, +) -> list[WeightConverter | WeightRenaming]: + if is_segmentation: + weight_mapping = [ + # backbone RfDetrConvEncoder + WeightRenaming("backbone.0.encoder.encoder", "rf_detr.model.backbone.backbone"), + WeightRenaming("backbone.0.projector", "rf_detr.model.backbone.projector"), + # RfDetrDecoder + WeightRenaming("transformer.decoder", "rf_detr.model.decoder"), + # RfDetrForObjectDetection + WeightRenaming(r"transformer.enc_out_bbox_embed", r"rf_detr.model.enc_out_bbox_embed"), + WeightRenaming(r"transformer.enc_output.(\d+)", r"rf_detr.model.enc_output.\1"), + WeightRenaming(r"transformer.enc_output_norm.(\d+)", r"rf_detr.model.enc_output_norm.\1"), + WeightRenaming(r"transformer.enc_out_class_embed.(\d+)", r"rf_detr.model.enc_out_class_embed.\1"), + WeightRenaming(r"refpoint_embed.weight", r"rf_detr.model.reference_point_embed.weight"), + ] + else: + weight_mapping = [ + # backbone RfDetrConvEncoder + WeightRenaming("backbone.0.encoder.encoder", "backbone.backbone"), + WeightRenaming("backbone.0.projector", "backbone.projector"), + # RfDetrDecoder + WeightRenaming("transformer.decoder", "decoder"), + # RfDetrForObjectDetection + WeightRenaming("transformer.enc_out_bbox_embed", "enc_out_bbox_embed"), + WeightRenaming(r"transformer.enc_output.(\d+)", r"enc_output.\1"), + WeightRenaming(r"transformer.enc_output_norm.(\d+)", r"enc_output_norm.\1"), + WeightRenaming(r"transformer.enc_out_class_embed.(\d+)", r"enc_out_class_embed.\1"), + WeightRenaming(r"refpoint_embed.weight", r"reference_point_embed.weight"), + ] + + weight_mapping.extend( + [ + # RfDetrConvEncoder + ## RfDetrMultiScaleProjector + WeightRenaming(r"projector.stages_sampling.(\d+)", r"projector.scale_layers.\1.sampling_layers"), + WeightRenaming(r"projector.stages.(\d+).0", r"projector.scale_layers.\1.projector_layer"), + WeightRenaming(r"projector.stages.(\d+).1", r"projector.scale_layers.\1.layer_norm"), + ## RfDetrSamplingLayer + 
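+            # Descriptive note: the later rules in this list only match parameter names already
+            # produced by the earlier projector rules above, so they are meant to apply on top of
+            # one another. For instance (the key below is an assumption inferred from these
+            # patterns, not copied from an actual checkpoint), a key suffix like
+            #   "projector.stages.0.0.cv1.conv.weight"
+            # would first become "projector.scale_layers.0.projector_layer.cv1.conv.weight" via the
+            # RfDetrMultiScaleProjector rules, and then
+            #   "projector.scale_layers.0.projector_layer.conv1.conv.weight"
+            # via the RfDetrC2FLayer rules below.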
WeightRenaming(r"sampling_layers.(\d+)", r"sampling_layers.\1.layers"), + WeightRenaming(r"layers.(\d+).bn", r"layers.\1.norm"), + ### RfDetrC2FLayer + WeightRenaming(r"projector_layer.cv1.conv", r"projector_layer.conv1.conv"), + WeightRenaming(r"projector_layer.cv1.bn", r"projector_layer.conv1.norm"), + WeightRenaming(r"projector_layer.cv2.conv", r"projector_layer.conv2.conv"), + WeightRenaming(r"projector_layer.cv2.bn", r"projector_layer.conv2.norm"), + WeightRenaming(r"projector_layer.m.(\d+)", r"projector_layer.bottlenecks.\1"), + #### RfDetrRepVggBlock + WeightRenaming(r"bottlenecks.(\d+).cv1.conv", r"bottlenecks.\1.conv1.conv"), + WeightRenaming(r"bottlenecks.(\d+).cv1.bn", r"bottlenecks.\1.conv1.norm"), + WeightRenaming(r"bottlenecks.(\d+).cv2.conv", r"bottlenecks.\1.conv2.conv"), + WeightRenaming(r"bottlenecks.(\d+).cv2.bn", r"bottlenecks.\1.conv2.norm"), + # RfDetrDecoder + ## RfDetrDecoderLayer + WeightRenaming(r"decoder.layers.(\d+).norm1", r"decoder.layers.\1.self_attn_layer_norm"), + WeightRenaming(r"decoder.layers.(\d+).norm2", r"decoder.layers.\1.cross_attn_layer_norm"), + WeightRenaming(r"decoder.layers.(\d+).linear1", r"decoder.layers.\1.mlp.fc1"), + WeightRenaming(r"decoder.layers.(\d+).linear2", r"decoder.layers.\1.mlp.fc2"), + WeightRenaming(r"decoder.layers.(\d+).norm3", r"decoder.layers.\1.layer_norm"), + WeightRenaming("decoder.norm", r"decoder.layernorm"), + ### RfDetrAttention + WeightRenaming(r"self_attn.out_proj", r"self_attn.o_proj"), + WeightConverter( + r"self_attn.in_proj_bias", + [r"self_attn.q_proj.bias", r"self_attn.k_proj.bias", r"self_attn.v_proj.bias"], + operations=[Chunk(dim=0)], + ), + WeightConverter( + r"self_attn.in_proj_weight", + [r"self_attn.q_proj.weight", r"self_attn.k_proj.weight", r"self_attn.v_proj.weight"], + operations=[Chunk(dim=0)], + ), + ] + ) + + # Indices depend on the value of projector_scale_factors + for i, scale in enumerate(rf_detr_config.projector_scale_factors): + if scale == 2.0: + weight_mapping.append( + WeightRenaming( + rf"projector.stages_sampling.{i}.(\d+).(\d+)", + rf"projector.scale_layers.{i}.sampling_layers.\1.layers.\2", + ) + ) + elif scale == 0.5: + weight_mapping.append( + WeightRenaming( + rf"projector.stages_sampling.{i}.(\d+).(\d+).conv.weight", + rf"projector.scale_layers.{i}.sampling_layers.\1.layers.\2.conv.weight", + ) + ) + weight_mapping.append( + WeightRenaming( + rf"projector.stages_sampling.{i}.(\d+).(\d+).bn", + rf"projector.scale_layers.{i}.sampling_layers.\1.layers.\2.norm", + ) + ) + + if is_segmentation: + weight_mapping.extend( + [ + # RfDetrForObjectDetection + WeightRenaming(r"bbox_embed.layers", "rf_detr.bbox_embed.layers"), + WeightRenaming(r"class_embed.(weight|bias)", r"rf_detr.class_embed.\1"), + WeightRenaming(r"query_feat.(weight|bias)", r"rf_detr.model.query_feat.\1"), + # Segmentation head + WeightRenaming(r"segmentation_head.blocks", r"blocks"), + WeightRenaming("segmentation_head.spatial_features_proj", "spatial_features_proj"), + WeightRenaming("segmentation_head.query_features_block", "query_features_block"), + WeightRenaming("segmentation_head.query_features_proj", "query_features_proj"), + WeightRenaming("segmentation_head.bias", "bias"), + ## RfDetrSegmentationMLPBlock + WeightRenaming("query_features_block.layers.0", "query_features_block.in_linear"), + WeightRenaming("query_features_block.layers.2", "query_features_block.out_linear"), + ## list[RfDetrSegmentationBlock] + WeightRenaming(r"blocks.(\d+)", r"blocks.\1"), + WeightRenaming(r"blocks.(\d+).norm", 
r"blocks.\1.layernorm"), + ] + ) + return weight_mapping + + +# We will verify our results on an image of cute cats +def prepare_img(): + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + im = Image.open(requests.get(url, stream=True).raw) + + return im + + +def original_preprocess_image(image, size): + transform = transforms.Compose( + [ + transforms.Resize(list(size.values())), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + ] + ) + image = transform(image) + return image + + +def test_models_outputs( + model: RfDetrForObjectDetection | RfDetrForInstanceSegmentation, + image_processor: DetrImageProcessor, + model_name: str, +): + expected_outputs = { + "rf-detr-nano": { + "logits": [-6.68004, -5.66107, -11.70373, -8.32324, -5.76176], + "boxes": [0.25828, 0.54991, 0.47220, 0.87432, 0.55099], + "loss": 14.893259, + }, + "rf-detr-small": { + "logits": [-6.83893, -4.55097, -10.53040, -8.20657, -5.55314], + "boxes": [0.25782, 0.55037, 0.47922, 0.87102, 0.77074], + "loss": 19.771887, + }, + "rf-detr-base": { + "logits": [-7.60410, -4.65943, -10.03144, -5.63881, -9.88291], + "boxes": [0.25465, 0.54864, 0.48583, 0.86991, 0.16926], + "loss": 21.967346, + }, + "rf-detr-base-2": { + "logits": [-6.81648, -6.80946, -7.72004, -6.06710, -10.37419], + "boxes": [0.16911, 0.19784, 0.21076, 0.09273, 0.25263], + "loss": 21.532478, + }, + "rf-detr-medium": { + "logits": [-6.58581, -8.07883, -12.52392, -7.78248, -10.47323], + "boxes": [0.16824, 0.19932, 0.21110, 0.09385, 0.77087], + "loss": 26.337656, + }, + "rf-detr-large-deprecated": { + "logits": [-7.60887, -4.36907, -4.98866, -8.06598, -5.52969], + "boxes": [0.25576, 0.55051, 0.47765, 0.87141, 0.76966], + "loss": 22.116581, + }, + "rf-detr-large": { + "logits": [-6.38973, -8.19355, -12.09174, -7.80438, -10.15835], + "boxes": [0.16901, 0.19936, 0.21087, 0.09311, 0.77199], + "loss": 27.111633, + }, + "rf-detr-seg-preview": { + "logits": [-7.05877, -4.23362, -6.54288, -8.13384, -6.36838], + "boxes": [0.25603, 0.55164, 0.48340, 0.87798, 0.73861], + "pred_masks": [-16.72129, -16.17153, -17.06426, -17.29409, -17.78559], + "loss": 76.156754, + }, + "rf-detr-seg-nano": { + "logits": [-7.38427, -5.59449, -9.97889, -11.03668, -8.62285], + "boxes": [0.25230, 0.54825, 0.48196, 0.86925, 0.77119], + "pred_masks": [-12.01641, -12.37785, -13.37312, -13.54168, -13.53435], + "loss": 92.973061, + }, + "rf-detr-seg-small": { + "logits": [-7.35031, -5.09690, -9.58117, -10.80274, -8.35001], + "boxes": [0.25607, 0.54820, 0.48018, 0.87013, 0.90797], + "pred_masks": [-13.17243, -13.12057, -13.92742, -13.89896, -13.72802], + "loss": 87.512894, + }, + "rf-detr-seg-medium": { + "logits": [-7.48751, -5.21394, -9.35906, -9.31897, -8.08021], + "boxes": [0.76891, 0.41680, 0.46182, 0.72004, 0.16810], + "pred_masks": [-15.67913, -17.05902, -16.72426, -17.19833, -17.18960], + "loss": 95.562599, + }, + "rf-detr-seg-large": { + "logits": [-7.37005, -5.04871, -9.19777, -9.37870, -7.96562], + "boxes": [0.76796, 0.41489, 0.46220, 0.72197, 0.25254], + "pred_masks": [-15.13846, -16.88754, -16.55486, -17.23686, -17.40160], + "loss": 91.781540, + }, + "rf-detr-seg-xlarge": { + "logits": [-7.42486, -4.72502, -8.16429, -8.30500, -7.21668], + "boxes": [0.76863, 0.41618, 0.46055, 0.72461, 0.16735], + "pred_masks": [-15.15330, -17.61085, -16.79776, -17.33611, -17.39120], + "loss": 105.279922, + }, + "rf-detr-seg-xxlarge": { + "logits": [-7.33242, -5.11294, -6.31125, -7.06520, -7.07922], + "boxes": [0.25516, 0.53685, 0.49769, 0.88601, 
0.76872], + "pred_masks": [-7.94849, -8.57010, -8.60056, -8.92854, -8.99288], + "loss": 99.136574, + }, + } + + device = "cuda" if torch.cuda.is_available() else "cpu" + image = prepare_img() + # Fake annotation for testing + annotations = { + "image_id": 0, + "annotations": [ + { + "bbox": [250, 250, 350, 350], + "category_id": 0, + "iscrowd": 0, + "area": 122500, + "segments": [[0, 0, 0, 100, 100, 100, 100, 0]], + } + ], + } + is_segmentation = "segmentation" in model_name + + original_pixel_values = original_preprocess_image(image, image_processor.size).unsqueeze(0).to(device) + inputs = image_processor(images=image, annotations=annotations, return_tensors="pt").to(device) + inputs["labels"][0]["masks"] = torch.zeros((1, original_pixel_values.shape[-1], original_pixel_values.shape[-2])) + torch.testing.assert_close(original_pixel_values, inputs["pixel_values"], atol=1e-6, rtol=1e-6) + print("Preprocessing looks ok!") + + model.to(device) + model.eval() + model.config._attn_implementation = "eager" + outputs = model(**inputs, output_hidden_states=True, output_attentions=True) + + predicted_logits = outputs.logits.flatten()[:5] + expected_logits = expected_outputs[model_name]["logits"] + predicted_boxes = outputs.pred_boxes.flatten()[:5] + expected_boxes = expected_outputs[model_name]["boxes"] + torch.testing.assert_close(predicted_logits, torch.Tensor(expected_logits), rtol=5e-3, atol=5e-3) + torch.testing.assert_close(predicted_boxes, torch.Tensor(expected_boxes), rtol=5e-3, atol=5e-3) + + if is_segmentation: + predicted_mask_logits = outputs.pred_masks.flatten()[:5] + expected_mask_logits = expected_outputs[model_name]["pred_masks"] + torch.testing.assert_close(predicted_mask_logits, torch.Tensor(expected_mask_logits), rtol=5e-3, atol=5e-3) + + predicted_loss = outputs.loss + expected_loss = expected_outputs[model_name]["loss"] + torch.testing.assert_close(predicted_loss, torch.tensor(expected_loss), rtol=5e-3, atol=5e-3) + + print("Forward pass looks ok!") + + +@torch.no_grad() +def convert_rf_detr_checkpoint( + model_name: str, + checkpoint_url: str, + pytorch_dump_folder_path: str, + push_to_hub: bool = False, + organization: str = "stevenbucaille", +): + """ + Convert a RF-DETR checkpoint to HuggingFace format. 
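+
+    Example (illustrative sketch only; it assumes "rf-detr-nano" is one of the keys in
+    `HOSTED_MODELS` and uses an arbitrary output folder):
+
+        convert_rf_detr_checkpoint(
+            model_name="rf-detr-nano",
+            checkpoint_url=None,  # fall back to the hosted checkpoint URL for this model
+            pytorch_dump_folder_path="./rf-detr-nano-converted",
+        )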
+
+    Args:
+        model_name: Name of the model (e.g., "rf-detr-nano")
+        checkpoint_url: URL of the checkpoint file (defaults to the hosted checkpoint for `model_name` when not provided)
+        pytorch_dump_folder_path: Path to save the converted model
+        push_to_hub: Whether to push the model to the hub
+        organization: Organization to push the model to
+    """
+    print(f"Converting {model_name} checkpoint...")
+
+    # Get model configuration
+    config, image_processor_config = get_model_config(model_name)
+    rf_detr_config = RfDetrConfig(**config)
+
+    # Load checkpoint
+    checkpoint_url = checkpoint_url if checkpoint_url is not None else HOSTED_MODELS[model_name]
+    print(f"Loading checkpoint from {checkpoint_url}...")
+    checkpoint = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu", weights_only=False)
+
+    # Create model and load weights
+    print("Creating model and loading weights...")
+    is_segmentation = "seg" in model_name
+    if is_segmentation:
+        model = RfDetrForInstanceSegmentation(rf_detr_config)
+    else:
+        model = RfDetrForObjectDetection(rf_detr_config)
+
+    weight_mapping = get_weight_mapping(rf_detr_config, is_segmentation)
+
+    # Handle different checkpoint formats
+    if "state_dict" in checkpoint:
+        state_dict = checkpoint["state_dict"]
+    elif "model" in checkpoint:
+        state_dict = checkpoint["model"]
+    else:
+        state_dict = checkpoint
+
+    missing, unexpected, mismatch, _, misc = convert_and_load_state_dict_in_model(
+        model, state_dict, weight_mapping, tp_plan=None, hf_quantizer=None
+    )
+    print("Checkpoint loaded...")
+    if len(missing) > 0 or len(unexpected) > 0 or len(mismatch) > 0:
+        print("MISSING:", len(missing))
+        print("\n".join(sorted(missing)))
+        print("UNEXPECTED:", len(unexpected))
+        print("\n".join(sorted(unexpected)))
+        print("MISMATCH:", len(mismatch))
+        print(mismatch)
+
+    image_processor = DetrImageProcessor(**image_processor_config, use_fast=True)
+    test_models_outputs(model, image_processor, model_name)
+
+    repo_id = f"{organization}/{model_name}"
+    # Save model
+    print("Saving model..." + (" and pushing to hub..." if push_to_hub else ""))
+    model.save_pretrained(
+        pytorch_dump_folder_path,
+        save_original_format=False,
+        push_to_hub=push_to_hub,
+        repo_id=repo_id,
+        commit_message=f"Add {model_name} model",
+    )
+
+    # Save image processor
+    print("Saving image processor..." + (" and pushing to hub..." if push_to_hub else ""))
+    image_processor.save_pretrained(
+        pytorch_dump_folder_path,
+        push_to_hub=push_to_hub,
+        repo_id=repo_id,
+        commit_message=f"Add {model_name} image processor",
+    )
+
+    # Save config
+    print("Saving config..." + " and pushing to hub..."
if push_to_hub else "") + rf_detr_config.save_pretrained( + pytorch_dump_folder_path, push_to_hub=push_to_hub, repo_id=repo_id, commit_message=f"Add {model_name} config" + ) + + if push_to_hub: + print("Pushed model to hub successfully!") + + print(f"Conversion completed successfully for {model_name}!") + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_name", + type=str, + required=True, + choices=list(HOSTED_MODELS.keys()), + help="Name of the model to convert", + ) + parser.add_argument( + "--pytorch_dump_folder_path", type=str, required=True, help="Path to the output PyTorch model directory" + ) + parser.add_argument("--checkpoint_path", type=str, help="Path to the checkpoint file (if not using hub download)") + parser.add_argument("--push_to_hub", action="store_true", help="Push model to the hub") + parser.add_argument("--organization", type=str, default="stevenbucaille", help="Organization to push the model to") + + args = parser.parse_args() + + # Get checkpoint path + checkpoint_path = args.checkpoint_path + + # Convert checkpoint + convert_rf_detr_checkpoint( + model_name=args.model_name, + checkpoint_url=checkpoint_path, + pytorch_dump_folder_path=args.pytorch_dump_folder_path, + push_to_hub=args.push_to_hub, + organization=args.organization, + ) + + +if __name__ == "__main__": + main() diff --git a/src/transformers/models/rf_detr/modeling_rf_detr.py b/src/transformers/models/rf_detr/modeling_rf_detr.py new file mode 100644 index 000000000000..c350ec634e73 --- /dev/null +++ b/src/transformers/models/rf_detr/modeling_rf_detr.py @@ -0,0 +1,2133 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/rf_detr/modular_rf_detr.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_rf_detr.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import collections.abc +import math +import warnings +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any + +import torch +from torch import Tensor, nn +from torch.nn import functional as F + +from ... 
import initialization as init +from ...activations import ACT2CLS, ACT2FN +from ...integrations import use_kernel_forward_from_hub +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BackboneOutput, BaseModelOutput +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...pytorch_utils import meshgrid +from ...utils import auto_docstring, torch_compilable_check, torch_int +from ...utils.backbone_utils import BackboneMixin +from ...utils.generic import ModelOutput, TransformersKwargs, can_return_tuple, check_model_inputs +from .configuration_rf_detr import RfDetrConfig, RfDetrDinov2Config + + +class RfDetrDinov2PatchEmbeddings(nn.Module): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. + """ + + def __init__(self, config): + super().__init__() + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size + + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + + self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + num_channels = pixel_values.shape[1] + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + f" Expected {self.num_channels} but got {num_channels}." + ) + embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2) + return embeddings + + +def window_partition( + embeddings: torch.Tensor, num_windows: int, patch_size: int, height: int, width: int +) -> torch.Tensor: + batch_size = embeddings.shape[0] + num_h_patches = height // patch_size + num_w_patches = width // patch_size + cls_token_with_pos_embed = embeddings[:, :1] + pixel_tokens_with_pos_embed = embeddings[:, 1:] + pixel_tokens_with_pos_embed = pixel_tokens_with_pos_embed.view(batch_size, num_h_patches, num_w_patches, -1) + num_w_patches_per_window = num_w_patches // num_windows + num_h_patches_per_window = num_h_patches // num_windows + windowed_pixel_tokens = pixel_tokens_with_pos_embed.view( + batch_size, num_windows, num_h_patches_per_window, num_windows, num_h_patches_per_window, -1 + ) + windowed_pixel_tokens = windowed_pixel_tokens.permute(0, 1, 3, 2, 4, 5) + windowed_pixel_tokens = windowed_pixel_tokens.reshape( + batch_size * num_windows**2, num_h_patches_per_window * num_w_patches_per_window, -1 + ) + windowed_cls_token_with_pos_embed = cls_token_with_pos_embed.repeat(num_windows**2, 1, 1) + embeddings = torch.cat((windowed_cls_token_with_pos_embed, windowed_pixel_tokens), dim=1) + return embeddings + + +class RfDetrDinov2Embeddings(nn.Module): + """ + Construct the CLS token, mask token, position and patch embeddings. 
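+
+    When `config.num_windows > 1`, the embedded sequence is additionally split into windows by
+    `window_partition` above: a `(batch_size, 1 + num_patches, hidden_size)` sequence becomes
+    `(batch_size * num_windows**2, 1 + num_patches // num_windows**2, hidden_size)`, with the [CLS]
+    token replicated into every window. As a worked example (using one of the released backbone
+    configurations purely as an illustration: image_size 432, patch_size 12, num_windows 2), the
+    image gives a 36x36 patch grid, and windowing yields 4 windows of 18x18 = 324 patch tokens
+    each, plus one [CLS] token per window.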
+ """ + + def __init__(self, config: RfDetrDinov2Config) -> None: + super().__init__() + + self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size)) + if config.use_mask_token: + self.mask_token = nn.Parameter(torch.zeros(1, config.hidden_size)) + self.patch_embeddings = RfDetrDinov2PatchEmbeddings(config) + num_patches = self.patch_embeddings.num_patches + self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size)) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.patch_size = config.patch_size + self.use_mask_token = config.use_mask_token + self.config = config + + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing and interpolation at torch.float32 precision. + + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 + """ + + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embeddings.shape[1] - 1 + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: + return self.position_embeddings + + class_pos_embed = self.position_embeddings[:, :1] + patch_pos_embed = self.position_embeddings[:, 1:] + + dim = embeddings.shape[-1] + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + target_dtype = patch_pos_embed.dtype + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.to(torch.float32), + size=(new_height, new_width), + mode="bicubic", + align_corners=False, + antialias=True, + ).to(dtype=target_dtype) + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + + return torch.cat((class_pos_embed, patch_pos_embed), dim=1) + + def forward(self, pixel_values: torch.Tensor, bool_masked_pos: torch.Tensor | None = None) -> torch.Tensor: + batch_size, _, height, width = pixel_values.shape + target_dtype = self.patch_embeddings.projection.weight.dtype + embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype)) + + if bool_masked_pos is not None: + embeddings = torch.where( + bool_masked_pos.unsqueeze(-1), self.mask_token.to(embeddings.dtype).unsqueeze(0), embeddings + ) + + # add the [CLS] token to the embedded patch tokens + cls_tokens = self.cls_token.expand(batch_size, -1, -1) + embeddings = torch.cat((cls_tokens, embeddings), dim=1) + + # add positional encoding to each token + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + + if self.config.num_windows > 1: + # reshape for windows + embeddings = window_partition(embeddings, self.config.num_windows, self.config.patch_size, height, width) + embeddings = self.dropout(embeddings) + + return embeddings + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor | None, + scaling: float 
| None = None, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + if scaling is None: + scaling = query.size(-1) ** -0.5 + + # Take the dot product between "query" and "key" to get the raw attention scores. + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + + if attention_mask is not None: + attention_mask = attention_mask[:, :, :, : key.shape[-2]] + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class RfDetrDinov2SelfAttention(nn.Module): + def __init__(self, config: RfDetrDinov2Config): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size {config.hidden_size} is not a multiple of the number of attention " + f"heads {config.num_attention_heads}." + ) + + self.config = config + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.dropout_prob = config.attention_probs_dropout_prob + self.scaling = self.attention_head_size**-0.5 + self.is_causal = False + + self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + self.num_key_value_groups = 1 + + def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + batch_size = hidden_states.shape[0] + new_shape = batch_size, -1, self.num_attention_heads, self.attention_head_size + + key_layer = self.key(hidden_states).view(*new_shape).transpose(1, 2) + value_layer = self.value(hidden_states).view(*new_shape).transpose(1, 2) + query_layer = self.query(hidden_states).view(*new_shape).transpose(1, 2) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + context_layer, attention_probs = attention_interface( + self, + query_layer, + key_layer, + value_layer, + None, + is_causal=self.is_causal, + scaling=self.scaling, + dropout=0.0 if not self.training else self.dropout_prob, + ) + + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.reshape(new_context_layer_shape) + + return context_layer, attention_probs + + +class RfDetrDinov2SelfOutput(nn.Module): + """ + The residual connection is defined in RfDetrDinov2Layer instead of here (as is the case with other models), due to the + layernorm applied before each block. 
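+    In other words, `RfDetrDinov2Layer` below computes (ignoring the window handling) roughly
+    `hidden_states = drop_path(layer_scale1(attention(norm1(hidden_states)))) + hidden_states`,
+    rather than adding the residual inside this module.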
+ """ + + def __init__(self, config: RfDetrDinov2Config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +class RfDetrDinov2Attention(nn.Module): + def __init__(self, config: RfDetrDinov2Config): + super().__init__() + self.attention = RfDetrDinov2SelfAttention(config) + self.output = RfDetrDinov2SelfOutput(config) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + self_attn_output, _ = self.attention(hidden_states) + output = self.output(self_attn_output, hidden_states) + return output + + +class RfDetrDinov2LayerScale(nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.lambda1 = nn.Parameter(config.layerscale_value * torch.ones(config.hidden_size)) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + return hidden_state * self.lambda1 + + +def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + """ + if drop_prob == 0.0 or not training: + return input + keep_prob = 1 - drop_prob + shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) + random_tensor.floor_() # binarize + output = input.div(keep_prob) * random_tensor + return output + + +class RfDetrDinov2DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: float | None = None) -> None: + super().__init__() + self.drop_prob = drop_prob + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return drop_path(hidden_states, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return f"p={self.drop_prob}" + + +class RfDetrDinov2MLP(nn.Module): + def __init__(self, config) -> None: + super().__init__() + in_features = out_features = config.hidden_size + hidden_features = int(config.hidden_size * config.mlp_ratio) + self.fc1 = nn.Linear(in_features, hidden_features, bias=True) + if isinstance(config.hidden_act, str): + self.activation = ACT2FN[config.hidden_act] + else: + self.activation = config.hidden_act + self.fc2 = nn.Linear(hidden_features, out_features, bias=True) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + hidden_state = self.fc1(hidden_state) + hidden_state = self.activation(hidden_state) + hidden_state = self.fc2(hidden_state) + return hidden_state + + +class RfDetrDinov2SwiGLUFFN(nn.Module): + def __init__(self, config) -> None: + super().__init__() + in_features = out_features = config.hidden_size + hidden_features = int(config.hidden_size * config.mlp_ratio) + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + + self.weights_in = nn.Linear(in_features, 2 * hidden_features, bias=True) + self.weights_out = nn.Linear(hidden_features, out_features, bias=True) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + hidden_state = self.weights_in(hidden_state) + x1, x2 = hidden_state.chunk(2, dim=-1) + hidden = nn.functional.silu(x1) * x2 + return self.weights_out(hidden) + + +def 
window_unpartition_before_attention(hidden_states: torch.Tensor, num_windows: int) -> torch.Tensor: + batch_size, seq_len, channels = hidden_states.shape + num_windows_squared = num_windows**2 + hidden_states = hidden_states.view(batch_size // num_windows_squared, num_windows_squared * seq_len, channels) + return hidden_states + + +def window_partition_after_attention( + hidden_states: torch.Tensor, self_attention_output: torch.Tensor, num_windows: int +) -> torch.Tensor: + batch_size, seq_len, channels = hidden_states.shape + num_windows_squared = num_windows**2 + self_attention_output = self_attention_output.view( + batch_size * num_windows_squared, seq_len // num_windows_squared, channels + ) + return self_attention_output + + +class RfDetrDinov2Layer(GradientCheckpointingLayer): + """This corresponds to the Block class in the original implementation.""" + + def __init__(self, config: RfDetrDinov2Config, layer_idx: int) -> None: + super().__init__() + + self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.attention = RfDetrDinov2Attention(config) + self.layer_scale1 = RfDetrDinov2LayerScale(config) + self.drop_path = RfDetrDinov2DropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity() + + self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + if config.use_swiglu_ffn: + self.mlp = RfDetrDinov2SwiGLUFFN(config) + else: + self.mlp = RfDetrDinov2MLP(config) + self.layer_scale2 = RfDetrDinov2LayerScale(config) + self.num_windows = config.num_windows + self.global_attention = layer_idx not in config.window_block_indexes + + def forward( + self, + hidden_states: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor]: + shortcut = hidden_states + if self.global_attention: + hidden_states = window_unpartition_before_attention(hidden_states, self.num_windows) + + hidden_states_norm = self.norm1(hidden_states) + self_attention_output = self.attention(hidden_states_norm) + + if self.global_attention: + self_attention_output = window_partition_after_attention( + hidden_states, self_attention_output, self.num_windows + ) + + self_attention_output = self.layer_scale1(self_attention_output) + + # first residual connection + hidden_states = self.drop_path(self_attention_output) + shortcut + + # in Dinov2, layernorm is also applied after self-attention + layer_output = self.norm2(hidden_states) + layer_output = self.mlp(layer_output) + layer_output = self.layer_scale2(layer_output) + + # second residual connection + layer_output = self.drop_path(layer_output) + hidden_states + + return layer_output + + +class RfDetrDinov2Encoder(nn.Module): + def __init__(self, config: RfDetrDinov2Config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([RfDetrDinov2Layer(config, i) for i in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor, output_hidden_states: bool = False) -> BaseModelOutput: + all_hidden_states = [hidden_states] if output_hidden_states else None + for i, layer_module in enumerate(self.layer): + hidden_states = layer_module(hidden_states) + if all_hidden_states: + all_hidden_states.append(hidden_states) + + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=tuple(all_hidden_states) if all_hidden_states else None, + ) + + +@auto_docstring +class RfDetrDinov2PreTrainedModel(PreTrainedModel): + config: RfDetrDinov2Config + base_model_prefix = "rf_detr_dinov2" + main_input_name = 
"pixel_values" + input_modalities = ("image",) + supports_gradient_checkpointing = True + _no_split_modules = ["RfDetrDinov2Layer"] + _supports_sdpa = True + _supports_flash_attn = True + _supports_flex_attn = True + _supports_attention_backend = True + _can_record_outputs = { + "attentions": RfDetrDinov2SelfAttention, + } + + @torch.no_grad() + def _init_weights(self, module: nn.Linear | nn.Conv2d | nn.LayerNorm) -> None: + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + init.trunc_normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + init.zeros_(module.bias) + elif isinstance(module, nn.LayerNorm): + init.zeros_(module.bias) + init.ones_(module.weight) + elif isinstance(module, RfDetrDinov2Embeddings): + init.trunc_normal_(module.position_embeddings, mean=0.0, std=self.config.initializer_range) + init.trunc_normal_(module.cls_token, mean=0.0, std=self.config.initializer_range) + if self.config.use_mask_token: + init.zeros_(module.mask_token) + elif isinstance(module, RfDetrDinov2LayerScale): + init.constant_(module.lambda1, self.config.layerscale_value) + + +def window_unpartition( + hidden_state: torch.Tensor, + num_windows: int, + num_h_patches: int, + num_w_patches: int, +) -> torch.Tensor: + hidden_batch_size, seq_len, channels = hidden_state.shape + num_windows_squared = num_windows**2 + num_h_patches_per_window = num_h_patches // num_windows + num_w_patches_per_window = num_w_patches // num_windows + hidden_state = hidden_state.reshape( + hidden_batch_size // num_windows_squared, num_windows_squared * seq_len, channels + ) + hidden_state = hidden_state.view( + hidden_batch_size // num_windows_squared, + num_windows, + num_windows, + num_h_patches_per_window, + num_w_patches_per_window, + channels, + ) + hidden_state = hidden_state.permute(0, 1, 3, 2, 4, 5) + return hidden_state + + +@auto_docstring( + custom_intro=""" + RfDetrDinov2 backbone, to be used with frameworks like DETR and MaskFormer. + """ +) +class RfDetrDinov2Backbone(RfDetrDinov2PreTrainedModel, BackboneMixin): + def __init__(self, config): + super().__init__(config) + super()._init_backbone(config) + + self.num_features = [config.hidden_size for _ in range(config.num_hidden_layers + 1)] + self.embeddings = RfDetrDinov2Embeddings(config) + self.encoder = RfDetrDinov2Encoder(config) + + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> RfDetrDinov2PatchEmbeddings: + return self.embeddings.patch_embeddings + + @check_model_inputs + @auto_docstring + def forward( + self, + pixel_values: torch.Tensor, + output_hidden_states: bool | None = None, + **kwargs, + ) -> BackboneOutput: + r""" + Examples: + + ```python + >>> from transformers import AutoImageProcessor, AutoBackbone + >>> import torch + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base") + >>> model = AutoBackbone.from_pretrained( + ... "facebook/dinov2-base", out_features=["stage2", "stage5", "stage8", "stage11"] + ... 
) + + >>> inputs = processor(image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> feature_maps = outputs.feature_maps + >>> list(feature_maps[-1].shape) + [1, 768, 16, 16] + ```""" + if output_hidden_states is None: + output_hidden_states = self.config.output_hidden_states + + embedding_output = self.embeddings(pixel_values) + + output: BaseModelOutput = self.encoder(embedding_output, output_hidden_states=True) + hidden_states = output.hidden_states + + feature_maps = () + for stage, hidden_state in zip(self.stage_names, hidden_states): + if stage in self.out_features: + if self.config.apply_layernorm: + hidden_state = self.layernorm(hidden_state) + if self.config.reshape_hidden_states: + hidden_state = hidden_state[:, 1:] + # this was actually a bug in the original implementation that we copied here, + # cause normally the order is height, width + batch_size, _, height, width = pixel_values.shape + patch_size = self.config.patch_size + + num_h_patches = height // patch_size + num_w_patches = width // patch_size + + if self.config.num_windows > 1: + hidden_state = window_unpartition( + hidden_state, self.config.num_windows, num_h_patches, num_w_patches + ) + + hidden_state = hidden_state.reshape(batch_size, num_h_patches, num_w_patches, -1) + hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous() + + feature_maps += (hidden_state,) + + return BackboneOutput( + feature_maps=feature_maps, + hidden_states=hidden_states if output_hidden_states else None, + ) + + +class RfDetrLayerNorm(nn.LayerNorm): + r"""LayerNorm that supports two data formats: channels_last (default) or channels_first. + The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height, + width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width). 
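+    For example, with `data_format="channels_first"` a hypothetical `(2, 256, 64, 64)` input is
+    permuted to `(2, 64, 64, 256)`, normalized over the trailing channel dimension, and permuted
+    back to `(2, 256, 64, 64)`; `normalized_shape` should therefore equal the channel count
+    (256 in this made-up example).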
+ """ + + def __init__(self, normalized_shape, *, eps=1e-6, data_format="channels_last", **kwargs): + super().__init__(normalized_shape, eps=eps, **kwargs) + if data_format not in ["channels_last", "channels_first"]: + raise NotImplementedError(f"Unsupported data format: {data_format}") + self.data_format = data_format + + def forward(self, features: torch.Tensor) -> torch.Tensor: + """ + Args: + features: Tensor of shape (batch_size, channels, height, width) OR (batch_size, height, width, channels) + """ + if self.data_format == "channels_first": + features = features.permute(0, 2, 3, 1) + features = super().forward(features) + features = features.permute(0, 3, 1, 2) + else: + features = super().forward(features) + return features + + +class RfDetrConvNormLayer(nn.Module): + def __init__( + self, + config: RfDetrConfig, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int, + activation: str | None = None, + ): + super().__init__() + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride, + padding=kernel_size // 2, + bias=False, + ) + self.norm = RfDetrLayerNorm(out_channels, data_format="channels_first", eps=config.layer_norm_eps) + self.activation = nn.Identity() if activation is None else ACT2CLS[activation]() + + def forward(self, hidden_state): + hidden_state = self.conv(hidden_state) + hidden_state = self.norm(hidden_state) + hidden_state = self.activation(hidden_state) + return hidden_state + + +class RfDetrRepVggBlock(nn.Module): + def __init__(self, config: RfDetrConfig): + super().__init__() + hidden_channels = int(config.d_model * config.hidden_expansion) + self.conv1 = RfDetrConvNormLayer( + config, hidden_channels, hidden_channels, 3, 1, activation=config.activation_function + ) + self.conv2 = RfDetrConvNormLayer( + config, hidden_channels, hidden_channels, 3, 1, activation=config.activation_function + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + y = self.conv1(x) + y = self.conv2(y) + return y + + +class RfDetrC2FLayer(nn.Module): + # Inspired by RTDetrCSPRepLayer + def __init__(self, config: RfDetrConfig, in_channels: int): + super().__init__() + num_blocks = config.c2f_num_blocks + activation = config.activation_function + out_channels = config.d_model + + self.hidden_channels = int(out_channels * config.hidden_expansion) + + conv1_out_channels = 2 * self.hidden_channels + self.conv1 = RfDetrConvNormLayer(config, in_channels, conv1_out_channels, 1, 1, activation=activation) + + conv2_in_channels = (2 + num_blocks) * self.hidden_channels + self.conv2 = RfDetrConvNormLayer(config, conv2_in_channels, out_channels, 1, 1, activation=activation) + + self.bottlenecks = nn.ModuleList(RfDetrRepVggBlock(config) for _ in range(num_blocks)) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.conv1(hidden_states) + all_hidden_states = list(hidden_states.split(self.hidden_channels, 1)) + hidden_states = all_hidden_states[-1] + + for bottleneck in self.bottlenecks: + hidden_states = bottleneck(hidden_states) + all_hidden_states.append(hidden_states) + + hidden_states = torch.cat(all_hidden_states, 1) + hidden_states = self.conv2(hidden_states) + return hidden_states + + +class RfDetrSamplingLayer(nn.Module): + def __init__(self, config: RfDetrConfig, channel_size: int, scale: float): + super().__init__() + + self.scale = scale + self.channel_size = channel_size + + layers = [] + if scale == 2.0: + layers.append(nn.ConvTranspose2d(channel_size, channel_size // 2, 2, 2)) + elif scale == 0.5: 
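+            # Shape sketch (the 36x36 input size below is an assumption used only for illustration):
+            # given a (1, 384, 36, 36) backbone feature map, the scale == 2.0 branch above
+            # (ConvTranspose2d, stride 2) yields (1, 192, 72, 72), while this scale == 0.5 branch
+            # (3x3 conv, stride 2, padding 1) yields (1, 384, 18, 18).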
+ layers.append(RfDetrConvNormLayer(config, channel_size, channel_size, 3, 2, activation="relu")) + self.layers = nn.ModuleList(layers) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + for layer in self.layers: + hidden_states = layer(hidden_states) + return hidden_states + + +class RfDetrScaleProjector(nn.Module): + def __init__(self, config: RfDetrConfig, scale: float): + super().__init__() + + intermediate_dims = [config.backbone_config.hidden_size] * len(config.backbone_config.out_indices) + sampling_layers = [] + for channel_size in intermediate_dims: + sampling_layers.append(RfDetrSamplingLayer(config, channel_size, scale)) + self.sampling_layers = nn.ModuleList(sampling_layers) + + intermediate_dim = intermediate_dims[-1] + if scale == 2.0: + intermediate_dim = intermediate_dim // 2 + projector_input_dim = intermediate_dim * len(intermediate_dims) + + self.projector_layer = RfDetrC2FLayer(config, projector_input_dim) + self.layer_norm = RfDetrLayerNorm(config.d_model, data_format="channels_first") + + def forward(self, hidden_states_tuple: tuple[torch.Tensor]) -> torch.Tensor: + sampled_hidden_states = [] + for sampling_layer, hidden_states in zip(self.sampling_layers, hidden_states_tuple): + hidden_states = sampling_layer(hidden_states) + sampled_hidden_states.append(hidden_states) + hidden_states = torch.cat(sampled_hidden_states, dim=1) + hidden_states = self.projector_layer(hidden_states) + hidden_states = self.layer_norm(hidden_states) + return hidden_states + + +class RfDetrMultiScaleProjector(nn.Module): + def __init__(self, config: RfDetrConfig): + super().__init__() + + self.config = config + scale_factors = config.projector_scale_factors + + self.scale_layers = nn.ModuleList([RfDetrScaleProjector(config, scale) for scale in scale_factors]) + + def forward(self, hidden_states: tuple[torch.Tensor]) -> list[torch.Tensor]: + output_hidden_states = [] + for scale_layer in self.scale_layers: + output_hidden_states.append(scale_layer(hidden_states)) + return output_hidden_states + + +class RfDetrConvEncoder(nn.Module): + def __init__(self, config: RfDetrConfig): + super().__init__() + self.backbone = RfDetrDinov2Backbone(config.backbone_config) + self.projector = RfDetrMultiScaleProjector(config) + + def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor): + # send pixel_values through the model to get list of feature maps + features = self.backbone(pixel_values).feature_maps + features = self.projector(features) + out = [] + for feature_map in features: + # downsample pixel_mask to match shape of corresponding feature_map + mask = nn.functional.interpolate(pixel_mask[None].float(), size=feature_map.shape[-2:]).to(torch.bool)[0] + out.append((feature_map, mask)) + return out + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class RfDetrAttention(nn.Module): + def __init__(self, config: RfDetrConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.d_model // config.decoder_self_attention_heads) + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = False + self.num_key_value_groups = 1 + + self.q_proj = nn.Linear( + config.d_model, config.decoder_self_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.d_model, config.decoder_self_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.d_model, config.decoder_self_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.decoder_self_attention_heads * self.head_dim, config.d_model, bias=config.attention_bias + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + batch_size, seq_len, _ = hidden_states.shape + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + hidden_states_original = hidden_states + if position_embeddings is not None: + hidden_states = hidden_states if position_embeddings is None else hidden_states + position_embeddings + + if self.training: + # at training, we use group detr technique to add more supervision by using multiple weight-sharing decoders at once for faster convergence + # at inference, we only use one decoder + hidden_states_original = torch.cat( + hidden_states_original.split(seq_len // self.config.group_detr, dim=1), dim=0 + ) + hidden_states = torch.cat(hidden_states.split(seq_len // self.config.group_detr, dim=1), dim=0) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states_original).view(hidden_shape).transpose(1, 2) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask=None, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if self.training: + attn_output = torch.cat(torch.split(attn_output, batch_size, dim=0), dim=1) + + return attn_output, attn_weights + + +@use_kernel_forward_from_hub("MultiScaleDeformableAttention") +class MultiScaleDeformableAttention(nn.Module): + def forward( + self, + value: Tensor, + value_spatial_shapes: Tensor, + value_spatial_shapes_list: list[tuple], + level_start_index: Tensor, + sampling_locations: Tensor, + attention_weights: Tensor, + im2col_step: 
int, + ): + batch_size, _, num_heads, hidden_dim = value.shape + _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape + value_list = value.split([height * width for height, width in value_spatial_shapes_list], dim=1) + sampling_grids = 2 * sampling_locations - 1 + sampling_value_list = [] + for level_id, (height, width) in enumerate(value_spatial_shapes_list): + # batch_size, height*width, num_heads, hidden_dim + # -> batch_size, height*width, num_heads*hidden_dim + # -> batch_size, num_heads*hidden_dim, height*width + # -> batch_size*num_heads, hidden_dim, height, width + value_l_ = ( + value_list[level_id] + .flatten(2) + .transpose(1, 2) + .reshape(batch_size * num_heads, hidden_dim, height, width) + ) + # batch_size, num_queries, num_heads, num_points, 2 + # -> batch_size, num_heads, num_queries, num_points, 2 + # -> batch_size*num_heads, num_queries, num_points, 2 + sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1) + # batch_size*num_heads, hidden_dim, num_queries, num_points + sampling_value_l_ = nn.functional.grid_sample( + value_l_, + sampling_grid_l_, + mode="bilinear", + padding_mode="zeros", + align_corners=False, + ) + sampling_value_list.append(sampling_value_l_) + # (batch_size, num_queries, num_heads, num_levels, num_points) + # -> (batch_size, num_heads, num_queries, num_levels, num_points) + # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points) + attention_weights = attention_weights.transpose(1, 2).reshape( + batch_size * num_heads, 1, num_queries, num_levels * num_points + ) + output = ( + (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights) + .sum(-1) + .view(batch_size, num_heads * hidden_dim, num_queries) + ) + return output.transpose(1, 2).contiguous() + + +class RfDetrMultiscaleDeformableAttention(nn.Module): + """ + Multiscale deformable attention as proposed in Deformable DETR. + """ + + def __init__(self, config: RfDetrConfig, num_heads: int, n_points: int): + super().__init__() + + self.attn = MultiScaleDeformableAttention() + + if config.d_model % num_heads != 0: + raise ValueError( + f"embed_dim (d_model) must be divisible by num_heads, but got {config.d_model} and {num_heads}" + ) + dim_per_head = config.d_model // num_heads + # check if dim_per_head is power of 2 + if not ((dim_per_head & (dim_per_head - 1) == 0) and dim_per_head != 0): + warnings.warn( + "You'd better set embed_dim (d_model) in RfDetrMultiscaleDeformableAttention to make the" + " dimension of each attention head a power of 2 which is more efficient in the authors' CUDA" + " implementation." 
+ ) + + self.im2col_step = 64 + + self.d_model = config.d_model + self.n_levels = config.num_feature_levels + self.n_heads = num_heads + self.n_points = n_points + + self.sampling_offsets = nn.Linear(config.d_model, num_heads * self.n_levels * n_points * 2) + self.attention_weights = nn.Linear(config.d_model, num_heads * self.n_levels * n_points) + self.value_proj = nn.Linear(config.d_model, config.d_model) + self.output_proj = nn.Linear(config.d_model, config.d_model) + + self.disable_custom_kernels = config.disable_custom_kernels + + def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Tensor | None): + return tensor if position_embeddings is None else tensor + position_embeddings + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + encoder_hidden_states=None, + encoder_attention_mask=None, + position_embeddings: torch.Tensor | None = None, + reference_points=None, + spatial_shapes=None, + spatial_shapes_list=None, + level_start_index=None, + **kwargs: Unpack[TransformersKwargs], + ): + # add position embeddings to the hidden states before projecting to queries and keys + if position_embeddings is not None: + hidden_states = self.with_pos_embed(hidden_states, position_embeddings) + + batch_size, num_queries, _ = hidden_states.shape + batch_size, sequence_length, _ = encoder_hidden_states.shape + total_elements = sum(height * width for height, width in spatial_shapes_list) + torch_compilable_check( + total_elements == sequence_length, + "Make sure to align the spatial shapes with the sequence length of the encoder hidden states", + ) + + value = self.value_proj(encoder_hidden_states) + if attention_mask is not None: + # we invert the attention_mask + value = value.masked_fill(~attention_mask[..., None], float(0)) + value = value.view(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads) + sampling_offsets = self.sampling_offsets(hidden_states).view( + batch_size, num_queries, self.n_heads, self.n_levels, self.n_points, 2 + ) + attention_weights = self.attention_weights(hidden_states).view( + batch_size, num_queries, self.n_heads, self.n_levels * self.n_points + ) + attention_weights = F.softmax(attention_weights, -1).view( + batch_size, num_queries, self.n_heads, self.n_levels, self.n_points + ) + # batch_size, num_queries, n_heads, n_levels, n_points, 2 + num_coordinates = reference_points.shape[-1] + if num_coordinates == 2: + offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1) + sampling_locations = ( + reference_points[:, :, None, :, None, :] + + sampling_offsets / offset_normalizer[None, None, None, :, None, :] + ) + elif num_coordinates == 4: + sampling_locations = ( + reference_points[:, :, None, :, None, :2] + + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 + ) + else: + raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}") + + output = self.attn( + value, + spatial_shapes, + spatial_shapes_list, + level_start_index, + sampling_locations, + attention_weights, + self.im2col_step, + ) + + output = self.output_proj(output) + + return output, attention_weights + + +class RfDetrMLP(nn.Module): + def __init__(self, config: RfDetrConfig): + super().__init__() + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.decoder_activation_function] + self.fc1 = nn.Linear(config.d_model, config.decoder_ffn_dim) + self.fc2 = nn.Linear(config.decoder_ffn_dim, config.d_model) + + def 
forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + residual = hidden_states + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + return hidden_states + + +class RfDetrDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: RfDetrConfig, layer_idx: int): + nn.Module.__init__(self) + + # self-attention + self.self_attn = RfDetrAttention(config, layer_idx=layer_idx) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.decoder_activation_function] + self.activation_dropout = config.activation_dropout + self.self_attn_layer_norm = nn.LayerNorm(config.d_model) + + # cross-attention + self.cross_attn = RfDetrMultiscaleDeformableAttention( + config, + num_heads=config.decoder_cross_attention_heads, + n_points=config.decoder_n_points, + ) + self.cross_attn_layer_norm = nn.LayerNorm(config.d_model) + + # mlp + self.mlp = RfDetrMLP(config) + self.layer_norm = nn.LayerNorm(config.d_model) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: torch.Tensor | None = None, + reference_points=None, + spatial_shapes=None, + spatial_shapes_list=None, + level_start_index=None, + encoder_hidden_states: torch.Tensor | None = None, + encoder_attention_mask: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ): + self_attention_output, self_attn_weights = self.self_attn( + hidden_states, position_embeddings=position_embeddings, **kwargs + ) + + self_attention_output = nn.functional.dropout(self_attention_output, p=self.dropout, training=self.training) + hidden_states = hidden_states + self_attention_output + hidden_states = self.self_attn_layer_norm(hidden_states) + + cross_attention_output, cross_attn_weights = self.cross_attn( + hidden_states=hidden_states, + attention_mask=encoder_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + position_embeddings=position_embeddings, + reference_points=reference_points, + spatial_shapes=spatial_shapes, + spatial_shapes_list=spatial_shapes_list, + level_start_index=level_start_index, + **kwargs, + ) + cross_attention_output = nn.functional.dropout(cross_attention_output, p=self.dropout, training=self.training) + hidden_states = hidden_states + cross_attention_output + hidden_states = self.cross_attn_layer_norm(hidden_states) + + hidden_states = self.mlp(hidden_states) + hidden_states = self.layer_norm(hidden_states) + + return hidden_states + + +@auto_docstring +class RfDetrPreTrainedModel(PreTrainedModel): + config: RfDetrConfig + base_model_prefix = "model" + main_input_name = "pixel_values" + _no_split_modules = [ + r"RfDetrConvEncoder", + r"RfDetrDecoderLayer", + ] + _supports_sdpa = True + _supports_flash_attn = True + _supports_flex_attn = True + _supports_attention_backend = True + _can_record_outputs = { + "attentions": [RfDetrAttention, RfDetrMultiscaleDeformableAttention], + "hidden_states": [RfDetrDecoderLayer], + } + + @torch.no_grad() + def _init_weights(self, module): + super()._init_weights(module) + + if isinstance(module, RfDetrMultiscaleDeformableAttention): + init.constant_(module.sampling_offsets.weight, 0.0) + thetas = torch.arange(module.n_heads, dtype=torch.int64).float() * (2.0 * math.pi / 
module.n_heads) + grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) + grid_init = ( + (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) + .view(module.n_heads, 1, 1, 2) + .repeat(1, module.n_levels, module.n_points, 1) + ) + for i in range(module.n_points): + grid_init[:, :, i, :] *= i + 1 + + init.copy_(module.sampling_offsets.bias, grid_init.view(-1)) + init.constant_(module.attention_weights.weight, 0.0) + init.constant_(module.attention_weights.bias, 0.0) + init.xavier_uniform_(module.value_proj.weight) + init.constant_(module.value_proj.bias, 0.0) + init.xavier_uniform_(module.output_proj.weight) + init.constant_(module.output_proj.bias, 0.0) + if hasattr(module, "level_embed"): + init.normal_(module.level_embed) + if hasattr(module, "refpoint_embed") and module.refpoint_embed is not None: + init.constant_(module.refpoint_embed.weight, 0) + if hasattr(module, "class_embed") and module.class_embed is not None: + prior_prob = 0.01 + bias_value = -math.log((1 - prior_prob) / prior_prob) + init.constant_(module.class_embed.bias, bias_value) + if hasattr(module, "bbox_embed") and module.bbox_embed is not None: + init.constant_(module.bbox_embed.layers[-1].weight, 0) + init.constant_(module.bbox_embed.layers[-1].bias, 0) + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for outputs of the RfDetr backbone-decoder model. + """ +) +class RfDetrModelOutput(ModelOutput): + r""" + init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): + Initial reference points sent through the Transformer decoder. + intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`): + Stacked intermediate hidden states (output of each layer of the decoder). + intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`): + Stacked intermediate reference points (reference points of each layer of the decoder). + enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`): + Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are + picked as region proposals in the first stage. Output of bounding box binary classification (i.e. + foreground and background). + enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`): + Logits of predicted bounding boxes coordinates in the first stage. + backbone_features (list of `torch.FloatTensor` of shape `(batch_size, config.num_channels, config.image_size, config.image_size)`): + Features from the backbone. + """ + + init_reference_points: torch.FloatTensor | None = None + last_hidden_state: torch.FloatTensor | None = None + intermediate_hidden_states: torch.FloatTensor | None = None + intermediate_reference_points: torch.FloatTensor | None = None + enc_outputs_class: torch.FloatTensor | None = None + enc_outputs_coord_logits: torch.FloatTensor | None = None + + backbone_features: list[torch.Tensor] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for outputs of the RfDetrDecoder. This class adds two attributes to + BaseModelOutputWithCrossAttentions, namely: + - a stacked tensor of intermediate decoder hidden states (i.e. 
the output of each decoder layer) + - a stacked tensor of intermediate reference points. + """ +) +class RfDetrDecoderOutput(ModelOutput): + r""" + intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`): + Stacked intermediate hidden states (output of each layer of the decoder). + intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, hidden_size)`): + Stacked intermediate reference points (reference points of each layer of the decoder). + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax, + used to compute the weighted average in the cross-attention heads. + """ + + last_hidden_state: torch.FloatTensor | None = None + intermediate_hidden_states: torch.FloatTensor | None = None + intermediate_reference_points: torch.FloatTensor | None = None + hidden_states: tuple[torch.FloatTensor] | None = None + attentions: tuple[torch.FloatTensor] | None = None + cross_attentions: tuple[torch.FloatTensor] | None = None + + +# function to generate sine positional embedding for 4d coordinates +def gen_sine_position_embeddings(pos_tensor, hidden_size=256): + """ + This function computes position embeddings using sine and cosine functions from the input positional tensor, + which has a shape of (batch_size, num_queries, 4). + The last dimension of `pos_tensor` represents the following coordinates: + - 0: x-coord + - 1: y-coord + - 2: width + - 3: height + + The output shape is (batch_size, num_queries, 512), where final dim (hidden_size*2 = 512) is the total embedding dimension + achieved by concatenating the sine and cosine values for each coordinate. + """ + scale = 2 * math.pi + dim = hidden_size // 2 + dim_t = torch.arange(dim, dtype=torch.float32, device=pos_tensor.device) + dim_t = 10000 ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / dim) + x_embed = pos_tensor[:, :, 0] * scale + y_embed = pos_tensor[:, :, 1] * scale + pos_x = x_embed[:, :, None] / dim_t + pos_y = y_embed[:, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2) + pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2) + if pos_tensor.size(-1) == 4: + w_embed = pos_tensor[:, :, 2] * scale + pos_w = w_embed[:, :, None] / dim_t + pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2) + + h_embed = pos_tensor[:, :, 3] * scale + pos_h = h_embed[:, :, None] / dim_t + pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2) + + pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2) + else: + raise ValueError(f"Unknown pos_tensor shape(-1):{pos_tensor.size(-1)}") + return pos.to(pos_tensor.dtype) + + +class RfDetrMLPPredictionHead(nn.Module): + """ + Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates, + height and width of a bounding box w.r.t. an image. 
+ + Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py + + """ + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + + +class RfDetrDecoder(RfDetrPreTrainedModel): + """ + Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`DeformableDetrDecoderLayer`]. + + The decoder updates the query embeddings through multiple self-attention and deformable cross-attention layers. + + Some tweaks for RfDetr: + + - it uses group detr technique at training for faster convergence. + + Args: + config: RfDetrConfig + """ + + def __init__(self, config: RfDetrConfig): + super().__init__(config) + self.dropout = config.dropout + self.layers = nn.ModuleList([RfDetrDecoderLayer(config, i) for i in range(config.decoder_layers)]) + self.layernorm = nn.LayerNorm(config.d_model) + + self.gradient_checkpointing = False + + self.ref_point_head = RfDetrMLPPredictionHead(2 * config.d_model, config.d_model, config.d_model, num_layers=2) + + self.post_init() + + def get_reference(self, reference_points, valid_ratios): + # batch_size, num_queries, batch_size, 4 + obj_center = reference_points[..., :4] + + # batch_size, num_queries, num_levels, 4 + reference_points_inputs = obj_center[:, :, None] * torch.cat([valid_ratios, valid_ratios], -1)[:, None] + + # batch_size, num_queries, d_model * 2 + query_sine_embed = gen_sine_position_embeddings(reference_points_inputs[:, :, 0, :], self.config.d_model) + + # batch_size, num_queries, d_model + query_pos = self.ref_point_head(query_sine_embed) + return reference_points_inputs, query_pos + + def forward( + self, + inputs_embeds: torch.Tensor | None = None, + reference_points: torch.Tensor | None = None, + spatial_shapes: torch.Tensor | None = None, + spatial_shapes_list: torch.Tensor | None = None, + level_start_index: torch.Tensor | None = None, + valid_ratios: torch.Tensor | None = None, + encoder_hidden_states: torch.Tensor | None = None, + encoder_attention_mask: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ): + intermediate = () + intermediate_reference_points = (reference_points,) + + if inputs_embeds is not None: + hidden_states = inputs_embeds + + reference_points_inputs, query_pos = self.get_reference(reference_points, valid_ratios) + + for idx, decoder_layer in enumerate(self.layers): + hidden_states = decoder_layer( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + position_embeddings=query_pos, + reference_points=reference_points_inputs, + spatial_shapes=spatial_shapes, + spatial_shapes_list=spatial_shapes_list, + level_start_index=level_start_index, + **kwargs, + ) + intermediate_hidden_states = self.layernorm(hidden_states) + intermediate += (intermediate_hidden_states,) + + intermediate = torch.stack(intermediate) + last_hidden_state = intermediate[-1] + intermediate_reference_points = torch.stack(intermediate_reference_points) + + return RfDetrDecoderOutput( + last_hidden_state=last_hidden_state, + intermediate_hidden_states=intermediate, + intermediate_reference_points=intermediate_reference_points, + ) + + +def refine_bboxes(reference_points, deltas): + 
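+ # Decodes box deltas relative to reference boxes given in (center_x, center_y, width, height) format:
+ # the first two delta channels shift the center by a fraction of the reference width/height, and the
+ # last two rescale the width/height through an exponential (a delta of 0 leaves the size unchanged).
+ # For example, a reference box (0.50, 0.50, 0.20, 0.20) with deltas (0.1, 0.0, 0.0, 0.0) becomes
+ # (0.52, 0.50, 0.20, 0.20): the center moves by 0.1 * 0.2 = 0.02 and the width/height stay the same.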
reference_points = reference_points.to(deltas.device) + new_reference_points_cxcy = deltas[..., :2] * reference_points[..., 2:] + reference_points[..., :2] + new_reference_points_wh = deltas[..., 2:].exp() * reference_points[..., 2:] + new_reference_points = torch.cat((new_reference_points_cxcy, new_reference_points_wh), -1) + return new_reference_points + + +@auto_docstring( + custom_intro=""" + The bare LW Detr Model (consisting of a backbone and decoder Transformer) outputting raw + hidden-states without any specific head on top. + """ +) +class RfDetrModel(RfDetrPreTrainedModel): + def __init__(self, config: RfDetrConfig): + super().__init__(config) + + # Create backbone + positional encoding + self.backbone = RfDetrConvEncoder(config) + + self.group_detr = config.group_detr + self.num_queries = config.num_queries + hidden_dim = config.d_model + self.reference_point_embed = nn.Embedding(self.num_queries * self.group_detr, 4) + self.query_feat = nn.Embedding(self.num_queries * self.group_detr, hidden_dim) + + self.decoder = RfDetrDecoder(config) + + self.enc_output = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(self.group_detr)]) + self.enc_output_norm = nn.ModuleList([nn.LayerNorm(hidden_dim) for _ in range(self.group_detr)]) + # Should normally be None and then instantiated in the ForObjectDetection class + self.enc_out_bbox_embed = nn.ModuleList( + [RfDetrMLPPredictionHead(config.d_model, config.d_model, 4, num_layers=3) for _ in range(self.group_detr)] + ) + self.enc_out_class_embed = nn.ModuleList( + [nn.Linear(config.d_model, config.num_labels) for _ in range(self.group_detr)] + ) + + self.post_init() + + def freeze_backbone(self): + for name, param in self.backbone.conv_encoder.model.named_parameters(): + param.requires_grad_(False) + + def unfreeze_backbone(self): + for name, param in self.backbone.conv_encoder.model.named_parameters(): + param.requires_grad_(True) + + def get_valid_ratio(self, mask, dtype=torch.float32): + """Get the valid ratio of all feature maps.""" + + _, height, width = mask.shape + valid_height = torch.sum(mask[:, :, 0], 1) + valid_width = torch.sum(mask[:, 0, :], 1) + valid_ratio_height = valid_height.to(dtype) / height + valid_ratio_width = valid_width.to(dtype) / width + valid_ratio = torch.stack([valid_ratio_width, valid_ratio_height], -1) + return valid_ratio + + def get_proposal_pos_embed(self, proposals): + """Get the position embedding of the proposals.""" + + num_pos_feats = self.config.d_model // 2 + temperature = 10000 + scale = 2 * math.pi + + dim_t = torch.arange(num_pos_feats, dtype=proposals.dtype, device=proposals.device) + dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats) + # batch_size, num_queries, 4 + proposals = proposals.sigmoid() * scale + # batch_size, num_queries, 4, 128 + pos = proposals[:, :, :, None] / dim_t + # batch_size, num_queries, 4, 64, 2 -> batch_size, num_queries, 512 + pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), dim=4).flatten(2) + return pos + + def gen_encoder_output_proposals(self, enc_output, padding_mask, spatial_shapes): + """Generate the encoder output proposals from encoded enc_output. + + Args: + enc_output (Tensor[batch_size, sequence_length, hidden_size]): Output of the encoder. + padding_mask (Tensor[batch_size, sequence_length]): Padding mask for `enc_output`. + spatial_shapes (list[tuple[int, int]]): Spatial shapes of the feature maps. + + Returns: + `tuple(torch.FloatTensor)`: A tuple of feature map and bbox prediction. 
+ - object_query (Tensor[batch_size, sequence_length, hidden_size]): Object query features. Later used to + directly predict a bounding box. (without the need of a decoder) + - output_proposals (Tensor[batch_size, sequence_length, 4]): Normalized proposals, after an inverse + sigmoid. + """ + batch_size = enc_output.shape[0] + proposals = [] + _cur = 0 + for level, (height, width) in enumerate(spatial_shapes): + mask_flatten_ = padding_mask[:, _cur : (_cur + height * width)].view(batch_size, height, width, 1) + valid_height = torch.sum(~mask_flatten_[:, :, 0, 0], 1) + valid_width = torch.sum(~mask_flatten_[:, 0, :, 0], 1) + + grid_y, grid_x = meshgrid( + torch.linspace( + 0, + height - 1, + height, + dtype=enc_output.dtype, + device=enc_output.device, + ), + torch.linspace( + 0, + width - 1, + width, + dtype=enc_output.dtype, + device=enc_output.device, + ), + indexing="ij", + ) + grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) + + scale = torch.cat([valid_width.unsqueeze(-1), valid_height.unsqueeze(-1)], 1).view(batch_size, 1, 1, 2) + grid = (grid.unsqueeze(0).expand(batch_size, -1, -1, -1) + 0.5) / scale + width_height = torch.ones_like(grid) * 0.05 * (2.0**level) + proposal = torch.cat((grid, width_height), -1).view(batch_size, -1, 4) + proposals.append(proposal) + _cur += height * width + output_proposals = torch.cat(proposals, 1) + output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True) + output_proposals = output_proposals.masked_fill(padding_mask.unsqueeze(-1), float("inf")) + output_proposals = output_proposals.masked_fill(~output_proposals_valid, float("inf")) + + # assign each pixel as an object query + object_query = enc_output + object_query = object_query.masked_fill(padding_mask.unsqueeze(-1), float(0)) + object_query = object_query.masked_fill(~output_proposals_valid, float(0)) + return object_query, output_proposals + + @check_model_inputs + @auto_docstring + def forward( + self, + pixel_values: torch.FloatTensor, + pixel_mask: torch.LongTensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> RfDetrModelOutput: + r""" + Examples: + + ```python + >>> from transformers import AutoImageProcessor, DeformableDetrModel + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("stevenbucaille/rfdetr_small_60e_coco") + >>> model = DeformableDetrModel.from_pretrained("stevenbucaille/rfdetr_small_60e_coco") + + >>> inputs = image_processor(images=image, return_tensors="pt") + + >>> outputs = model(**inputs) + + >>> last_hidden_states = outputs.last_hidden_state + >>> list(last_hidden_states.shape) + [1, 200, 256] + ```""" + batch_size, num_channels, height, width = pixel_values.shape + device = pixel_values.device + + if pixel_mask is None: + pixel_mask = torch.ones(((batch_size, height, width)), dtype=torch.long, device=device) + + # First, retrieve feature maps from backbone + features = self.backbone(pixel_values, pixel_mask) + + sources = [] + masks = [] + for level, (source, mask) in enumerate(features): + sources.append(source) + masks.append(mask) + if mask is None: + raise ValueError("No attention mask was provided") + + # Get initial reference points and query features + if self.training: + reference_points = self.reference_point_embed.weight + query_feat = self.query_feat.weight + else: + # only use first group of 
reference points and query features during inference + # reference_points (num_queries, 4) : spatial locations of the queries + # query_feat (num_queries, d_model) : features of the queries + reference_points = self.reference_point_embed.weight[: self.num_queries] + query_feat = self.query_feat.weight[: self.num_queries] + + # Prepare decoder inputs (by flattening) + source_flatten = [] + mask_flatten = [] + spatial_shapes_list = [] + for source, mask in zip(sources, masks): + batch_size, num_channels, height, width = source.shape + spatial_shape = (height, width) + spatial_shapes_list.append(spatial_shape) + source = source.flatten(2).transpose(1, 2) + mask = mask.flatten(1) + source_flatten.append(source) + mask_flatten.append(mask) + # source_flatten (batch_size, sum(H*W), d_model) : flattened multi-scale feature maps + # mask_flatten (batch_size, sum(H*W)) : flattened mask + # spatial_shapes (num_levels, 2) : spatial shapes of the feature maps + # level_start_index (num_levels,) : start index of each level in source_flatten + # valid_ratios (batch_size, num_levels, 2) : valid ratios of the feature maps + source_flatten = torch.cat(source_flatten, 1) + mask_flatten = torch.cat(mask_flatten, 1) + spatial_shapes = torch.as_tensor(spatial_shapes_list, dtype=torch.long, device=source_flatten.device) + level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) + valid_ratios = torch.stack([self.get_valid_ratio(m, dtype=source_flatten.dtype) for m in masks], 1) + + # Duplicate query features and reference points for each image in the batch + target = query_feat.unsqueeze(0).expand(batch_size, -1, -1) + reference_points = reference_points.unsqueeze(0).expand(batch_size, -1, -1) + + # Generate encoder output proposals + object_query_embedding, output_proposals = self.gen_encoder_output_proposals( + source_flatten, ~mask_flatten, spatial_shapes_list + ) + + group_detr = self.group_detr if self.training else 1 + topk = self.num_queries + topk_coords_logits = [] + object_query_undetach = [] + + # Iterate over each group of object queries to refine the object queries + for group_id in range(group_detr): + group_object_query = self.enc_output[group_id](object_query_embedding) + group_object_query = self.enc_output_norm[group_id](group_object_query) + + group_enc_outputs_class = self.enc_out_class_embed[group_id](group_object_query) + group_delta_bbox = self.enc_out_bbox_embed[group_id](group_object_query) + group_enc_outputs_coord = refine_bboxes(output_proposals, group_delta_bbox) + + group_topk_proposals = torch.topk(group_enc_outputs_class.max(-1)[0], topk, dim=1)[1] + group_topk_coords_logits_undetach = torch.gather( + group_enc_outputs_coord, + 1, + group_topk_proposals.unsqueeze(-1).repeat(1, 1, 4), + ) + group_topk_coords_logits = group_topk_coords_logits_undetach.detach() + group_object_query_undetach = torch.gather( + group_object_query, 1, group_topk_proposals.unsqueeze(-1).repeat(1, 1, self.config.d_model) + ) + + topk_coords_logits.append(group_topk_coords_logits) + object_query_undetach.append(group_object_query_undetach) + + # Concatenate the object queries and reference points from all groups + topk_coords_logits = torch.cat(topk_coords_logits, 1) + object_query_undetach = torch.cat(object_query_undetach, 1) + + # Get the class and coordinate logits from the object queries + # enc_outputs_class (batch_size, num_queries, d_model) : object queries + # enc_outputs_coord_logits (batch_size, num_queries, 4) : coordinate logits of the object 
queries + enc_outputs_class = object_query_undetach + enc_outputs_coord_logits = topk_coords_logits + + # Refine the reference points using the coordinate logits + two_stage_len = topk_coords_logits.shape[-2] + reference_points_two_stage_subset = reference_points[..., :two_stage_len, :] + reference_points_subset = reference_points[..., two_stage_len:, :] + reference_points_two_stage_subset = refine_bboxes(topk_coords_logits, reference_points_two_stage_subset) + reference_points = torch.cat([reference_points_two_stage_subset, reference_points_subset], dim=-2) + init_reference_points = reference_points + + # Pass the object queries and reference points to the decoder + decoder_outputs = self.decoder( + inputs_embeds=target, + reference_points=reference_points, + spatial_shapes=spatial_shapes, + spatial_shapes_list=spatial_shapes_list, + level_start_index=level_start_index, + valid_ratios=valid_ratios, + encoder_hidden_states=source_flatten, + encoder_attention_mask=mask_flatten, + **kwargs, + ) + + # init_reference_points (batch_size, num_queries, 4) : initial reference points + # last_hidden_state (batch_size, num_queries, d_model) : final object queries + # intermediate_hidden_states (batch_size, num_decoder_layers, num_queries, d_model) : intermediate object queries + # intermediate_reference_points (batch_size, num_decoder_layers, num_queries, 4) : intermediate reference points + # backbone_features list(batch_size, num_levels, d_model, H, W) : backbone features + # enc_outputs_class (batch_size, num_queries, d_model) : encoder outputs object queries + # enc_outputs_coord_logits (batch_size, num_queries, 4) : coordinate logits of encoder object queries + return RfDetrModelOutput( + init_reference_points=init_reference_points, + last_hidden_state=decoder_outputs.last_hidden_state, + intermediate_hidden_states=decoder_outputs.intermediate_hidden_states, + intermediate_reference_points=decoder_outputs.intermediate_reference_points, + backbone_features=sources, + enc_outputs_class=enc_outputs_class, + enc_outputs_coord_logits=enc_outputs_coord_logits, + ) + + +@dataclass +@auto_docstring( + custom_intro=""" + Output type of [`RfDetrForObjectDetection`]. + """ +) +class RfDetrObjectDetectionOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided)): + Total loss as a linear combination of a negative log-likehood (cross-entropy) for class prediction and a + bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized + scale-invariant IoU loss. + loss_dict (`Dict`, *optional*): + A dictionary containing the individual losses. Useful for logging. + logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`): + Classification logits (including no-object) for all queries. + pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): + Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These + values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding + possible padding). You can use [`~DeformableDetrProcessor.post_process_object_detection`] to retrieve the + unnormalized bounding boxes. + auxiliary_outputs (`list[Dict]`, *optional*): + Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`) + and labels are provided. 
It is a list of dictionaries containing the two above keys (`logits` and + `pred_boxes`) for each decoder layer. + init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): + Initial reference points sent through the Transformer decoder. + intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`): + Stacked intermediate hidden states (output of each layer of the decoder). + intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`): + Stacked intermediate reference points (reference points of each layer of the decoder). + enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`): + Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are + picked as region proposals in the first stage. Output of bounding box binary classification (i.e. + foreground and background). + enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`): + Logits of predicted bounding boxes coordinates in the first stage. + backbone_features (list of `torch.FloatTensor` of shape `(batch_size, config.num_channels, config.image_size, config.image_size)`): + Features from the backbone. + """ + + loss: torch.FloatTensor | None = None + loss_dict: dict | None = None + logits: torch.FloatTensor | None = None + pred_boxes: torch.FloatTensor | None = None + auxiliary_outputs: list[dict] | None = None + init_reference_points: torch.FloatTensor | None = None + last_hidden_state: torch.FloatTensor | None = None + intermediate_hidden_states: torch.FloatTensor | None = None + intermediate_reference_points: torch.FloatTensor | None = None + enc_outputs_class: Any = None + enc_outputs_coord_logits: torch.FloatTensor | None = None + + backbone_features: list[torch.Tensor] | None = None + + +@auto_docstring( + custom_intro=""" + RF-DETR Model (consisting of a backbone and decoder Transformer) with object detection heads on + top, for tasks such as COCO detection. + """ +) +class RfDetrForObjectDetection(RfDetrPreTrainedModel): + # When using clones, all layers > 0 will be clones, but layer 0 *is* required + # We can't initialize the model on meta device as some weights are modified during the initialization + _no_split_modules = None + _tied_weights_keys = None + + def __init__(self, config: RfDetrConfig): + super().__init__(config) + self.model = RfDetrModel(config) + self.class_embed = nn.Linear(config.d_model, config.num_labels) + self.bbox_embed = RfDetrMLPPredictionHead(config.d_model, config.d_model, 4, num_layers=3) + + self.post_init() + + @check_model_inputs + @auto_docstring + def forward( + self, + pixel_values: torch.FloatTensor = None, + pixel_mask: torch.LongTensor | None = None, + labels: list[dict] | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> RfDetrObjectDetectionOutput: + r""" + decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*): + Not used by default. Can be used to mask object queries.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you + can choose to directly pass a flattened representation of an image. + decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*): + Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an + embedded representation. + labels (`list[Dict]` of len `(batch_size,)`, *optional*): + Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the + following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch + respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding boxes + in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)`. + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, RfDetrForObjectDetection + >>> from PIL import Image + >>> import requests + >>> import torch + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("stevenbucaille/rfdetr_small_60e_coco") + >>> model = RfDetrForObjectDetection.from_pretrained("stevenbucaille/rfdetr_small_60e_coco") + + >>> inputs = image_processor(images=image, return_tensors="pt") + >>> outputs = model(**inputs) + + >>> # convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax) + >>> target_sizes = torch.tensor([image.size[::-1]]) + >>> results = image_processor.post_process_object_detection(outputs, threshold=0.5, target_sizes=target_sizes)[ + ... 0 + ... ] + >>> for score, label, box in zip(results["scores"], results["labels"], results["boxes"]): + ... box = [round(i, 2) for i in box.tolist()] + ... print( + ... f"Detected {model.config.id2label[label.item()]} with confidence " + ... f"{round(score.item(), 3)} at location {box}" + ...
) + Detected cat with confidence 0.8 at location [16.5, 52.84, 318.25, 470.78] + Detected cat with confidence 0.789 at location [342.19, 24.3, 640.02, 372.25] + Detected remote with confidence 0.633 at location [40.79, 72.78, 176.76, 117.25] + ```""" + outputs = self.model( + pixel_values, + pixel_mask=pixel_mask, + **kwargs, + ) + + last_hidden_states = outputs.last_hidden_state + intermediate_reference_points = outputs.intermediate_reference_points + enc_outputs_class = outputs.enc_outputs_class + enc_outputs_boxes_logits = outputs.enc_outputs_coord_logits + + # Get logits and boxes from first stage object queries + enc_outputs_class_logits = self.get_encoder_outputs_class_logits(enc_outputs_class) + + # Get logits and boxes from second stage object queries + logits = self.class_embed(last_hidden_states) + pred_boxes_delta = self.bbox_embed(last_hidden_states) + pred_boxes = refine_bboxes(intermediate_reference_points[-1], pred_boxes_delta) + + loss, loss_dict, auxiliary_outputs = None, None, None + if labels is not None: + outputs_class, outputs_coord = None, None + if self.config.auxiliary_loss: + intermediate_hidden_states = outputs.intermediate_hidden_states + outputs_coord_delta = self.bbox_embed(intermediate_hidden_states) + outputs_coord = refine_bboxes(intermediate_reference_points, outputs_coord_delta) + outputs_class = self.class_embed(intermediate_hidden_states) + + loss, loss_dict, auxiliary_outputs = self.loss_function( + logits, + labels, + self.device, + pred_boxes, + self.config, + outputs_class, + outputs_coord, + enc_outputs_class_logits, + enc_outputs_boxes_logits, + ) + + return RfDetrObjectDetectionOutput( + loss=loss, + loss_dict=loss_dict, + logits=logits, + pred_boxes=pred_boxes, + auxiliary_outputs=auxiliary_outputs, + last_hidden_state=outputs.last_hidden_state, + intermediate_hidden_states=outputs.intermediate_hidden_states, + intermediate_reference_points=outputs.intermediate_reference_points, + init_reference_points=outputs.init_reference_points, + enc_outputs_class=enc_outputs_class_logits, + enc_outputs_coord_logits=enc_outputs_boxes_logits, + backbone_features=outputs.backbone_features, + ) + + def get_encoder_outputs_class_logits(self, enc_outputs_class_logits: torch.Tensor) -> Tensor: + enc_outputs_class_logits_list = enc_outputs_class_logits.split(self.config.num_queries, dim=1) + group_detr = self.config.group_detr if self.training else 1 + pred_class = [ + self.model.enc_out_class_embed[group_index](enc_outputs_class_logits_list[group_index]) + for group_index in range(group_detr) + ] + enc_outputs_class_logits = torch.cat(pred_class, dim=1) + return enc_outputs_class_logits + + +@dataclass +@auto_docstring( + custom_intro=""" + Output type of [`RfDetrForInstanceSegmentation`]. + """ +) +class RfDetrInstanceSegmentationOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided)): + Total loss as a linear combination of a negative log-likehood (cross-entropy) for class prediction and a + bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized + scale-invariant IoU loss. + loss_dict (`Dict`, *optional*): + A dictionary containing the individual losses. Useful for logging. + logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`): + Classification logits (including no-object) for all queries. 
+ pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): + Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These + values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding + possible padding). You can use [`~DeformableDetrProcessor.post_process_object_detection`] to retrieve the + unnormalized bounding boxes. + pred_masks (`torch.FloatTensor` of shape `(batch_size, num_queries, height/4, width/4)`): + Segmentation masks logits for all queries. See also + [`~DetrImageProcessor.post_process_semantic_segmentation`] or + [`~DetrImageProcessor.post_process_instance_segmentation`] + [`~DetrImageProcessor.post_process_panoptic_segmentation`] to evaluate semantic, instance and panoptic + segmentation masks respectively. + auxiliary_outputs (`list[Dict]`, *optional*): + Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`) + and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and + `pred_boxes`) for each decoder layer. + init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): + Initial reference points sent through the Transformer decoder. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the decoder of the model. + intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`): + Stacked intermediate hidden states (output of each layer of the decoder). + intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`): + Stacked intermediate reference points (reference points of each layer of the decoder). + enc_outputs_mask_logits (`torch.FloatTensor` of shape `(batch_size, num_queries, width, height)`, *optional*): + Mask logits from the encoder for all queries. + """ + + loss: torch.FloatTensor | None = None + loss_dict: dict | None = None + logits: torch.FloatTensor | None = None + pred_boxes: torch.FloatTensor | None = None + pred_masks: torch.FloatTensor = None + auxiliary_outputs: list[dict] | None = None + init_reference_points: torch.FloatTensor | None = None + last_hidden_state: torch.FloatTensor | None = None + intermediate_hidden_states: torch.FloatTensor | None = None + intermediate_reference_points: torch.FloatTensor | None = None + enc_outputs_mask_logits: torch.FloatTensor | None = None + + +class RfDetrSegmentationBlock(nn.Module): + """This corresponds to the `Block` class in the original implementation. + + There are two equivalent implementations: [DwConv, LayerNorm (channels_first), Conv, GELU,1x1 Conv]; all in (N, C, + H, W) (2) [DwConv, Permute to (N, H, W, C), LayerNorm (channels_last), Linear, GELU, Linear]; Permute back + + The authors used (2) as they find it slightly faster in PyTorch. + + Args: + config ([`RfDetrConfig`]): Model configuration class. + dim (`int`): Number of input channels. + drop_path (`float`): Stochastic depth rate. Default: 0.0. 
+ """ + + def __init__(self, config: RfDetrConfig): + super().__init__() + dim = config.d_model + self.dwconv = nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim) # depthwise conv + self.layernorm = RfDetrLayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear(dim, dim) # pointwise/1x1 convs, implemented with linear layers + self.act = ACT2FN[config.segmentation_head_activation_function] + + def forward(self, features: torch.Tensor) -> torch.Tensor: + residual = features + features = self.dwconv(features) + features = features.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + features = self.layernorm(features) + features = self.pwconv1(features) + features = self.act(features) + features = features.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + features = features + residual + return features + + +class RfDetrSegmentationMLPBlock(nn.Module): + def __init__(self, config: RfDetrConfig): + super().__init__() + dim = config.d_model + self.norm_in = nn.LayerNorm(dim) + self.in_linear = nn.Linear(dim, dim * 4) + self.act = ACT2FN[config.segmentation_head_activation_function] + self.out_linear = nn.Linear(dim * 4, dim) + + def forward(self, features: torch.Tensor) -> torch.Tensor: + residual = features + features = self.norm_in(features) + features = self.in_linear(features) + features = self.act(features) + features = self.out_linear(features) + features = features + residual + return features + + +class RfDetrForInstanceSegmentation(RfDetrPreTrainedModel): + def __init__(self, config: RfDetrConfig): + super().__init__(config) + + self.rf_detr = RfDetrForObjectDetection(config) + + num_blocks = config.decoder_layers + self.downsample_ratio = config.mask_downsample_ratio + self.blocks = nn.ModuleList([RfDetrSegmentationBlock(config) for _ in range(num_blocks)]) + self.spatial_features_proj = nn.Conv2d(config.d_model, config.d_model, kernel_size=1) + + self.query_features_block = RfDetrSegmentationMLPBlock(config) + self.query_features_proj = nn.Linear(config.d_model, config.d_model) + + self.bias = nn.Parameter(torch.zeros(1), requires_grad=True) + + self.post_init() + + def segmentation_head(self, spatial_features, query_features, image_size: torch.Size, skip_blocks: bool = False): + # spatial features: (B, C, H, W) + # query features: [(B, N, C)] for each decoder layer + # output: (B, N, H*r, W*r) + target_size = (image_size[0] // self.downsample_ratio, image_size[1] // self.downsample_ratio) + spatial_features = F.interpolate(spatial_features, size=target_size, mode="bilinear", align_corners=False) + list_mask_logits = [] + if not skip_blocks: + for block, qf in zip(self.blocks, query_features): + spatial_features = block(spatial_features) + spatial_features_proj = self.spatial_features_proj(spatial_features) + qf = self.query_features_block(qf) + qf = self.query_features_proj(qf) + mask_logits = torch.einsum("bchw,bnc->bnhw", spatial_features_proj, qf) + mask_logits = mask_logits + self.bias + list_mask_logits.append(mask_logits) + else: + query_features = self.query_features_block(query_features) + query_features = self.query_features_proj(query_features) + mask_logits = torch.einsum("bchw,bnc->bnhw", spatial_features, query_features) + mask_logits = mask_logits + self.bias + list_mask_logits.append(mask_logits) + + return list_mask_logits + + @check_model_inputs + @can_return_tuple + @auto_docstring + def forward( + self, + pixel_values: torch.FloatTensor = None, + pixel_mask: torch.LongTensor | None = None, + labels: list[dict] | None = None, + **kwargs: 
Unpack[TransformersKwargs], + ) -> dict[str, torch.Tensor]: + image_size = pixel_values.shape[-2:] + + outputs = self.rf_detr.model( + pixel_values, + pixel_mask=pixel_mask, + **kwargs, + ) + + spatial_features = outputs.backbone_features[-1] + last_hidden_states = outputs.last_hidden_state + intermediate_reference_points = outputs.intermediate_reference_points + enc_outputs_class = outputs.enc_outputs_class + enc_outputs_boxes_logits = outputs.enc_outputs_coord_logits + query_features = outputs.intermediate_hidden_states + last_hidden_state = outputs.last_hidden_state + + # First stage segmentation proposals + enc_outputs_class_logits = self.rf_detr.get_encoder_outputs_class_logits(enc_outputs_class) + enc_outputs_masks = self.segmentation_head(spatial_features, enc_outputs_class, image_size, skip_blocks=True) + enc_outputs_masks = torch.cat(enc_outputs_masks, dim=1) + + # Second stage segmentation proposals + logits = self.rf_detr.class_embed(last_hidden_states) + pred_boxes_delta = self.rf_detr.bbox_embed(last_hidden_states) + pred_boxes = refine_bboxes(intermediate_reference_points[-1], pred_boxes_delta) + outputs_masks = self.segmentation_head(spatial_features, query_features, image_size) + + pred_masks = outputs_masks[-1] + + loss, loss_dict, auxiliary_outputs = None, None, None + if labels is not None: + outputs_class, outputs_coord = None, None + if self.config.auxiliary_loss: + intermediate_hidden_states = outputs.intermediate_hidden_states + outputs_coord_delta = self.rf_detr.bbox_embed(intermediate_hidden_states) + outputs_coord = refine_bboxes(intermediate_reference_points, outputs_coord_delta) + outputs_class = self.rf_detr.class_embed(intermediate_hidden_states) + loss, loss_dict, auxiliary_outputs = self.loss_function( + logits, + labels, + self.device, + pred_boxes, + pred_masks, + self.config, + outputs_class, + outputs_coord, + outputs_masks, + enc_outputs_class_logits, + enc_outputs_boxes_logits, + enc_outputs_masks, + ) + + return RfDetrInstanceSegmentationOutput( + loss=loss, + loss_dict=loss_dict, + logits=logits, + pred_boxes=pred_boxes, + pred_masks=pred_masks, + auxiliary_outputs=auxiliary_outputs, + last_hidden_state=last_hidden_state, + intermediate_hidden_states=outputs.intermediate_hidden_states, + intermediate_reference_points=outputs.intermediate_reference_points, + init_reference_points=outputs.init_reference_points, + enc_outputs_mask_logits=enc_outputs_masks, + ) + + +__all__ = [ + "RfDetrModel", + "RfDetrForObjectDetection", + "RfDetrForInstanceSegmentation", + "RfDetrPreTrainedModel", + "RfDetrDinov2Backbone", + "RfDetrDinov2PreTrainedModel", +] diff --git a/src/transformers/models/rf_detr/modular_rf_detr.py b/src/transformers/models/rf_detr/modular_rf_detr.py new file mode 100644 index 000000000000..0d9489c654cb --- /dev/null +++ b/src/transformers/models/rf_detr/modular_rf_detr.py @@ -0,0 +1,1297 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
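+ # A minimal inference sketch for the `RfDetrForInstanceSegmentation` class defined earlier in this diff.
+ # The checkpoint name below is hypothetical; substitute any RF-DETR instance segmentation checkpoint:
+ #
+ #     import torch
+ #     import requests
+ #     from PIL import Image
+ #     from transformers import AutoImageProcessor, RfDetrForInstanceSegmentation
+ #
+ #     image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
+ #     image_processor = AutoImageProcessor.from_pretrained("stevenbucaille/rfdetr_small_seg")  # hypothetical checkpoint
+ #     model = RfDetrForInstanceSegmentation.from_pretrained("stevenbucaille/rfdetr_small_seg")  # hypothetical checkpoint
+ #
+ #     inputs = image_processor(images=image, return_tensors="pt")
+ #     with torch.no_grad():
+ #         outputs = model(**inputs)
+ #
+ #     # outputs.pred_masks has shape (batch_size, num_queries, height / 4, width / 4) with the default
+ #     # mask_downsample_ratio of 4; outputs.logits and outputs.pred_boxes match the object detection head.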
+from dataclasses import dataclass + +import torch +from torch import Tensor, nn +from torch.nn import functional as F + +from ...activations import ACT2FN +from ...configuration_utils import PreTrainedConfig +from ...modeling_outputs import BackboneOutput, BaseModelOutput +from ...processing_utils import Unpack +from ...utils import auto_docstring, logging, torch_int +from ...utils.generic import ModelOutput, TransformersKwargs, can_return_tuple, check_model_inputs +from ..auto import CONFIG_MAPPING +from ..convnext.modeling_convnext import ConvNextLayer +from ..dinov2.configuration_dinov2 import Dinov2Config +from ..dinov2.modeling_dinov2 import ( + Dinov2Backbone, + Dinov2Embeddings, + Dinov2Encoder, + Dinov2Layer, + Dinov2PreTrainedModel, + Dinov2SelfAttention, +) +from ..lw_detr.configuration_lw_detr import LwDetrConfig +from ..lw_detr.modeling_lw_detr import ( + LwDetrC2FLayer, + LwDetrConvEncoder, + LwDetrConvNormLayer, + LwDetrForObjectDetection, + LwDetrLayerNorm, + LwDetrModel, + LwDetrModelOutput, + LwDetrObjectDetectionOutput, + LwDetrPreTrainedModel, + LwDetrSamplingLayer, + LwDetrScaleProjector, + refine_bboxes, +) + + +logger = logging.get_logger(__name__) + + +class RfDetrDinov2Config(Dinov2Config): + r""" + This is the configuration class to store the configuration of a [`RfDetrDinov2Model`]. It is used to instantiate an + RfDetrDinov2 model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the DINOv2 + [facebook/dinov2-base](https://huggingface.co/facebook/dinov2-base) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + mlp_ratio (`int`, *optional*, defaults to 4): + Ratio of the hidden size of the MLPs relative to the `hidden_size`. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 14): + The size (resolution) of each patch. + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + qkv_bias (`bool`, *optional*, defaults to `True`): + Whether to add a bias to the queries, keys and values. 
+ layerscale_value (`float`, *optional*, defaults to 1.0): + Initial value to use for layer scale. + drop_path_rate (`float`, *optional*, defaults to 0.0): + Stochastic depth rate per sample (when applied in the main path of residual layers). + use_swiglu_ffn (`bool`, *optional*, defaults to `False`): + Whether to use the SwiGLU feedforward neural network. + out_features (`list[str]`, *optional*): + If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc. + (depending on how many stages the model has). If unset and `out_indices` is set, will default to the + corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the + same order as defined in the `stage_names` attribute. + out_indices (`list[int]`, *optional*): + If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how + many stages the model has). If unset and `out_features` is set, will default to the corresponding stages. + If unset and `out_features` is unset, will default to the last stage. Must be in the + same order as defined in the `stage_names` attribute. + apply_layernorm (`bool`, *optional*, defaults to `True`): + Whether to apply layer normalization to the feature maps in case the model is used as backbone. + reshape_hidden_states (`bool`, *optional*, defaults to `True`): + Whether to reshape the feature maps to 4D tensors of shape `(batch_size, hidden_size, height, width)` in + case the model is used as backbone. If `False`, the feature maps will be 3D tensors of shape `(batch_size, + seq_len, hidden_size)`. + use_mask_token (`bool`, *optional*, defaults to `True`): + Whether to use mask_token in embeddings. + num_windows (`int`, *optional*, defaults to 4): + Number of windows to use for windowed attention. If 1, no windowed attention is used. + Example: + + ```python + >>> from transformers import RfDetrDinov2Config, RfDetrDinov2Backbone + + >>> # Initializing a RfDetrDinov2 base style configuration + >>> configuration = RfDetrDinov2Config() + + >>> # Initializing a model (with random weights) from the base style configuration + >>> model = RfDetrDinov2Backbone(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "rf_detr_dinov2" + + def __init__(self, num_windows: int = 4, **super_kwargs): + super().__init__(**super_kwargs) + + self.num_windows = num_windows + window_block_indexes = set(range(self._out_indices[-1] + 1)) + window_block_indexes.difference_update(self._out_indices) + window_block_indexes = list(window_block_indexes) + self.window_block_indexes = window_block_indexes + + +class RfDetrConfig(LwDetrConfig): + r""" + This is the configuration class to store the configuration of a [`RfDetrModel`]. It is used to instantiate + an RF-DETR model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the RF-DETR + [stevenbucaille/rfdetr_small_60e_coco](https://huggingface.co/stevenbucaille/rfdetr_small_60e_coco) architecture. + + RF-DETR is a transformer-based object detection model designed for real-time + detection tasks. It replaces traditional CNN-based detectors like YOLO with a more efficient transformer architecture + that achieves competitive performance while being computationally lightweight.
+ + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + backbone_config (`PretrainedConfig` or `dict`, *optional*): + The configuration of the backbone model. If not provided, will default to `RfDetrDinov2Config` + with a small ViT architecture optimized for detection tasks. + projector_scale_factors (`list[float]`, *optional*, defaults to `[]`): + Scale factors for the feature pyramid network. Each scale factor determines the resolution of features + at different levels. Supported values are 0.5, 1.0, and 2.0. + hidden_expansion (`float`, *optional*, defaults to 0.5): + Expansion factor for hidden dimensions in the projector layers. + c2f_num_blocks (`int`, *optional*, defaults to 3): + Number of blocks in the C2F layer. + activation_function (`str`, *optional*, defaults to `"silu"`): + The non-linear activation function in the projector. Supported values are `"silu"`, `"relu"`, `"gelu"`. + layer_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon value for layer normalization layers. + d_model (`int`, *optional*, defaults to 256): + Dimension of the model layers and the number of expected features in the decoder inputs. + dropout (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + decoder_ffn_dim (`int`, *optional*, defaults to 2048): + Dimension of the "intermediate" (often named feed-forward) layer in decoder. + decoder_n_points (`int`, *optional*, defaults to 4): + The number of sampled keys in each feature level for each attention head in the decoder. + decoder_layers (`int`, *optional*, defaults to 3): + Number of decoder layers in the transformer. + decoder_self_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the decoder self-attention. + decoder_cross_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the decoder cross-attention. + decoder_activation_function (`str`, *optional*, defaults to `"relu"`): + The non-linear activation function in the decoder. Supported values are `"relu"`, `"silu"`, `"gelu"`. + num_queries (`int`, *optional*, defaults to 300): + Number of object queries, i.e. detection slots. This is the maximal number of objects + [`RfDetrModel`] can detect in a single image. + attention_bias (`bool`, *optional*, defaults to `True`): + Whether to add bias to the attention layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + activation_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for activations inside the fully connected layer. + group_detr (`int`, *optional*, defaults to 13): + Number of groups for Group DETR attention mechanism, which helps reduce computational complexity. + init_std (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + disable_custom_kernels (`bool`, *optional*, defaults to `True`): + Disable the use of custom CUDA and CPU kernels. This option is necessary for the ONNX export, as custom + kernels are not supported by PyTorch ONNX export. + class_cost (`float`, *optional*, defaults to 2): + Relative weight of the classification error in the Hungarian matching cost. 
+ bbox_cost (`float`, *optional*, defaults to 5): + Relative weight of the L1 error of the bounding box coordinates in the Hungarian matching cost. + giou_cost (`float`, *optional*, defaults to 2): + Relative weight of the generalized IoU loss of the bounding box in the Hungarian matching cost. + class_loss_coefficient (`float`, *optional*, defaults to 1): + Relative weight of the classification loss in the Hungarian matching cost. + mask_loss_coefficient (`float`, *optional*, defaults to 1): + Relative weight of the Focal loss in the instance segmentation mask loss. + dice_loss_coefficient (`float`, *optional*, defaults to 1): + Relative weight of the DICE/F-1 loss in the object detection loss. + bbox_loss_coefficient (`float`, *optional*, defaults to 5): + Relative weight of the L1 bounding box loss in the object detection loss. + giou_loss_coefficient (`float`, *optional*, defaults to 2): + Relative weight of the generalized IoU loss in the object detection loss. + eos_coefficient (`float`, *optional*, defaults to 0.1): + Relative classification weight of the 'no-object' class in the object detection loss. + focal_alpha (`float`, *optional*, defaults to 0.25): + Alpha parameter in the focal loss. + auxiliary_loss (`bool`, *optional*, defaults to `True`): + Whether auxiliary decoding losses (loss at each decoder layer) are to be used. + mask_point_sample_ratio (`int`, *optional*, defaults to 16): + The ratio of points to sample for the mask loss calculation. + mask_downsample_ratio (`int`, *optional*, defaults to 4): + The downsample ratio for the segmentation masks compared to the input image resolution. + mask_class_loss_coefficient (`float`, *optional*, defaults to 5.0): + Relative weight of the Focal loss in the instance segmentation loss. + mask_dice_loss_coefficient (`float`, *optional*, defaults to 5.0): + Relative weight of the DICE/F-1 loss in the instance segmentation loss. + segmentation_head_activation_function (`str`, *optional*, defaults to `"gelu"`): + The non-linear activation function in the segmentation head. Supported values are `"relu"`, `"silu"`, `"gelu"`. 
+ Examples: + + ```python + >>> from transformers import RfDetrConfig, RfDetrModel + + >>> # Initializing a LW-DETR stevenbucaille/RfDetr_small_60e_coco style configuration + >>> configuration = RfDetrConfig() + + >>> # Initializing a model (with random weights) from the stevenbucaille/RfDetr_small_60e_coco style configuration + >>> model = RfDetrModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "rf_detr" + + def __init__( + self, + # backbone + backbone_config=None, + # projector + projector_scale_factors: list[float] = [], + hidden_expansion=0.5, + c2f_num_blocks=3, + activation_function="silu", + layer_norm_eps=1e-5, + # decoder + d_model=256, + dropout=0.1, + decoder_ffn_dim=2048, + decoder_n_points=4, + decoder_layers: int = 3, + decoder_self_attention_heads: int = 8, + decoder_cross_attention_heads: int = 16, + decoder_activation_function="relu", + # model + num_queries=300, + attention_bias=True, + attention_dropout=0.0, + activation_dropout=0.0, + group_detr: int = 13, + init_std=0.02, + disable_custom_kernels=True, + # loss + class_cost=2, + bbox_cost=5, + giou_cost=2, + class_loss_coefficient=1, + mask_loss_coefficient=1, + dice_loss_coefficient=1, + bbox_loss_coefficient=5, + giou_loss_coefficient=2, + eos_coefficient=0.1, + focal_alpha=0.25, + auxiliary_loss=True, + mask_point_sample_ratio=16, + # segmentation + mask_downsample_ratio=4, + mask_class_loss_coefficient=5.0, + mask_dice_loss_coefficient=5.0, + segmentation_head_activation_function="gelu", + **kwargs, + ): + self.layer_norm_eps = layer_norm_eps + + # backbone + if backbone_config is None: + logger.info( + "`backbone_config` is `None`. Initializing the config with the default `RfDetrDinov2` backbone." + ) + backbone_config = RfDetrDinov2Config( + attention_probs_dropout_prob=0.0, + drop_path_rate=0.0, + hidden_act="gelu", + hidden_dropout_prob=0.0, + initializer_range=0.02, + layer_norm_eps=1e-06, + layerscale_value=1.0, + mlp_ratio=4, + num_attention_heads=6, + num_channels=3, + num_hidden_layers=12, + qkv_bias=True, + use_swiglu_ffn=False, + out_features=["stage2", "stage5", "stage8", "stage11"], + hidden_size=384, + patch_size=14, + num_windows=4, + num_register_tokens=0, + image_size=518, + **kwargs, + ) + elif isinstance(backbone_config, dict): + backbone_model_type = backbone_config.pop("model_type") + config_class = CONFIG_MAPPING[backbone_model_type] + backbone_config = config_class.from_dict(backbone_config) + + self.backbone_config = backbone_config + + # projector + self.projector_scale_factors = projector_scale_factors + for scale in projector_scale_factors: + if scale not in [0.5, 1.0, 2.0]: + raise ValueError(f"Unsupported scale factor: {scale}") + self.projector_in_channels = [d_model] * len(projector_scale_factors) + self.projector_out_channels = d_model + self.activation_function = activation_function + self.hidden_expansion = hidden_expansion + self.c2f_num_blocks = c2f_num_blocks + # decoder + self.d_model = d_model + self.dropout = dropout + self.num_queries = num_queries + self.decoder_ffn_dim = decoder_ffn_dim + self.num_feature_levels = len(self.projector_scale_factors) + self.decoder_n_points = decoder_n_points + self.decoder_layers = decoder_layers + self.decoder_activation_function = decoder_activation_function + self.decoder_self_attention_heads = decoder_self_attention_heads + self.decoder_cross_attention_heads = decoder_cross_attention_heads + self.attention_bias = attention_bias + self.attention_dropout = 
attention_dropout
+        self.activation_dropout = activation_dropout
+        # model
+        self.init_std = init_std
+        self.group_detr = group_detr
+        # Loss
+        self.auxiliary_loss = auxiliary_loss
+        # Hungarian matcher
+        self.class_cost = class_cost
+        self.bbox_cost = bbox_cost
+        self.giou_cost = giou_cost
+        # Loss coefficients
+        self.class_loss_coefficient = class_loss_coefficient
+        self.mask_loss_coefficient = mask_loss_coefficient
+        self.dice_loss_coefficient = dice_loss_coefficient
+        self.bbox_loss_coefficient = bbox_loss_coefficient
+        self.giou_loss_coefficient = giou_loss_coefficient
+        self.mask_class_loss_coefficient = mask_class_loss_coefficient
+        self.mask_dice_loss_coefficient = mask_dice_loss_coefficient
+        self.eos_coefficient = eos_coefficient
+        self.focal_alpha = focal_alpha
+        self.disable_custom_kernels = disable_custom_kernels
+        self.mask_point_sample_ratio = mask_point_sample_ratio
+        # segmentation
+        self.mask_downsample_ratio = mask_downsample_ratio
+        self.segmentation_head_activation_function = segmentation_head_activation_function
+        PreTrainedConfig.__init__(self, **kwargs)
+
+
+def window_partition(
+    embeddings: torch.Tensor, num_windows: int, patch_size: int, height: int, width: int
+) -> torch.Tensor:
+    batch_size = embeddings.shape[0]
+    num_h_patches = height // patch_size
+    num_w_patches = width // patch_size
+    cls_token_with_pos_embed = embeddings[:, :1]
+    pixel_tokens_with_pos_embed = embeddings[:, 1:]
+    pixel_tokens_with_pos_embed = pixel_tokens_with_pos_embed.view(batch_size, num_h_patches, num_w_patches, -1)
+    num_w_patches_per_window = num_w_patches // num_windows
+    num_h_patches_per_window = num_h_patches // num_windows
+    # split the patch grid into (num_windows x num_windows) windows of size
+    # (num_h_patches_per_window, num_w_patches_per_window)
+    windowed_pixel_tokens = pixel_tokens_with_pos_embed.view(
+        batch_size, num_windows, num_h_patches_per_window, num_windows, num_w_patches_per_window, -1
+    )
+    windowed_pixel_tokens = windowed_pixel_tokens.permute(0, 1, 3, 2, 4, 5)
+    windowed_pixel_tokens = windowed_pixel_tokens.reshape(
+        batch_size * num_windows**2, num_h_patches_per_window * num_w_patches_per_window, -1
+    )
+    windowed_cls_token_with_pos_embed = cls_token_with_pos_embed.repeat(num_windows**2, 1, 1)
+    embeddings = torch.cat((windowed_cls_token_with_pos_embed, windowed_pixel_tokens), dim=1)
+    return embeddings
+
+
+class RfDetrDinov2Embeddings(Dinov2Embeddings):
+    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+        """
+        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
+        images. This method is also adapted to support torch.jit tracing and interpolation at torch.float32 precision.
+ + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 + """ + + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embeddings.shape[1] - 1 + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: + return self.position_embeddings + + class_pos_embed = self.position_embeddings[:, :1] + patch_pos_embed = self.position_embeddings[:, 1:] + + dim = embeddings.shape[-1] + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + target_dtype = patch_pos_embed.dtype + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.to(torch.float32), + size=(new_height, new_width), + mode="bicubic", + align_corners=False, + antialias=True, + ).to(dtype=target_dtype) + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + + return torch.cat((class_pos_embed, patch_pos_embed), dim=1) + + def forward(self, pixel_values: torch.Tensor, bool_masked_pos: torch.Tensor | None = None) -> torch.Tensor: + batch_size, _, height, width = pixel_values.shape + target_dtype = self.patch_embeddings.projection.weight.dtype + embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype)) + + if bool_masked_pos is not None: + embeddings = torch.where( + bool_masked_pos.unsqueeze(-1), self.mask_token.to(embeddings.dtype).unsqueeze(0), embeddings + ) + + # add the [CLS] token to the embedded patch tokens + cls_tokens = self.cls_token.expand(batch_size, -1, -1) + embeddings = torch.cat((cls_tokens, embeddings), dim=1) + + # add positional encoding to each token + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + + if self.config.num_windows > 1: + # reshape for windows + embeddings = window_partition(embeddings, self.config.num_windows, self.config.patch_size, height, width) + embeddings = self.dropout(embeddings) + + return embeddings + + +def window_unpartition_before_attention(hidden_states: torch.Tensor, num_windows: int) -> torch.Tensor: + batch_size, seq_len, channels = hidden_states.shape + num_windows_squared = num_windows**2 + hidden_states = hidden_states.view(batch_size // num_windows_squared, num_windows_squared * seq_len, channels) + return hidden_states + + +def window_partition_after_attention( + hidden_states: torch.Tensor, self_attention_output: torch.Tensor, num_windows: int +) -> torch.Tensor: + batch_size, seq_len, channels = hidden_states.shape + num_windows_squared = num_windows**2 + self_attention_output = self_attention_output.view( + batch_size * num_windows_squared, seq_len // num_windows_squared, channels + ) + return self_attention_output + + +class RfDetrDinov2SelfAttention(Dinov2SelfAttention): + def __init__(self, config: RfDetrDinov2Config): + super().__init__(config) + self.num_key_value_groups = 1 + + +class RfDetrDinov2Layer(Dinov2Layer): + def __init__(self, config: RfDetrDinov2Config, layer_idx: int): + super().__init__(config) + self.num_windows = config.num_windows + self.global_attention = layer_idx not in 
config.window_block_indexes + + def forward( + self, + hidden_states: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor]: + shortcut = hidden_states + if self.global_attention: + hidden_states = window_unpartition_before_attention(hidden_states, self.num_windows) + + hidden_states_norm = self.norm1(hidden_states) + self_attention_output = self.attention(hidden_states_norm) + + if self.global_attention: + self_attention_output = window_partition_after_attention( + hidden_states, self_attention_output, self.num_windows + ) + + self_attention_output = self.layer_scale1(self_attention_output) + + # first residual connection + hidden_states = self.drop_path(self_attention_output) + shortcut + + # in Dinov2, layernorm is also applied after self-attention + layer_output = self.norm2(hidden_states) + layer_output = self.mlp(layer_output) + layer_output = self.layer_scale2(layer_output) + + # second residual connection + layer_output = self.drop_path(layer_output) + hidden_states + + return layer_output + + +class RfDetrDinov2Encoder(Dinov2Encoder): + def __init__(self, config: RfDetrDinov2Config): + super().__init__(config) + self.layer = nn.ModuleList([RfDetrDinov2Layer(config, i) for i in range(config.num_hidden_layers)]) + + +class RfDetrDinov2PreTrainedModel(Dinov2PreTrainedModel): + pass + + +def window_unpartition( + hidden_state: torch.Tensor, + num_windows: int, + num_h_patches: int, + num_w_patches: int, +) -> torch.Tensor: + hidden_batch_size, seq_len, channels = hidden_state.shape + num_windows_squared = num_windows**2 + num_h_patches_per_window = num_h_patches // num_windows + num_w_patches_per_window = num_w_patches // num_windows + hidden_state = hidden_state.reshape( + hidden_batch_size // num_windows_squared, num_windows_squared * seq_len, channels + ) + hidden_state = hidden_state.view( + hidden_batch_size // num_windows_squared, + num_windows, + num_windows, + num_h_patches_per_window, + num_w_patches_per_window, + channels, + ) + hidden_state = hidden_state.permute(0, 1, 3, 2, 4, 5) + return hidden_state + + +class RfDetrDinov2Backbone(Dinov2Backbone): + def forward( + self, + pixel_values: torch.Tensor, + output_hidden_states: bool | None = None, + **kwargs, + ) -> BackboneOutput: + r""" + Examples: + + ```python + >>> from transformers import AutoImageProcessor, AutoBackbone + >>> import torch + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base") + >>> model = AutoBackbone.from_pretrained( + ... "facebook/dinov2-base", out_features=["stage2", "stage5", "stage8", "stage11"] + ... 
) + + >>> inputs = processor(image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> feature_maps = outputs.feature_maps + >>> list(feature_maps[-1].shape) + [1, 768, 16, 16] + ```""" + if output_hidden_states is None: + output_hidden_states = self.config.output_hidden_states + + embedding_output = self.embeddings(pixel_values) + + output: BaseModelOutput = self.encoder(embedding_output, output_hidden_states=True) + hidden_states = output.hidden_states + + feature_maps = () + for stage, hidden_state in zip(self.stage_names, hidden_states): + if stage in self.out_features: + if self.config.apply_layernorm: + hidden_state = self.layernorm(hidden_state) + if self.config.reshape_hidden_states: + hidden_state = hidden_state[:, 1:] + # this was actually a bug in the original implementation that we copied here, + # cause normally the order is height, width + batch_size, _, height, width = pixel_values.shape + patch_size = self.config.patch_size + + num_h_patches = height // patch_size + num_w_patches = width // patch_size + + if self.config.num_windows > 1: + hidden_state = window_unpartition( + hidden_state, self.config.num_windows, num_h_patches, num_w_patches + ) + + hidden_state = hidden_state.reshape(batch_size, num_h_patches, num_w_patches, -1) + hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous() + + feature_maps += (hidden_state,) + + return BackboneOutput( + feature_maps=feature_maps, + hidden_states=hidden_states if output_hidden_states else None, + ) + + +class RfDetrLayerNorm(LwDetrLayerNorm): + pass + + +class RfDetrConvNormLayer(LwDetrConvNormLayer): + def __init__( + self, + config: RfDetrConfig, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int, + activation: str | None = None, + ): + super().__init__( + config, + in_channels, + out_channels, + kernel_size, + stride, + activation, + ) + self.norm = RfDetrLayerNorm(out_channels, data_format="channels_first", eps=config.layer_norm_eps) + + +class RfDetrC2FLayer(LwDetrC2FLayer): + pass + + +class RfDetrSamplingLayer(LwDetrSamplingLayer): + def __init__(self, config: RfDetrConfig, channel_size: int, scale: float): + nn.Module.__init__(self) + + self.scale = scale + self.channel_size = channel_size + + layers = [] + if scale == 2.0: + layers.append(nn.ConvTranspose2d(channel_size, channel_size // 2, 2, 2)) + elif scale == 0.5: + layers.append(RfDetrConvNormLayer(config, channel_size, channel_size, 3, 2, activation="relu")) + self.layers = nn.ModuleList(layers) + + +class RfDetrScaleProjector(LwDetrScaleProjector): + def __init__(self, config: RfDetrConfig, scale: float): + nn.Module.__init__(self) + + intermediate_dims = [config.backbone_config.hidden_size] * len(config.backbone_config.out_indices) + sampling_layers = [] + for channel_size in intermediate_dims: + sampling_layers.append(RfDetrSamplingLayer(config, channel_size, scale)) + self.sampling_layers = nn.ModuleList(sampling_layers) + + intermediate_dim = intermediate_dims[-1] + if scale == 2.0: + intermediate_dim = intermediate_dim // 2 + projector_input_dim = intermediate_dim * len(intermediate_dims) + + self.projector_layer = RfDetrC2FLayer(config, projector_input_dim) + self.layer_norm = RfDetrLayerNorm(config.d_model, data_format="channels_first") + + +class RfDetrConvEncoder(LwDetrConvEncoder): + def __init__(self, config: RfDetrConfig): + super().__init__(config) + self.backbone = RfDetrDinov2Backbone(config.backbone_config) + + +class RfDetrPreTrainedModel(LwDetrPreTrainedModel): + pass + + +class 
RfDetrModelOutput(LwDetrModelOutput):
+    r"""
+    init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
+        Initial reference points sent through the Transformer decoder.
+    intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
+        Stacked intermediate hidden states (output of each layer of the decoder).
+    intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
+        Stacked intermediate reference points (reference points of each layer of the decoder).
+    enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
+        Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are
+        picked as region proposals in the first stage. Output of bounding box binary classification (i.e.
+        foreground and background).
+    enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
+        Logits of predicted bounding boxes coordinates in the first stage.
+    backbone_features (list of `torch.FloatTensor` of shape `(batch_size, config.num_channels, config.image_size, config.image_size)`):
+        Features from the backbone.
+    """
+
+    backbone_features: list[torch.Tensor] = None
+
+
+class RfDetrModel(LwDetrModel):
+    def forward(
+        self,
+        pixel_values: torch.FloatTensor,
+        pixel_mask: torch.LongTensor | None = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> RfDetrModelOutput:
+        r"""
+        Examples:
+
+        ```python
+        >>> from transformers import AutoImageProcessor, RfDetrModel
+        >>> from PIL import Image
+        >>> import requests
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> image_processor = AutoImageProcessor.from_pretrained("stevenbucaille/rfdetr_small_60e_coco")
+        >>> model = RfDetrModel.from_pretrained("stevenbucaille/rfdetr_small_60e_coco")
+
+        >>> inputs = image_processor(images=image, return_tensors="pt")
+
+        >>> outputs = model(**inputs)
+
+        >>> last_hidden_states = outputs.last_hidden_state
+        >>> list(last_hidden_states.shape)
+        [1, 200, 256]
+        ```"""
+        batch_size, num_channels, height, width = pixel_values.shape
+        device = pixel_values.device
+
+        if pixel_mask is None:
+            pixel_mask = torch.ones((batch_size, height, width), dtype=torch.long, device=device)
+
+        # First, retrieve feature maps from backbone
+        features = self.backbone(pixel_values, pixel_mask)
+
+        sources = []
+        masks = []
+        for level, (source, mask) in enumerate(features):
+            sources.append(source)
+            masks.append(mask)
+            if mask is None:
+                raise ValueError("No attention mask was provided")
+
+        # Get initial reference points and query features
+        if self.training:
+            reference_points = self.reference_point_embed.weight
+            query_feat = self.query_feat.weight
+        else:
+            # only use first group of reference points and query features during inference
+            # reference_points (num_queries, 4) : spatial locations of the queries
+            # query_feat (num_queries, d_model) : features of the queries
+            reference_points = self.reference_point_embed.weight[: self.num_queries]
+            query_feat = self.query_feat.weight[: self.num_queries]
+
+        # Prepare decoder inputs (by flattening)
+        source_flatten = []
+        mask_flatten = []
+        spatial_shapes_list = []
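+        # Each feature level `source` has shape (batch_size, d_model, height, width); the loop below flattens it to
+        # (batch_size, height*width, d_model) and records its (height, width) so the deformable cross-attention can
+        # map flattened tokens back to spatial locations.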
+ for source, mask in zip(sources, masks): + batch_size, num_channels, height, width = source.shape + spatial_shape = (height, width) + spatial_shapes_list.append(spatial_shape) + source = source.flatten(2).transpose(1, 2) + mask = mask.flatten(1) + source_flatten.append(source) + mask_flatten.append(mask) + # source_flatten (batch_size, sum(H*W), d_model) : flattened multi-scale feature maps + # mask_flatten (batch_size, sum(H*W)) : flattened mask + # spatial_shapes (num_levels, 2) : spatial shapes of the feature maps + # level_start_index (num_levels,) : start index of each level in source_flatten + # valid_ratios (batch_size, num_levels, 2) : valid ratios of the feature maps + source_flatten = torch.cat(source_flatten, 1) + mask_flatten = torch.cat(mask_flatten, 1) + spatial_shapes = torch.as_tensor(spatial_shapes_list, dtype=torch.long, device=source_flatten.device) + level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) + valid_ratios = torch.stack([self.get_valid_ratio(m, dtype=source_flatten.dtype) for m in masks], 1) + + # Duplicate query features and reference points for each image in the batch + target = query_feat.unsqueeze(0).expand(batch_size, -1, -1) + reference_points = reference_points.unsqueeze(0).expand(batch_size, -1, -1) + + # Generate encoder output proposals + object_query_embedding, output_proposals = self.gen_encoder_output_proposals( + source_flatten, ~mask_flatten, spatial_shapes_list + ) + + group_detr = self.group_detr if self.training else 1 + topk = self.num_queries + topk_coords_logits = [] + object_query_undetach = [] + + # Iterate over each group of object queries to refine the object queries + for group_id in range(group_detr): + group_object_query = self.enc_output[group_id](object_query_embedding) + group_object_query = self.enc_output_norm[group_id](group_object_query) + + group_enc_outputs_class = self.enc_out_class_embed[group_id](group_object_query) + group_delta_bbox = self.enc_out_bbox_embed[group_id](group_object_query) + group_enc_outputs_coord = refine_bboxes(output_proposals, group_delta_bbox) + + group_topk_proposals = torch.topk(group_enc_outputs_class.max(-1)[0], topk, dim=1)[1] + group_topk_coords_logits_undetach = torch.gather( + group_enc_outputs_coord, + 1, + group_topk_proposals.unsqueeze(-1).repeat(1, 1, 4), + ) + group_topk_coords_logits = group_topk_coords_logits_undetach.detach() + group_object_query_undetach = torch.gather( + group_object_query, 1, group_topk_proposals.unsqueeze(-1).repeat(1, 1, self.config.d_model) + ) + + topk_coords_logits.append(group_topk_coords_logits) + object_query_undetach.append(group_object_query_undetach) + + # Concatenate the object queries and reference points from all groups + topk_coords_logits = torch.cat(topk_coords_logits, 1) + object_query_undetach = torch.cat(object_query_undetach, 1) + + # Get the class and coordinate logits from the object queries + # enc_outputs_class (batch_size, num_queries, d_model) : object queries + # enc_outputs_coord_logits (batch_size, num_queries, 4) : coordinate logits of the object queries + enc_outputs_class = object_query_undetach + enc_outputs_coord_logits = topk_coords_logits + + # Refine the reference points using the coordinate logits + two_stage_len = topk_coords_logits.shape[-2] + reference_points_two_stage_subset = reference_points[..., :two_stage_len, :] + reference_points_subset = reference_points[..., two_stage_len:, :] + reference_points_two_stage_subset = refine_bboxes(topk_coords_logits, 
reference_points_two_stage_subset) + reference_points = torch.cat([reference_points_two_stage_subset, reference_points_subset], dim=-2) + init_reference_points = reference_points + + # Pass the object queries and reference points to the decoder + decoder_outputs = self.decoder( + inputs_embeds=target, + reference_points=reference_points, + spatial_shapes=spatial_shapes, + spatial_shapes_list=spatial_shapes_list, + level_start_index=level_start_index, + valid_ratios=valid_ratios, + encoder_hidden_states=source_flatten, + encoder_attention_mask=mask_flatten, + **kwargs, + ) + + # init_reference_points (batch_size, num_queries, 4) : initial reference points + # last_hidden_state (batch_size, num_queries, d_model) : final object queries + # intermediate_hidden_states (batch_size, num_decoder_layers, num_queries, d_model) : intermediate object queries + # intermediate_reference_points (batch_size, num_decoder_layers, num_queries, 4) : intermediate reference points + # backbone_features list(batch_size, num_levels, d_model, H, W) : backbone features + # enc_outputs_class (batch_size, num_queries, d_model) : encoder outputs object queries + # enc_outputs_coord_logits (batch_size, num_queries, 4) : coordinate logits of encoder object queries + return RfDetrModelOutput( + init_reference_points=init_reference_points, + last_hidden_state=decoder_outputs.last_hidden_state, + intermediate_hidden_states=decoder_outputs.intermediate_hidden_states, + intermediate_reference_points=decoder_outputs.intermediate_reference_points, + backbone_features=sources, + enc_outputs_class=enc_outputs_class, + enc_outputs_coord_logits=enc_outputs_coord_logits, + ) + + +class RfDetrObjectDetectionOutput(LwDetrObjectDetectionOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided)): + Total loss as a linear combination of a negative log-likehood (cross-entropy) for class prediction and a + bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized + scale-invariant IoU loss. + loss_dict (`Dict`, *optional*): + A dictionary containing the individual losses. Useful for logging. + logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`): + Classification logits (including no-object) for all queries. + pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): + Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These + values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding + possible padding). You can use [`~DeformableDetrProcessor.post_process_object_detection`] to retrieve the + unnormalized bounding boxes. + auxiliary_outputs (`list[Dict]`, *optional*): + Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`) + and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and + `pred_boxes`) for each decoder layer. + init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): + Initial reference points sent through the Transformer decoder. + intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`): + Stacked intermediate hidden states (output of each layer of the decoder). 
+ intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`): + Stacked intermediate reference points (reference points of each layer of the decoder). + enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`): + Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are + picked as region proposals in the first stage. Output of bounding box binary classification (i.e. + foreground and background). + enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`): + Logits of predicted bounding boxes coordinates in the first stage. + backbone_features (list of `torch.FloatTensor` of shape `(batch_size, config.num_channels, config.image_size, config.image_size)`): + Features from the backbone. + """ + + backbone_features: list[torch.Tensor] = None + + +class RfDetrForObjectDetection(LwDetrForObjectDetection): + def get_encoder_outputs_class_logits(self, enc_outputs_class_logits: torch.Tensor) -> Tensor: + enc_outputs_class_logits_list = enc_outputs_class_logits.split(self.config.num_queries, dim=1) + group_detr = self.config.group_detr if self.training else 1 + pred_class = [ + self.model.enc_out_class_embed[group_index](enc_outputs_class_logits_list[group_index]) + for group_index in range(group_detr) + ] + enc_outputs_class_logits = torch.cat(pred_class, dim=1) + return enc_outputs_class_logits + + @check_model_inputs + @auto_docstring + def forward( + self, + pixel_values: torch.FloatTensor = None, + pixel_mask: torch.LongTensor | None = None, + labels: list[dict] | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> RfDetrObjectDetectionOutput: + r""" + decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*): + Not used by default. Can be used to mask object queries. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you + can choose to directly pass a flattened representation of an image. + decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*): + Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an + embedded representation. + labels (`list[Dict]` of len `(batch_size,)`, *optional*): + Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the + following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch + respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding boxes + in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)`. 
+
+        Examples:
+
+        ```python
+        >>> from transformers import AutoImageProcessor, RfDetrForObjectDetection
+        >>> from PIL import Image
+        >>> import requests
+        >>> import torch
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> image_processor = AutoImageProcessor.from_pretrained("stevenbucaille/rfdetr_small_60e_coco")
+        >>> model = RfDetrForObjectDetection.from_pretrained("stevenbucaille/rfdetr_small_60e_coco")
+
+        >>> inputs = image_processor(images=image, return_tensors="pt")
+        >>> outputs = model(**inputs)
+
+        >>> # convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax)
+        >>> target_sizes = torch.tensor([image.size[::-1]])
+        >>> results = image_processor.post_process_object_detection(outputs, threshold=0.5, target_sizes=target_sizes)[
+        ...     0
+        ... ]
+        >>> for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
+        ...     box = [round(i, 2) for i in box.tolist()]
+        ...     print(
+        ...         f"Detected {model.config.id2label[label.item()]} with confidence "
+        ...         f"{round(score.item(), 3)} at location {box}"
+        ...     )
+        Detected cat with confidence 0.8 at location [16.5, 52.84, 318.25, 470.78]
+        Detected cat with confidence 0.789 at location [342.19, 24.3, 640.02, 372.25]
+        Detected remote with confidence 0.633 at location [40.79, 72.78, 176.76, 117.25]
+        ```"""
+        outputs = self.model(
+            pixel_values,
+            pixel_mask=pixel_mask,
+            **kwargs,
+        )
+
+        last_hidden_states = outputs.last_hidden_state
+        intermediate_reference_points = outputs.intermediate_reference_points
+        enc_outputs_class = outputs.enc_outputs_class
+        enc_outputs_boxes_logits = outputs.enc_outputs_coord_logits
+
+        # Get logits and boxes from first stage object queries
+        enc_outputs_class_logits = self.get_encoder_outputs_class_logits(enc_outputs_class)
+
+        # Get logits and boxes from second stage object queries
+        logits = self.class_embed(last_hidden_states)
+        pred_boxes_delta = self.bbox_embed(last_hidden_states)
+        pred_boxes = refine_bboxes(intermediate_reference_points[-1], pred_boxes_delta)
+
+        loss, loss_dict, auxiliary_outputs = None, None, None
+        if labels is not None:
+            outputs_class, outputs_coord = None, None
+            if self.config.auxiliary_loss:
+                intermediate_hidden_states = outputs.intermediate_hidden_states
+                outputs_coord_delta = self.bbox_embed(intermediate_hidden_states)
+                outputs_coord = refine_bboxes(intermediate_reference_points, outputs_coord_delta)
+                outputs_class = self.class_embed(intermediate_hidden_states)
+
+            loss, loss_dict, auxiliary_outputs = self.loss_function(
+                logits,
+                labels,
+                self.device,
+                pred_boxes,
+                self.config,
+                outputs_class,
+                outputs_coord,
+                enc_outputs_class_logits,
+                enc_outputs_boxes_logits,
+            )
+
+        return RfDetrObjectDetectionOutput(
+            loss=loss,
+            loss_dict=loss_dict,
+            logits=logits,
+            pred_boxes=pred_boxes,
+            auxiliary_outputs=auxiliary_outputs,
+            last_hidden_state=outputs.last_hidden_state,
+            intermediate_hidden_states=outputs.intermediate_hidden_states,
+            intermediate_reference_points=outputs.intermediate_reference_points,
+            init_reference_points=outputs.init_reference_points,
+            enc_outputs_class=enc_outputs_class_logits,
+            enc_outputs_coord_logits=enc_outputs_boxes_logits,
+            backbone_features=outputs.backbone_features,
+        )
+
+
+@dataclass
+@auto_docstring(
+    custom_intro="""
+    Output type of [`RfDetrForInstanceSegmentation`].
+ """ +) +class RfDetrInstanceSegmentationOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided)): + Total loss as a linear combination of a negative log-likehood (cross-entropy) for class prediction and a + bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized + scale-invariant IoU loss. + loss_dict (`Dict`, *optional*): + A dictionary containing the individual losses. Useful for logging. + logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`): + Classification logits (including no-object) for all queries. + pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): + Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These + values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding + possible padding). You can use [`~DeformableDetrProcessor.post_process_object_detection`] to retrieve the + unnormalized bounding boxes. + pred_masks (`torch.FloatTensor` of shape `(batch_size, num_queries, height/4, width/4)`): + Segmentation masks logits for all queries. See also + [`~DetrImageProcessor.post_process_semantic_segmentation`] or + [`~DetrImageProcessor.post_process_instance_segmentation`] + [`~DetrImageProcessor.post_process_panoptic_segmentation`] to evaluate semantic, instance and panoptic + segmentation masks respectively. + auxiliary_outputs (`list[Dict]`, *optional*): + Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`) + and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and + `pred_boxes`) for each decoder layer. + init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): + Initial reference points sent through the Transformer decoder. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the decoder of the model. + intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`): + Stacked intermediate hidden states (output of each layer of the decoder). + intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`): + Stacked intermediate reference points (reference points of each layer of the decoder). + enc_outputs_mask_logits (`torch.FloatTensor` of shape `(batch_size, num_queries, width, height)`, *optional*): + Mask logits from the encoder for all queries. 
+ """ + + loss: torch.FloatTensor | None = None + loss_dict: dict | None = None + logits: torch.FloatTensor | None = None + pred_boxes: torch.FloatTensor | None = None + pred_masks: torch.FloatTensor = None + auxiliary_outputs: list[dict] | None = None + init_reference_points: torch.FloatTensor | None = None + last_hidden_state: torch.FloatTensor | None = None + intermediate_hidden_states: torch.FloatTensor | None = None + intermediate_reference_points: torch.FloatTensor | None = None + enc_outputs_mask_logits: torch.FloatTensor | None = None + + +class RfDetrSegmentationBlock(ConvNextLayer): + def __init__(self, config: RfDetrConfig): + dim = config.d_model + super().__init__(config) + self.dwconv = nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim) # depthwise conv + self.layernorm = RfDetrLayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear(dim, dim) # pointwise/1x1 convs, implemented with linear layers + self.act = ACT2FN[config.segmentation_head_activation_function] + del self.pwconv2 + del self.layer_scale_parameter + del self.drop_path + + def forward(self, features: torch.Tensor) -> torch.Tensor: + residual = features + features = self.dwconv(features) + features = features.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + features = self.layernorm(features) + features = self.pwconv1(features) + features = self.act(features) + features = features.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + features = features + residual + return features + + +class RfDetrSegmentationMLPBlock(nn.Module): + def __init__(self, config: RfDetrConfig): + super().__init__() + dim = config.d_model + self.norm_in = nn.LayerNorm(dim) + self.in_linear = nn.Linear(dim, dim * 4) + self.act = ACT2FN[config.segmentation_head_activation_function] + self.out_linear = nn.Linear(dim * 4, dim) + + def forward(self, features: torch.Tensor) -> torch.Tensor: + residual = features + features = self.norm_in(features) + features = self.in_linear(features) + features = self.act(features) + features = self.out_linear(features) + features = features + residual + return features + + +class RfDetrForInstanceSegmentation(RfDetrPreTrainedModel): + def __init__(self, config: RfDetrConfig): + super().__init__(config) + + self.rf_detr = RfDetrForObjectDetection(config) + + num_blocks = config.decoder_layers + self.downsample_ratio = config.mask_downsample_ratio + self.blocks = nn.ModuleList([RfDetrSegmentationBlock(config) for _ in range(num_blocks)]) + self.spatial_features_proj = nn.Conv2d(config.d_model, config.d_model, kernel_size=1) + + self.query_features_block = RfDetrSegmentationMLPBlock(config) + self.query_features_proj = nn.Linear(config.d_model, config.d_model) + + self.bias = nn.Parameter(torch.zeros(1), requires_grad=True) + + self.post_init() + + def segmentation_head(self, spatial_features, query_features, image_size: torch.Size, skip_blocks: bool = False): + # spatial features: (B, C, H, W) + # query features: [(B, N, C)] for each decoder layer + # output: (B, N, H*r, W*r) + target_size = (image_size[0] // self.downsample_ratio, image_size[1] // self.downsample_ratio) + spatial_features = F.interpolate(spatial_features, size=target_size, mode="bilinear", align_corners=False) + list_mask_logits = [] + if not skip_blocks: + for block, qf in zip(self.blocks, query_features): + spatial_features = block(spatial_features) + spatial_features_proj = self.spatial_features_proj(spatial_features) + qf = self.query_features_block(qf) + qf = self.query_features_proj(qf) + mask_logits = 
torch.einsum("bchw,bnc->bnhw", spatial_features_proj, qf) + mask_logits = mask_logits + self.bias + list_mask_logits.append(mask_logits) + else: + query_features = self.query_features_block(query_features) + query_features = self.query_features_proj(query_features) + mask_logits = torch.einsum("bchw,bnc->bnhw", spatial_features, query_features) + mask_logits = mask_logits + self.bias + list_mask_logits.append(mask_logits) + + return list_mask_logits + + @check_model_inputs + @can_return_tuple + @auto_docstring + def forward( + self, + pixel_values: torch.FloatTensor = None, + pixel_mask: torch.LongTensor | None = None, + labels: list[dict] | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> dict[str, torch.Tensor]: + image_size = pixel_values.shape[-2:] + + outputs = self.rf_detr.model( + pixel_values, + pixel_mask=pixel_mask, + **kwargs, + ) + + spatial_features = outputs.backbone_features[-1] + last_hidden_states = outputs.last_hidden_state + intermediate_reference_points = outputs.intermediate_reference_points + enc_outputs_class = outputs.enc_outputs_class + enc_outputs_boxes_logits = outputs.enc_outputs_coord_logits + query_features = outputs.intermediate_hidden_states + last_hidden_state = outputs.last_hidden_state + + # First stage segmentation proposals + enc_outputs_class_logits = self.rf_detr.get_encoder_outputs_class_logits(enc_outputs_class) + enc_outputs_masks = self.segmentation_head(spatial_features, enc_outputs_class, image_size, skip_blocks=True) + enc_outputs_masks = torch.cat(enc_outputs_masks, dim=1) + + # Second stage segmentation proposals + logits = self.rf_detr.class_embed(last_hidden_states) + pred_boxes_delta = self.rf_detr.bbox_embed(last_hidden_states) + pred_boxes = refine_bboxes(intermediate_reference_points[-1], pred_boxes_delta) + outputs_masks = self.segmentation_head(spatial_features, query_features, image_size) + + pred_masks = outputs_masks[-1] + + loss, loss_dict, auxiliary_outputs = None, None, None + if labels is not None: + outputs_class, outputs_coord = None, None + if self.config.auxiliary_loss: + intermediate_hidden_states = outputs.intermediate_hidden_states + outputs_coord_delta = self.rf_detr.bbox_embed(intermediate_hidden_states) + outputs_coord = refine_bboxes(intermediate_reference_points, outputs_coord_delta) + outputs_class = self.rf_detr.class_embed(intermediate_hidden_states) + loss, loss_dict, auxiliary_outputs = self.loss_function( + logits, + labels, + self.device, + pred_boxes, + pred_masks, + self.config, + outputs_class, + outputs_coord, + outputs_masks, + enc_outputs_class_logits, + enc_outputs_boxes_logits, + enc_outputs_masks, + ) + + return RfDetrInstanceSegmentationOutput( + loss=loss, + loss_dict=loss_dict, + logits=logits, + pred_boxes=pred_boxes, + pred_masks=pred_masks, + auxiliary_outputs=auxiliary_outputs, + last_hidden_state=last_hidden_state, + intermediate_hidden_states=outputs.intermediate_hidden_states, + intermediate_reference_points=outputs.intermediate_reference_points, + init_reference_points=outputs.init_reference_points, + enc_outputs_mask_logits=enc_outputs_masks, + ) + + +__all__ = [ + "RfDetrConfig", + "RfDetrModel", + "RfDetrForObjectDetection", + "RfDetrForInstanceSegmentation", + "RfDetrPreTrainedModel", + "RfDetrDinov2Config", + "RfDetrDinov2Backbone", + "RfDetrDinov2PreTrainedModel", +] diff --git a/tests/models/rf_detr/__init__.py b/tests/models/rf_detr/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/rf_detr/test_modeling_rf_detr.py 
b/tests/models/rf_detr/test_modeling_rf_detr.py new file mode 100644 index 000000000000..b6c977fc37de --- /dev/null +++ b/tests/models/rf_detr/test_modeling_rf_detr.py @@ -0,0 +1,863 @@ +# coding = utf-8 +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import tempfile +import unittest +from functools import cached_property + +import numpy as np + +from transformers import ( + CONFIG_NAME, + DetrImageProcessor, + RfDetrConfig, + RfDetrDinov2Config, + is_torch_available, + is_vision_available, +) +from transformers.testing_utils import ( + Expectations, + require_torch, + require_vision, + slow, + torch_device, +) + +from ...test_backbone_common import BackboneTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor +from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_torch_available(): + import torch + + from transformers import RfDetrDinov2Backbone, RfDetrForInstanceSegmentation, RfDetrForObjectDetection, RfDetrModel + +if is_vision_available(): + from PIL import Image + +CHECKPOINT = { + "base": "stevenbucaille/rf-detr-base", + "large": "stevenbucaille/rf-detr-large", + "segmentation": "stevenbucaille/rf-detr-segmentation", +} + + +def prepare_img(): + image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") + return image + + +class RfDetrModelTester: + def __init__( + self, + parent, + batch_size=3, + is_training=True, + image_size=256, + num_labels=5, + n_targets=4, + use_labels=True, + initializer_range=0.02, + layer_norm_eps=1e-5, + batch_norm_eps=1e-5, + # backbone + backbone_config=None, + # projector + projector_scale_factors=[0.5, 2.0], + # decoder + d_model=32, + decoder_ffn_dim=32, + decoder_layers=2, + decoder_self_attention_heads=2, + decoder_cross_attention_heads=4, + # model + num_queries=10, + group_detr=2, + dropout=0.0, + activation_dropout=0.0, + attention_dropout=0.0, + attn_implementation="eager", + ): + self.parent = parent + self.batch_size = batch_size + self.is_training = is_training + self.num_channels = 3 + self.image_size = image_size + self.num_labels = num_labels + self.n_targets = n_targets + self.use_labels = use_labels + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.batch_norm_eps = batch_norm_eps + self.backbone_config = backbone_config + self.projector_scale_factors = projector_scale_factors + self.d_model = d_model + self.decoder_ffn_dim = decoder_ffn_dim + self.decoder_layers = decoder_layers + self.decoder_self_attention_heads = decoder_self_attention_heads + self.decoder_cross_attention_heads = decoder_cross_attention_heads + self.num_queries = num_queries + self.group_detr = group_detr + self.dropout = dropout + self.activation_dropout = activation_dropout + self.attention_dropout = attention_dropout + self.attn_implementation = attn_implementation + + def prepare_config_and_inputs(self): + pixel_values = 
floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + pixel_mask = torch.ones([self.batch_size, self.image_size, self.image_size], device=torch_device) + labels = None + if self.use_labels: + labels = [] + for i in range(self.batch_size): + target = {} + target["class_labels"] = torch.randint( + high=self.num_labels, size=(self.n_targets,), device=torch_device + ) + target["boxes"] = torch.rand(self.n_targets, 4, device=torch_device) + target["masks"] = torch.rand(self.n_targets, self.image_size, self.image_size, device=torch_device) + labels.append(target) + + config = self.get_config() + config.num_labels = self.num_labels + return config, pixel_values, pixel_mask, labels + + def get_config(self): + backbone_config = RfDetrDinov2Config( + attention_probs_dropout_prob=0.0, + drop_path_rate=0.0, + hidden_act="gelu", + hidden_dropout_prob=0.0, + initializer_range=0.02, + layer_norm_eps=1e-06, + layerscale_value=1.0, + mlp_ratio=4, + num_attention_heads=2, + num_channels=3, + num_hidden_layers=4, + qkv_bias=True, + use_swiglu_ffn=False, + out_features=["stage2", "stage3"], + hidden_size=self.d_model, + patch_size=16, + num_windows=2, + image_size=self.image_size, + attn_implementation=self.attn_implementation, + ) + return RfDetrConfig( + backbone_config=backbone_config, + d_model=self.d_model, + projector_scale_factors=self.projector_scale_factors, + decoder_ffn_dim=self.decoder_ffn_dim, + decoder_layers=self.decoder_layers, + decoder_self_attention_heads=self.decoder_self_attention_heads, + decoder_cross_attention_heads=self.decoder_cross_attention_heads, + num_queries=self.num_queries, + group_detr=self.group_detr, + dropout=self.dropout, + activation_dropout=self.activation_dropout, + attention_dropout=self.attention_dropout, + attn_implementation=self.attn_implementation, + _attn_implementation=self.attn_implementation, + ) + + def prepare_config_and_inputs_for_common(self): + config, pixel_values, pixel_mask, labels = self.prepare_config_and_inputs() + inputs_dict = {"pixel_values": pixel_values, "pixel_mask": pixel_mask} + return config, inputs_dict + + def create_and_check_rf_detr_model(self, config, pixel_values, pixel_mask, labels): + model = RfDetrModel(config=config) + model.to(torch_device) + model.eval() + + result = model(pixel_values=pixel_values, pixel_mask=pixel_mask) + result = model(pixel_values) + + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.num_queries, self.d_model)) + + def create_and_check_rf_detr_object_detection_head_model(self, config, pixel_values, pixel_mask, labels): + model = RfDetrForObjectDetection(config=config) + model.to(torch_device) + model.eval() + + result = model(pixel_values=pixel_values, pixel_mask=pixel_mask) + result = model(pixel_values) + + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels)) + self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4)) + + result = model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels) + + self.parent.assertEqual(result.loss.shape, ()) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels)) + self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4)) + + +@require_torch +class RfDetrModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): + all_model_classes = ( + (RfDetrModel, RfDetrForObjectDetection, RfDetrForInstanceSegmentation) if is_torch_available() else () + 
) + pipeline_model_mapping = ( + { + "image-feature-extraction": RfDetrModel, + "object-detection": RfDetrForObjectDetection, + "instance-segmentation": RfDetrForInstanceSegmentation, + } + if is_torch_available() + else {} + ) + is_encoder_decoder = False + test_missing_keys = False + test_torch_exportable = True + + # special case for head models + def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) + + if return_labels: + if model_class.__name__ in ["RfDetrForObjectDetection", "RfDetrForInstanceSegmentation"]: + labels = [] + for i in range(self.model_tester.batch_size): + target = {} + target["class_labels"] = torch.ones( + size=(self.model_tester.n_targets,), device=torch_device, dtype=torch.long + ) + target["boxes"] = torch.ones( + self.model_tester.n_targets, 4, device=torch_device, dtype=torch.float + ) + target["masks"] = torch.ones( + self.model_tester.n_targets, + self.model_tester.image_size, + self.model_tester.image_size, + device=torch_device, + dtype=torch.float, + ) + labels.append(target) + inputs_dict["labels"] = labels + + return inputs_dict + + def setUp(self): + self.model_tester = RfDetrModelTester(self) + self.config_tester = ConfigTester( + self, + config_class=RfDetrConfig, + has_text_modality=False, + common_properties=["d_model", "decoder_self_attention_heads"], + ) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_rf_detr_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_rf_detr_model(*config_and_inputs) + + def test_rf_detr_object_detection_head_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_rf_detr_object_detection_head_model(*config_and_inputs) + + @unittest.skip(reason="RTDetr does not use inputs_embeds") + def test_inputs_embeds(self): + pass + + @unittest.skip(reason="RTDetr does not use test_inputs_embeds_matches_input_ids") + def test_inputs_embeds_matches_input_ids(self): + pass + + @unittest.skip(reason="RTDetr does not support input and output embeddings") + def test_model_get_set_embeddings(self): + pass + + @unittest.skip(reason="RTDetr does not support input and output embeddings") + def test_model_common_attributes(self): + pass + + @unittest.skip(reason="RTDetr does not use token embeddings") + def test_resize_tokens_embeddings(self): + pass + + @unittest.skip(reason="Feed forward chunking is not implemented") + def test_feed_forward_chunking(self): + pass + + def test_attention_outputs(self): + def check_attention_outputs(inputs_dict, config, model_class): + model = model_class._from_config(config, attn_implementation="eager") + config = model.config + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.attentions + self.assertEqual(len(attentions), self.model_tester.decoder_layers) + expected_attentions_shape = [ + self.model_tester.batch_size, + self.model_tester.decoder_self_attention_heads, + self.model_tester.num_queries, + self.model_tester.num_queries, + ] + for i in range(self.model_tester.decoder_layers): + self.assertEqual(expected_attentions_shape, list(attentions[i].shape)) + + # check cross_attentions outputs + expected_attentions_shape = [ + self.model_tester.batch_size, + self.model_tester.num_queries, + 
+                self.model_tester.decoder_cross_attention_heads,
+                config.num_feature_levels,
+                config.decoder_n_points,
+            ]
+            cross_attentions = outputs.cross_attentions
+            self.assertEqual(len(cross_attentions), self.model_tester.decoder_layers)
+            for i in range(self.model_tester.decoder_layers):
+                self.assertEqual(expected_attentions_shape, list(cross_attentions[i].shape))
+
+            out_len = len(outputs)
+
+            if model_class.__name__ == "RfDetrModel":
+                correct_outlen = 9  # 7 + attentions + cross_attentions
+            elif model_class.__name__ == "RfDetrForObjectDetection":
+                correct_outlen = 11  # 9 + attentions + cross_attentions
+                if "labels" in inputs_dict:
+                    correct_outlen += 3  # loss, loss_dict and auxiliary outputs are added at the beginning
+            elif model_class.__name__ == "RfDetrForInstanceSegmentation":
+                correct_outlen = 10  # 8 + attentions + cross_attentions
+                if "labels" in inputs_dict:
+                    correct_outlen += 3  # loss, loss_dict and auxiliary outputs are added at the beginning
+
+            self.assertEqual(correct_outlen, out_len)
+
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        config.return_dict = True
+        inputs_dict["output_hidden_states"] = False
+
+        for model_class in self.all_model_classes:
+            inputs_dict["output_attentions"] = True
+            check_attention_outputs(inputs_dict, config, model_class)
+
+            # check that output_attentions also works using config
+            del inputs_dict["output_attentions"]
+            config.output_attentions = True
+            check_attention_outputs(inputs_dict, config, model_class)
+
+    def test_hidden_states_output(self):
+        def check_hidden_states_output(inputs_dict, config, model_class):
+            model = model_class(config)
+            model.to(torch_device)
+            model.eval()
+
+            with torch.no_grad():
+                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+            hidden_states = outputs.hidden_states
+
+            expected_num_hidden_states = self.model_tester.decoder_layers + 1
+            self.assertEqual(len(hidden_states), expected_num_hidden_states)
+
+            for i in range(expected_num_hidden_states):
+                self.assertListEqual(
+                    list(hidden_states[i].shape),
+                    [
+                        self.model_tester.batch_size,
+                        self.model_tester.num_queries,
+                        self.model_tester.d_model,
+                    ],
+                )
+
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            inputs_dict["output_attentions"] = False
+            inputs_dict["output_hidden_states"] = True
+            check_hidden_states_output(inputs_dict, config, model_class)
+
+            # check that output_hidden_states also works using config
+            del inputs_dict["output_hidden_states"]
+            config.output_hidden_states = True
+            check_hidden_states_output(inputs_dict, config, model_class)
+
+    def test_initialization(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        configs_no_init = _config_zero_init(config)
+        for model_class in self.all_model_classes:
+            print("Model class:", model_class)
+            model = model_class(config=configs_no_init)
+            for name, param in model.named_parameters():
+                if param.requires_grad:
+                    if (
+                        "level_embed" in name
+                        or "sampling_offsets.bias" in name
+                        or "value_proj" in name
+                        or "output_proj" in name
+                        or "reference_points" in name
+                        or "class_embed" in name
+                        or "gamma_1" in name
+                        or "gamma_2" in name
+                    ):
+                        continue
+                    self.assertIn(
+                        ((param.data.mean() * 1e9).round() / 1e9).item(),
+                        [0.0, 1.0],
+                        msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+                    )
+
+    def test_retain_grad_hidden_states_attentions(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        config.output_hidden_states = True
+        config.output_attentions = True
+
+        model_class = self.all_model_classes[0]
+        model = model_class(config)
+        model.to(torch_device)
+
+        inputs = self._prepare_for_class(inputs_dict, model_class)
+
+        outputs = model(**inputs)
+
+        # we take the first output since last_hidden_state is the first item
+        output = outputs.last_hidden_state
+
+        hidden_states = outputs.hidden_states[0]
+        attentions = outputs.attentions[0]
+        hidden_states.retain_grad()
+        attentions.retain_grad()
+
+        output.flatten()[0].backward(retain_graph=True)
+
+        self.assertIsNotNone(hidden_states.grad)
+        self.assertIsNotNone(attentions.grad)
+
+    def test_save_load(self):
+        def check_save_load(out1, out2):
+            # make sure we don't have nans
+            out_2 = out2.cpu().numpy()
+            out_2[np.isnan(out_2)] = 0
+            out_2 = out_2[~np.isneginf(out_2)]
+
+            out_1 = out1.cpu().numpy()
+            out_1[np.isnan(out_1)] = 0
+            out_1 = out_1[~np.isneginf(out_1)]
+            max_diff = np.amax(np.abs(out_1 - out_2))
+            self.assertLessEqual(max_diff, 1e-5)
+
+        for model_class in self.all_model_classes:
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+            model = model_class(config)
+            model.to(torch_device)
+            model.eval()
+            with torch.no_grad():
+                first = model(**self._prepare_for_class(inputs_dict, model_class))[0]
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                model.save_pretrained(tmpdirname)
+
+                # the config file (and the generation config file, if it can generate) should be saved
+                self.assertTrue(os.path.exists(os.path.join(tmpdirname, CONFIG_NAME)))
+
+                model = model_class.from_pretrained(tmpdirname)
+                model.config._attn_implementation = "eager"  # TODO: we have to force eager attention for testing, why?
+                model.to(torch_device)
+                with torch.no_grad():
+                    second = model(**self._prepare_for_class(inputs_dict, model_class))[0]
+
+                # Save and load a second time because `from_pretrained` adds a bunch of new config fields,
+                # so we need to make sure those fields can be loaded back after saving.
+                # Simply initializing as `model(config)` doesn't add those fields.
+                model.save_pretrained(tmpdirname)
+                model = model_class.from_pretrained(tmpdirname)
+
+            if isinstance(first, tuple) and isinstance(second, tuple):
+                for tensor1, tensor2 in zip(first, second):
+                    check_save_load(tensor1, tensor2)
+            else:
+                check_save_load(first, second)
+
+    def test_forward_auxiliary_loss(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        config.auxiliary_loss = True
+
+        # only test the object detection and segmentation models
+        for model_class in self.all_model_classes[1:]:
+            model = model_class(config)
+            model.to(torch_device)
+
+            inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+
+            outputs = model(**inputs)
+
+            self.assertIsNotNone(outputs.auxiliary_outputs)
+            self.assertEqual(len(outputs.auxiliary_outputs), self.model_tester.decoder_layers - 1)
+
+
+@require_torch
+@require_vision
+class RfDetrModelIntegrationTest(unittest.TestCase):
+    @cached_property
+    def default_image_processor(self):
+        if is_vision_available():
+            return {
+                "base": DetrImageProcessor.from_pretrained(CHECKPOINT["base"]),
+                "large": DetrImageProcessor.from_pretrained(CHECKPOINT["large"]),
+                "segmentation": DetrImageProcessor.from_pretrained(CHECKPOINT["segmentation"]),
+            }
+
+    @slow
+    def test_inference_object_detection_head_base(self):
+        size = "base"
+        model = RfDetrForObjectDetection.from_pretrained(CHECKPOINT[size], attn_implementation="eager").to(
+            torch_device
+        )
+
+        image_processor = self.default_image_processor[size]
+        image = prepare_img()
+        encoding = image_processor(images=image, return_tensors="pt").to(torch_device)
+        pixel_values = encoding["pixel_values"].to(torch_device)
+        pixel_mask = encoding["pixel_mask"].to(torch_device)
+        with torch.no_grad():
+            outputs = model(pixel_values, pixel_mask)
+
+        expected_logits_shape = torch.Size((1, model.config.num_queries, model.config.num_labels))
+        self.assertEqual(outputs.logits.shape, expected_logits_shape)
+
+        expectations = Expectations(
+            {
+                ("cpu", None): [-7.60410, -4.65943, -10.03144, -5.63881, -9.88291],
+            }
+        )
+        expected_logits = torch.tensor(expectations.get_expectation()).to(torch_device)
+        torch.testing.assert_close(outputs.logits.flatten()[:5], expected_logits, rtol=2e-4, atol=2e-4)
+
+        expected_boxes_shape = torch.Size((1, model.config.num_queries, 4))
+        self.assertEqual(outputs.pred_boxes.shape, expected_boxes_shape)
+
+        expectations = Expectations(
+            {
+                ("cpu", None): [0.25465, 0.54864, 0.48583, 0.86991, 0.16926],
+            }
+        )
+        expected_boxes = torch.tensor(expectations.get_expectation()).to(torch_device)
+
+        torch.testing.assert_close(outputs.pred_boxes.flatten()[:5], expected_boxes, rtol=2e-4, atol=2e-4)
+
+        results = image_processor.post_process_object_detection(
+            outputs, threshold=0.0, target_sizes=[image.size[::-1]]
+        )[0]
+
+        expectations = Expectations(
+            {
+                ("cpu", None): [0.9829, 0.9763, 0.9780, 0.8663],
+            }
+        )
+        expected_scores = torch.tensor(expectations.get_expectation()).to(torch_device)
+
+        expected_labels = [17, 75, 17, 75]
+
+        expectations = Expectations(
+            {
+                ("cpu", None): [
+                    [7.5101, 54.5667, 318.4421, 472.1259],
+                    [40.7178, 72.6109, 175.9414, 117.5903],
+                    [343.0202, 23.9165, 639.3325, 372.2062],
+                    [333.5370, 76.9845, 370.3848, 187.3158],
+                ],
+            }
+        )
+        expected_slice_boxes = torch.tensor(expectations.get_expectation()).to(torch_device)
+
+        torch.testing.assert_close(results["scores"][:4], expected_scores, atol=1e-3, rtol=2e-4)
+        self.assertSequenceEqual(results["labels"][:4].tolist(), expected_labels)
+        torch.testing.assert_close(results["boxes"][:4], expected_slice_boxes, atol=1e-3, rtol=2e-4)
+
+    @slow
+    def test_inference_object_detection_head_large(self):
+        size = "large"
+        model = RfDetrForObjectDetection.from_pretrained(CHECKPOINT[size], attn_implementation="eager").to(
+            torch_device
+        )
+
+        image_processor = self.default_image_processor[size]
+        image = prepare_img()
+        encoding = image_processor(images=image, return_tensors="pt").to(torch_device)
+        pixel_values = encoding["pixel_values"].to(torch_device)
+        pixel_mask = encoding["pixel_mask"].to(torch_device)
+        with torch.no_grad():
+            outputs = model(pixel_values, pixel_mask)
+
+        expected_logits_shape = torch.Size((1, model.config.num_queries, model.config.num_labels))
+        self.assertEqual(outputs.logits.shape, expected_logits_shape)
+
+        expectations = Expectations(
+            {
+                ("cpu", None): [-7.60888, -4.36906, -4.98865, -8.06598, -5.52970],
+            }
+        )
+        expected_logits = torch.tensor(expectations.get_expectation()).to(torch_device)
+        torch.testing.assert_close(outputs.logits.flatten()[:5], expected_logits, rtol=2e-3, atol=2e-3)
+
+        expected_boxes_shape = torch.Size((1, model.config.num_queries, 4))
+        self.assertEqual(outputs.pred_boxes.shape, expected_boxes_shape)
+
+        expectations = Expectations(
+            {
+                ("cpu", None): [0.25576, 0.55051, 0.47765, 0.87141, 0.76966],
+            }
+        )
+        expected_boxes = torch.tensor(expectations.get_expectation()).to(torch_device)
+
+        torch.testing.assert_close(outputs.pred_boxes.flatten()[:5], expected_boxes, rtol=2e-3, atol=2e-3)
+
+        results = image_processor.post_process_object_detection(
+            outputs, threshold=0.0, target_sizes=[image.size[::-1]]
+        )[0]
+
+        expectations = Expectations(
+            {
+                ("cpu", None): [0.9820, 0.9874, 0.9918, 0.9696],
+            }
+        )
+        expected_scores = torch.tensor(expectations.get_expectation()).to(torch_device)
+
+        expected_labels = [17, 17, 75, 75]
+
+        expectations = Expectations(
+            {
+                ("cpu", None): [
+                    [10.8431, 55.1121, 316.5317, 473.3869],
+                    [345.1129, 24.5076, 640.0582, 373.1581],
+                    [40.3736, 73.1451, 175.8807, 117.5796],
+                    [333.9091, 76.8915, 370.1848, 186.7155],
+                ],
+            }
+        )
+        expected_slice_boxes = torch.tensor(expectations.get_expectation()).to(torch_device)
+
+        torch.testing.assert_close(results["scores"][:4], expected_scores, atol=1e-3, rtol=2e-4)
+        self.assertSequenceEqual(results["labels"][:4].tolist(), expected_labels)
+        torch.testing.assert_close(results["boxes"][:4], expected_slice_boxes, atol=1e-3, rtol=2e-4)
+
+    @slow
+    def test_inference_segmentation(self):
+        size = "segmentation"
+        model = RfDetrForInstanceSegmentation.from_pretrained(CHECKPOINT[size], attn_implementation="eager").to(
+            torch_device
+        )
+
+        image_processor = self.default_image_processor[size]
+        image = prepare_img()
+        encoding = image_processor(images=image, return_tensors="pt").to(torch_device)
+        pixel_values = encoding["pixel_values"].to(torch_device)
+        pixel_mask = encoding["pixel_mask"].to(torch_device)
+        with torch.no_grad():
+            outputs = model(pixel_values, pixel_mask)
+
+        expected_logits_shape = torch.Size((1, model.config.num_queries, model.config.num_labels))
+        self.assertEqual(outputs.logits.shape, expected_logits_shape)
+
+        score_expectations = Expectations(
+            {
+                ("cpu", None): [-7.05877, -4.23362, -6.54288, -8.13384, -6.36838],
+            }
+        )
+        expected_logits = torch.tensor(score_expectations.get_expectation()).to(torch_device)
+        torch.testing.assert_close(outputs.logits.flatten()[:5], expected_logits, rtol=2e-3, atol=2e-3)
+
+        expected_boxes_shape = torch.Size((1, model.config.num_queries, 4))
+        self.assertEqual(outputs.pred_boxes.shape, expected_boxes_shape)
+
+        score_expectations = Expectations(
+            {
+                ("cpu", None): [0.25603, 0.55164, 0.48340, 0.87798, 0.73861],
+            }
+        )
+        expected_boxes = torch.tensor(score_expectations.get_expectation()).to(torch_device)
+
+        torch.testing.assert_close(outputs.pred_boxes.flatten()[:5], expected_boxes, rtol=2e-3, atol=2e-3)
+
+        expected_masks_shape = torch.Size(
+            (
+                1,
+                model.config.num_queries,
+                pixel_values.shape[-2] // model.config.mask_downsample_ratio,
+                pixel_values.shape[-1] // model.config.mask_downsample_ratio,
+            )
+        )
+        self.assertEqual(outputs.pred_masks.shape, expected_masks_shape)
+
+        score_expectations = Expectations(
+            {
+                ("cpu", None): [-16.72129, -16.17153, -17.06426, -17.29409, -17.78559],
+            }
+        )
+        expected_masks = torch.tensor(score_expectations.get_expectation()).to(torch_device)
+
+        torch.testing.assert_close(outputs.pred_masks.flatten()[:5], expected_masks, rtol=2e-3, atol=2e-3)
+
+        results = image_processor.post_process_instance_segmentation(
+            outputs, threshold=0.0, target_sizes=[image.size[::-1]]
+        )[0]
+
+        expected_labels = [17, 75, 75, 17]
+        score_expectations = Expectations(
+            {
+                ("cpu", None): [0.98604, 0.985644, 0.950492, 0.967915],
+            }
+        )
+        expected_scores = torch.tensor(score_expectations.get_expectation()).to(torch.float32)
+
+        predicted_labels = [segment["label_id"] for segment in results["segments_info"]]
+        predicted_scores = torch.tensor([segment["score"] for segment in results["segments_info"]])
+
+        self.assertSequenceEqual(predicted_labels, expected_labels)
+        torch.testing.assert_close(predicted_scores, expected_scores, atol=1e-3, rtol=2e-4)
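Outside of `unittest`, the segmentation integration test above corresponds to roughly the following inference flow. This is a sketch only: the checkpoint id and image path are placeholders (the test resolves real ids through its `CHECKPOINT` mapping), while the processor and post-processing calls are the ones exercised by `test_inference_segmentation`.

```python
import torch
from PIL import Image
from transformers import DetrImageProcessor, RfDetrForInstanceSegmentation

checkpoint = "<rf-detr-segmentation-checkpoint>"  # placeholder id
image_processor = DetrImageProcessor.from_pretrained(checkpoint)
model = RfDetrForInstanceSegmentation.from_pretrained(checkpoint).eval()

image = Image.open("example.jpg")  # any RGB image
inputs = image_processor(images=image, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)  # logits, pred_boxes and pred_masks

# Same post-processing call as the test: one dict per image with a "segmentation"
# map and a list of "segments_info" entries carrying "label_id" and "score".
results = image_processor.post_process_instance_segmentation(
    outputs, threshold=0.5, target_sizes=[image.size[::-1]]
)[0]
for segment in results["segments_info"]:
    print(segment["label_id"], round(segment["score"], 3))
```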
+
+
+class RfDetrDinov2ModelTester:
+    def __init__(
+        self,
+        parent,
+        batch_size=13,
+        image_size=32,
+        patch_size=2,
+        num_channels=3,
+        is_training=True,
+        use_labels=True,
+        hidden_size=32,
+        num_hidden_layers=2,
+        num_attention_heads=4,
+        intermediate_size=37,
+        hidden_act="gelu",
+        hidden_dropout_prob=0.1,
+        attention_probs_dropout_prob=0.1,
+        type_sequence_label_size=10,
+        initializer_range=0.02,
+        mask_ratio=0.5,
+        num_windows=2,
+    ):
+        self.parent = parent
+        self.batch_size = batch_size
+        self.image_size = image_size
+        self.patch_size = patch_size
+        self.num_channels = num_channels
+        self.is_training = is_training
+        self.use_labels = use_labels
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.intermediate_size = intermediate_size
+        self.hidden_act = hidden_act
+        self.hidden_dropout_prob = hidden_dropout_prob
+        self.attention_probs_dropout_prob = attention_probs_dropout_prob
+        self.type_sequence_label_size = type_sequence_label_size
+        self.initializer_range = initializer_range
+
+        # in Dinov2, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
+        num_patches = (image_size // patch_size) ** 2
+        self.seq_length = num_patches + 1
+        self.mask_ratio = mask_ratio
+        self.num_masks = int(mask_ratio * self.seq_length)
+        self.mask_length = num_patches
+        self.num_windows = num_windows
+
+    def prepare_config_and_inputs(self):
+        pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
+        config = self.get_config()
+
+        return config, pixel_values
+
+    def get_config(self):
+        return RfDetrDinov2Config(
+            image_size=self.image_size,
+            patch_size=self.patch_size,
+            num_channels=self.num_channels,
+            hidden_size=self.hidden_size,
+            num_hidden_layers=self.num_hidden_layers,
+            num_attention_heads=self.num_attention_heads,
+            intermediate_size=self.intermediate_size,
+            hidden_act=self.hidden_act,
+            hidden_dropout_prob=self.hidden_dropout_prob,
+            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
+            is_decoder=False,
+            initializer_range=self.initializer_range,
+            num_windows=self.num_windows,
+        )
+
+    def create_and_check_backbone(self, config, pixel_values):
+        model = RfDetrDinov2Backbone(config=config)
+        model.to(torch_device)
+        model.eval()
+        result = model(pixel_values)
+
+        # verify hidden states
+        self.parent.assertEqual(len(result.feature_maps), len(config.out_features))
+        expected_size = self.image_size // config.patch_size
+        self.parent.assertListEqual(
+            list(result.feature_maps[0].shape), [self.batch_size, model.channels[0], expected_size, expected_size]
+        )
+
+        # verify channels
+        self.parent.assertEqual(len(model.channels), len(config.out_features))
+
+        # verify backbone works with out_features=None
+        config.out_features = None
+        model = RfDetrDinov2Backbone(config=config)
+        model.to(torch_device)
+        model.eval()
+        result = model(pixel_values)
+
+        # verify feature maps
+        self.parent.assertEqual(len(result.feature_maps), 1)
+        self.parent.assertListEqual(
+            list(result.feature_maps[0].shape), [self.batch_size, model.channels[0], expected_size, expected_size]
+        )
+
+        # verify channels
+        self.parent.assertEqual(len(model.channels), 1)
+
+        # verify backbone works with apply_layernorm=False and reshape_hidden_states=False
+        config.apply_layernorm = False
+        config.reshape_hidden_states = False
+
+        model = RfDetrDinov2Backbone(config=config)
+        model.to(torch_device)
+        model.eval()
+        result = model(pixel_values)
+
+        # verify feature maps
+        self.parent.assertEqual(len(result.feature_maps), 1)
+        self.parent.assertListEqual(
+            list(result.feature_maps[0].shape), [self.batch_size, self.seq_length, self.hidden_size]
+        )
+
+    def prepare_config_and_inputs_for_common(self):
+        config, pixel_values = self.prepare_config_and_inputs()
+        inputs_dict = {"pixel_values": pixel_values}
+        return config, inputs_dict
+
+
+@require_torch
+class RfDetrDinov2BackboneTest(unittest.TestCase, BackboneTesterMixin):
+    all_model_classes = (RfDetrDinov2Backbone,) if is_torch_available() else ()
+    config_class = RfDetrDinov2Config
+
+    has_attentions = False
+
+    def setUp(self):
+        self.model_tester = RfDetrDinov2ModelTester(self)
diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py
index 996e14c5e7f5..33770b210efc 100644
--- a/utils/check_config_attributes.py
+++ b/utils/check_config_attributes.py
@@ -125,6 +125,7 @@
     "IdeficsPerceiverConfig": True,
     "GptOssConfig": True,
     "LwDetrConfig": True,
+    "RfDetrConfig": True,
 }
 
 # Common and important attributes, even if they do not always appear in the modeling files (can be a regex pattern)
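To make the backbone checks above easier to read, here is a small sketch of the standalone `RfDetrDinov2Backbone` API they exercise. The configuration values are the tester's toy defaults (real checkpoints are much larger), and the `out_features` names and top-level imports are assumptions based on this test file rather than documented guarantees.

```python
import torch
from transformers import RfDetrDinov2Config, RfDetrDinov2Backbone

# Toy configuration roughly matching RfDetrDinov2ModelTester above.
config = RfDetrDinov2Config(
    image_size=32,
    patch_size=2,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=37,
    num_windows=2,
    out_features=["stage1", "stage2"],
)
backbone = RfDetrDinov2Backbone(config).eval()

pixel_values = torch.randn(1, 3, 32, 32)
with torch.no_grad():
    outputs = backbone(pixel_values)

# One (batch, channels, H/patch_size, W/patch_size) map per requested stage,
# matching the shape assertions in create_and_check_backbone.
for channels, feature_map in zip(backbone.channels, outputs.feature_maps):
    print(channels, tuple(feature_map.shape))
```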