diff --git a/src/transformers/modeling_tf_outputs.py b/src/transformers/modeling_tf_outputs.py index 0fed3e78511a..fb61b36f4ca4 100644 --- a/src/transformers/modeling_tf_outputs.py +++ b/src/transformers/modeling_tf_outputs.py @@ -949,3 +949,30 @@ class TFImageClassifierOutputWithNoAttention(ModelOutput): loss: Optional[tf.Tensor] = None logits: tf.Tensor = None hidden_states: Optional[Tuple[tf.Tensor, ...]] = None + + +@dataclass +class TFMaskedImageCompletionOutput(ModelOutput): + """ + Base class for outputs of masked image completion / in-painting models. + + Args: + loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Reconstruction loss. + reconstruction (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`): + Reconstructed / completed images. + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for + the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states (also called + feature maps) of the model at the output of each stage. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[tf.Tensor] = None + reconstruction: tf.Tensor = None + hidden_states: Optional[Tuple[tf.Tensor]] = None + attentions: Optional[Tuple[tf.Tensor]] = None diff --git a/src/transformers/models/deit/modeling_deit.py b/src/transformers/models/deit/modeling_deit.py index f05b16efe7a0..76a07d65467e 100644 --- a/src/transformers/models/deit/modeling_deit.py +++ b/src/transformers/models/deit/modeling_deit.py @@ -26,7 +26,12 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput, MaskedLMOutput +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPooling, + ImageClassifierOutput, + MaskedImageCompletionOutput, +) from ...modeling_utils import PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( @@ -588,7 +593,7 @@ def __init__(self, config: DeiTConfig) -> None: self.post_init() @add_start_docstrings_to_model_forward(DEIT_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC) + @replace_return_docstrings(output_type=MaskedImageCompletionOutput, config_class=_CONFIG_FOR_DOC) def forward( self, pixel_values: Optional[torch.Tensor] = None, @@ -597,7 +602,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - ) -> Union[tuple, MaskedLMOutput]: + ) -> Union[tuple, MaskedImageCompletionOutput]: r""" bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`): Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). @@ -666,9 +671,9 @@ def forward( output = (reconstructed_pixel_values,) + outputs[1:] return ((masked_im_loss,) + output) if masked_im_loss is not None else output - return MaskedLMOutput( + return MaskedImageCompletionOutput( loss=masked_im_loss, - logits=reconstructed_pixel_values, + reconstruction=reconstructed_pixel_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) diff --git a/src/transformers/models/deit/modeling_tf_deit.py b/src/transformers/models/deit/modeling_tf_deit.py index 161f2518d068..8d4150f65b85 100644 --- a/src/transformers/models/deit/modeling_tf_deit.py +++ b/src/transformers/models/deit/modeling_tf_deit.py @@ -27,7 +27,7 @@ TFBaseModelOutput, TFBaseModelOutputWithPooling, TFImageClassifierOutput, - TFMaskedLMOutput, + TFMaskedImageCompletionOutput, ) from ...modeling_tf_utils import ( TFPreTrainedModel, @@ -765,7 +765,7 @@ def __init__(self, config: DeiTConfig) -> None: @unpack_inputs @add_start_docstrings_to_model_forward(DEIT_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFMaskedLMOutput, config_class=_CONFIG_FOR_DOC) + @replace_return_docstrings(output_type=TFMaskedImageCompletionOutput, config_class=_CONFIG_FOR_DOC) def call( self, pixel_values: Optional[tf.Tensor] = None, @@ -775,7 +775,7 @@ def call( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, training: bool = False, - ) -> Union[tuple, TFMaskedLMOutput]: + ) -> Union[tuple, TFMaskedImageCompletionOutput]: r""" bool_masked_pos (`tf.Tensor` of type bool and shape `(batch_size, num_patches)`): Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). @@ -856,18 +856,20 @@ def call( output = (reconstructed_pixel_values,) + outputs[1:] return ((masked_im_loss,) + output) if masked_im_loss is not None else output - return TFMaskedLMOutput( + return TFMaskedImageCompletionOutput( loss=masked_im_loss, - logits=reconstructed_pixel_values, + reconstruction=reconstructed_pixel_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) - def serving_output(self, output: TFMaskedLMOutput) -> TFMaskedLMOutput: + def serving_output(self, output: TFMaskedImageCompletionOutput) -> TFMaskedImageCompletionOutput: hidden_states = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None attentions = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None - return TFMaskedLMOutput(logits=output.logits, hidden_states=hidden_states, attentions=attentions) + return TFMaskedImageCompletionOutput( + reconstruction=output.reconstruction, hidden_states=hidden_states, attentions=attentions + ) @add_start_docstrings( diff --git a/tests/models/deit/test_modeling_deit.py b/tests/models/deit/test_modeling_deit.py index fbf6d7353b07..1564b23aa659 100644 --- a/tests/models/deit/test_modeling_deit.py +++ b/tests/models/deit/test_modeling_deit.py @@ -145,7 +145,7 @@ def create_and_check_for_masked_image_modeling(self, config, pixel_values, label model.eval() result = model(pixel_values) self.parent.assertEqual( - result.logits.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size) + result.reconstruction.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size) ) # test greyscale images @@ -156,7 +156,7 @@ def create_and_check_for_masked_image_modeling(self, config, pixel_values, label pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size]) result = model(pixel_values) - self.parent.assertEqual(result.logits.shape, (self.batch_size, 1, self.image_size, self.image_size)) + self.parent.assertEqual(result.reconstruction.shape, (self.batch_size, 1, self.image_size, self.image_size)) def create_and_check_for_image_classification(self, config, pixel_values, labels): config.num_labels = self.type_sequence_label_size diff --git a/tests/models/deit/test_modeling_tf_deit.py b/tests/models/deit/test_modeling_tf_deit.py index c7c1fc84568b..223d164d4aaf 100644 --- a/tests/models/deit/test_modeling_tf_deit.py +++ b/tests/models/deit/test_modeling_tf_deit.py @@ -130,7 +130,7 @@ def create_and_check_for_masked_image_modeling(self, config, pixel_values, label model = TFDeiTForMaskedImageModeling(config=config) result = model(pixel_values) self.parent.assertEqual( - result.logits.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size) + result.reconstruction.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size) ) # test greyscale images @@ -139,7 +139,7 @@ def create_and_check_for_masked_image_modeling(self, config, pixel_values, label pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size]) result = model(pixel_values) - self.parent.assertEqual(result.logits.shape, (self.batch_size, 1, self.image_size, self.image_size)) + self.parent.assertEqual(result.reconstruction.shape, (self.batch_size, 1, self.image_size, self.image_size)) def create_and_check_for_image_classification(self, config, pixel_values, labels): config.num_labels = self.type_sequence_label_size