86 changes: 43 additions & 43 deletions docs/source/index.rst

Large diffs are not rendered by default.

32 changes: 17 additions & 15 deletions docs/source/model_doc/detr.rst
@@ -55,12 +55,12 @@ than usual, but with a smaller :obj:`d_model` (which in NLP is typically 768 or
Next, this is sent through the encoder, outputting :obj:`encoder_hidden_states` of the same shape (you can consider
these as image features). Next, so-called **object queries** are sent through the decoder. This is a tensor of shape
:obj:`(batch_size, num_queries, d_model)`, with :obj:`num_queries` typically set to 100 and initialized with zeros.
These input embeddings are learnt positional encodings that the authors refer to as object queries, and similarly to
the encoder, they are added to the input of each attention layer. Each object query will look for a particular object
in the image. The decoder updates these embeddings through multiple self-attention and encoder-decoder attention layers
to output :obj:`decoder_hidden_states` of the same shape: :obj:`(batch_size, num_queries, d_model)`. Next, two heads
are added on top for object detection: a linear layer for classifying each object query into one of the objects or "no
object", and a MLP to predict bounding boxes for each query.
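The two detection heads described above can be sketched in plain PyTorch. This is a toy sketch with made-up sizes (91 COCO classes, d_model of 256), not the actual :obj:`DetrForObjectDetection` modules:

```python
import torch
import torch.nn as nn

# Illustrative sizes: 100 object queries, d_model = 256, 91 classes (COCO).
batch_size, num_queries, d_model, num_classes = 2, 100, 256, 91

# Learnt positional encodings acting as object queries (an embedding table).
query_embed = nn.Embedding(num_queries, d_model)
object_queries = query_embed.weight.unsqueeze(0).expand(batch_size, -1, -1)

# Stand-in for the decoder output; in DETR this comes from the self-attention
# and encoder-decoder attention layers.
decoder_hidden_states = torch.randn(batch_size, num_queries, d_model)

# Two heads on top: a linear classifier (num_classes + 1 for "no object")
# and a small MLP predicting 4 normalized box coordinates per query.
class_head = nn.Linear(d_model, num_classes + 1)
bbox_head = nn.Sequential(
    nn.Linear(d_model, d_model), nn.ReLU(),
    nn.Linear(d_model, d_model), nn.ReLU(),
    nn.Linear(d_model, 4),
)

logits = class_head(decoder_hidden_states)               # (2, 100, 92)
pred_boxes = bbox_head(decoder_hidden_states).sigmoid()  # (2, 100, 4), in [0, 1]
```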

The model is trained using a **bipartite matching loss**: so what we actually do is compare the predicted classes +
bounding boxes of each of the N = 100 object queries to the ground truth annotations, padded up to the same length N
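The bipartite matching step can be illustrated with SciPy's Hungarian-algorithm solver. The cost values below are made up for illustration; in DETR the cost combines class probability, L1 box distance, and generalized IoU:

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

# Toy cost matrix: rows are predicted queries, columns are (padded) ground-truth slots.
cost = np.array([
    [0.9, 0.1, 0.5],
    [0.2, 0.8, 0.6],
    [0.7, 0.4, 0.1],
])

# The Hungarian algorithm finds the one-to-one assignment with minimal total cost.
pred_idx, gt_idx = linear_sum_assignment(cost)
print(list(zip(pred_idx, gt_idx)))  # each prediction matched to exactly one target
```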
@@ -89,15 +89,17 @@ Tips:
`num_boxes` variable in the `SetCriterion` class of `modeling_detr.py`. When training on multiple nodes, this should
be set to the average number of target boxes across all nodes, as can be seen in the original implementation `here
<https://github.com/facebookresearch/detr/blob/a54b77800eb8e64e3ad0d8237789fcbf2f8350c5/models/detr.py#L227-L232>`__.
- :class:`~transformers.DetrForObjectDetection` can be initialized with any convolutional backbone available in the
`timm library <https://github.com/rwightman/pytorch-image-models>`__. Initializing with a MobileNet backbone for
example can be done by setting the :obj:`backbone` attribute of :class:`~transformers.DetrConfig` to
:obj:`"tf_mobilenetv3_small_075"`, and then initializing :class:`~transformers.DetrForObjectDetection` with that
config.
- At inference time, DETR resizes the input images such that the shortest side is at least 800 pixels while the longest
at most 1333 pixels. One can use :class:`~transformers.DetrFeatureExtractor` to prepare images (and optional
annotations in COCO format) for the model. Due to this, images in a batch can have different sizes. DETR solves this
by padding images up to the largest size in a batch, and by creating a pixel mask that indicates which pixels are
real/which are padding. Alternatively, one can also define a custom :obj:`collate_fn` in order to batch images
together, using :meth:`~transformers.DetrFeatureExtractor.pad_and_create_pixel_mask`.
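Assuming :obj:`transformers` (with :obj:`timm`) is installed, the backbone tip above might look like the following configuration sketch (untested here, shown for illustration):

```python
from transformers import DetrConfig, DetrForObjectDetection

# Sketch: select any timm convolutional backbone by name via the config.
config = DetrConfig(backbone="tf_mobilenetv3_small_075")
model = DetrForObjectDetection(config)
```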

DetrConfig
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
1 change: 1 addition & 0 deletions src/transformers/__init__.py
@@ -1527,6 +1527,7 @@
is_sklearn_available,
is_speech_available,
is_tf_available,
is_timm_available,
is_tokenizers_available,
is_torch_available,
is_torch_tpu_available,
4 changes: 4 additions & 0 deletions src/transformers/models/detr/configuration_detr.py
@@ -71,6 +71,8 @@ class DetrConfig(PretrainedConfig):
just in case (e.g., 512 or 1024 or 2048).
init_std (:obj:`float`, `optional`, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
init_xavier_std (:obj:`float`, `optional`, defaults to 1.):
The scaling factor used for the Xavier initialization gain in the MH attention map module.
encoder_layerdrop: (:obj:`float`, `optional`, defaults to 0.0):
The LayerDrop probability for the encoder. See the `LayerDrop paper
<https://arxiv.org/abs/1909.11556>`__ for more details.
@@ -142,6 +144,7 @@ def __init__(
attention_dropout=0.0,
activation_dropout=0.0,
init_std=0.02,
init_xavier_std=1.0,
classifier_dropout=0.0,
scale_embedding=False,
auxiliary_loss=False,
@@ -176,6 +179,7 @@ def __init__(
self.activation_dropout = activation_dropout
self.activation_function = activation_function
self.init_std = init_std
self.init_xavier_std = init_xavier_std
[Review comment from the author] Specific variable to ensure we can control these initializations as well.

self.encoder_layerdrop = encoder_layerdrop
self.decoder_layerdrop = decoder_layerdrop
self.classifier_dropout = classifier_dropout
16 changes: 8 additions & 8 deletions src/transformers/models/detr/feature_extraction_detr.py
@@ -460,8 +460,8 @@ def __call__(
:class:`~transformers.BatchFeature`: A :class:`~transformers.BatchFeature` with the following fields:

- **pixel_values** -- Pixel values to be fed to a model.
- **pixel_mask** -- Pixel mask to be fed to a model (when :obj:`pad_and_return_pixel_mask=True` or if
`"pixel_mask"` is in :obj:`self.model_input_names`).
"""
# Input type checking for clearer error

@@ -623,9 +623,9 @@ def _max_by_axis(self, the_list):
maxes[index] = max(maxes[index], item)
return maxes

def pad_and_create_pixel_mask(
self, pixel_values_list: List[torch.Tensor], return_tensors: Optional[Union[str, TensorType]] = None
):
"""
Pad images up to the largest image in a batch and create a corresponding :obj:`pixel_mask`.

@@ -641,11 +641,11 @@ def pad_and_create_pixel_mask(self,
:class:`~transformers.BatchFeature`: A :class:`~transformers.BatchFeature` with the following fields:

- **pixel_values** -- Pixel values to be fed to a model.
- **pixel_mask** -- Pixel mask to be fed to a model (when :obj:`pad_and_return_pixel_mask=True` or if
`"pixel_mask"` is in :obj:`self.model_input_names`).

"""

max_size = self._max_by_axis([list(image.shape) for image in pixel_values_list])
c, h, w = max_size
padded_images = []
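The padding-and-mask behavior documented above can be sketched independently of the library. This is a simplified stand-in, not the actual :meth:`DetrFeatureExtractor.pad_and_create_pixel_mask`:

```python
import torch

# Pad each image in a batch up to the largest (C, H, W) in the batch and
# build a pixel mask where 1 = real pixel and 0 = padding.
def pad_and_create_pixel_mask_sketch(pixel_values_list):
    max_c = max(img.shape[0] for img in pixel_values_list)
    max_h = max(img.shape[1] for img in pixel_values_list)
    max_w = max(img.shape[2] for img in pixel_values_list)
    pixel_values = torch.zeros(len(pixel_values_list), max_c, max_h, max_w)
    pixel_mask = torch.zeros(len(pixel_values_list), max_h, max_w, dtype=torch.long)
    for i, img in enumerate(pixel_values_list):
        c, h, w = img.shape
        pixel_values[i, :c, :h, :w] = img
        pixel_mask[i, :h, :w] = 1
    return {"pixel_values": pixel_values, "pixel_mask": pixel_mask}

# Two images of different sizes are padded up to (3, 5, 6).
batch = pad_and_create_pixel_mask_sketch([torch.ones(3, 4, 6), torch.ones(3, 5, 3)])
```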
89 changes: 81 additions & 8 deletions src/transformers/models/detr/modeling_detr.py
@@ -100,6 +100,11 @@ class DetrObjectDetectionOutput(ModelOutput):
Optional, only returned when auxiliary losses are activated (i.e. :obj:`config.auxiliary_loss` is set to
`True`) and labels are provided. It is a list of dictionaries containing the two above keys (:obj:`logits`
and :obj:`pred_boxes`) for each decoder layer.
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the decoder of the model.

If :obj:`past_key_values` is used only the last hidden-state of the sequences of shape :obj:`(batch_size,
1, hidden_size)` is output.
decoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`. Hidden-states of the decoder at the output of
@@ -129,6 +134,7 @@ class DetrObjectDetectionOutput(ModelOutput):
logits: torch.FloatTensor = None
pred_boxes: torch.FloatTensor = None
auxiliary_outputs: Optional[List[Dict]] = None
last_hidden_state: Optional[torch.FloatTensor] = None
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
@@ -138,16 +144,72 @@ class DetrObjectDetectionOutput(ModelOutput):


@dataclass
class DetrForSegmentationOutput(DetrObjectDetectionOutput):
class DetrForSegmentationOutput(ModelOutput):
"""
This class adds one attribute to DetrObjectDetectionOutput, namely predicted masks.
Output type of :class:`~transformers.DetrForSegmentation`.

Args:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` are provided):
Total loss as a linear combination of a negative log-likelihood (cross-entropy) for class prediction and a
bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized
scale-invariant IoU loss.
loss_dict (:obj:`Dict`, `optional`):
A dictionary containing the individual losses.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_queries, num_classes + 1)`):
Classification logits (including no-object) for all queries.
pred_boxes (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_queries, 4)`):
Normalized box coordinates for all queries, represented as (center_x, center_y, width, height). These
values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
possible padding). You can use :class:`~transformers.DetrForObjectDetection.post_process` to retrieve the
unnormalized bounding boxes.
pred_masks (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_queries, width, height)`):
[Review comment from the author] The pred_masks attribute was put at the end of the output, which isn't what we want.

...
auxiliary_outputs (:obj:`list[Dict]`, `optional`):
Optional, only returned when auxiliary losses are activated (i.e. :obj:`config.auxiliary_loss` is set to
`True`) and labels are provided. It is a list of dictionaries containing the two above keys (:obj:`logits`
and :obj:`pred_boxes`) for each decoder layer.
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the decoder of the model.

If :obj:`past_key_values` is used only the last hidden-state of the sequences of shape :obj:`(batch_size,
1, hidden_size)` is output.
decoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`. Hidden-states of the decoder at the output of
each layer plus the initial embedding outputs.
decoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`. Attention weights of the decoder, after the attention softmax, used to
compute the weighted average in the self-attention heads.
cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`. Attention weights of the decoder's cross-attention layer, after the
attention softmax, used to compute the weighted average in the cross-attention heads.
encoder_last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder of the model.
encoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`. Hidden-states of the encoder at the output of
each layer plus the initial embedding outputs.
encoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`. Attention weights of the encoder, after the attention softmax, used to
compute the weighted average in the self-attention heads.
"""

loss: Optional[torch.FloatTensor] = None
loss_dict: Optional[Dict] = None
logits: torch.FloatTensor = None
pred_boxes: torch.FloatTensor = None
pred_masks: torch.FloatTensor = None
auxiliary_outputs: Optional[List[Dict]] = None
last_hidden_state: Optional[torch.FloatTensor] = None
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None


# BELOW: utilities copied from
@@ -676,7 +738,9 @@ class DetrPreTrainedModel(PreTrainedModel):

def _init_weights(self, module):
std = self.config.init_std
if isinstance(module, nn.Linear):
if isinstance(module, (nn.Linear, nn.Conv2d)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
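The updated :obj:`_init_weights` hunk above, which now also covers :obj:`nn.Conv2d`, can be exercised in isolation. This is a standalone sketch mirroring the diff, not the exact :obj:`DetrPreTrainedModel` code:

```python
import torch
import torch.nn as nn

def init_weights_sketch(module, std=0.02):
    # Mirrors the hunk above: normal(0, std) weights for both Linear and Conv2d,
    # with biases zeroed out.
    if isinstance(module, (nn.Linear, nn.Conv2d)):
        module.weight.data.normal_(mean=0.0, std=std)
        if module.bias is not None:
            module.bias.data.zero_()

# Apply recursively to every submodule, as PreTrainedModel.init_weights does.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Linear(8, 4))
model.apply(init_weights_sketch)
```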
@@ -1412,6 +1476,7 @@ class labels themselves should be a :obj:`torch.LongTensor` of len :obj:`(number
logits=logits,
pred_boxes=pred_boxes,
auxiliary_outputs=auxiliary_outputs,
last_hidden_state=outputs.last_hidden_state,
decoder_hidden_states=outputs.decoder_hidden_states,
decoder_attentions=outputs.decoder_attentions,
cross_attentions=outputs.cross_attentions,
@@ -1424,7 +1489,7 @@ class labels themselves should be a :obj:`torch.LongTensor` of len :obj:`(number
@add_start_docstrings(
"""
DETR Model (consisting of a backbone and encoder-decoder Transformer) with a segmentation head on top, for tasks
such as COCO panoptic.
such as COCO panoptic.

""",
DETR_START_DOCSTRING,
@@ -1439,9 +1504,16 @@ def __init__(self, config: DetrConfig):

# segmentation head
hidden_size, number_of_heads = config.d_model, config.encoder_attention_heads
self.bbox_attention = DetrMHAttentionMap(hidden_size, hidden_size, number_of_heads, dropout=0.0)
self.mask_head = DetrMaskHeadSmallConv(hidden_size + number_of_heads, [1024, 512, 256], hidden_size)

self.init_weights()

# The DetrMHAttentionMap has a custom layer initialization scheme which must not get overwritten by the
# self.init_weights()
self.bbox_attention = DetrMHAttentionMap(
hidden_size, hidden_size, number_of_heads, dropout=0.0, std=config.init_xavier_std
)

@add_start_docstrings_to_model_forward(DETR_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=DetrForSegmentationOutput, config_class=_CONFIG_FOR_DOC)
def forward(
@@ -1622,6 +1694,7 @@ def forward(
pred_boxes=pred_boxes,
pred_masks=pred_masks,
auxiliary_outputs=auxiliary_outputs,
last_hidden_state=decoder_outputs.last_hidden_state,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
@@ -1717,7 +1790,7 @@ def forward(self, x: Tensor, bbox_mask: Tensor, fpns: List[Tensor]):
class DetrMHAttentionMap(nn.Module):
"""This is a 2D attention module, which only returns the attention softmax (no multiplication by value)"""

def __init__(self, query_dim, hidden_dim, num_heads, dropout=0.0, bias=True):
def __init__(self, query_dim, hidden_dim, num_heads, dropout=0.0, bias=True, std=None):
super().__init__()
self.num_heads = num_heads
self.hidden_dim = hidden_dim
@@ -1728,8 +1801,8 @@ def __init__(self, query_dim, hidden_dim, num_heads, dropout=0.0, bias=True):

nn.init.zeros_(self.k_linear.bias)
nn.init.zeros_(self.q_linear.bias)
nn.init.xavier_uniform_(self.k_linear.weight)
nn.init.xavier_uniform_(self.q_linear.weight)
nn.init.xavier_uniform_(self.k_linear.weight, gain=std)
nn.init.xavier_uniform_(self.q_linear.weight, gain=std)
self.normalize_fact = float(hidden_dim / self.num_heads) ** -0.5

def forward(self, q, k, mask: Optional[Tensor] = None):
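The gist of the 2D attention map with the new configurable Xavier gain can be sketched as follows. Shapes and the `gain` value are illustrative (in DETR the gain comes from :obj:`config.init_xavier_std`), and this is a simplified stand-in for :obj:`DetrMHAttentionMap`:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

num_heads, hidden_dim = 8, 256

# Query/key projections with zero biases and xavier-uniform weights, where the
# gain is now a tunable knob instead of the default 1.0.
q_linear = nn.Linear(hidden_dim, hidden_dim)
k_linear = nn.Linear(hidden_dim, hidden_dim)
for layer in (q_linear, k_linear):
    nn.init.zeros_(layer.bias)
    nn.init.xavier_uniform_(layer.weight, gain=1.0)  # gain=config.init_xavier_std in DETR

normalize_fact = float(hidden_dim / num_heads) ** -0.5

# Queries: (batch, num_queries, dim); keys: a 2D feature map (batch, dim, H, W).
q = q_linear(torch.randn(2, 100, hidden_dim))
k = k_linear(torch.randn(2, hidden_dim, 16, 16).permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

qh = q.view(2, 100, num_heads, hidden_dim // num_heads)
kh = k.reshape(2, num_heads, hidden_dim // num_heads, 16, 16)

# Scaled dot-product per head over spatial positions; only the softmax is returned
# (no multiplication by values), matching the module's docstring.
weights = torch.einsum("bqnc,bnchw->bqnhw", qh * normalize_fact, kh)
attn = F.softmax(weights.flatten(2), dim=-1).view(weights.size())
```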
2 changes: 1 addition & 1 deletion tests/test_feature_extraction_detr.py
@@ -265,7 +265,7 @@ def test_equivalence_pad_and_create_pixel_mask(self):
image_inputs = self.feature_extract_tester.prepare_inputs(equal_resolution=False, torchify=True)
for image in image_inputs:
self.assertIsInstance(image, torch.Tensor)

# Test whether the method "pad_and_create_pixel_mask" and calling the feature extractor return the same tensors
encoded_images_with_method = feature_extractor_1.pad_and_create_pixel_mask(image_inputs, return_tensors="pt")
encoded_images = feature_extractor_2(image_inputs, return_tensors="pt")