diff --git a/docs/source/en/model_doc/maskformer.md b/docs/source/en/model_doc/maskformer.md index 835eb7f29e8e..e1fe6574ae9b 100644 --- a/docs/source/en/model_doc/maskformer.md +++ b/docs/source/en/model_doc/maskformer.md @@ -64,6 +64,10 @@ This model was contributed by [francesco](https://huggingface.co/francesco). The [[autodoc]] models.maskformer.modeling_maskformer.MaskFormerForInstanceSegmentationOutput +## MaskFormerDetrConfig + +[[autodoc]] MaskFormerDetrConfig + ## MaskFormerConfig [[autodoc]] MaskFormerConfig diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py index 1e1d8fd98492..04bcc7e10190 100755 --- a/src/transformers/conversion_mapping.py +++ b/src/transformers/conversion_mapping.py @@ -55,7 +55,6 @@ "qwen3_omni_moe": "qwen2_moe", "qwen3_omni_moe_thinker": "qwen2_moe", "qwen3_next": "qwen2_moe", - "qwen3_5_moe": "qwen2_moe", "hunyuan_v1_moe": "qwen2_moe", "flex_olmo": "qwen2_moe", "olmoe": "qwen2_moe", @@ -91,7 +90,6 @@ def _build_checkpoint_conversion_mapping(): ], "colpali": [ WeightRenaming(source_patterns=r"vlm(?!\.model)", target_patterns="vlm.model"), - WeightRenaming(source_patterns=r"language_model.model", target_patterns="language_model"), ], "emu3": [ WeightRenaming(source_patterns=r"text_model.model", target_patterns="text_model"), @@ -109,20 +107,16 @@ def _build_checkpoint_conversion_mapping(): source_patterns=r"(? list[WeightConverter | WeightRenaming] | None: + model_type = getattr(model.config, "model_type", None) + if model_type is not None: + model_specific_conversions = get_checkpoint_conversion_mapping(model_type) + return model_specific_conversions + return None def get_model_conversion_mapping( @@ -517,28 +545,35 @@ def get_model_conversion_mapping( For a given `model`, obtain the weight conversion mapping if any are registered either as a simple renaming `_checkpoint_conversion_mapping` class argument, or in the general WeightConverter mapping. 
""" + # Lazy import to avoid circular import issues + from .modeling_utils import PreTrainedModel + # note: this function is used in PEFT, so changing the API requires coordination weight_conversions = [] # Load models with explicit, user-provided key mapping if key_mapping is not None: weight_conversions = [WeightRenaming(source_patterns=k, target_patterns=v) for k, v in key_mapping.items()] - elif any( - allowed_name in class_name.__name__.lower() - for class_name in model.__class__.__mro__[:-1] - for allowed_name in VLMS - ): - weight_conversions = [ - WeightRenaming(source_patterns=k, target_patterns=v) - for k, v in model._checkpoint_conversion_mapping.items() - ] - # TODO: should be checked recursively on submodels!! - model_type = getattr(model.config, "model_type", None) - if model_type is not None: - model_specific_conversions = get_checkpoint_conversion_mapping(model_type) - if model_specific_conversions is not None: - weight_conversions.extend(model_specific_conversions) + # Model have several `PreTrainedModel` within with the same model type + # For ex: XForConditionalGeneration -> XModel. 
We don't want to apply the same + # conversion pattern twice because of that + seen_model_types = set() + if (conversions := extract_weight_conversions_for_model(model)) is not None: + weight_conversions.extend(conversions) + seen_model_types.add(model.config.model_type) + + # Recurse over submodules and collect all conversions + for submodule in model.modules(): + if ( + submodule is not model + and isinstance(submodule, PreTrainedModel) + and submodule.config.model_type not in seen_model_types + ): + conversions = extract_weight_conversions_for_model(submodule) + if conversions is not None: + weight_conversions.extend(conversions) + seen_model_types.add(submodule.config.model_type) if add_legacy: weight_conversions.extend(get_checkpoint_conversion_mapping("legacy")) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 1cdb033cb709..f50774ef8065 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -4100,7 +4100,7 @@ def from_pretrained( # instantiated model, as the flags can be modified by instances sometimes) dtype_plan = model._get_dtype_plan(dtype) - # Obtain the weight conversion mapping for this model if any are registered + # Obtain the weight conversion mapping for this model if any are registered and apply to all submodels recursively weight_conversions = get_model_conversion_mapping(model, key_mapping, hf_quantizer) if _torch_distributed_available and device_mesh is not None: # add hooks to nn.Modules: no weights diff --git a/src/transformers/models/conditional_detr/modeling_conditional_detr.py b/src/transformers/models/conditional_detr/modeling_conditional_detr.py index 44e1893c7b75..2f10c81b38e1 100644 --- a/src/transformers/models/conditional_detr/modeling_conditional_detr.py +++ b/src/transformers/models/conditional_detr/modeling_conditional_detr.py @@ -1621,26 +1621,6 @@ def _set_aux_loss(self, outputs_class, outputs_coord): """ ) class 
ConditionalDetrForSegmentation(ConditionalDetrPreTrainedModel): - _checkpoint_conversion_mapping = { - "bbox_attention.q_linear": "bbox_attention.q_proj", - "bbox_attention.k_linear": "bbox_attention.k_proj", - # Mask head refactor - "mask_head.lay1": "mask_head.conv1.conv", - "mask_head.gn1": "mask_head.conv1.norm", - "mask_head.lay2": "mask_head.conv2.conv", - "mask_head.gn2": "mask_head.conv2.norm", - "mask_head.adapter1": "mask_head.fpn_stages.0.fpn_adapter", - "mask_head.lay3": "mask_head.fpn_stages.0.refine.conv", - "mask_head.gn3": "mask_head.fpn_stages.0.refine.norm", - "mask_head.adapter2": "mask_head.fpn_stages.1.fpn_adapter", - "mask_head.lay4": "mask_head.fpn_stages.1.refine.conv", - "mask_head.gn4": "mask_head.fpn_stages.1.refine.norm", - "mask_head.adapter3": "mask_head.fpn_stages.2.fpn_adapter", - "mask_head.lay5": "mask_head.fpn_stages.2.refine.conv", - "mask_head.gn5": "mask_head.fpn_stages.2.refine.norm", - "mask_head.out_lay": "mask_head.output_conv", - } - def __init__(self, config: ConditionalDetrConfig): super().__init__(config) diff --git a/src/transformers/models/detr/modeling_detr.py b/src/transformers/models/detr/modeling_detr.py index 4906b3510f44..384cc388cfd7 100644 --- a/src/transformers/models/detr/modeling_detr.py +++ b/src/transformers/models/detr/modeling_detr.py @@ -1435,26 +1435,6 @@ def forward( """ ) class DetrForSegmentation(DetrPreTrainedModel): - _checkpoint_conversion_mapping = { - "bbox_attention.q_linear": "bbox_attention.q_proj", - "bbox_attention.k_linear": "bbox_attention.k_proj", - # Mask head refactor - "mask_head.lay1": "mask_head.conv1.conv", - "mask_head.gn1": "mask_head.conv1.norm", - "mask_head.lay2": "mask_head.conv2.conv", - "mask_head.gn2": "mask_head.conv2.norm", - "mask_head.adapter1": "mask_head.fpn_stages.0.fpn_adapter", - "mask_head.lay3": "mask_head.fpn_stages.0.refine.conv", - "mask_head.gn3": "mask_head.fpn_stages.0.refine.norm", - "mask_head.adapter2": "mask_head.fpn_stages.1.fpn_adapter", - 
"mask_head.lay4": "mask_head.fpn_stages.1.refine.conv", - "mask_head.gn4": "mask_head.fpn_stages.1.refine.norm", - "mask_head.adapter3": "mask_head.fpn_stages.2.fpn_adapter", - "mask_head.lay5": "mask_head.fpn_stages.2.refine.conv", - "mask_head.gn5": "mask_head.fpn_stages.2.refine.norm", - "mask_head.out_lay": "mask_head.output_conv", - } - def __init__(self, config: DetrConfig): super().__init__(config) diff --git a/src/transformers/models/maskformer/configuration_maskformer.py b/src/transformers/models/maskformer/configuration_maskformer.py index cb7650e59577..abfe6ae0154e 100644 --- a/src/transformers/models/maskformer/configuration_maskformer.py +++ b/src/transformers/models/maskformer/configuration_maskformer.py @@ -1,4 +1,10 @@ -# Copyright 2022 Meta Platforms, Inc.and The HuggingFace Inc. team. All rights reserved. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/maskformer/modular_maskformer.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_maskformer.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2022 Meta Platforms, Inc.s and The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,20 +17,107 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-"""MaskFormer model configuration""" - from huggingface_hub.dataclasses import strict from ...backbone_utils import consolidate_backbone_kwargs_to_config from ...configuration_utils import PreTrainedConfig from ...utils import auto_docstring, logging from ..auto import CONFIG_MAPPING, AutoConfig -from ..detr import DetrConfig logger = logging.get_logger(__name__) +@auto_docstring(checkpoint="facebook/maskformer-swin-base-ade") +@strict +class MaskFormerDetrConfig(PreTrainedConfig): + r""" + num_queries (`int`, *optional*, defaults to 100): + Number of object queries, i.e. detection slots. This is the maximal number of objects + [`ConditionalDetrModel`] can detect in a single image. For COCO, we recommend 100 queries. + position_embedding_type (`str`, *optional*, defaults to `"sine"`): + Type of position embeddings to be used on top of the image features. One of `"sine"` or `"learned"`. + dilation (`bool`, *optional*, defaults to `False`): + Whether to replace stride with dilation in the last convolutional block (DC5). Only supported when + `use_timm_backbone` = `True`. 
+ + Examples: + + ```python + >>> from transformers import MaskFormerDetrConfig, MaskFormerDetrModel + + >>> # Initializing a MASK_FORMER_DETR facebook/mask_former_detr-resnet-50 style configuration + >>> configuration = MaskFormerDetrConfig() + + >>> # Initializing a model (with random weights) from the facebook/mask_former_detr-resnet-50 style configuration + >>> model = MaskFormerDetrModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "detr" + sub_configs = {"backbone_config": AutoConfig} + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = { + "hidden_size": "d_model", + "num_attention_heads": "encoder_attention_heads", + "num_hidden_layers": "encoder_layers", + } + + backbone_config: dict | PreTrainedConfig | None = None + num_channels: int = 3 + num_queries: int = 100 + encoder_layers: int = 6 + encoder_ffn_dim: int = 2048 + encoder_attention_heads: int = 8 + decoder_layers: int = 6 + decoder_ffn_dim: int = 2048 + decoder_attention_heads: int = 8 + encoder_layerdrop: float | int = 0.0 + decoder_layerdrop: float | int = 0.0 + is_encoder_decoder: bool = True + activation_function: str = "relu" + d_model: int = 256 + dropout: float | int = 0.1 + attention_dropout: float | int = 0.0 + activation_dropout: float | int = 0.0 + init_std: float = 0.02 + init_xavier_std: float = 1.0 + auxiliary_loss: bool = False + position_embedding_type: str = "sine" + dilation: bool = False + class_cost: int = 1 + bbox_cost: int = 5 + giou_cost: int = 2 + mask_loss_coefficient: int = 1 + dice_loss_coefficient: int = 1 + bbox_loss_coefficient: int = 5 + giou_loss_coefficient: int = 2 + eos_coefficient: float = 0.1 + + def __post_init__(self, **kwargs): + backbone_kwargs = kwargs.get("backbone_kwargs", {}) + timm_default_kwargs = { + "num_channels": backbone_kwargs.get("num_channels", self.num_channels), + "features_only": True, + "use_pretrained_backbone": False, + "out_indices": 
backbone_kwargs.get("out_indices", [1, 2, 3, 4]), + } + if self.dilation: + timm_default_kwargs["output_stride"] = backbone_kwargs.get("output_stride", 16) + + self.backbone_config, kwargs = consolidate_backbone_kwargs_to_config( + backbone_config=self.backbone_config, + default_backbone="resnet50", + default_config_type="resnet", + default_config_kwargs={"out_features": ["stage4"]}, + timm_default_kwargs=timm_default_kwargs, + **kwargs, + ) + super().__post_init__(**kwargs) + + @auto_docstring(checkpoint="facebook/maskformer-swin-base-ade") @strict class MaskFormerConfig(PreTrainedConfig): @@ -108,7 +201,7 @@ def __post_init__(self, **kwargs): if self.decoder_config is None: # fall back to https://huggingface.co/facebook/detr-resnet-50 - self.decoder_config = DetrConfig() + self.decoder_config = MaskFormerDetrConfig() else: # verify that the decoder is supported decoder_type = ( @@ -130,4 +223,4 @@ def __post_init__(self, **kwargs): super().__post_init__(**kwargs) -__all__ = ["MaskFormerConfig"] +__all__ = ["MaskFormerConfig", "MaskFormerDetrConfig"] diff --git a/src/transformers/models/maskformer/modeling_maskformer.py b/src/transformers/models/maskformer/modeling_maskformer.py index 9e565eed356e..788775a52fcb 100644 --- a/src/transformers/models/maskformer/modeling_maskformer.py +++ b/src/transformers/models/maskformer/modeling_maskformer.py @@ -1,3 +1,9 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/maskformer/modular_maskformer.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_maskformer.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # Copyright 2022 Meta Platforms, Inc.s and The HuggingFace Inc. team. All rights reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); @@ -11,9 +17,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""PyTorch MaskFormer model.""" import math +from collections.abc import Callable from dataclasses import dataclass from numbers import Number @@ -23,11 +29,10 @@ from ... import initialization as init from ...activations import ACT2FN -from ...backbone_utils import load_backbone from ...masking_utils import create_bidirectional_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithCrossAttentions -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...pytorch_utils import compile_compatible_method_lru_cache from ...utils import ( @@ -36,11 +41,12 @@ auto_docstring, is_accelerate_available, is_scipy_available, - logging, requires_backends, ) -from ..detr import DetrConfig -from .configuration_maskformer import MaskFormerConfig +from ...utils.generic import merge_with_config_defaults +from ...utils.output_capturing import capture_outputs +from ..auto import AutoBackbone +from .configuration_maskformer import MaskFormerConfig, MaskFormerDetrConfig from .configuration_maskformer_swin import MaskFormerSwinConfig @@ -51,18 +57,15 @@ if is_scipy_available(): from scipy.optimize import linear_sum_assignment -logger = logging.get_logger(__name__) - @dataclass @auto_docstring( custom_intro=""" - Base class for outputs of the DETR decoder. This class adds one attribute to BaseModelOutputWithCrossAttentions, + Base class for outputs of the MASK_FORMER_DETR decoder. This class adds one attribute to BaseModelOutputWithCrossAttentions, namely an optional stack of intermediate decoder activations, i.e. 
the output of each decoder layer, each of them gone through a layernorm. This is useful when training the model with auxiliary decoding losses. """ ) -# Copied from transformers.models.detr.modeling_detr.DetrDecoderOutput class DetrDecoderOutput(BaseModelOutputWithCrossAttentions): r""" cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`): @@ -229,364 +232,312 @@ class MaskFormerForInstanceSegmentationOutput(ModelOutput): attentions: tuple[torch.FloatTensor] | None = None -def upsample_like(pixel_values: Tensor, like: Tensor, mode: str = "bilinear") -> Tensor: - """ - An utility function that upsamples `pixel_values` to match the dimension of `like`. - - Args: - pixel_values (`torch.Tensor`): - The tensor we wish to upsample. - like (`torch.Tensor`): - The tensor we wish to use as size target. - mode (str, *optional*, defaults to `"bilinear"`): - The interpolation mode. - - Returns: - `torch.Tensor`: The upsampled tensor +@dataclass +@auto_docstring( + custom_intro=""" + Base class for outputs of the MASK_FORMER_DETR decoder. This class adds one attribute to BaseModelOutputWithCrossAttentions, + namely an optional stack of intermediate decoder activations, i.e. the output of each decoder layer, each of them + gone through a layernorm. This is useful when training the model with auxiliary decoding losses. 
""" - _, _, height, width = like.shape - upsampled = nn.functional.interpolate(pixel_values, size=(height, width), mode=mode, align_corners=False) - return upsampled - - -# refactored from original implementation -def dice_loss(inputs: Tensor, labels: Tensor, num_masks: int) -> Tensor: +) +class MaskFormerDetrDecoderOutput(BaseModelOutputWithCrossAttentions): r""" - Compute the DICE loss, similar to generalized IOU for masks as follows: - - $$ \mathcal{L}_{\text{dice}(x, y) = 1 - \frac{2 * x \cap y }{x \cup y + 1}} $$ - - In practice, since `labels` is a binary mask, (only 0s and 1s), dice can be computed as follow + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax, + used to compute the weighted average in the cross-attention heads. + intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, hidden_size)`, *optional*, returned when `config.auxiliary_loss=True`): + Intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a + layernorm. + """ - $$ \mathcal{L}_{\text{dice}(x, y) = 1 - \frac{2 * x * y }{x + y + 1}} $$ + intermediate_hidden_states: torch.FloatTensor | None = None - Args: - inputs (`torch.Tensor`): - A tensor representing a mask. - labels (`torch.Tensor`): - A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs - (0 for the negative class and 1 for the positive class). - num_masks (`int`): - The number of masks present in the current batch, used for normalization. - Returns: - `torch.Tensor`: The computed loss. 
+class MaskFormerDetrLearnedPositionEmbedding(nn.Module): + """ + This module learns positional embeddings up to a fixed maximum size. """ - probs = inputs.sigmoid().flatten(1) - numerator = 2 * (probs * labels).sum(-1) - denominator = probs.sum(-1) + labels.sum(-1) - loss = 1 - (numerator + 1) / (denominator + 1) - loss = loss.sum() / num_masks - return loss + def __init__(self, embedding_dim=256): + super().__init__() + self.row_embeddings = nn.Embedding(50, embedding_dim) + self.column_embeddings = nn.Embedding(50, embedding_dim) -# refactored from original implementation -def sigmoid_focal_loss( - inputs: Tensor, labels: Tensor, num_masks: int, alpha: float = 0.25, gamma: float = 2 -) -> Tensor: - r""" - Focal loss proposed in [Focal Loss for Dense Object Detection](https://huggingface.co/papers/1708.02002) originally used in - RetinaNet. The loss is computed as follows: + @compile_compatible_method_lru_cache(maxsize=1) + def forward( + self, + shape: torch.Size, + device: torch.device | str, + dtype: torch.dtype, + mask: torch.Tensor | None = None, + ): + height, width = shape[-2:] + width_values = torch.arange(width, device=device) + height_values = torch.arange(height, device=device) + x_emb = self.column_embeddings(width_values) + y_emb = self.row_embeddings(height_values) + pos = torch.cat([x_emb.unsqueeze(0).repeat(height, 1, 1), y_emb.unsqueeze(1).repeat(1, width, 1)], dim=-1) + pos = pos.permute(2, 0, 1) + pos = pos.unsqueeze(0) + pos = pos.repeat(shape[0], 1, 1, 1) + # Flatten spatial dimensions and permute to (batch_size, sequence_length, hidden_size) format + # expected by the encoder + pos = pos.flatten(2).permute(0, 2, 1) + return pos - $$ \mathcal{L}_{\text{focal loss} = -(1 - p_t)^{\gamma}\log{(p_t)} $$ - where \\(CE(p_t) = -\log{(p_t)}}\\), CE is the standard Cross Entropy Loss +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor | None, + scaling: 
float | None = None, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + if scaling is None: + scaling = query.size(-1) ** -0.5 - Please refer to equation (1,2,3) of the paper for a better understanding. + # Take the dot product between "query" and "key" to get the raw attention scores. + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling - Args: - inputs (`torch.Tensor`): - A float tensor of arbitrary shape. - labels (`torch.Tensor`): - A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs - (0 for the negative class and 1 for the positive class). - num_masks (`int`): - The number of masks present in the current batch, used for normalization. - alpha (float, *optional*, defaults to 0.25): - Weighting factor in range (0,1) to balance positive vs negative examples. - gamma (float, *optional*, defaults to 2.0): - Exponent of the modulating factor \\(1 - p_t\\) to balance easy vs hard examples. + if attention_mask is not None: + attn_weights = attn_weights + attention_mask - Returns: - `torch.Tensor`: The computed loss. 
- """ - criterion = nn.BCEWithLogitsLoss(reduction="none") - probs = inputs.sigmoid() - cross_entropy_loss = criterion(inputs, labels) - p_t = probs * labels + (1 - probs) * (1 - labels) - loss = cross_entropy_loss * ((1 - p_t) ** gamma) + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) - if alpha >= 0: - alpha_t = alpha * labels + (1 - alpha) * (1 - labels) - loss = alpha_t * loss + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() - loss = loss.mean(1).sum() / num_masks - return loss + return attn_output, attn_weights -# refactored from original implementation -def pair_wise_dice_loss(inputs: Tensor, labels: Tensor) -> Tensor: +class MaskFormerDetrSelfAttention(nn.Module): """ - A pair wise version of the dice loss, see `dice_loss` for usage. + Multi-headed self-attention from 'Attention Is All You Need' paper. - Args: - inputs (`torch.Tensor`): - A tensor representing a mask - labels (`torch.Tensor`): - A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs - (0 for the negative class and 1 for the positive class). - - Returns: - `torch.Tensor`: The computed loss between each pairs. + In MASK_FORMER_DETR, position embeddings are added to both queries and keys (but not values) in self-attention. """ - inputs = inputs.sigmoid().flatten(1) - numerator = 2 * torch.matmul(inputs, labels.T) - # using broadcasting to get a [num_queries, NUM_CLASSES] matrix - denominator = inputs.sum(-1)[:, None] + labels.sum(-1)[None, :] - loss = 1 - (numerator + 1) / (denominator + 1) - return loss - -# refactored from original implementation -def pair_wise_sigmoid_focal_loss(inputs: Tensor, labels: Tensor, alpha: float = 0.25, gamma: float = 2.0) -> Tensor: - r""" - A pair wise version of the focal loss, see `sigmoid_focal_loss` for usage. 
- - Args: - inputs (`torch.Tensor`): - A tensor representing a mask. - labels (`torch.Tensor`): - A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs - (0 for the negative class and 1 for the positive class). - alpha (float, *optional*, defaults to 0.25): - Weighting factor in range (0,1) to balance positive vs negative examples. - gamma (float, *optional*, defaults to 2.0): - Exponent of the modulating factor \\(1 - p_t\\) to balance easy vs hard examples. + def __init__( + self, + config: MaskFormerDetrConfig, + hidden_size: int, + num_attention_heads: int, + dropout: float = 0.0, + bias: bool = True, + ): + super().__init__() + self.config = config + self.head_dim = hidden_size // num_attention_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = dropout + self.is_causal = False - Returns: - `torch.Tensor`: The computed loss between each pairs. - """ - if alpha < 0: - raise ValueError("alpha must be positive") + self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias) + self.v_proj = nn.Linear(hidden_size, hidden_size, bias=bias) + self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias) + self.o_proj = nn.Linear(hidden_size, hidden_size, bias=bias) - height_and_width = inputs.shape[1] + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_embeddings: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Position embeddings are added to both queries and keys (but not values). 
+ """ + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) - criterion = nn.BCEWithLogitsLoss(reduction="none") - prob = inputs.sigmoid() - cross_entropy_loss_pos = criterion(inputs, torch.ones_like(inputs)) - focal_pos = ((1 - prob) ** gamma) * cross_entropy_loss_pos - focal_pos *= alpha + query_key_input = hidden_states + position_embeddings if position_embeddings is not None else hidden_states - cross_entropy_loss_neg = criterion(inputs, torch.zeros_like(inputs)) + query_states = self.q_proj(query_key_input).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(query_key_input).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - focal_neg = (prob**gamma) * cross_entropy_loss_neg - focal_neg *= 1 - alpha + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) - loss = torch.matmul(focal_pos, labels.T) + torch.matmul(focal_neg, (1 - labels).T) + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) - return loss / height_and_width + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights -# TODO: use modular - Copied from transformers.models.detr.modeling_detr.DetrAttention -class DetrAttention(nn.Module): +class MaskFormerDetrCrossAttention(nn.Module): """ - Multi-headed attention from 'Attention Is All You Need' paper. + Multi-headed cross-attention from 'Attention Is All You Need' paper. - Here, we add position embeddings to the queries and keys (as explained in the DETR paper). + In MASK_FORMER_DETR, queries get their own position embeddings, while keys get encoder position embeddings. 
+ Values don't get any position embeddings. """ def __init__( self, - embed_dim: int, - num_heads: int, + config: MaskFormerDetrConfig, + hidden_size: int, + num_attention_heads: int, dropout: float = 0.0, bias: bool = True, ): super().__init__() - self.embed_dim = embed_dim - self.num_heads = num_heads - self.dropout = dropout - self.head_dim = embed_dim // num_heads - if self.head_dim * num_heads != self.embed_dim: - raise ValueError( - f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" - f" {num_heads})." - ) + self.config = config + self.head_dim = hidden_size // num_attention_heads self.scaling = self.head_dim**-0.5 + self.attention_dropout = dropout + self.is_causal = False - self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - - def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int): - return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - - def with_pos_embed(self, tensor: torch.Tensor, object_queries: Tensor | None): - return tensor if object_queries is None else tensor + object_queries + self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias) + self.v_proj = nn.Linear(hidden_size, hidden_size, bias=bias) + self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias) + self.o_proj = nn.Linear(hidden_size, hidden_size, bias=bias) def forward( self, hidden_states: torch.Tensor, + key_value_states: torch.Tensor, attention_mask: torch.Tensor | None = None, - object_queries: torch.Tensor | None = None, - key_value_states: torch.Tensor | None = None, - spatial_position_embeddings: torch.Tensor | None = None, - output_attentions: bool = False, - ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: - """Input shape: Batch x Time x Channel""" - # if 
key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - batch_size, target_len, embed_dim = hidden_states.size() - - # add position embeddings to the hidden states before projecting to queries and keys - if object_queries is not None: - hidden_states_original = hidden_states - hidden_states = self.with_pos_embed(hidden_states, object_queries) - - # add key-value position embeddings to the key value states - if spatial_position_embeddings is not None: - key_value_states_original = key_value_states - key_value_states = self.with_pos_embed(key_value_states, spatial_position_embeddings) - - # get query proj - query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj - if is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, batch_size) - value_states = self._shape(self.v_proj(key_value_states_original), -1, batch_size) - else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, batch_size) - value_states = self._shape(self.v_proj(hidden_states_original), -1, batch_size) - - proj_shape = (batch_size * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, target_len, batch_size).view(*proj_shape) - key_states = key_states.view(*proj_shape) - value_states = value_states.view(*proj_shape) - - source_len = key_states.size(1) + position_embeddings: torch.Tensor | None = None, + encoder_position_embeddings: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Position embeddings logic: + - Queries get position_embeddings + - Keys get encoder_position_embeddings + - Values don't get any position embeddings + """ + query_input_shape = hidden_states.shape[:-1] + query_hidden_shape = (*query_input_shape, -1, self.head_dim) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + 
kv_input_shape = key_value_states.shape[:-1] + kv_hidden_shape = (*kv_input_shape, -1, self.head_dim) - if attn_weights.size() != (batch_size * self.num_heads, target_len, source_len): - raise ValueError( - f"Attention weights should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is" - f" {attn_weights.size()}" - ) + query_input = hidden_states + position_embeddings if position_embeddings is not None else hidden_states + key_input = ( + key_value_states + encoder_position_embeddings + if encoder_position_embeddings is not None + else key_value_states + ) - if attention_mask is not None: - if attention_mask.size() != (batch_size, 1, target_len, source_len): - raise ValueError( - f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is" - f" {attention_mask.size()}" - ) - if attention_mask.dtype == torch.bool: - attention_mask = torch.zeros_like(attention_mask, dtype=attn_weights.dtype).masked_fill_( - attention_mask, -torch.inf - ) - attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask - attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. 
- # In order to do so, attn_weights have to reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len) - attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len) - else: - attn_weights_reshaped = None + query_states = self.q_proj(query_input).view(query_hidden_shape).transpose(1, 2) + key_states = self.k_proj(key_input).view(kv_hidden_shape).transpose(1, 2) + value_states = self.v_proj(key_value_states).view(kv_hidden_shape).transpose(1, 2) - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) - attn_output = torch.bmm(attn_probs, value_states) + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) - if attn_output.size() != (batch_size * self.num_heads, target_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(batch_size, self.num_heads, target_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) + attn_output = attn_output.reshape(*query_input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights - attn_output = attn_output.view(batch_size, self.num_heads, target_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - attn_output = attn_output.reshape(batch_size, target_len, embed_dim) - attn_output = self.out_proj(attn_output) +class MaskFormerDetrMLP(nn.Module): + def __init__(self, config: MaskFormerDetrConfig, hidden_size: int, intermediate_size: int): + super().__init__() + self.fc1 = nn.Linear(hidden_size, intermediate_size) + self.fc2 = nn.Linear(intermediate_size, hidden_size) + 
self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + self.dropout = config.dropout - return attn_output, attn_weights_reshaped + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + return hidden_states -# TODO: use modular - Copied from transformers.models.detr.modeling_detr.DetrDecoderLayer -class DetrDecoderLayer(GradientCheckpointingLayer): - def __init__(self, config: DetrConfig): +class MaskFormerDetrDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: MaskFormerDetrConfig): super().__init__() - self.embed_dim = config.d_model + self.hidden_size = config.d_model - self.self_attn = DetrAttention( - embed_dim=self.embed_dim, - num_heads=config.decoder_attention_heads, + self.self_attn = MaskFormerDetrSelfAttention( + config=config, + hidden_size=self.hidden_size, + num_attention_heads=config.decoder_attention_heads, dropout=config.attention_dropout, ) self.dropout = config.dropout - self.activation_fn = ACT2FN[config.activation_function] - self.activation_dropout = config.activation_dropout - self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = DetrAttention( - self.embed_dim, - config.decoder_attention_heads, + self.self_attn_layer_norm = nn.LayerNorm(self.hidden_size) + self.encoder_attn = MaskFormerDetrCrossAttention( + config=config, + hidden_size=self.hidden_size, + num_attention_heads=config.decoder_attention_heads, dropout=config.attention_dropout, ) - self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) - self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) - 
self.final_layer_norm = nn.LayerNorm(self.embed_dim) + self.encoder_attn_layer_norm = nn.LayerNorm(self.hidden_size) + self.mlp = MaskFormerDetrMLP(config, self.hidden_size, config.decoder_ffn_dim) + self.final_layer_norm = nn.LayerNorm(self.hidden_size) def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor | None = None, - object_queries: torch.Tensor | None = None, - query_position_embeddings: torch.Tensor | None = None, + spatial_position_embeddings: torch.Tensor | None = None, + object_queries_position_embeddings: torch.Tensor | None = None, encoder_hidden_states: torch.Tensor | None = None, encoder_attention_mask: torch.Tensor | None = None, - output_attentions: bool | None = False, **kwargs: Unpack[TransformersKwargs], - ): + ) -> torch.Tensor: """ Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, hidden_size)` attention_mask (`torch.FloatTensor`): attention mask of size `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative values. - object_queries (`torch.FloatTensor`, *optional*): - object_queries that are added to the hidden states - in the cross-attention layer. - query_position_embeddings (`torch.FloatTensor`, *optional*): - position embeddings that are added to the queries and keys - in the self-attention layer. + spatial_position_embeddings (`torch.FloatTensor`, *optional*): + Spatial position embeddings (2D positional encodings from encoder) that are added to the keys only + in the cross-attention layer (not to values). + object_queries_position_embeddings (`torch.FloatTensor`, *optional*): + Position embeddings for the object query slots. In self-attention, these are added to both queries + and keys (not values). In cross-attention, these are added to queries only (not to keys or values). 
encoder_hidden_states (`torch.FloatTensor`): - cross attention input to the layer of shape `(batch, seq_len, embed_dim)` + cross attention input to the layer of shape `(batch, seq_len, hidden_size)` encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative values. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. """ residual = hidden_states # Self Attention - hidden_states, self_attn_weights = self.self_attn( + hidden_states, _ = self.self_attn( hidden_states=hidden_states, - object_queries=query_position_embeddings, + position_embeddings=object_queries_position_embeddings, attention_mask=attention_mask, - output_attentions=output_attentions, + **kwargs, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) @@ -594,17 +545,16 @@ def forward( hidden_states = self.self_attn_layer_norm(hidden_states) # Cross-Attention Block - cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states - hidden_states, cross_attn_weights = self.encoder_attn( + hidden_states, _ = self.encoder_attn( hidden_states=hidden_states, - object_queries=query_position_embeddings, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, - spatial_position_embeddings=object_queries, - output_attentions=output_attentions, + position_embeddings=object_queries_position_embeddings, + encoder_position_embeddings=spatial_position_embeddings, + **kwargs, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) @@ -613,66 +563,256 @@ def forward( # Fully Connected residual = hidden_states - hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, 
training=self.training) - hidden_states = self.fc2(hidden_states) - hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states hidden_states = self.final_layer_norm(hidden_states) - outputs = (hidden_states,) + return hidden_states - if output_attentions: - outputs += (self_attn_weights, cross_attn_weights) - return outputs +class MaskFormerDetrConvBlock(nn.Module): + """Basic conv block: Conv3x3 -> GroupNorm -> Activation.""" + def __init__(self, in_channels: int, out_channels: int, activation: str = "relu"): + super().__init__() + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1) + self.norm = nn.GroupNorm(min(8, out_channels), out_channels) + self.activation = ACT2FN[activation] -class DetrDecoder(PreTrainedModel): - config: DetrConfig - base_model_prefix = "model" + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.activation(self.norm(self.conv(x))) + + +class MaskFormerDetrFPNFusionStage(nn.Module): + """Single FPN fusion stage combining low-resolution features with high-resolution FPN features.""" + + def __init__(self, fpn_channels: int, current_channels: int, output_channels: int, activation: str = "relu"): + super().__init__() + self.fpn_adapter = nn.Conv2d(fpn_channels, current_channels, kernel_size=1) + self.refine = MaskFormerDetrConvBlock(current_channels, output_channels, activation) + + def forward(self, features: torch.Tensor, fpn_features: torch.Tensor) -> torch.Tensor: + """ + Args: + features: Current features to upsample, shape (B*Q, current_channels, H_in, W_in) + fpn_features: FPN features at target resolution, shape (B*Q, fpn_channels, H_out, W_out) + + Returns: + Fused and refined features, shape (B*Q, output_channels, H_out, W_out) + """ + fpn_features = self.fpn_adapter(fpn_features) + features = nn.functional.interpolate(features, size=fpn_features.shape[-2:], mode="nearest") + 
return self.refine(fpn_features + features) + +class MaskFormerDetrMaskHeadSmallConv(nn.Module): + """ + Segmentation mask head that generates per-query masks using FPN-based progressive upsampling. + + Combines attention maps (spatial localization) with encoder features (semantics) and progressively + upsamples through multiple scales, fusing with FPN features for high-resolution detail. """ - Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`DetrDecoderLayer`]. - The decoder updates the query embeddings through multiple self-attention and cross-attention layers. + def __init__( + self, + input_channels: int, + fpn_channels: list[int], + hidden_size: int, + activation_function: str = "relu", + ): + super().__init__() + if input_channels % 8 != 0: + raise ValueError(f"input_channels must be divisible by 8, got {input_channels}") + + self.conv1 = MaskFormerDetrConvBlock(input_channels, input_channels, activation_function) + self.conv2 = MaskFormerDetrConvBlock(input_channels, hidden_size // 2, activation_function) + + # Progressive channel reduction: /2 -> /4 -> /8 -> /16 + self.fpn_stages = nn.ModuleList( + [ + MaskFormerDetrFPNFusionStage(fpn_channels[0], hidden_size // 2, hidden_size // 4, activation_function), + MaskFormerDetrFPNFusionStage(fpn_channels[1], hidden_size // 4, hidden_size // 8, activation_function), + MaskFormerDetrFPNFusionStage( + fpn_channels[2], hidden_size // 8, hidden_size // 16, activation_function + ), + ] + ) + + self.output_conv = nn.Conv2d(hidden_size // 16, 1, kernel_size=3, padding=1) + + def forward( + self, + features: torch.Tensor, + attention_masks: torch.Tensor, + fpn_features: list[torch.Tensor], + ) -> torch.Tensor: + """ + Args: + features: Encoder output features, shape (batch_size, hidden_size, H, W) + attention_masks: Cross-attention maps from decoder, shape (batch_size, num_queries, num_heads, H, W) + fpn_features: List of 3 FPN features from low to high resolution, each (batch_size, C, H, 
W) + + Returns: + Predicted masks, shape (batch_size * num_queries, 1, output_H, output_W) + """ + num_queries = attention_masks.shape[1] + + # Expand to (batch_size * num_queries) dimension + features = features.unsqueeze(1).expand(-1, num_queries, -1, -1, -1).flatten(0, 1) + attention_masks = attention_masks.flatten(0, 1) + fpn_features = [ + fpn_feat.unsqueeze(1).expand(-1, num_queries, -1, -1, -1).flatten(0, 1) for fpn_feat in fpn_features + ] + + hidden_states = torch.cat([features, attention_masks], dim=1) + hidden_states = self.conv1(hidden_states) + hidden_states = self.conv2(hidden_states) + + for fpn_stage, fpn_feat in zip(self.fpn_stages, fpn_features): + hidden_states = fpn_stage(hidden_states, fpn_feat) + + return self.output_conv(hidden_states) + + +class MaskFormerDetrMHAttentionMap(nn.Module): + """This is a 2D attention module, which only returns the attention softmax (no multiplication by value)""" + + def __init__( + self, + hidden_size: int, + num_attention_heads: int, + dropout: float = 0.0, + bias: bool = True, + ): + super().__init__() + self.head_dim = hidden_size // num_attention_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = dropout + + self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias) + self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias) + + def forward( + self, query_states: torch.Tensor, key_states: torch.Tensor, attention_mask: torch.Tensor | None = None + ): + query_hidden_shape = (*query_states.shape[:-1], -1, self.head_dim) + key_hidden_shape = (key_states.shape[0], -1, self.head_dim, *key_states.shape[-2:]) + + query_states = self.q_proj(query_states).view(query_hidden_shape) + key_states = nn.functional.conv2d( + key_states, self.k_proj.weight.unsqueeze(-1).unsqueeze(-1), self.k_proj.bias + ).view(key_hidden_shape) + + batch_size, num_queries, num_heads, head_dim = query_states.shape + _, _, _, height, width = key_states.shape + query_shape = (batch_size * num_heads, num_queries, 
head_dim) + key_shape = (batch_size * num_heads, height * width, head_dim) + attn_weights_shape = (batch_size, num_heads, num_queries, height, width) + + query = query_states.transpose(1, 2).contiguous().view(query_shape) + key = key_states.permute(0, 1, 3, 4, 2).contiguous().view(key_shape) + + attn_weights = ( + (torch.matmul(query * self.scaling, key.transpose(1, 2))).view(attn_weights_shape).transpose(1, 2) + ) + + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights.flatten(2), dim=-1).view(attn_weights.size()) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + + return attn_weights + + +@auto_docstring +class MaskFormerDetrPreTrainedModel(PreTrainedModel): + config: MaskFormerDetrConfig + base_model_prefix = "model" + main_input_name = "pixel_values" + input_modalities = ("image",) + _no_split_modules = [r"MaskFormerDetrConvEncoder", r"MaskFormerDetrEncoderLayer", r"MaskFormerDetrDecoderLayer"] + supports_gradient_checkpointing = True + _supports_sdpa = True + _supports_flash_attn = True + _supports_attention_backend = True + _supports_flex_attn = True # Uses create_bidirectional_masks for attention masking + _keys_to_ignore_on_load_unexpected = [ + r"mask_former_detr\.model\.backbone\.model\.layer\d+\.0\.downsample\.1\.num_batches_tracked" + ] + + @torch.no_grad() + def _init_weights(self, module): + std = self.config.init_std + xavier_std = self.config.init_xavier_std + + if isinstance(module, MaskFormerDetrMaskHeadSmallConv): + # MaskFormerDetrMaskHeadSmallConv uses kaiming initialization for all its Conv2d layers + for m in module.modules(): + if isinstance(m, nn.Conv2d): + init.kaiming_uniform_(m.weight, a=1) + if m.bias is not None: + init.constant_(m.bias, 0) + elif isinstance(module, MaskFormerDetrMHAttentionMap): + init.zeros_(module.k_proj.bias) + init.zeros_(module.q_proj.bias) + 
init.xavier_uniform_(module.k_proj.weight, gain=xavier_std) + init.xavier_uniform_(module.q_proj.weight, gain=xavier_std) + elif isinstance(module, MaskFormerDetrLearnedPositionEmbedding): + init.uniform_(module.row_embeddings.weight) + init.uniform_(module.column_embeddings.weight) + elif isinstance(module, (nn.Linear, nn.Conv2d)): + init.normal_(module.weight, mean=0.0, std=std) + if module.bias is not None: + init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + init.normal_(module.weight, mean=0.0, std=std) + # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it loses the flag + if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False): + init.zeros_(module.weight[module.padding_idx]) + elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): + init.ones_(module.weight) + init.zeros_(module.bias) - Some small tweaks for DETR: - - object_queries and query_position_embeddings are added to the forward pass. - - if self.config.auxiliary_loss is set to True, also returns a stack of activations from all decoding layers. +class MaskFormerDetrDecoder(MaskFormerDetrPreTrainedModel): + """ + Transformer decoder that refines a set of object queries. It is composed of a stack of [`MaskFormerDetrDecoderLayer`] modules, + which apply self-attention to the queries and cross-attention to the encoder's outputs. Args: - config: DetrConfig + config (`MaskFormerDetrConfig`): Model configuration object. 
""" - def __init__(self, config: DetrConfig): - super().__init__(config) + _can_record_outputs = { + "hidden_states": MaskFormerDetrDecoderLayer, + "attentions": MaskFormerDetrSelfAttention, + "cross_attentions": MaskFormerDetrCrossAttention, + } + def __init__(self, config: MaskFormerDetrConfig): + super().__init__(config) self.dropout = config.dropout - self.layerdrop = config.decoder_layerdrop - self.layers = nn.ModuleList([DetrDecoderLayer(config) for _ in range(config.decoder_layers)]) - # in DETR, the decoder uses layernorm after the last decoder layer output + self.layers = nn.ModuleList([MaskFormerDetrDecoderLayer(config) for _ in range(config.decoder_layers)]) + # in MASK_FORMER_DETR, the decoder uses layernorm after the last decoder layer output self.layernorm = nn.LayerNorm(config.d_model) - self.gradient_checkpointing = False - + # Initialize weights and apply final processing self.post_init() + @merge_with_config_defaults + @capture_outputs def forward( self, inputs_embeds=None, attention_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, - object_queries=None, - query_position_embeddings=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, + spatial_position_embeddings=None, + object_queries_position_embeddings=None, **kwargs: Unpack[TransformersKwargs], - ): + ) -> MaskFormerDetrDecoderOutput: r""" Args: inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): @@ -695,99 +835,122 @@ def forward( - 1 for pixels that are real (i.e. **not masked**), - 0 for pixels that are padding (i.e. **masked**). - object_queries (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Position embeddings that are added to the queries and keys in each cross-attention layer. 
- query_position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`): - , *optional*): Position embeddings that are added to the queries and keys in each self-attention layer. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + spatial_position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Spatial position embeddings (2D positional encodings from encoder) that are added to the keys in each cross-attention layer. + object_queries_position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*): + Position embeddings for the object query slots that are added to the queries and keys in each self-attention layer. 
""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict if inputs_embeds is not None: hidden_states = inputs_embeds - encoder_attention_mask = create_bidirectional_mask( - config=self.config, - inputs_embeds=inputs_embeds, - attention_mask=encoder_attention_mask, - encoder_hidden_states=encoder_hidden_states, - ) + if attention_mask is not None: + attention_mask = create_bidirectional_mask( + config=self.config, + inputs_embeds=hidden_states, + attention_mask=attention_mask, + ) + + # expand encoder attention mask (for cross-attention on encoder outputs) + if encoder_hidden_states is not None and encoder_attention_mask is not None: + encoder_attention_mask = create_bidirectional_mask( + config=self.config, + inputs_embeds=hidden_states, + attention_mask=encoder_attention_mask, + encoder_hidden_states=encoder_hidden_states, + ) # optional intermediate hidden states intermediate = () if self.config.auxiliary_loss else None # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None for idx, decoder_layer in enumerate(self.layers): - # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) - if output_hidden_states: - all_hidden_states += (hidden_states,) - if self.training: - dropout_probability = torch.rand([]) - if dropout_probability < self.layerdrop: - continue - - layer_outputs = decoder_layer( + hidden_states = decoder_layer( hidden_states, - None, # attention_mask - object_queries, - query_position_embeddings, + attention_mask, + spatial_position_embeddings, + object_queries_position_embeddings, 
encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, - output_attentions=output_attentions, **kwargs, ) - hidden_states = layer_outputs[0] - if self.config.auxiliary_loss: hidden_states = self.layernorm(hidden_states) intermediate += (hidden_states,) - if output_attentions: - all_self_attns += (layer_outputs[1],) - - if encoder_hidden_states is not None: - all_cross_attentions += (layer_outputs[2],) - # finally, apply layernorm hidden_states = self.layernorm(hidden_states) - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - # stack intermediate decoder activations if self.config.auxiliary_loss: intermediate = torch.stack(intermediate) - if not return_dict: - return tuple( - v - for v in [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions, intermediate] - if v is not None - ) - return DetrDecoderOutput( - last_hidden_state=hidden_states, - hidden_states=all_hidden_states, - attentions=all_self_attns, - cross_attentions=all_cross_attentions, - intermediate_hidden_states=intermediate, - ) + return MaskFormerDetrDecoderOutput(last_hidden_state=hidden_states, intermediate_hidden_states=intermediate) + + +# refactored from original implementation +def pair_wise_dice_loss(inputs: Tensor, labels: Tensor) -> Tensor: + """ + A pair wise version of the dice loss, see `dice_loss` for usage. + + Args: + inputs (`torch.Tensor`): + A tensor representing a mask + labels (`torch.Tensor`): + A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs + (0 for the negative class and 1 for the positive class). + + Returns: + `torch.Tensor`: The computed loss between each pairs. 
 + """ + inputs = inputs.sigmoid().flatten(1) + numerator = 2 * torch.matmul(inputs, labels.T) + # using broadcasting to get a [num_queries, NUM_CLASSES] matrix + denominator = inputs.sum(-1)[:, None] + labels.sum(-1)[None, :] + loss = 1 - (numerator + 1) / (denominator + 1) + return loss + + +# refactored from original implementation +def pair_wise_sigmoid_focal_loss(inputs: Tensor, labels: Tensor, alpha: float = 0.25, gamma: float = 2.0) -> Tensor: + r""" + A pair wise version of the focal loss, see `sigmoid_focal_loss` for usage. + + Args: + inputs (`torch.Tensor`): + A tensor representing a mask. + labels (`torch.Tensor`): + A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs + (0 for the negative class and 1 for the positive class). + alpha (float, *optional*, defaults to 0.25): + Weighting factor in range (0,1) to balance positive vs negative examples. + gamma (float, *optional*, defaults to 2.0): + Exponent of the modulating factor \\(1 - p_t\\) to balance easy vs hard examples. + + Returns: + `torch.Tensor`: The computed loss between each pair. 
 + """ + if alpha < 0: + raise ValueError("alpha must be positive") + + height_and_width = inputs.shape[1] + + criterion = nn.BCEWithLogitsLoss(reduction="none") + prob = inputs.sigmoid() + cross_entropy_loss_pos = criterion(inputs, torch.ones_like(inputs)) + focal_pos = ((1 - prob) ** gamma) * cross_entropy_loss_pos + focal_pos *= alpha + + cross_entropy_loss_neg = criterion(inputs, torch.zeros_like(inputs)) + + focal_neg = (prob**gamma) * cross_entropy_loss_neg + focal_neg *= 1 - alpha + + loss = torch.matmul(focal_pos, labels.T) + torch.matmul(focal_neg, (1 - labels).T) + + return loss / height_and_width # refactored from original implementation @@ -888,6 +1051,81 @@ def __repr__(self): return "\n".join(lines) +# refactored from original implementation +def dice_loss(inputs: Tensor, labels: Tensor, num_masks: int) -> Tensor: + r""" + Compute the DICE loss, similar to generalized IOU for masks as follows: + + $$ \mathcal{L}_{\text{dice}}(x, y) = 1 - \frac{2 * x \cap y }{x \cup y + 1} $$ + + In practice, since `labels` is a binary mask, (only 0s and 1s), dice can be computed as follows + + $$ \mathcal{L}_{\text{dice}}(x, y) = 1 - \frac{2 * x * y }{x + y + 1} $$ + + Args: + inputs (`torch.Tensor`): + A tensor representing a mask. + labels (`torch.Tensor`): + A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs + (0 for the negative class and 1 for the positive class). + num_masks (`int`): + The number of masks present in the current batch, used for normalization. + + Returns: + `torch.Tensor`: The computed loss. 
 + """ + probs = inputs.sigmoid().flatten(1) + numerator = 2 * (probs * labels).sum(-1) + denominator = probs.sum(-1) + labels.sum(-1) + loss = 1 - (numerator + 1) / (denominator + 1) + loss = loss.sum() / num_masks + return loss + + +# refactored from original implementation +def sigmoid_focal_loss( + inputs: Tensor, labels: Tensor, num_masks: int, alpha: float = 0.25, gamma: float = 2 +) -> Tensor: + r""" + Focal loss proposed in [Focal Loss for Dense Object Detection](https://huggingface.co/papers/1708.02002) originally used in + RetinaNet. The loss is computed as follows: + + $$ \mathcal{L}_{\text{focal loss}} = -(1 - p_t)^{\gamma}\log{(p_t)} $$ + + where \\(CE(p_t) = -\log{(p_t)}\\), CE is the standard Cross Entropy Loss + + Please refer to equation (1,2,3) of the paper for a better understanding. + + Args: + inputs (`torch.Tensor`): + A float tensor of arbitrary shape. + labels (`torch.Tensor`): + A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs + (0 for the negative class and 1 for the positive class). + num_masks (`int`): + The number of masks present in the current batch, used for normalization. + alpha (float, *optional*, defaults to 0.25): + Weighting factor in range (0,1) to balance positive vs negative examples. + gamma (float, *optional*, defaults to 2.0): + Exponent of the modulating factor \\(1 - p_t\\) to balance easy vs hard examples. + + Returns: + `torch.Tensor`: The computed loss. 
+ """ + criterion = nn.BCEWithLogitsLoss(reduction="none") + probs = inputs.sigmoid() + cross_entropy_loss = criterion(inputs, labels) + p_t = probs * labels + (1 - probs) * (1 - labels) + loss = cross_entropy_loss * ((1 - p_t) ** gamma) + + if alpha >= 0: + alpha_t = alpha * labels + (1 - alpha) * (1 - labels) + loss = alpha_t * loss + + loss = loss.mean(1).sum() / num_masks + return loss + + # copied and adapted from original implementation class MaskFormerLoss(nn.Module): def __init__( @@ -1352,7 +1590,7 @@ def __init__(self, config: MaskFormerConfig): backbone_config = MaskFormerSwinConfig.from_dict(backbone_config.to_dict()) backbone_config.out_features = ["stage1", "stage2", "stage3", "stage4"] config.backbone_config = backbone_config - self.encoder = load_backbone(config) + self.encoder = AutoBackbone.from_config(config=config.backbone_config) feature_channels = self.encoder.channels self.decoder = MaskFormerPixelDecoder( @@ -1397,7 +1635,7 @@ def __init__(self, in_features: int, config: MaskFormerConfig): self.position_embedder = MaskFormerSinePositionEmbedding(num_pos_feats=hidden_size // 2, normalize=True) self.queries_embedder = nn.Embedding(config.decoder_config.num_queries, hidden_size) self.input_projection = nn.Conv2d(in_features, hidden_size, kernel_size=1) if should_project else None - self.decoder = DetrDecoder(config=config.decoder_config) + self.decoder = MaskFormerDetrDecoder(config=config.decoder_config) def forward( self, @@ -1428,8 +1666,8 @@ def forward( attention_mask=None, encoder_hidden_states=image_features, encoder_attention_mask=None, - object_queries=object_queries, - query_position_embeddings=queries_embeddings, + spatial_position_embeddings=object_queries, + object_queries_position_embeddings=queries_embeddings, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, @@ -1825,4 +2063,9 @@ def forward( ) -__all__ = ["MaskFormerForInstanceSegmentation", "MaskFormerModel", 
"MaskFormerPreTrainedModel"] +__all__ = [ + "MaskFormerForInstanceSegmentation", + "MaskFormerModel", + "MaskFormerPreTrainedModel", + "MaskFormerDetrPreTrainedModel", +] diff --git a/src/transformers/models/maskformer/modular_maskformer.py b/src/transformers/models/maskformer/modular_maskformer.py new file mode 100644 index 000000000000..06705906c891 --- /dev/null +++ b/src/transformers/models/maskformer/modular_maskformer.py @@ -0,0 +1,1526 @@ +# Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch MaskFormer model.""" + +import math +from dataclasses import dataclass +from numbers import Number + +import numpy as np +import torch +from huggingface_hub.dataclasses import strict +from torch import Tensor, nn + +from ... 
import initialization as init +from ...backbone_utils import consolidate_backbone_kwargs_to_config +from ...configuration_utils import PreTrainedConfig +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import compile_compatible_method_lru_cache +from ...utils import ( + ModelOutput, + auto_docstring, + is_accelerate_available, + is_scipy_available, + logging, + requires_backends, +) +from ..auto import CONFIG_MAPPING, AutoBackbone, AutoConfig +from ..detr.configuration_detr import DetrConfig +from ..detr.modeling_detr import DetrDecoder, DetrDecoderOutput +from .configuration_maskformer_swin import MaskFormerSwinConfig + + +if is_accelerate_available(): + from accelerate import PartialState + from accelerate.utils import reduce + +if is_scipy_available(): + from scipy.optimize import linear_sum_assignment + + +logger = logging.get_logger(__name__) + + +@auto_docstring(checkpoint="facebook/maskformer-swin-base-ade") +@strict +class MaskFormerDetrConfig(DetrConfig): + model_type = "detr" + + +@auto_docstring(checkpoint="facebook/maskformer-swin-base-ade") +@strict +class MaskFormerConfig(PreTrainedConfig): + r""" + fpn_feature_size (`int`, *optional*, defaults to 256): + The Feature Pyramid Network's features size. + mask_feature_size (`int`, *optional*, defaults to 256): + The masks' features size, this value will also be used to specify the Feature Pyramid Network features' + size. + decoder_config (`Dict`, *optional*): + The configuration passed to the transformer decoder model, if unset the base config for `detr-resnet-50` + will be used. + cross_entropy_weight (`float`, *optional*, defaults to 1.0): + The weight for the cross entropy loss. + output_auxiliary_logits (`bool`, *optional*): + Should the model output its `auxiliary_logits` or not. 
+ + Raises: + `ValueError`: + Raised if the backbone model type selected is not in `["swin"]` or the decoder model type selected is not + in `["detr"]` + + Examples: + + ```python + >>> from transformers import MaskFormerConfig, MaskFormerModel + + >>> # Initializing a MaskFormer facebook/maskformer-swin-base-ade configuration + >>> configuration = MaskFormerConfig() + + >>> # Initializing a model (with random weights) from the facebook/maskformer-swin-base-ade style configuration + >>> model = MaskFormerModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + + """ + + model_type = "maskformer" + sub_configs = {"backbone_config": AutoConfig, "decoder_config": AutoConfig} + attribute_map = {"hidden_size": "mask_feature_size"} + backbones_supported = ["resnet", "swin"] + decoders_supported = ["detr"] + + fpn_feature_size: int = 256 + mask_feature_size: int = 256 + no_object_weight: float = 0.1 + use_auxiliary_loss: bool = False + backbone_config: dict | PreTrainedConfig | None = None + decoder_config: dict | PreTrainedConfig | None = None + init_std: float = 0.02 + init_xavier_std: float = 1.0 + dice_weight: float = 1.0 + cross_entropy_weight: float = 1.0 + mask_weight: float = 20.0 + output_auxiliary_logits: bool | None = None + + def __post_init__(self, **kwargs): + self.backbone_config, kwargs = consolidate_backbone_kwargs_to_config( + backbone_config=self.backbone_config, + default_config_type="swin", + default_config_kwargs={ + "depths": [2, 2, 18, 2], + "drop_path_rate": 0.3, + "image_size": 384, + "embed_dim": 128, + "num_heads": [4, 8, 16, 32], + "window_size": 12, + "out_features": ["stage1", "stage2", "stage3", "stage4"], + }, + **kwargs, + ) + + # verify that the backbone is supported + if self.backbone_config is not None and self.backbone_config.model_type not in self.backbones_supported: + logger.warning_once( + f"Backbone {self.backbone_config.model_type} is not a supported model and may not be 
compatible with MaskFormer. " + f"Supported model types: {','.join(self.backbones_supported)}" + ) + + if self.decoder_config is None: + # fall back to https://huggingface.co/facebook/detr-resnet-50 + self.decoder_config = MaskFormerDetrConfig() + else: + # verify that the decoder is supported + decoder_type = ( + self.decoder_config.pop("model_type") + if isinstance(self.decoder_config, dict) + else self.decoder_config.model_type + ) + if decoder_type not in self.decoders_supported: + raise ValueError( + f"Transformer Decoder {decoder_type} not supported, please use one of" + f" {','.join(self.decoders_supported)}" + ) + if isinstance(self.decoder_config, dict): + config_class = CONFIG_MAPPING[decoder_type] + self.decoder_config = config_class.from_dict(self.decoder_config) + + self.num_attention_heads = self.decoder_config.encoder_attention_heads + self.num_hidden_layers = self.decoder_config.num_hidden_layers + super().__post_init__(**kwargs) + + +class DetrDecoderOutput(DetrDecoderOutput): + pass + + +@dataclass +@auto_docstring( + custom_intro=""" + MaskFormer's pixel level module output. It returns both the last and (optionally) the hidden states from the + `encoder` and `decoder`. By default, the `encoder` is a MaskFormerSwin Transformer and the `decoder` is a Feature + Pyramid Network (FPN). + + The `encoder_last_hidden_state` are referred on the paper as **images features**, while `decoder_last_hidden_state` + as **pixel embeddings** + """ +) +class MaskFormerPixelLevelModuleOutput(ModelOutput): + r""" + encoder_last_hidden_state (`torch.FloatTensor` of shape`(batch_size, num_channels, height, width)`): + Last hidden states (final feature map) of the last stage of the encoder. + decoder_last_hidden_state (`torch.FloatTensor` of shape`(batch_size, num_channels, height, width)`): + Last hidden states (final feature map) of the last stage of the decoder. 
+ encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the model at + the output of each stage. + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the model at + the output of each stage. + """ + + encoder_last_hidden_state: torch.FloatTensor | None = None + decoder_last_hidden_state: torch.FloatTensor | None = None + encoder_hidden_states: tuple[torch.FloatTensor] | None = None + decoder_hidden_states: tuple[torch.FloatTensor] | None = None + + +@dataclass +@auto_docstring( + custom_intro=""" + MaskFormer's pixel decoder module output, practically a Feature Pyramid Network. It returns the last hidden state + and (optionally) the hidden states. + """ +) +class MaskFormerPixelDecoderOutput(ModelOutput): + r""" + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Last hidden states (final feature map) of the last stage of the model. + """ + + last_hidden_state: torch.FloatTensor | None = None + hidden_states: tuple[torch.FloatTensor] | None = None + attentions: tuple[torch.FloatTensor] | None = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Class for outputs of [`MaskFormerModel`]. This class returns all the needed hidden states to compute the logits. 
+ """ +) +class MaskFormerModelOutput(ModelOutput): + r""" + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Last hidden states (final feature map) of the last stage of the encoder model (backbone). + pixel_decoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Last hidden states (final feature map) of the last stage of the pixel decoder model (FPN). + transformer_decoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Last hidden states (final feature map) of the last stage of the transformer decoder model. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the encoder + model at the output of each stage. + pixel_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the pixel + decoder model at the output of each stage. + transformer_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, sequence_length, hidden_size)`. Hidden-states (also called feature maps) of the + transformer decoder at the output of each stage. 
+ hidden_states `tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` containing `encoder_hidden_states`, `pixel_decoder_hidden_states` and + `decoder_hidden_states` + """ + + encoder_last_hidden_state: torch.FloatTensor | None = None + pixel_decoder_last_hidden_state: torch.FloatTensor | None = None + transformer_decoder_last_hidden_state: torch.FloatTensor | None = None + encoder_hidden_states: tuple[torch.FloatTensor] | None = None + pixel_decoder_hidden_states: tuple[torch.FloatTensor] | None = None + transformer_decoder_hidden_states: tuple[torch.FloatTensor] | None = None + hidden_states: tuple[torch.FloatTensor] | None = None + attentions: tuple[torch.FloatTensor] | None = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Class for outputs of [`MaskFormerForInstanceSegmentation`]. + + This output can be directly passed to [`~MaskFormerImageProcessor.post_process_semantic_segmentation`] or + [`~MaskFormerImageProcessor.post_process_instance_segmentation`] or + [`~MaskFormerImageProcessor.post_process_panoptic_segmentation`] depending on the task. Please, see + [`~MaskFormerImageProcessor] for details regarding usage. + """ +) +class MaskFormerForInstanceSegmentationOutput(ModelOutput): + r""" + loss (`torch.Tensor`, *optional*): + The computed loss, returned when labels are present. + class_queries_logits (`torch.FloatTensor`): + A tensor of shape `(batch_size, num_queries, num_labels + 1)` representing the proposed classes for each + query. Note the `+ 1` is needed because we incorporate the null class. + masks_queries_logits (`torch.FloatTensor`): + A tensor of shape `(batch_size, num_queries, height, width)` representing the proposed masks for each + query. 
+ auxiliary_logits (`Dict[str, torch.FloatTensor]`, *optional*, returned when `output_auxiliary_logits=True`): + Dictionary containing auxiliary predictions for each decoder layer when auxiliary losses are enabled. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Last hidden states (final feature map) of the last stage of the encoder model (backbone). + pixel_decoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Last hidden states (final feature map) of the last stage of the pixel decoder model (FPN). + transformer_decoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Last hidden states (final feature map) of the last stage of the transformer decoder model. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the encoder + model at the output of each stage. + pixel_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the pixel + decoder model at the output of each stage. 
+ transformer_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the transformer decoder at the output + of each stage. + hidden_states `tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` containing `encoder_hidden_states`, `pixel_decoder_hidden_states` and + `decoder_hidden_states`. + """ + + loss: torch.FloatTensor | None = None + class_queries_logits: torch.FloatTensor | None = None + masks_queries_logits: torch.FloatTensor | None = None + auxiliary_logits: torch.FloatTensor | None = None + encoder_last_hidden_state: torch.FloatTensor | None = None + pixel_decoder_last_hidden_state: torch.FloatTensor | None = None + transformer_decoder_last_hidden_state: torch.FloatTensor | None = None + encoder_hidden_states: tuple[torch.FloatTensor] | None = None + pixel_decoder_hidden_states: tuple[torch.FloatTensor] | None = None + transformer_decoder_hidden_states: tuple[torch.FloatTensor] | None = None + hidden_states: tuple[torch.FloatTensor] | None = None + attentions: tuple[torch.FloatTensor] | None = None + + +def upsample_like(pixel_values: Tensor, like: Tensor, mode: str = "bilinear") -> Tensor: + """ + An utility function that upsamples `pixel_values` to match the dimension of `like`. + + Args: + pixel_values (`torch.Tensor`): + The tensor we wish to upsample. + like (`torch.Tensor`): + The tensor we wish to use as size target. + mode (str, *optional*, defaults to `"bilinear"`): + The interpolation mode. 
# refactored from original implementation
def dice_loss(inputs: Tensor, labels: Tensor, num_masks: int) -> Tensor:
    r"""
    Compute the DICE loss, similar to generalized IOU for masks as follows:

    $$ \mathcal{L}_{\text{dice}(x, y) = 1 - \frac{2 * x \cap y }{x \cup y + 1}} $$

    In practice, since `labels` is a binary mask, (only 0s and 1s), dice can be computed as follow

    $$ \mathcal{L}_{\text{dice}(x, y) = 1 - \frac{2 * x * y }{x + y + 1}} $$

    Args:
        inputs (`torch.Tensor`):
            A tensor representing a mask.
        labels (`torch.Tensor`):
            A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs
            (0 for the negative class and 1 for the positive class).
        num_masks (`int`):
            The number of masks present in the current batch, used for normalization.

    Returns:
        `torch.Tensor`: The computed loss.
    """
    # Flatten each mask's spatial dimensions so intersection/union become row-wise sums.
    mask_probs = inputs.sigmoid().flatten(1)
    intersection = 2 * (mask_probs * labels).sum(-1)
    union = mask_probs.sum(-1) + labels.sum(-1)
    # The `+ 1` smoothing terms keep the ratio finite for empty masks.
    per_mask_loss = 1 - (intersection + 1) / (union + 1)
    return per_mask_loss.sum() / num_masks
