Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions src/transformers/modeling_tf_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -949,3 +949,30 @@ class TFImageClassifierOutputWithNoAttention(ModelOutput):
loss: Optional[tf.Tensor] = None
logits: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor, ...]] = None


@dataclass
class TFMaskedImageCompletionOutput(ModelOutput):
    """
    Base class for outputs of masked image completion / in-painting models.

    Args:
        loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Reconstruction loss.
        reconstruction (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
            Reconstructed / completed images.
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when
        `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
            the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states (also called
            feature maps) of the model at the output of each stage.
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when
        `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[tf.Tensor] = None
    reconstruction: tf.Tensor = None
    # Variable-length tuples (one tensor per stage / layer), hence the trailing ellipsis.
    hidden_states: Optional[Tuple[tf.Tensor, ...]] = None
    attentions: Optional[Tuple[tf.Tensor, ...]] = None
15 changes: 10 additions & 5 deletions src/transformers/models/deit/modeling_deit.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,12 @@
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput, MaskedLMOutput
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPooling,
ImageClassifierOutput,
MaskedImageCompletionOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
Expand Down Expand Up @@ -588,7 +593,7 @@ def __init__(self, config: DeiTConfig) -> None:
self.post_init()

@add_start_docstrings_to_model_forward(DEIT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
@replace_return_docstrings(output_type=MaskedImageCompletionOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
pixel_values: Optional[torch.Tensor] = None,
Expand All @@ -597,7 +602,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[tuple, MaskedLMOutput]:
) -> Union[tuple, MaskedImageCompletionOutput]:
r"""
bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
Expand Down Expand Up @@ -666,9 +671,9 @@ def forward(
output = (reconstructed_pixel_values,) + outputs[1:]
return ((masked_im_loss,) + output) if masked_im_loss is not None else output

return MaskedLMOutput(
return MaskedImageCompletionOutput(
loss=masked_im_loss,
logits=reconstructed_pixel_values,
reconstruction=reconstructed_pixel_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
Expand Down
16 changes: 9 additions & 7 deletions src/transformers/models/deit/modeling_tf_deit.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
TFBaseModelOutput,
TFBaseModelOutputWithPooling,
TFImageClassifierOutput,
TFMaskedLMOutput,
TFMaskedImageCompletionOutput,
)
from ...modeling_tf_utils import (
TFPreTrainedModel,
Expand Down Expand Up @@ -765,7 +765,7 @@ def __init__(self, config: DeiTConfig) -> None:

@unpack_inputs
@add_start_docstrings_to_model_forward(DEIT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFMaskedLMOutput, config_class=_CONFIG_FOR_DOC)
@replace_return_docstrings(output_type=TFMaskedImageCompletionOutput, config_class=_CONFIG_FOR_DOC)
def call(
self,
pixel_values: Optional[tf.Tensor] = None,
Expand All @@ -775,7 +775,7 @@ def call(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: bool = False,
) -> Union[tuple, TFMaskedLMOutput]:
) -> Union[tuple, TFMaskedImageCompletionOutput]:
r"""
bool_masked_pos (`tf.Tensor` of type bool and shape `(batch_size, num_patches)`):
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
Expand Down Expand Up @@ -856,18 +856,20 @@ def call(
output = (reconstructed_pixel_values,) + outputs[1:]
return ((masked_im_loss,) + output) if masked_im_loss is not None else output

return TFMaskedLMOutput(
return TFMaskedImageCompletionOutput(
loss=masked_im_loss,
logits=reconstructed_pixel_values,
reconstruction=reconstructed_pixel_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)

def serving_output(self, output: TFMaskedLMOutput) -> TFMaskedLMOutput:
def serving_output(self, output: TFMaskedImageCompletionOutput) -> TFMaskedImageCompletionOutput:
hidden_states = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attentions = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None

return TFMaskedLMOutput(logits=output.logits, hidden_states=hidden_states, attentions=attentions)
return TFMaskedImageCompletionOutput(
reconstruction=output.reconstruction, hidden_states=hidden_states, attentions=attentions
)


@add_start_docstrings(
Expand Down
4 changes: 2 additions & 2 deletions tests/models/deit/test_modeling_deit.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def create_and_check_for_masked_image_modeling(self, config, pixel_values, label
model.eval()
result = model(pixel_values)
self.parent.assertEqual(
result.logits.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size)
result.reconstruction.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size)
)

# test greyscale images
Expand All @@ -156,7 +156,7 @@ def create_and_check_for_masked_image_modeling(self, config, pixel_values, label

pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
result = model(pixel_values)
self.parent.assertEqual(result.logits.shape, (self.batch_size, 1, self.image_size, self.image_size))
self.parent.assertEqual(result.reconstruction.shape, (self.batch_size, 1, self.image_size, self.image_size))

def create_and_check_for_image_classification(self, config, pixel_values, labels):
config.num_labels = self.type_sequence_label_size
Expand Down
4 changes: 2 additions & 2 deletions tests/models/deit/test_modeling_tf_deit.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def create_and_check_for_masked_image_modeling(self, config, pixel_values, label
model = TFDeiTForMaskedImageModeling(config=config)
result = model(pixel_values)
self.parent.assertEqual(
result.logits.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size)
result.reconstruction.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size)
)

# test greyscale images
Expand All @@ -139,7 +139,7 @@ def create_and_check_for_masked_image_modeling(self, config, pixel_values, label

pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
result = model(pixel_values)
self.parent.assertEqual(result.logits.shape, (self.batch_size, 1, self.image_size, self.image_size))
self.parent.assertEqual(result.reconstruction.shape, (self.batch_size, 1, self.image_size, self.image_size))

def create_and_check_for_image_classification(self, config, pixel_values, labels):
config.num_labels = self.type_sequence_label_size
Expand Down