# refactored from original implementation
def sigmoid_focal_loss(
    inputs: Tensor, labels: Tensor, num_masks: int, alpha: float = 0.25, gamma: float = 2
) -> Tensor:
    r"""
    Focal loss proposed in [Focal Loss for Dense Object Detection](https://huggingface.co/papers/1708.02002) originally used in
    RetinaNet. The loss is computed as follows:

    $$ \mathcal{L}_{\text{focal loss} = -(1 - p_t)^{\gamma}\log{(p_t)} $$

    where \\(CE(p_t) = -\log{(p_t)}}\\), CE is the standard Cross Entropy Loss

    Please refer to equation (1,2,3) of the paper for a better understanding.

    Args:
        inputs (`torch.Tensor`):
            A float tensor of arbitrary shape.
        labels (`torch.Tensor`):
            A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs
            (0 for the negative class and 1 for the positive class).
        num_masks (`int`):
            The number of masks present in the current batch, used for normalization.
        alpha (float, *optional*, defaults to 0.25):
            Weighting factor in range (0,1) to balance positive vs negative examples.
        gamma (float, *optional*, defaults to 2.0):
            Exponent of the modulating factor \\(1 - p_t\\) to balance easy vs hard examples.

    Returns:
        `torch.Tensor`: The computed loss.
    """
    probabilities = inputs.sigmoid()
    # Per-element binary cross entropy, kept unreduced so the focal modulation can be applied element-wise.
    cross_entropy = nn.functional.binary_cross_entropy_with_logits(inputs, labels, reduction="none")
    # p_t is the model's probability for the *true* class of each element.
    p_t = probabilities * labels + (1 - probabilities) * (1 - labels)
    focal = cross_entropy * (1 - p_t) ** gamma

    # A negative alpha disables class balancing entirely.
    if alpha >= 0:
        alpha_t = alpha * labels + (1 - alpha) * (1 - labels)
        focal = alpha_t * focal

    return focal.mean(1).sum() / num_masks
+ """ + inputs = inputs.sigmoid().flatten(1) + numerator = 2 * torch.matmul(inputs, labels.T) + # using broadcasting to get a [num_queries, NUM_CLASSES] matrix + denominator = inputs.sum(-1)[:, None] + labels.sum(-1)[None, :] + loss = 1 - (numerator + 1) / (denominator + 1) + return loss + + +# refactored from original implementation +def pair_wise_sigmoid_focal_loss(inputs: Tensor, labels: Tensor, alpha: float = 0.25, gamma: float = 2.0) -> Tensor: + r""" + A pair wise version of the focal loss, see `sigmoid_focal_loss` for usage. + + Args: + inputs (`torch.Tensor`): + A tensor representing a mask. + labels (`torch.Tensor`): + A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs + (0 for the negative class and 1 for the positive class). + alpha (float, *optional*, defaults to 0.25): + Weighting factor in range (0,1) to balance positive vs negative examples. + gamma (float, *optional*, defaults to 2.0): + Exponent of the modulating factor \\(1 - p_t\\) to balance easy vs hard examples. + + Returns: + `torch.Tensor`: The computed loss between each pairs. + """ + if alpha < 0: + raise ValueError("alpha must be positive") + + height_and_width = inputs.shape[1] + + criterion = nn.BCEWithLogitsLoss(reduction="none") + prob = inputs.sigmoid() + cross_entropy_loss_pos = criterion(inputs, torch.ones_like(inputs)) + focal_pos = ((1 - prob) ** gamma) * cross_entropy_loss_pos + focal_pos *= alpha + + cross_entropy_loss_neg = criterion(inputs, torch.zeros_like(inputs)) + + focal_neg = (prob**gamma) * cross_entropy_loss_neg + focal_neg *= 1 - alpha + + loss = torch.matmul(focal_pos, labels.T) + torch.matmul(focal_neg, (1 - labels).T) + + return loss / height_and_width + + +class MaskFormerDetrDecoder(DetrDecoder): + pass + + +# refactored from original implementation +class MaskFormerHungarianMatcher(nn.Module): + """This class computes an assignment between the labels and the predictions of the network. 
    def __init__(self, cost_class: float = 1.0, cost_mask: float = 1.0, cost_dice: float = 1.0):
        """Creates the matcher

        Params:
            cost_class (float, *optional*, defaults to 1.0):
                This is the relative weight of the classification error in the matching cost.
            cost_mask (float, *optional*, defaults to 1.0):
                This is the relative weight of the focal loss of the binary mask in the matching cost.
            cost_dice (float, *optional*, defaults to 1.0):
                This is the relative weight of the dice loss of the binary mask in the matching cost
        """
        super().__init__()
        # With every weight at zero all assignments would have identical (zero) cost.
        if cost_class == 0 and cost_mask == 0 and cost_dice == 0:
            raise ValueError("All costs can't be 0")
        self.cost_class = cost_class
        self.cost_mask = cost_mask
        self.cost_dice = cost_dice

    @torch.no_grad()
    def forward(self, masks_queries_logits, class_queries_logits, mask_labels, class_labels) -> list[tuple[Tensor]]:
        """Performs the matching

        Params:
            masks_queries_logits (`torch.Tensor`):
                A tensor of dim `batch_size, num_queries, height, width` with the
                predicted masks.
            class_queries_logits (`torch.Tensor`):
                A tensor of dim `batch_size, num_queries, num_labels` with the
                classification logits.

            class_labels (`torch.Tensor`):
                A tensor of dim `num_target_boxes` (where num_target_boxes is the number
                of ground-truth objects in the target) containing the class labels.
            mask_labels (`torch.Tensor`):
                A tensor of dim `num_target_boxes, height, width` containing the target
                masks.

        Returns:
            `list[tuple[Tensor]]`: A list of size batch_size, containing tuples of (index_i, index_j) where:
                - index_i is the indices of the selected predictions (in order)
                - index_j is the indices of the corresponding selected labels (in order)
            For each batch element, it holds:
                len(index_i) = len(index_j) = min(num_queries, num_target_boxes).
        """
        indices: list[tuple[np.array]] = []

        preds_masks = masks_queries_logits
        preds_probs = class_queries_logits
        # iterate through batch size
        for pred_probs, pred_mask, target_mask, labels in zip(preds_probs, preds_masks, mask_labels, class_labels):
            # downsample the target mask to the prediction's resolution, save memory
            target_mask = nn.functional.interpolate(target_mask[:, None], size=pred_mask.shape[-2:], mode="nearest")
            pred_probs = pred_probs.softmax(-1)
            # Compute the classification cost. Contrary to the loss, we don't use the NLL,
            # but approximate it in 1 - proba[target class].
            # The 1 is a constant that doesn't change the matching, it can be omitted.
            cost_class = -pred_probs[:, labels]
            # flatten spatial dimension "q h w -> q (h w)"
            pred_mask_flat = pred_mask.flatten(1)  # [num_queries, height*width]
            # same for target_mask "c h w -> c (h w)"
            target_mask_flat = target_mask[:, 0].flatten(1)  # [num_total_labels, height*width]
            # compute the focal loss between each mask pairs -> shape (num_queries, num_labels)
            cost_mask = pair_wise_sigmoid_focal_loss(pred_mask_flat, target_mask_flat)
            # Compute the dice loss between each mask pairs -> shape (num_queries, num_labels)
            cost_dice = pair_wise_dice_loss(pred_mask_flat, target_mask_flat)
            # final cost matrix: weighted sum of the three matching costs
            cost_matrix = self.cost_mask * cost_mask + self.cost_class * cost_class + self.cost_dice * cost_dice
            # do the assignment using the hungarian algorithm in scipy
            assigned_indices: tuple[np.array] = linear_sum_assignment(cost_matrix.cpu())
            indices.append(assigned_indices)

        # It could be stacked in one tensor
        matched_indices = [
            (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices
        ]
        return matched_indices

    def __repr__(self):
        head = "Matcher " + self.__class__.__name__
        body = [
            f"cost_class: {self.cost_class}",
            f"cost_mask: {self.cost_mask}",
            f"cost_dice: {self.cost_dice}",
        ]
        _repr_indent = 4
        lines = [head] + [" " * _repr_indent + line for line in body]
        return "\n".join(lines)
+ matcher (`MaskFormerHungarianMatcher`): + A torch module that computes the assignments between the predictions and labels. + weight_dict (`dict[str, float]`): + A dictionary of weights to be applied to the different losses. + eos_coef (`float`): + Weight to apply to the null class. + """ + + super().__init__() + requires_backends(self, ["scipy"]) + self.num_labels = num_labels + self.matcher = matcher + self.weight_dict = weight_dict + self.eos_coef = eos_coef + empty_weight = torch.ones(self.num_labels + 1) + empty_weight[-1] = self.eos_coef + self.register_buffer("empty_weight", empty_weight) + + def _max_by_axis(self, the_list: list[list[int]]) -> list[int]: + maxes = the_list[0] + for sublist in the_list[1:]: + for index, item in enumerate(sublist): + maxes[index] = max(maxes[index], item) + return maxes + + def _pad_images_to_max_in_batch(self, tensors: list[Tensor]) -> tuple[Tensor, Tensor]: + # get the maximum size in the batch + max_size = self._max_by_axis([list(tensor.shape) for tensor in tensors]) + batch_size = len(tensors) + # compute finel size + batch_shape = [batch_size] + max_size + b, _, h, w = batch_shape + # get metadata + dtype = tensors[0].dtype + device = tensors[0].device + padded_tensors = torch.zeros(batch_shape, dtype=dtype, device=device) + padding_masks = torch.ones((b, h, w), dtype=torch.bool, device=device) + # pad the tensors to the size of the biggest one + for tensor, padded_tensor, padding_mask in zip(tensors, padded_tensors, padding_masks): + padded_tensor[: tensor.shape[0], : tensor.shape[1], : tensor.shape[2]].copy_(tensor) + padding_mask[: tensor.shape[1], : tensor.shape[2]] = False + + return padded_tensors, padding_masks + + def loss_labels( + self, class_queries_logits: Tensor, class_labels: list[Tensor], indices: tuple[np.array] + ) -> dict[str, Tensor]: + """Compute the losses related to the labels using cross entropy. 
    def loss_labels(
        self, class_queries_logits: Tensor, class_labels: list[Tensor], indices: tuple[np.array]
    ) -> dict[str, Tensor]:
        """Compute the losses related to the labels using cross entropy.

        Args:
            class_queries_logits (`torch.Tensor`):
                A tensor of shape `batch_size, num_queries, num_labels`
            class_labels (`list[torch.Tensor]`):
                List of class labels of shape `(labels)`.
            indices (`tuple[np.array])`:
                The indices computed by the Hungarian matcher.

        Returns:
            `dict[str, Tensor]`: A dict of `torch.Tensor` containing the following key:
            - **loss_cross_entropy** -- The loss computed using cross entropy on the predicted and ground truth labels.
        """

        pred_logits = class_queries_logits
        batch_size, num_queries, _ = pred_logits.shape
        # `empty_weight` down-weights the "no object" class (index `num_labels`) by `eos_coef`.
        criterion = nn.CrossEntropyLoss(weight=self.empty_weight)
        idx = self._get_predictions_permutation_indices(indices)
        # shape = (batch_size, num_queries)
        target_classes_o = torch.cat([target[j] for target, (_, j) in zip(class_labels, indices)])
        # shape = (batch_size, num_queries); queries left unmatched get the "no object" class
        target_classes = torch.full(
            (batch_size, num_queries), fill_value=self.num_labels, dtype=torch.int64, device=pred_logits.device
        )
        target_classes[idx] = target_classes_o
        # CrossEntropyLoss expects logits as (batch_size, num_labels, num_queries), so permute pred_logits "b q c -> b c q"
        pred_logits_transposed = pred_logits.transpose(1, 2)
        loss_ce = criterion(pred_logits_transposed, target_classes)
        losses = {"loss_cross_entropy": loss_ce}
        return losses
    def loss_masks(
        self, masks_queries_logits: Tensor, mask_labels: list[Tensor], indices: tuple[np.array], num_masks: int
    ) -> dict[str, Tensor]:
        """Compute the losses related to the masks using focal and dice loss.

        Args:
            masks_queries_logits (`torch.Tensor`):
                A tensor of shape `batch_size, num_queries, height, width`
            mask_labels (`torch.Tensor`):
                List of mask labels of shape `(labels, height, width)`.
            indices (`tuple[np.array])`:
                The indices computed by the Hungarian matcher.
            num_masks (`int)`:
                The number of masks, used for normalization.

        Returns:
            `dict[str, Tensor]`: A dict of `torch.Tensor` containing two keys:
            - **loss_mask** -- The loss computed using sigmoid focal loss on the predicted and ground truth masks.
            - **loss_dice** -- The loss computed using dice loss on the predicted and ground truth masks.
        """
        src_idx = self._get_predictions_permutation_indices(indices)
        tgt_idx = self._get_targets_permutation_indices(indices)
        # shape (batch_size * num_queries, height, width)
        pred_masks = masks_queries_logits[src_idx]
        # shape (batch_size, num_queries, height, width)
        # pad all and stack the targets to the num_labels dimension
        target_masks, _ = self._pad_images_to_max_in_batch(mask_labels)
        target_masks = target_masks[tgt_idx]
        # upsample predictions to the target size, we have to add one dim to use interpolate
        pred_masks = nn.functional.interpolate(
            pred_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
        )
        pred_masks = pred_masks[:, 0].flatten(1)

        target_masks = target_masks.flatten(1)
        losses = {
            "loss_mask": sigmoid_focal_loss(pred_masks, target_masks, num_masks),
            "loss_dice": dice_loss(pred_masks, target_masks, num_masks),
        }
        return losses

    def _get_predictions_permutation_indices(self, indices):
        # permute predictions following indices; first element identifies the batch item,
        # second the selected query index within that item
        batch_indices = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        predictions_indices = torch.cat([src for (src, _) in indices])
        return batch_indices, predictions_indices

    def _get_targets_permutation_indices(self, indices):
        # permute labels following indices (same layout as the prediction indices above)
        batch_indices = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        target_indices = torch.cat([tgt for (_, tgt) in indices])
        return batch_indices, target_indices
    def forward(
        self,
        masks_queries_logits: Tensor,
        class_queries_logits: Tensor,
        mask_labels: list[Tensor],
        class_labels: list[Tensor],
        # NOTE: iterated with `enumerate` and indexed per-layer below, so this is a list of
        # per-decoder-layer dicts, not a single dict.
        auxiliary_predictions: list[dict[str, Tensor]] | None = None,
    ) -> dict[str, Tensor]:
        """
        This performs the loss computation.

        Args:
            masks_queries_logits (`torch.Tensor`):
                A tensor of shape `batch_size, num_queries, height, width`
            class_queries_logits (`torch.Tensor`):
                A tensor of shape `batch_size, num_queries, num_labels`
            mask_labels (`torch.Tensor`):
                List of mask labels of shape `(labels, height, width)`.
            class_labels (`list[torch.Tensor]`):
                List of class labels of shape `(labels)`.
            auxiliary_predictions (`list[dict[str, torch.Tensor]]`, *optional*):
                if `use_auxiliary_loss` was set to `true` in [`MaskFormerConfig`], then it contains the logits from the
                inner layers of the Detr's Decoder.

        Returns:
            `dict[str, Tensor]`: A dict of `torch.Tensor` containing two keys:
            - **loss_cross_entropy** -- The loss computed using cross entropy on the predicted and ground truth labels.
            - **loss_mask** -- The loss computed using sigmoid focal loss on the predicted and ground truth masks.
            - **loss_dice** -- The loss computed using dice loss on the predicted and ground truth masks.
            if `use_auxiliary_loss` was set to `true` in [`MaskFormerConfig`], the dictionary contains additional losses
            for each auxiliary predictions.
        """

        # retrieve the matching between the outputs of the last layer and the labels
        indices = self.matcher(masks_queries_logits, class_queries_logits, mask_labels, class_labels)
        # compute the average number of target masks for normalization purposes
        num_masks: Tensor = self.get_num_masks(class_labels, device=class_labels[0].device)
        # get all the losses
        losses: dict[str, Tensor] = {
            **self.loss_masks(masks_queries_logits, mask_labels, indices, num_masks),
            **self.loss_labels(class_queries_logits, class_labels, indices),
        }
        # in case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        if auxiliary_predictions is not None:
            for idx, aux_outputs in enumerate(auxiliary_predictions):
                masks_queries_logits = aux_outputs["masks_queries_logits"]
                class_queries_logits = aux_outputs["class_queries_logits"]
                # recursive call without `auxiliary_predictions` -> single-level recursion only
                loss_dict = self.forward(masks_queries_logits, class_queries_logits, mask_labels, class_labels)
                loss_dict = {f"{key}_{idx}": value for key, value in loss_dict.items()}
                losses.update(loss_dict)

        return losses

    def get_num_masks(self, class_labels: list[Tensor], device: torch.device) -> torch.Tensor:
        """
        Computes the average number of target masks across the batch, for normalization purposes.
        """
        num_masks = sum(len(classes) for classes in class_labels)
        num_masks = torch.as_tensor(num_masks, dtype=torch.float, device=device)
        world_size = 1
        if is_accelerate_available():
            # When running distributed under `accelerate`, average the count over all processes
            # so every rank normalizes its loss identically.
            if PartialState._shared_state != {}:
                num_masks = reduce(num_masks)
                world_size = PartialState().num_processes

        # clamp to 1 so an all-empty batch never divides by zero
        num_masks = torch.clamp(num_masks / world_size, min=1)
        return num_masks
self.0 is not permitted and so need to register + # explicitly + self.add_module(str(i), layer) + + def forward(self, input: Tensor) -> Tensor: + hidden_state = input + for layer in self.layers: + hidden_state = layer(hidden_state) + return hidden_state + + +class MaskFormerFPNLayer(nn.Module): + def __init__(self, in_features: int, lateral_features: int): + """ + A Feature Pyramid Network Layer (FPN) layer. It creates a feature map by aggregating features from the previous + and backbone layer. Due to the spatial mismatch, the tensor coming from the previous layer is upsampled. + + Args: + in_features (`int`): + The number of input features (channels). + lateral_features (`int`): + The number of lateral features (channels). + """ + super().__init__() + self.proj = nn.Sequential( + nn.Conv2d(lateral_features, in_features, kernel_size=1, padding=0, bias=False), + nn.GroupNorm(32, in_features), + ) + + self.block = MaskFormerFPNConvLayer(in_features, in_features) + + def forward(self, down: Tensor, left: Tensor) -> Tensor: + left = self.proj(left) + down = nn.functional.interpolate(down, size=left.shape[-2:], mode="nearest") + down += left + down = self.block(down) + return down + + +class MaskFormerFPNModel(nn.Module): + def __init__(self, in_features: int, lateral_widths: list[int], feature_size: int = 256): + """ + Feature Pyramid Network, given an input tensor and a set of feature map of different feature/spatial size, it + creates a list of feature maps with the same feature size. + + Args: + in_features (`int`): + The number of input features (channels). + lateral_widths (`list[int]`): + A list with the features (channels) size of each lateral connection. + feature_size (int, *optional*, defaults to 256): + The features (channels) of the resulting feature maps. 
+ """ + super().__init__() + self.stem = MaskFormerFPNConvLayer(in_features, feature_size) + self.layers = nn.Sequential( + *[MaskFormerFPNLayer(feature_size, lateral_width) for lateral_width in lateral_widths[::-1]] + ) + + def forward(self, features: list[Tensor]) -> list[Tensor]: + fpn_features = [] + last_feature = features[-1] + other_features = features[:-1] + output = self.stem(last_feature) + for layer, left in zip(self.layers, other_features[::-1]): + output = layer(output, left) + fpn_features.append(output) + return fpn_features + + +class MaskFormerPixelDecoder(nn.Module): + def __init__(self, *args, feature_size: int = 256, mask_feature_size: int = 256, **kwargs): + r""" + Pixel Decoder Module proposed in [Per-Pixel Classification is Not All You Need for Semantic + Segmentation](https://huggingface.co/papers/2107.06278). It first runs the backbone's features into a Feature Pyramid + Network creating a list of feature maps. Then, it projects the last one to the correct `mask_size`. + + Args: + feature_size (`int`, *optional*, defaults to 256): + The feature size (channel dimension) of the FPN feature maps. + mask_feature_size (`int`, *optional*, defaults to 256): + The features (channels) of the target masks size \\(C_{\epsilon}\\) in the paper. 
+ """ + super().__init__() + + self.fpn = MaskFormerFPNModel(*args, feature_size=feature_size, **kwargs) + self.mask_projection = nn.Conv2d(feature_size, mask_feature_size, kernel_size=3, padding=1) + + def forward( + self, features: list[Tensor], output_hidden_states: bool = False, return_dict: bool = True + ) -> MaskFormerPixelDecoderOutput: + fpn_features = self.fpn(features) + # we use the last feature map + last_feature_projected = self.mask_projection(fpn_features[-1]) + + if not return_dict: + return (last_feature_projected, tuple(fpn_features)) if output_hidden_states else (last_feature_projected,) + + return MaskFormerPixelDecoderOutput( + last_hidden_state=last_feature_projected, hidden_states=tuple(fpn_features) if output_hidden_states else () + ) + + +# copied and adapted from original implementation, also practically equal to DetrSinePositionEmbedding +class MaskFormerSinePositionEmbedding(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one used by the Attention is all you + need paper, generalized to work on images. 
+ """ + + def __init__( + self, num_pos_feats: int = 64, temperature: int = 10000, normalize: bool = False, scale: float | None = None + ): + super().__init__() + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + self.scale = 2 * math.pi if scale is None else scale + + @compile_compatible_method_lru_cache(maxsize=1) + def forward( + self, + shape: torch.Size, + device: torch.device | str, + dtype: torch.dtype, + mask: Tensor | None = None, + ) -> Tensor: + if mask is None: + mask = torch.zeros((shape[0], shape[2], shape[3]), device=device, dtype=torch.bool) + not_mask = (~mask).to(dtype) + y_embed = not_mask.cumsum(1) + x_embed = not_mask.cumsum(2) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.int64, device=device).to(dtype) + dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + +class PredictionBlock(nn.Module): + def __init__(self, in_dim: int, out_dim: int, activation: nn.Module) -> None: + super().__init__() + self.layers = [nn.Linear(in_dim, out_dim), activation] + # Maintain submodule indexing as if part of a Sequential block + for i, layer in enumerate(self.layers): + self.add_module(str(i), layer) + + def forward(self, input: Tensor) -> Tensor: + hidden_state = input + for layer in self.layers: + hidden_state = layer(hidden_state) + 
return hidden_state + + +class MaskformerMLPPredictionHead(nn.Module): + def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int = 3): + """ + A classic Multi Layer Perceptron (MLP). + + Args: + input_dim (`int`): + The input dimensions. + hidden_dim (`int`): + The hidden dimensions. + output_dim (`int`): + The output dimensions. + num_layers (int, *optional*, defaults to 3): + The number of layers. + """ + super().__init__() + in_dims = [input_dim] + [hidden_dim] * (num_layers - 1) + out_dims = [hidden_dim] * (num_layers - 1) + [output_dim] + + self.layers = [] + for i, (in_dim, out_dim) in enumerate(zip(in_dims, out_dims)): + activation = nn.ReLU() if i < num_layers - 1 else nn.Identity() + layer = PredictionBlock(in_dim, out_dim, activation=activation) + self.layers.append(layer) + # Provide backwards compatibility from when the class inherited from nn.Sequential + # In nn.Sequential subclasses, the name given to the layer is its index in the sequence. + # In nn.Module subclasses they derived from the instance attribute they are assigned to e.g. + # self.my_layer_name = Layer() + # We can't give instance attributes integer names i.e. self.0 is not permitted and so need to register + # explicitly + self.add_module(str(i), layer) + + def forward(self, input: Tensor) -> Tensor: + hidden_state = input + for layer in self.layers: + hidden_state = layer(hidden_state) + return hidden_state + + +class MaskFormerPixelLevelModule(nn.Module): + def __init__(self, config: MaskFormerConfig): + """ + Pixel Level Module proposed in [Per-Pixel Classification is Not All You Need for Semantic + Segmentation](https://huggingface.co/papers/2107.06278). It runs the input image through a backbone and a pixel + decoder, generating an image feature map and pixel embeddings. + + Args: + config ([`MaskFormerConfig`]): + The configuration used to instantiate this model. 
+ """ + super().__init__() + if getattr(config, "backbone_config") is not None and config.backbone_config.model_type == "swin": + # for backwards compatibility + backbone_config = config.backbone_config + backbone_config = MaskFormerSwinConfig.from_dict(backbone_config.to_dict()) + backbone_config.out_features = ["stage1", "stage2", "stage3", "stage4"] + config.backbone_config = backbone_config + self.encoder = AutoBackbone.from_config(config=config.backbone_config) + + feature_channels = self.encoder.channels + self.decoder = MaskFormerPixelDecoder( + in_features=feature_channels[-1], + feature_size=config.fpn_feature_size, + mask_feature_size=config.mask_feature_size, + lateral_widths=feature_channels[:-1], + ) + + def forward( + self, pixel_values: Tensor, output_hidden_states: bool = False, return_dict: bool = True + ) -> MaskFormerPixelLevelModuleOutput: + features = self.encoder(pixel_values).feature_maps + decoder_output = self.decoder(features, output_hidden_states, return_dict=return_dict) + + if not return_dict: + last_hidden_state = decoder_output[0] + outputs = (features[-1], last_hidden_state) + if output_hidden_states: + hidden_states = decoder_output[1] + outputs = outputs + (tuple(features),) + (hidden_states,) + return outputs + + return MaskFormerPixelLevelModuleOutput( + # the last feature is actually the output from the last layer + encoder_last_hidden_state=features[-1], + decoder_last_hidden_state=decoder_output.last_hidden_state, + encoder_hidden_states=tuple(features) if output_hidden_states else (), + decoder_hidden_states=decoder_output.hidden_states if output_hidden_states else (), + ) + + +class MaskFormerTransformerModule(nn.Module): + """ + The MaskFormer's transformer module. 
+ """ + + def __init__(self, in_features: int, config: MaskFormerConfig): + super().__init__() + hidden_size = config.decoder_config.hidden_size + should_project = in_features != hidden_size + self.position_embedder = MaskFormerSinePositionEmbedding(num_pos_feats=hidden_size // 2, normalize=True) + self.queries_embedder = nn.Embedding(config.decoder_config.num_queries, hidden_size) + self.input_projection = nn.Conv2d(in_features, hidden_size, kernel_size=1) if should_project else None + self.decoder = MaskFormerDetrDecoder(config=config.decoder_config) + + def forward( + self, + image_features: Tensor, + output_hidden_states: bool = False, + output_attentions: bool = False, + return_dict: bool | None = None, + ) -> DetrDecoderOutput: + if self.input_projection is not None: + image_features = self.input_projection(image_features) + object_queries = self.position_embedder(image_features.shape, image_features.device, image_features.dtype) + # repeat the queries "q c -> b q c" + batch_size = image_features.shape[0] + queries_embeddings = self.queries_embedder.weight.unsqueeze(0).repeat(batch_size, 1, 1) + inputs_embeds = torch.zeros_like(queries_embeddings, requires_grad=self.training) + + # torch.export.export does no support requires_grad + if self.training: + inputs_embeds.requires_grad_(True) + + batch_size, num_channels, height, width = image_features.shape + # rearrange both image_features and object_queries "b c h w -> b (h w) c" + image_features = image_features.view(batch_size, num_channels, height * width).permute(0, 2, 1) + object_queries = object_queries.view(batch_size, num_channels, height * width).permute(0, 2, 1) + + decoder_output: DetrDecoderOutput = self.decoder( + inputs_embeds=inputs_embeds, + attention_mask=None, + encoder_hidden_states=image_features, + encoder_attention_mask=None, + spatial_position_embeddings=object_queries, + object_queries_position_embeddings=queries_embeddings, + output_attentions=output_attentions, + 
output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + return decoder_output + + +@auto_docstring +class MaskFormerPreTrainedModel(PreTrainedModel): + config: MaskFormerConfig + base_model_prefix = "model" + main_input_name = "pixel_values" + input_modalities = ("image",) + + @torch.no_grad() + def _init_weights(self, module: nn.Module): + xavier_std = self.config.init_xavier_std + std = self.config.init_std + if isinstance(module, MaskFormerTransformerModule): + if module.input_projection is not None: + init.xavier_uniform_(module.input_projection.weight, gain=xavier_std) + init.constant_(module.input_projection.bias, 0) + # FPN + elif isinstance(module, MaskFormerFPNModel): + init.xavier_uniform_(module.stem.get_submodule("0").weight, gain=xavier_std) + + elif isinstance(module, MaskFormerFPNLayer): + init.xavier_uniform_(module.proj[0].weight, gain=xavier_std) + + elif isinstance(module, MaskFormerFPNConvLayer): + init.xavier_uniform_(module.get_submodule("0").weight, gain=xavier_std) + # The MLP head + elif isinstance(module, MaskformerMLPPredictionHead): + # I was not able to find the correct initializer in the original implementation + # we'll use xavier + for submodule in module.modules(): + if isinstance(submodule, nn.Linear): + init.xavier_uniform_(submodule.weight, gain=xavier_std) + init.constant_(submodule.bias, 0) + elif isinstance(module, nn.LayerNorm): + init.zeros_(module.bias) + init.ones_(module.weight) + # copied from DETR + if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): + init.normal_(module.weight, mean=0.0, std=std) + if module.bias is not None: + init.zeros_(module.bias) + if getattr(module, "running_mean", None) is not None: + init.zeros_(module.running_mean) + init.ones_(module.running_var) + init.zeros_(module.num_batches_tracked) + elif isinstance(module, nn.Embedding): + init.normal_(module.weight, mean=0.0, std=std) + # Here we need the check explicitly, as we slice the weight in the `zeros_` call, 
so it looses the flag + if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False): + init.zeros_(module.weight[module.padding_idx]) + elif isinstance(module, MaskFormerLoss): + empty_weight = torch.ones(module.num_labels + 1) + empty_weight[-1] = module.eos_coef + init.copy_(module.empty_weight, empty_weight) + + +@auto_docstring +class MaskFormerModel(MaskFormerPreTrainedModel): + def __init__(self, config: MaskFormerConfig): + super().__init__(config) + self.pixel_level_module = MaskFormerPixelLevelModule(config) + self.transformer_module = MaskFormerTransformerModule( + in_features=self.pixel_level_module.encoder.channels[-1], config=config + ) + + self.post_init() + + @auto_docstring + def forward( + self, + pixel_values: Tensor, + pixel_mask: Tensor | None = None, + output_hidden_states: bool | None = None, + output_attentions: bool | None = None, + return_dict: bool | None = None, + **kwargs, + ) -> MaskFormerModelOutput: + r""" + Examples: + + ```python + >>> from transformers import AutoImageProcessor, MaskFormerModel + >>> from PIL import Image + >>> import httpx + >>> from io import BytesIO + + >>> # load MaskFormer fine-tuned on ADE20k semantic segmentation + >>> image_processor = AutoImageProcessor.from_pretrained("facebook/maskformer-swin-base-ade") + >>> model = MaskFormerModel.from_pretrained("facebook/maskformer-swin-base-ade") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> with httpx.stream("GET", url) as response: + ... 
image = Image.open(BytesIO(response.read())) + + >>> inputs = image_processor(image, return_tensors="pt") + + >>> # forward pass + >>> outputs = model(**inputs) + + >>> # the decoder of MaskFormer outputs hidden states of shape (batch_size, num_queries, hidden_size) + >>> transformer_decoder_last_hidden_state = outputs.transformer_decoder_last_hidden_state + >>> list(transformer_decoder_last_hidden_state.shape) + [1, 100, 256] + ```""" + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + batch_size, _, height, width = pixel_values.shape + + if pixel_mask is None: + pixel_mask = torch.ones((batch_size, height, width), device=pixel_values.device) + + pixel_level_module_output = self.pixel_level_module( + pixel_values, output_hidden_states, return_dict=return_dict + ) + image_features = pixel_level_module_output[0] + pixel_embeddings = pixel_level_module_output[1] + + transformer_module_output = self.transformer_module(image_features, output_hidden_states, output_attentions) + queries = transformer_module_output.last_hidden_state + + encoder_hidden_states = None + pixel_decoder_hidden_states = None + transformer_decoder_hidden_states = None + hidden_states = None + + if output_hidden_states: + encoder_hidden_states = pixel_level_module_output[2] + pixel_decoder_hidden_states = pixel_level_module_output[3] + transformer_decoder_hidden_states = transformer_module_output[1] + hidden_states = encoder_hidden_states + pixel_decoder_hidden_states + transformer_decoder_hidden_states + + output = MaskFormerModelOutput( + encoder_last_hidden_state=image_features, + pixel_decoder_last_hidden_state=pixel_embeddings, + 
transformer_decoder_last_hidden_state=queries, + encoder_hidden_states=encoder_hidden_states, + pixel_decoder_hidden_states=pixel_decoder_hidden_states, + transformer_decoder_hidden_states=transformer_decoder_hidden_states, + hidden_states=hidden_states, + attentions=transformer_module_output.attentions, + ) + + if not return_dict: + output = tuple(v for v in output.values()) + + return output + + +class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel): + def __init__(self, config: MaskFormerConfig): + super().__init__(config) + self.model = MaskFormerModel(config) + hidden_size = config.decoder_config.hidden_size + # + 1 because we add the "null" class + self.class_predictor = nn.Linear(hidden_size, config.num_labels + 1) + self.mask_embedder = MaskformerMLPPredictionHead(hidden_size, hidden_size, config.mask_feature_size) + + self.matcher = MaskFormerHungarianMatcher( + cost_class=1.0, cost_dice=config.dice_weight, cost_mask=config.mask_weight + ) + + self.weight_dict: dict[str, float] = { + "loss_cross_entropy": config.cross_entropy_weight, + "loss_mask": config.mask_weight, + "loss_dice": config.dice_weight, + } + + self.criterion = MaskFormerLoss( + config.num_labels, + matcher=self.matcher, + weight_dict=self.weight_dict, + eos_coef=config.no_object_weight, + ) + + self.post_init() + + def get_loss_dict( + self, + masks_queries_logits: Tensor, + class_queries_logits: Tensor, + mask_labels: Tensor, + class_labels: Tensor, + auxiliary_logits: dict[str, Tensor], + ) -> dict[str, Tensor]: + loss_dict: dict[str, Tensor] = self.criterion( + masks_queries_logits, class_queries_logits, mask_labels, class_labels, auxiliary_logits + ) + # weight each loss by `self.weight_dict[]` including auxiliary losses + for key, weight in self.weight_dict.items(): + for loss_key, loss in loss_dict.items(): + if key in loss_key: + loss *= weight + + return loss_dict + + def get_loss(self, loss_dict: dict[str, Tensor]) -> Tensor: + return sum(loss_dict.values()) + + def 
get_logits(self, outputs: MaskFormerModelOutput) -> tuple[Tensor, Tensor, dict[str, Tensor]]: + pixel_embeddings = outputs.pixel_decoder_last_hidden_state + # get the auxiliary predictions (one for each decoder's layer) + auxiliary_logits: list[str, Tensor] = [] + + # This code is a little bit cumbersome, an improvement can be to return a list of predictions. If we have auxiliary loss then we are going to return more than one element in the list + if self.config.use_auxiliary_loss: + stacked_transformer_decoder_outputs = torch.stack(outputs.transformer_decoder_hidden_states) + classes = self.class_predictor(stacked_transformer_decoder_outputs) + class_queries_logits = classes[-1] + # get the masks + mask_embeddings = self.mask_embedder(stacked_transformer_decoder_outputs) + binaries_masks = torch.einsum("lbqc, bchw -> lbqhw", mask_embeddings, pixel_embeddings) + + masks_queries_logits = binaries_masks[-1] + # go til [:-1] because the last one is always used + for aux_binary_masks, aux_classes in zip(binaries_masks[:-1], classes[:-1]): + auxiliary_logits.append( + {"masks_queries_logits": aux_binary_masks, "class_queries_logits": aux_classes} + ) + + else: + transformer_decoder_hidden_states = outputs.transformer_decoder_last_hidden_state + classes = self.class_predictor(transformer_decoder_hidden_states) + class_queries_logits = classes + # get the masks + mask_embeddings = self.mask_embedder(transformer_decoder_hidden_states) + # sum up over the channels + masks_queries_logits = torch.einsum("bqc, bchw -> bqhw", mask_embeddings, pixel_embeddings) + + return class_queries_logits, masks_queries_logits, auxiliary_logits + + @auto_docstring + def forward( + self, + pixel_values: Tensor, + mask_labels: list[Tensor] | None = None, + class_labels: list[Tensor] | None = None, + pixel_mask: Tensor | None = None, + output_auxiliary_logits: bool | None = None, + output_hidden_states: bool | None = None, + output_attentions: bool | None = None, + return_dict: bool | None = 
None, + **kwargs, + ) -> MaskFormerForInstanceSegmentationOutput: + r""" + mask_labels (`list[torch.Tensor]`, *optional*): + List of mask labels of shape `(num_labels, height, width)` to be fed to a model + class_labels (`list[torch.LongTensor]`, *optional*): + list of target class labels of shape `(num_labels, height, width)` to be fed to a model. They identify the + labels of `mask_labels`, e.g. the label of `mask_labels[i][j]` if `class_labels[i][j]`. + output_auxiliary_logits (`bool`, *optional*): + Whether or not to output auxiliary logits. + + Examples: + + Semantic segmentation example: + + ```python + >>> from transformers import AutoImageProcessor, MaskFormerForInstanceSegmentation + >>> from PIL import Image + >>> import httpx + >>> from io import BytesIO + + >>> # load MaskFormer fine-tuned on ADE20k semantic segmentation + >>> image_processor = AutoImageProcessor.from_pretrained("facebook/maskformer-swin-base-ade") + >>> model = MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-swin-base-ade") + + >>> url = ( + ... "https://huggingface.co/datasets/hf-internal-testing/fixtures_ade20k/resolve/main/ADE_val_00000001.jpg" + ... ) + >>> with httpx.stream("GET", url) as response: + ... image = Image.open(BytesIO(response.read())) + >>> inputs = image_processor(images=image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> # model predicts class_queries_logits of shape `(batch_size, num_queries)` + >>> # and masks_queries_logits of shape `(batch_size, num_queries, height, width)` + >>> class_queries_logits = outputs.class_queries_logits + >>> masks_queries_logits = outputs.masks_queries_logits + + >>> # you can pass them to image_processor for postprocessing + >>> predicted_semantic_map = image_processor.post_process_semantic_segmentation( + ... outputs, target_sizes=[(image.height, image.width)] + ... 
)[0] + + >>> # we refer to the demo notebooks for visualization (see "Resources" section in the MaskFormer docs) + >>> list(predicted_semantic_map.shape) + [512, 683] + ``` + + Panoptic segmentation example: + + ```python + >>> from transformers import AutoImageProcessor, MaskFormerForInstanceSegmentation + >>> from PIL import Image + >>> import httpx + >>> from io import BytesIO + + >>> # load MaskFormer fine-tuned on COCO panoptic segmentation + >>> image_processor = AutoImageProcessor.from_pretrained("facebook/maskformer-swin-base-coco") + >>> model = MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-swin-base-coco") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> with httpx.stream("GET", url) as response: + ... image = Image.open(BytesIO(response.read())) + >>> inputs = image_processor(images=image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> # model predicts class_queries_logits of shape `(batch_size, num_queries)` + >>> # and masks_queries_logits of shape `(batch_size, num_queries, height, width)` + >>> class_queries_logits = outputs.class_queries_logits + >>> masks_queries_logits = outputs.masks_queries_logits + + >>> # you can pass them to image_processor for postprocessing + >>> result = image_processor.post_process_panoptic_segmentation(outputs, target_sizes=[(image.height, image.width)])[0] + + >>> # we refer to the demo notebooks for visualization (see "Resources" section in the MaskFormer docs) + >>> predicted_panoptic_map = result["segmentation"] + >>> list(predicted_panoptic_map.shape) + [480, 640] + ``` + """ + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + raw_outputs = self.model( + pixel_values, + pixel_mask, 
+ output_hidden_states=output_hidden_states or self.config.use_auxiliary_loss, + return_dict=return_dict, + output_attentions=output_attentions, + ) + # We need to have raw_outputs optionally be returned as a dict to use torch.compile. For backwards + # compatibility we convert to a dataclass for the rest of the model logic + outputs = MaskFormerModelOutput( + encoder_last_hidden_state=raw_outputs[0], + pixel_decoder_last_hidden_state=raw_outputs[1], + transformer_decoder_last_hidden_state=raw_outputs[2], + encoder_hidden_states=raw_outputs[3] if output_hidden_states else None, + pixel_decoder_hidden_states=raw_outputs[4] if output_hidden_states else None, + transformer_decoder_hidden_states=raw_outputs[5] if output_hidden_states else None, + hidden_states=raw_outputs[6] if output_hidden_states else None, + attentions=raw_outputs[-1] if output_attentions else None, + ) + + loss, loss_dict, auxiliary_logits = None, None, None + + class_queries_logits, masks_queries_logits, auxiliary_logits = self.get_logits(outputs) + + if mask_labels is not None and class_labels is not None: + loss_dict: dict[str, Tensor] = self.get_loss_dict( + masks_queries_logits, class_queries_logits, mask_labels, class_labels, auxiliary_logits + ) + loss = self.get_loss(loss_dict) + + output_auxiliary_logits = ( + self.config.output_auxiliary_logits if output_auxiliary_logits is None else output_auxiliary_logits + ) + if not output_auxiliary_logits: + auxiliary_logits = None + + if not return_dict: + output = tuple( + v + for v in (loss, class_queries_logits, masks_queries_logits, auxiliary_logits, *outputs.values()) + if v is not None + ) + return output + + return MaskFormerForInstanceSegmentationOutput( + loss=loss, + **outputs, + class_queries_logits=class_queries_logits, + masks_queries_logits=masks_queries_logits, + auxiliary_logits=auxiliary_logits, + ) + + +__all__ = [ + "MaskFormerConfig", + "MaskFormerDetrConfig", + "MaskFormerForInstanceSegmentation", + "MaskFormerModel", + 
"MaskFormerPreTrainedModel", + "MaskFormerDetrPreTrainedModel", # noqa F821 +] diff --git a/tests/models/conditional_detr/test_modeling_conditional_detr.py b/tests/models/conditional_detr/test_modeling_conditional_detr.py index eabbe9194fdb..e2eeec9bfdfa 100644 --- a/tests/models/conditional_detr/test_modeling_conditional_detr.py +++ b/tests/models/conditional_detr/test_modeling_conditional_detr.py @@ -16,19 +16,25 @@ import copy import inspect import math +import os +import re +import tempfile import unittest from functools import cached_property from transformers import ConditionalDetrConfig, ResNetConfig, is_torch_available, is_vision_available +from transformers.conversion_mapping import get_model_conversion_mapping +from transformers.core_model_loading import WeightRenaming, process_target_pattern from transformers.testing_utils import require_timm, require_torch, require_vision, slow, torch_device from ...test_configuration_common import ConfigTester -from ...test_modeling_common import ModelTesterMixin, floats_tensor +from ...test_modeling_common import ModelTesterMixin, compare_state_dicts, floats_tensor from ...test_pipeline_mixin import PipelineTesterMixin if is_torch_available(): import torch + from safetensors.torch import load_file from transformers import ( ConditionalDetrForObjectDetection, @@ -234,6 +240,88 @@ def test_conditional_detr_object_detection_head_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_conditional_detr_object_detection_head_model(*config_and_inputs) + def test_reverse_loading_mapping(self, check_keys_were_modified=True): + # Some conversions from the mapping are specific to `DetrForSegmentation` model only + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + # Some MoE models alternate between a classic MLP and a MoE layer, in which case we want to have at + # lest one MoE layer here to check the mapping + config_to_set = 
config.get_text_config(decoder=True) + config_to_set.first_k_dense_replace = 1 # means that the first layer (idx 0) will be MLP, then MoE + config_to_set.moe_layer_start_index = 1 # same as above but for Ernie 4.5... + config_to_set.mlp_only_layers = [0] # same but for qwens + config_to_set.num_dense_layers = 1 # lfm2_moe + + for model_class in self.all_model_classes: + # Each individual model is a subtest + with self.subTest(model_class.__name__): + model = model_class(copy.deepcopy(config)) + # Skip if no conversions + conversions = get_model_conversion_mapping(model, add_legacy=False) + if len(conversions) == 0: + # No conversion mapping for this model only, needs to test other classes + continue + + # Find the model keys, so the targets according to the conversions + model_keys = list(model.state_dict().keys()) + + with tempfile.TemporaryDirectory() as tmpdirname: + # Serialize with reverse mapping + model.save_pretrained(tmpdirname) + state_dict = load_file(os.path.join(tmpdirname, "model.safetensors")) + # Get all the serialized keys that we just saved according to the reverse mapping + serialized_keys = list(state_dict.keys()) + + if check_keys_were_modified: + # They should be different, otherwise we did not perform any mapping + self.assertNotEqual(sorted(serialized_keys), sorted(model_keys), "No key mapping was performed!") + + # Check that for each conversion entry, we at least map to one key + for conversion in conversions: + for source_pattern in conversion.source_patterns: + # Sometimes the mappings specify keys that are tied, so absent from the saved state dict + if isinstance(conversion, WeightRenaming): + # We need to revert the target pattern to make it compatible with regex search + target_pattern_reversed = conversion.target_patterns[0] + captured_group = process_target_pattern(source_pattern)[1] + if captured_group: + target_pattern_reversed = target_pattern_reversed.replace(r"\1", captured_group) + if any(re.search(target_pattern_reversed, k) 
for k in model.all_tied_weights_keys.keys()): + continue + num_matches = sum(re.search(source_pattern, key) is not None for key in serialized_keys) + + # DIFF FROM MIXIN IS HERE + if ( + "bbox" in source_pattern or "mask_head" in source_pattern + ) and model_class != ConditionalDetrForSegmentation: + pass + else: + self.assertTrue( + num_matches > 0, + f"`{source_pattern}` in `{conversion}` did not match any of the source keys. " + "This indicates either that the pattern is not properly written, or that it could not be reversed correctly", + ) + + # If everything is still good at this point, let's test that we perform the same operations both when + # reverting ops from `from_pretrained` and from `__init__` + with tempfile.TemporaryDirectory() as tmpdirname: + # The model was instantiated from __init__ before being saved + model.save_pretrained(tmpdirname) + state_dict_saved_from_init = load_file(os.path.join(tmpdirname, "model.safetensors")) + + # Now reload it + model_reloaded = model_class.from_pretrained(tmpdirname) + + # Make sure both loaded state_dict are identical + self.assertTrue(compare_state_dicts(model_reloaded.state_dict(), model.state_dict())) + + # The model was instantiated from `from_pretrained` before being saved + model_reloaded.save_pretrained(tmpdirname) + state_dict_saved_from_pretrained = load_file(os.path.join(tmpdirname, "model.safetensors")) + + # Make sure both saved state_dict are identical + self.assertTrue(compare_state_dicts(state_dict_saved_from_init, state_dict_saved_from_pretrained)) + # TODO: check if this works again for PyTorch 2.x.y @unittest.skip(reason="Got `CUDA error: misaligned address` with PyTorch 2.0.0.") def test_multi_gpu_data_parallel_forward(self): diff --git a/tests/models/detr/test_modeling_detr.py b/tests/models/detr/test_modeling_detr.py index 2943ef755e34..c4baec276f4f 100644 --- a/tests/models/detr/test_modeling_detr.py +++ b/tests/models/detr/test_modeling_detr.py @@ -16,12 +16,17 @@ import copy import
inspect import math +import os +import re +import tempfile import unittest from functools import cached_property from parameterized import parameterized from transformers import DetrConfig, ResNetConfig, is_torch_available, is_vision_available +from transformers.conversion_mapping import get_model_conversion_mapping +from transformers.core_model_loading import WeightRenaming, process_target_pattern from transformers.testing_utils import Expectations, require_timm, require_torch, require_vision, slow, torch_device from ...test_configuration_common import ConfigTester @@ -29,6 +34,7 @@ TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION, ModelTesterMixin, _test_eager_matches_sdpa_inference, + compare_state_dicts, floats_tensor, ) from ...test_pipeline_mixin import PipelineTesterMixin @@ -36,6 +42,7 @@ if is_torch_available(): import torch + from safetensors.torch import load_file from transformers import DetrForObjectDetection, DetrForSegmentation, DetrModel @@ -199,6 +206,88 @@ class DetrModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): test_missing_keys = False zero_init_hidden_state = True + def test_reverse_loading_mapping(self, check_keys_were_modified=True): + # Some conversions from the mapping are specific to `DetrForSegmentation` model only + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + # Some MoE models alternate between a classic MLP and a MoE layer, in which case we want to have at + least one MoE layer here to check the mapping + config_to_set = config.get_text_config(decoder=True) + config_to_set.first_k_dense_replace = 1 # means that the first layer (idx 0) will be MLP, then MoE + config_to_set.moe_layer_start_index = 1 # same as above but for Ernie 4.5...
+ config_to_set.mlp_only_layers = [0] # same but for qwens + config_to_set.num_dense_layers = 1 # lfm2_moe + + for model_class in self.all_model_classes: + # Each individual model is a subtest + with self.subTest(model_class.__name__): + model = model_class(copy.deepcopy(config)) + # Skip if no conversions + conversions = get_model_conversion_mapping(model, add_legacy=False) + if len(conversions) == 0: + # No conversion mapping for this model only, needs to test other classes + continue + + # Find the model keys, so the targets according to the conversions + model_keys = list(model.state_dict().keys()) + + with tempfile.TemporaryDirectory() as tmpdirname: + # Serialize with reverse mapping + model.save_pretrained(tmpdirname) + state_dict = load_file(os.path.join(tmpdirname, "model.safetensors")) + # Get all the serialized keys that we just saved according to the reverse mapping + serialized_keys = list(state_dict.keys()) + + if check_keys_were_modified: + # They should be different, otherwise we did not perform any mapping + self.assertNotEqual(sorted(serialized_keys), sorted(model_keys), "No key mapping was performed!") + + # Check that for each conversion entry, we at least map to one key + for conversion in conversions: + for source_pattern in conversion.source_patterns: + # Sometimes the mappings specify keys that are tied, so absent from the saved state dict + if isinstance(conversion, WeightRenaming): + # We need to revert the target pattern to make it compatible with regex search + target_pattern_reversed = conversion.target_patterns[0] + captured_group = process_target_pattern(source_pattern)[1] + if captured_group: + target_pattern_reversed = target_pattern_reversed.replace(r"\1", captured_group) + if any(re.search(target_pattern_reversed, k) for k in model.all_tied_weights_keys.keys()): + continue + num_matches = sum(re.search(source_pattern, key) is not None for key in serialized_keys) + + # DIFF FROM MIXIN IS HERE + if ( + "bbox" in source_pattern or 
"mask_head" in source_pattern + ) and model_class != DetrForSegmentation: + pass + else: + self.assertTrue( + num_matches > 0, + f"`{source_pattern}` in `{conversion}` did not match any of the source keys. " + "This indicates either that the pattern is not properly written, or that it could not be reversed correctly", + ) + + # If everything is still good at this point, let's test that we perform the same operations both when + # reverting ops from `from_pretrained` and from `__init__` + with tempfile.TemporaryDirectory() as tmpdirname: + # The model was instantiated from __init__ before being saved + model.save_pretrained(tmpdirname) + state_dict_saved_from_init = load_file(os.path.join(tmpdirname, "model.safetensors")) + + # Now reload it + model_reloaded = model_class.from_pretrained(tmpdirname) + + # Make sure both loaded state_dict are identical + self.assertTrue(compare_state_dicts(model_reloaded.state_dict(), model.state_dict())) + + # The model was instantiated from `from_pretrained` before being saved + model_reloaded.save_pretrained(tmpdirname) + state_dict_saved_from_pretrained = load_file(os.path.join(tmpdirname, "model.safetensors")) + + # Make sure both saved state_dict are identical + self.assertTrue(compare_state_dicts(state_dict_saved_from_init, state_dict_saved_from_pretrained)) + # special case for head models def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) diff --git a/tests/models/edgetam/test_modeling_edgetam.py b/tests/models/edgetam/test_modeling_edgetam.py index 7cddae0337eb..c0f25af5e888 100644 --- a/tests/models/edgetam/test_modeling_edgetam.py +++ b/tests/models/edgetam/test_modeling_edgetam.py @@ -248,6 +248,10 @@ def setUp(self): def test_config(self): self.config_tester.run_common_tests() + @unittest.skip(reason="The model has TimmWrapper backbone but doesn't apply any conversion") + def
test_reverse_loading_mapping(self, check_keys_were_modified=True): + pass + @unittest.skip(reason="Timm model does not use inputs_embeds") def test_inputs_embeds(self): pass diff --git a/tests/models/fast_vlm/test_modeling_fast_vlm.py b/tests/models/fast_vlm/test_modeling_fast_vlm.py index 20391de210fc..f66f27b003bc 100644 --- a/tests/models/fast_vlm/test_modeling_fast_vlm.py +++ b/tests/models/fast_vlm/test_modeling_fast_vlm.py @@ -228,6 +228,10 @@ def test_can_be_initialized_on_meta(self): def test_get_image_features_attentions(self): pass + @unittest.skip(reason="The model has TimmWrapper backbone but doesn't apply any conversion") + def test_reverse_loading_mapping(self, check_keys_were_modified=True): + pass + def _image_features_get_expected_num_hidden_states(self, model_tester=None): # For models that rely on timm for their vision backend, it's hard to infer how many layers the model has # from the timm config alone. So, we're just hardcoding the expected number of hidden states here. 
diff --git a/tests/models/maskformer/test_modeling_maskformer.py b/tests/models/maskformer/test_modeling_maskformer.py index 753c4ebce15b..430d728727bc 100644 --- a/tests/models/maskformer/test_modeling_maskformer.py +++ b/tests/models/maskformer/test_modeling_maskformer.py @@ -245,6 +245,12 @@ def test_maskformer_instance_segmentation_head_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_maskformer_instance_segmentation_head_model(*config_and_inputs) + @unittest.skip( + reason="MaskFormer loads only the decoder from DETR and needs 2-3 conversion from the whole mapping" + ) + def test_reverse_loading_mapping(self, check_keys_were_modified=True): + pass + @unittest.skip(reason="MaskFormer does not use inputs_embeds") def test_inputs_embeds(self): pass diff --git a/tests/models/pe_audio_video/test_modeling_pe_audio_video.py b/tests/models/pe_audio_video/test_modeling_pe_audio_video.py index 63f7399f2596..92860bf9a7c1 100644 --- a/tests/models/pe_audio_video/test_modeling_pe_audio_video.py +++ b/tests/models/pe_audio_video/test_modeling_pe_audio_video.py @@ -225,6 +225,10 @@ def test_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_model(*config_and_inputs) + @unittest.skip(reason="The model has TimmWrapper backbone but doesn't apply any conversion") + def test_reverse_loading_mapping(self, check_keys_were_modified=True): + pass + @unittest.skip(reason="PeAudioVideoEncoder does not have usual input embeddings") def test_model_get_set_embeddings(self): pass diff --git a/tests/models/pe_video/test_modeling_pe_video.py b/tests/models/pe_video/test_modeling_pe_video.py index fdab3c353a89..d49391a33c3d 100644 --- a/tests/models/pe_video/test_modeling_pe_video.py +++ b/tests/models/pe_video/test_modeling_pe_video.py @@ -149,6 +149,10 @@ def test_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() 
self.model_tester.create_and_check_model(*config_and_inputs) + @unittest.skip(reason="The model has TimmWrapper backbone but doesn't apply any conversion") + def test_reverse_loading_mapping(self, check_keys_were_modified=True): + pass + @unittest.skip(reason="Timm Eva (PE) weights cannot be fully constructed in _init_weights") def test_can_init_all_missing_weights(self): pass @@ -327,6 +331,10 @@ def test_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_model(*config_and_inputs) + @unittest.skip(reason="The model has TimmWrapper backbone but doesn't apply any conversion") + def test_reverse_loading_mapping(self, check_keys_were_modified=True): + pass + @unittest.skip(reason="PeVideoModel does not have usual input embeddings") def test_model_get_set_embeddings(self): pass diff --git a/tests/models/perception_lm/test_modeling_perception_lm.py b/tests/models/perception_lm/test_modeling_perception_lm.py index 99ab7bee997b..75a7f9bcee81 100644 --- a/tests/models/perception_lm/test_modeling_perception_lm.py +++ b/tests/models/perception_lm/test_modeling_perception_lm.py @@ -289,6 +289,10 @@ def test_training_gradient_checkpointing_use_reentrant_true(self): self.all_model_classes = (PerceptionLMForConditionalGeneration,) if is_torch_available() else () super().test_training_gradient_checkpointing_use_reentrant_true() + @unittest.skip(reason="The model has TimmWrapper backbone but doesn't apply any conversion") + def test_reverse_loading_mapping(self, check_keys_were_modified=True): + pass + @unittest.skip( reason="PE/TIMM's attention implementation is self configured and won't raise ValueError on global attention implementation." 
) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 13b81855aaa6..e17511a5fc42 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -4715,7 +4715,7 @@ def test_reverse_loading_mapping(self, check_keys_were_modified=True): # Skip if no conversions conversions = get_model_conversion_mapping(model, add_legacy=False) if len(conversions) == 0: - self.skipTest("No conversion found for this model") + self.skipTest(f"No conversion found for {model_class}") # Find the model keys, so the targets according to the conversions model_keys = list(model.state_dict().keys()) @@ -4754,7 +4754,7 @@ def test_reverse_loading_mapping(self, check_keys_were_modified=True): self.assertTrue( num_matches > 0, f"`{source_pattern}` in `{conversion}` did not match any of the source keys. " - "This indicates whether that the pattern is not properly written, ot that it could not be reversed correctly", + "This indicates either that the pattern is not properly written, or that it could not be reversed correctly", ) # If everything is still good at this point, let's test that we perform the same operations both when @@ -4791,7 +4791,7 @@ def test_can_load_from_already_mapped_keys(self): # Skip if no conversions conversions = get_model_conversion_mapping(model, add_legacy=False) if len(conversions) == 0: - self.skipTest("No conversion found for this model") + self.skipTest(f"No conversion found for {model_class}") with tempfile.TemporaryDirectory() as tmpdirname: # Serialize without reverting the mapping diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py index f97168125564..ff8b3f24285d 100644 --- a/utils/check_config_attributes.py +++ b/utils/check_config_attributes.py @@ -99,6 +99,7 @@ "ConditionalDetrConfig": True, "DabDetrConfig": True, "SwitchTransformersConfig": True, + "MaskFormerDetrConfig": True, "DetrConfig": True, "DFineConfig": True, "GroundingDinoConfig": True,