From 9e72fcb8c1ba300f838e4c834c6eaced2ae13b50 Mon Sep 17 00:00:00 2001 From: robotjelly Date: Tue, 3 May 2022 21:44:21 +0530 Subject: [PATCH 1/3] type hints for pytorch models --- .../models/canine/modeling_canine.py | 197 ++++++------ .../models/convbert/modeling_convbert.py | 100 +++--- .../models/convnext/modeling_convnext.py | 34 ++- .../modeling_decision_transformer.py | 69 +++-- .../modeling_encoder_decoder.py | 30 +- src/transformers/models/glpn/modeling_glpn.py | 2 +- src/transformers/models/gpt2/modeling_gpt2.py | 183 +++++------ src/transformers/models/gptj/modeling_gptj.py | 135 ++++---- .../models/maskformer/modeling_maskformer.py | 24 +- .../megatron_bert/modeling_megatron_bert.py | 289 +++++++++--------- .../models/mobilebert/modeling_mobilebert.py | 182 +++++------ .../models/perceiver/modeling_perceiver.py | 200 +++++++----- .../models/retribert/modeling_retribert.py | 12 +- .../models/segformer/modeling_segformer.py | 2 +- src/transformers/models/swin/modeling_swin.py | 108 ++++--- .../models/transfo_xl/modeling_transfo_xl.py | 42 +-- src/transformers/models/van/modeling_van.py | 35 ++- 17 files changed, 890 insertions(+), 754 deletions(-) diff --git a/src/transformers/models/canine/modeling_canine.py b/src/transformers/models/canine/modeling_canine.py index e9cad1d15b86..4e7d9f2b1727 100644 --- a/src/transformers/models/canine/modeling_canine.py +++ b/src/transformers/models/canine/modeling_canine.py @@ -19,7 +19,7 @@ import math import os from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Optional, Tuple, Union import torch import torch.utils.checkpoint @@ -253,11 +253,11 @@ def _embed_hash_buckets(self, input_ids, embedding_size: int, num_hashes: int, n def forward( self, - input_ids=None, - token_type_ids=None, - position_ids=None, - inputs_embeds=None, - ): + input_ids: Optional[torch.LongTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: if input_ids is not None: input_shape = input_ids.size() else: @@ -356,7 +356,11 @@ def __init__(self, config): self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) - def forward(self, inputs, final_seq_char_positions=None): + def forward( + self, + inputs: torch.Tensor, + final_seq_char_positions: Optional[torch.Tensor] = None, + ) -> torch.Tensor: # inputs has shape [batch, mol_seq, molecule_hidden_size+char_hidden_final] # we transpose it to be [batch, molecule_hidden_size+char_hidden_final, mol_seq] inputs = torch.transpose(inputs, 1, 2) @@ -419,12 +423,12 @@ def transpose_for_scores(self, x): def forward( self, - from_tensor, - to_tensor, - attention_mask=None, - head_mask=None, - output_attentions=False, - ): + from_tensor: torch.Tensor, + to_tensor: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: mixed_query_layer = self.query(from_tensor) # If this is instantiated as a cross-attention module, the keys @@ -496,7 +500,9 @@ def __init__(self, config): self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) - def forward(self, hidden_states, input_tensor): + def forward( + self, hidden_states: Tuple[torch.FloatTensor], input_tensor: 
torch.FloatTensor + ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.LayerNorm(hidden_states + input_tensor) @@ -574,11 +580,11 @@ def prune_heads(self, heads): def forward( self, - hidden_states, - attention_mask=None, - head_mask=None, - output_attentions=False, - ): + hidden_states: Tuple[torch.FloatTensor], + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]: if not self.local: self_outputs = self.self(hidden_states, hidden_states, attention_mask, head_mask, output_attentions) attention_output = self_outputs[0] @@ -656,7 +662,7 @@ def __init__(self, config): else: self.intermediate_act_fn = config.hidden_act - def forward(self, hidden_states): + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: hidden_states = self.dense(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states) return hidden_states @@ -669,7 +675,7 @@ def __init__(self, config): self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) - def forward(self, hidden_states, input_tensor): + def forward(self, hidden_states: Tuple[torch.FloatTensor], input_tensor: torch.FloatTensor) -> torch.FloatTensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.LayerNorm(hidden_states + input_tensor) @@ -706,11 +712,11 @@ def __init__( def forward( self, - hidden_states, - attention_mask=None, - head_mask=None, - output_attentions=False, - ): + hidden_states: Tuple[torch.FloatTensor], + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]: self_attention_outputs = self.attention( hidden_states, attention_mask, @@ -767,13 +773,13 @@ def __init__( def forward( self, - hidden_states, - attention_mask=None, - head_mask=None, - output_attentions=False, - output_hidden_states=False, - return_dict=True, - ): + hidden_states: Tuple[torch.FloatTensor], + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[Tuple, BaseModelOutput]: all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None @@ -822,7 +828,7 @@ def __init__(self, config): self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.activation = nn.Tanh() - def forward(self, hidden_states): + def forward(self, hidden_states: Tuple[torch.FloatTensor]) -> torch.FloatTensor: # We "pool" the model by simply taking the hidden state corresponding # to the first token. 
first_token_tensor = hidden_states[:, 0] @@ -841,7 +847,7 @@ def __init__(self, config): self.transform_act_fn = config.hidden_act self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - def forward(self, hidden_states): + def forward(self, hidden_states: Tuple[torch.FloatTensor]) -> torch.FloatTensor: hidden_states = self.dense(hidden_states) hidden_states = self.transform_act_fn(hidden_states) hidden_states = self.LayerNorm(hidden_states) @@ -862,7 +868,7 @@ def __init__(self, config): # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias - def forward(self, hidden_states): + def forward(self, hidden_states: Tuple[torch.FloatTensor]) -> torch.FloatTensor: hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) return hidden_states @@ -873,7 +879,10 @@ def __init__(self, config): super().__init__() self.predictions = CanineLMPredictionHead(config) - def forward(self, sequence_output): + def forward( + self, + sequence_output: Tuple[torch.Tensor], + ) -> Tuple[torch.Tensor]: prediction_scores = self.predictions(sequence_output) return prediction_scores @@ -1093,16 +1102,16 @@ def _repeat_molecules(self, molecules: torch.Tensor, char_seq_length: torch.Tens ) def forward( self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CanineModelOutputWithPooling]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -1130,12 +1139,12 @@ def forward( # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] # ourselves in which case we just need to make it broadcastable to all heads. 
- extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device) molecule_attention_mask = self._downsample_attention_mask( attention_mask, downsampling_rate=self.config.downsampling_rate ) extended_molecule_attention_mask: torch.Tensor = self.get_extended_attention_mask( - molecule_attention_mask, (batch_size, molecule_attention_mask.shape[-1]) + molecule_attention_mask, (batch_size, molecule_attention_mask.shape[-1]), device ) # Prepare head mask if needed @@ -1275,17 +1284,17 @@ def __init__(self, config): ) def forward( self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - labels=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., @@ -1372,17 +1381,17 @@ def __init__(self, config): ) def forward( self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - labels=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, MultipleChoiceModelOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., @@ -1465,17 +1474,17 @@ def __init__(self, config): ) def forward( self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - labels=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the token classification loss. 
Indices should be in `[0, ..., config.num_labels - 1]`. @@ -1543,18 +1552,18 @@ def __init__(self, config): ) def forward( self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - start_positions=None, - end_positions=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: r""" start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for position (index) of the start of the labelled span for computing the token classification loss. diff --git a/src/transformers/models/convbert/modeling_convbert.py b/src/transformers/models/convbert/modeling_convbert.py index d7216b879eed..3a3b44b98644 100755 --- a/src/transformers/models/convbert/modeling_convbert.py +++ b/src/transformers/models/convbert/modeling_convbert.py @@ -201,7 +201,13 @@ def __init__(self, config): persistent=False, ) - def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None): + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ) -> torch.LongTensor: if input_ids is not None: input_shape = input_ids.size() else: @@ -287,7 +293,7 @@ def __init__(self, config, input_filters, output_filters, kernel_size, **kwargs) self.depthwise.weight.data.normal_(mean=0.0, std=config.initializer_range) self.pointwise.weight.data.normal_(mean=0.0, std=config.initializer_range) - def forward(self, hidden_states): + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: x = self.depthwise(hidden_states) x = self.pointwise(x) x += self.bias @@ -341,12 +347,12 @@ def transpose_for_scores(self, x): def forward( self, - hidden_states, - attention_mask=None, - head_mask=None, - encoder_hidden_states=None, - output_attentions=False, - ): + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: mixed_query_layer = self.query(hidden_states) batch_size = hidden_states.size(0) # If this is instantiated as a cross-attention module, the keys @@ -426,7 +432,7 @@ def __init__(self, config): self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) - def forward(self, hidden_states, input_tensor): + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.LayerNorm(hidden_states + input_tensor) @@ -460,12 +466,12 @@ def prune_heads(self, heads): def forward( self, - hidden_states, - 
attention_mask=None, - head_mask=None, - encoder_hidden_states=None, - output_attentions=False, - ): + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.FloatTensor]]: self_outputs = self.self( hidden_states, attention_mask, @@ -489,7 +495,7 @@ def __init__(self, input_size, output_size, num_groups): self.weight = nn.Parameter(torch.empty(self.num_groups, self.group_in_dim, self.group_out_dim)) self.bias = nn.Parameter(torch.empty(output_size)) - def forward(self, hidden_states): + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size = list(hidden_states.size())[0] x = torch.reshape(hidden_states, [-1, self.num_groups, self.group_in_dim]) x = x.permute(1, 0, 2) @@ -514,7 +520,7 @@ def __init__(self, config): else: self.intermediate_act_fn = config.hidden_act - def forward(self, hidden_states): + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states) return hidden_states @@ -532,7 +538,7 @@ def __init__(self, config): self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) - def forward(self, hidden_states, input_tensor): + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.LayerNorm(hidden_states + input_tensor) @@ -556,13 +562,13 @@ def __init__(self, config): def forward( self, - hidden_states, - attention_mask=None, - head_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - output_attentions=False, - ): + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.FloatTensor]]: self_attention_outputs = self.attention( hidden_states, attention_mask, @@ -608,15 +614,15 @@ def __init__(self, config): def forward( self, - hidden_states, - attention_mask=None, - head_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - output_attentions=False, - output_hidden_states=False, - return_dict=True, - ): + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[Tuple, BaseModelOutputWithCrossAttentions]: all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None @@ -684,7 +690,7 @@ def __init__(self, config): self.transform_act_fn = config.hidden_act self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - def forward(self, hidden_states): + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = 
self.dense(hidden_states) hidden_states = self.transform_act_fn(hidden_states) hidden_states = self.LayerNorm(hidden_states) @@ -795,16 +801,16 @@ class PreTrainedModel ) def forward( self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithCrossAttentions]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -864,7 +870,7 @@ def __init__(self, config): self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps) self.dense = nn.Linear(config.hidden_size, config.embedding_size) - def forward(self, generator_hidden_states): + def forward(self, generator_hidden_states: torch.FloatTensor) -> torch.FloatTensor: hidden_states = self.dense(generator_hidden_states) hidden_states = get_activation("gelu")(hidden_states) hidden_states = self.LayerNorm(hidden_states) @@ -966,7 +972,7 @@ def __init__(self, config): self.config = config - def forward(self, hidden_states, **kwargs): + def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor: x = hidden_states[:, 0, :] # take token (equiv. 
to [CLS]) x = self.dropout(x) x = self.dense(x) diff --git a/src/transformers/models/convnext/modeling_convnext.py b/src/transformers/models/convnext/modeling_convnext.py index a0ca01d4c8ac..248566a938d2 100755 --- a/src/transformers/models/convnext/modeling_convnext.py +++ b/src/transformers/models/convnext/modeling_convnext.py @@ -15,6 +15,8 @@ """ PyTorch ConvNext model.""" +from typing import Optional, Tuple, Union + import torch import torch.utils.checkpoint from torch import nn @@ -78,7 +80,7 @@ def __init__(self, drop_prob=None): super().__init__() self.drop_prob = drop_prob - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: return drop_path(x, self.drop_prob, self.training) @@ -98,7 +100,7 @@ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): raise NotImplementedError(f"Unsupported data format: {self.data_format}") self.normalized_shape = (normalized_shape,) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: if self.data_format == "channels_last": x = torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) elif self.data_format == "channels_first": @@ -121,7 +123,7 @@ def __init__(self, config): ) self.layernorm = ConvNextLayerNorm(config.hidden_sizes[0], eps=1e-6, data_format="channels_first") - def forward(self, pixel_values): + def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: embeddings = self.patch_embeddings(pixel_values) embeddings = self.layernorm(embeddings) return embeddings @@ -155,7 +157,7 @@ def __init__(self, config, dim, drop_path=0): ) self.drop_path = ConvNextDropPath(drop_path) if drop_path > 0.0 else nn.Identity() - def forward(self, hidden_states): + def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor: input = hidden_states x = self.dwconv(hidden_states) x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) @@ -197,7 +199,7 @@ def __init__(self, config, in_channels, out_channels, kernel_size=2, stride=2, d *[ConvNextLayer(config, dim=out_channels, drop_path=drop_path_rates[j]) for j in range(depth)] ) - def forward(self, hidden_states): + def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor: hidden_states = self.downsampling_layer(hidden_states) hidden_states = self.layers(hidden_states) return hidden_states @@ -224,7 +226,12 @@ def __init__(self, config): cur += config.depths[i] prev_chs = out_chs - def forward(self, hidden_states, output_hidden_states=False, return_dict=True): + def forward( + self, + hidden_states: torch.FloatTensor, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[Tuple, BaseModelOutputWithNoAttention]: all_hidden_states = () if output_hidden_states else None for i, layer_module in enumerate(self.stages): @@ -325,7 +332,12 @@ def __init__(self, config): modality="vision", expected_output=_EXPECTED_OUTPUT_SHAPE, ) - def forward(self, pixel_values=None, output_hidden_states=None, return_dict=None): + def forward( + self, + pixel_values: torch.FloatTensor = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPoolingAndNoAttention]: output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) @@ -387,7 +399,13 @@ def __init__(self, config): config_class=_CONFIG_FOR_DOC, expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, ) - def forward(self, pixel_values=None, labels=None, output_hidden_states=None, 
return_dict=None): + def forward( + self, + pixel_values: torch.FloatTensor = None, + labels: Optional[torch.LongTensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, ImageClassifierOutputWithNoAttention]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the image classification/regression loss. Indices should be in `[0, ..., diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py index b36ece31b731..f245f5654a82 100755 --- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py +++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py @@ -278,15 +278,18 @@ def _merge_heads(self, tensor, num_heads, attn_head_size): def forward( self, - hidden_states, - layer_past=None, - attention_mask=None, - head_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - use_cache=False, - output_attentions=False, - ): + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Union[ + Tuple[torch.Tensor, Tuple[torch.Tensor]], + Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]], + ]: if encoder_hidden_states is not None: if not hasattr(self, "q_attn"): raise ValueError( @@ -340,7 +343,7 @@ def __init__(self, intermediate_size, config): self.act = ACT2FN[config.activation_function] self.dropout = nn.Dropout(config.resid_pdrop) - def forward(self, hidden_states): + def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor: hidden_states = self.c_fc(hidden_states) hidden_states = self.act(hidden_states) hidden_states = self.c_proj(hidden_states) @@ -369,15 +372,15 @@ def __init__(self, config, layer_idx=None): def forward( self, - hidden_states, - layer_past=None, - attention_mask=None, - head_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - use_cache=False, - output_attentions=False, - ): + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]: residual = hidden_states hidden_states = self.ln_1(hidden_states) attn_outputs = self.attn( @@ -510,20 +513,20 @@ def set_input_embeddings(self, new_embeddings): # Copied from transformers.models.gpt2.modeling_gpt2.GPT2Model.forward def forward( self, - input_ids=None, - past_key_values=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): + input_ids: Optional[torch.LongTensor] = None, + 
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states diff --git a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py index 6eae4c876f44..a5e1b8311f9d 100644 --- a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py @@ -15,7 +15,7 @@ """ Classes to support Encoder-Decoder architectures""" import warnings -from typing import Optional +from typing import Optional, Tuple, Union import torch from torch import nn @@ -430,21 +430,21 @@ def from_encoder_decoder_pretrained( @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) def forward( self, - input_ids=None, - attention_mask=None, - decoder_input_ids=None, - decoder_attention_mask=None, - encoder_outputs=None, - past_key_values=None, - inputs_embeds=None, - decoder_inputs_embeds=None, - labels=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None, + past_key_values: Tuple[Tuple[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, **kwargs, - ): + ) -> Union[Tuple, Seq2SeqLMOutput]: r""" Returns: diff --git a/src/transformers/models/glpn/modeling_glpn.py b/src/transformers/models/glpn/modeling_glpn.py index 2361e8e61b4a..b7fc18b1d0f9 100755 --- a/src/transformers/models/glpn/modeling_glpn.py +++ b/src/transformers/models/glpn/modeling_glpn.py @@ -80,7 +80,7 @@ def __init__(self, drop_prob=None): super().__init__() self.drop_prob = drop_prob - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: return drop_path(x, self.drop_prob, self.training) diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index 00df05b39063..94d742cf18c5 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -18,7 +18,7 @@ import math import os from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Optional, Tuple, Union import torch import 
torch.utils.checkpoint @@ -289,15 +289,18 @@ def _merge_heads(self, tensor, num_heads, attn_head_size): def forward( self, - hidden_states, - layer_past=None, - attention_mask=None, - head_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - use_cache=False, - output_attentions=False, - ): + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Union[ + Tuple[torch.Tensor, Tuple[torch.Tensor]], + Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]], + ]: if encoder_hidden_states is not None: if not hasattr(self, "q_attn"): raise ValueError( @@ -350,7 +353,7 @@ def __init__(self, intermediate_size, config): self.act = ACT2FN[config.activation_function] self.dropout = nn.Dropout(config.resid_pdrop) - def forward(self, hidden_states): + def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor: hidden_states = self.c_fc(hidden_states) hidden_states = self.act(hidden_states) hidden_states = self.c_proj(hidden_states) @@ -376,15 +379,15 @@ def __init__(self, config, layer_idx=None): def forward( self, - hidden_states, - layer_past=None, - attention_mask=None, - head_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - use_cache=False, - output_attentions=False, - ): + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]: residual = hidden_states hidden_states = self.ln_1(hidden_states) attn_outputs = self.attn( @@ -742,20 +745,20 @@ def _prune_heads(self, heads_to_prune): ) def forward( self, - input_ids=None, - past_key_values=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else 
self.config.output_hidden_states @@ -1020,21 +1023,21 @@ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): ) def forward( self, - input_ids=None, - past_key_values=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - labels=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set @@ -1189,22 +1192,22 @@ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): @replace_return_docstrings(output_type=GPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC) def forward( self, - input_ids=None, - past_key_values=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - mc_token_ids=None, - labels=None, - mc_labels=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + mc_token_ids: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + mc_labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, **kwargs, - ): + ) -> Union[Tuple, GPT2DoubleHeadsModelOutput]: r""" mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input): Index of the classification token in each input sequence. 
Selected in the range `[0, input_ids.size(-1) - @@ -1352,19 +1355,19 @@ def __init__(self, config): ) def forward( self, - input_ids=None, - past_key_values=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - labels=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., @@ -1488,19 +1491,19 @@ def __init__(self, config): # fmt: on def forward( self, - input_ids=None, - past_key_values=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - labels=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index 53dc690fd6eb..d10c266d3f0e 100755 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -14,7 +14,7 @@ # limitations under the License. 
""" PyTorch GPT-J model.""" -from typing import Tuple +from typing import Optional, Tuple, Union import torch import torch.utils.checkpoint @@ -191,13 +191,16 @@ def _attn( def forward( self, - hidden_states, - attention_mask=None, - layer_past=None, - head_mask=None, - use_cache=False, - output_attentions=False, - ): + hidden_states: Optional[torch.FloatTensor], + attention_mask: Optional[torch.FloatTensor] = None, + layer_past: Optional[Tuple[torch.Tensor]] = None, + head_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Union[ + Tuple[torch.Tensor, Tuple[torch.Tensor]], + Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]], + ]: query = self.q_proj(hidden_states) key = self.k_proj(hidden_states) @@ -271,7 +274,7 @@ def __init__(self, intermediate_size, config): # in MLP: intermediate_size= 4 * self.act = ACT2FN[config.activation_function] self.dropout = nn.Dropout(config.resid_pdrop) - def forward(self, hidden_states): + def forward(self, hidden_states: Optional[torch.FloatTensor]) -> torch.FloatTensor: hidden_states = self.fc_in(hidden_states) hidden_states = self.act(hidden_states) hidden_states = self.fc_out(hidden_states) @@ -289,13 +292,13 @@ def __init__(self, config): def forward( self, - hidden_states, - layer_past=None, - attention_mask=None, - head_mask=None, - use_cache=False, - output_attentions=False, - ): + hidden_states: Optional[torch.FloatTensor], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]: residual = hidden_states hidden_states = self.ln_1(hidden_states) attn_outputs = self.attn( @@ -533,18 +536,18 @@ def set_input_embeddings(self, new_embeddings): ) def forward( self, - input_ids=None, - past_key_values=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -787,19 +790,19 @@ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): ) def forward( self, - input_ids=None, - past_key_values=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - labels=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = 
None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set @@ -911,19 +914,19 @@ def __init__(self, config): ) def forward( self, - input_ids=None, - past_key_values=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - labels=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., @@ -1038,18 +1041,18 @@ def __init__(self, config): ) def forward( self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - start_positions=None, - end_positions=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: r""" start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for position (index) of the start of the labelled span for computing the token classification loss. 
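
Note on the convention applied across every file in this patch: tensor arguments that default to `None` become `Optional[torch.LongTensor]` / `Optional[torch.FloatTensor]`, boolean flags become `Optional[bool]`, and each top-level `forward` is annotated to return `Union[Tuple, <SpecificModelOutput>]`, matching the `return_dict` switch. A minimal sketch of the pattern follows; the `ToyEncoder` module is hypothetical and not part of this diff (only `BaseModelOutput` is the real `transformers` class):

# Illustrative only -- a hypothetical module mirroring the annotation
# style applied throughout the patch; not part of the diff itself.
from typing import Optional, Tuple, Union

import torch
from torch import nn

from transformers.modeling_outputs import BaseModelOutput


class ToyEncoder(nn.Module):
    def __init__(self, hidden_size: int = 8):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)

    def forward(
        self,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple, BaseModelOutput]:
        if inputs_embeds is None:
            raise ValueError("You have to specify inputs_embeds")
        hidden_states = self.dense(inputs_embeds)
        if attention_mask is not None:
            # mask is (batch, seq); broadcast over the hidden dimension
            hidden_states = hidden_states * attention_mask.unsqueeze(-1)
        all_hidden_states = (hidden_states,) if output_hidden_states else None
        if not return_dict:
            # tuple return mirrors the `return_dict=False` path of the models above
            return tuple(v for v in (hidden_states, all_hidden_states) if v is not None)
        return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=all_hidden_states)

Type annotations are erased at runtime, so hints like these do not change model behavior; they let mypy and IDEs catch misuse (e.g. passing a tensor where `Optional[bool]` is expected) across all the signatures updated above.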
diff --git a/src/transformers/models/maskformer/modeling_maskformer.py b/src/transformers/models/maskformer/modeling_maskformer.py index 339de6eeeb10..bfb020895ce9 100644 --- a/src/transformers/models/maskformer/modeling_maskformer.py +++ b/src/transformers/models/maskformer/modeling_maskformer.py @@ -702,11 +702,11 @@ def transpose_for_scores(self, x): def forward( self, - hidden_states, - attention_mask=None, - head_mask=None, - output_attentions=False, - ): + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: batch_size, dim, num_channels = hidden_states.shape mixed_query_layer = self.query(hidden_states) @@ -764,7 +764,7 @@ def __init__(self, config, dim): self.dense = nn.Linear(dim, dim) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - def forward(self, hidden_states, input_tensor): + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) @@ -797,7 +797,13 @@ def prune_heads(self, heads): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) - def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False): + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: self_outputs = self.self(hidden_states, attention_mask, head_mask, output_attentions) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them @@ -814,7 +820,7 @@ def __init__(self, config, dim): else: self.intermediate_act_fn = config.hidden_act - def forward(self, hidden_states): + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states) return hidden_states @@ -827,7 +833,7 @@ def __init__(self, config, dim): self.dense = nn.Linear(int(config.mlp_ratio * dim), dim) self.dropout = nn.Dropout(config.hidden_dropout_prob) - def forward(self, hidden_states): + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) return hidden_states diff --git a/src/transformers/models/megatron_bert/modeling_megatron_bert.py b/src/transformers/models/megatron_bert/modeling_megatron_bert.py index 882798f68f07..24a06a336865 100755 --- a/src/transformers/models/megatron_bert/modeling_megatron_bert.py +++ b/src/transformers/models/megatron_bert/modeling_megatron_bert.py @@ -20,7 +20,7 @@ import os import warnings from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Optional, Tuple, Union import torch import torch.utils.checkpoint @@ -154,8 +154,13 @@ def __init__(self, config): self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") def forward( - self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0 - ): + self, + input_ids: Optional[torch.LongTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + 
inputs_embeds: Optional[torch.LongTensor] = None, + past_key_values_length: int = 0, + ) -> torch.Tensor: if input_ids is not None: input_shape = input_ids.size() else: @@ -319,7 +324,7 @@ def __init__(self, config): self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.dropout = nn.Dropout(config.hidden_dropout_prob) - def forward(self, hidden_states, residual): + def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) return residual + hidden_states @@ -354,14 +359,14 @@ def prune_heads(self, heads): def forward( self, - hidden_states, - attention_mask=None, - head_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - past_key_value=None, - output_attentions=False, - ): + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: ln_outputs = self.ln(hidden_states) self_outputs = self.self( ln_outputs, @@ -400,7 +405,7 @@ def __init__(self, config): self.dense = nn.Linear(config.intermediate_size, config.hidden_size) self.dropout = nn.Dropout(config.hidden_dropout_prob) - def forward(self, hidden_states, input_tensor): + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) return input_tensor + hidden_states @@ -425,14 +430,14 @@ def __init__(self, config): def forward( self, - hidden_states, - attention_mask=None, - head_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - past_key_value=None, - output_attentions=False, - ): + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( @@ -507,17 +512,17 @@ def __init__(self, config): def forward( self, - hidden_states, - attention_mask=None, - head_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - past_key_values=None, - use_cache=None, - output_attentions=False, - output_hidden_states=False, - return_dict=True, - ): + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: all_hidden_states = () if output_hidden_states else None all_self_attentions = () if 
output_attentions else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None @@ -873,20 +878,20 @@ class PreTrainedModel ) def forward( self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - past_key_values=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPoolingAndCrossAttentions]: r""" encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if @@ -1022,18 +1027,18 @@ def set_output_embeddings(self, new_embeddings): @replace_return_docstrings(output_type=MegatronBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) def forward( self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - labels=None, - next_sentence_label=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + next_sentence_label: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, MegatronBertForPreTrainingOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. 
Indices should be in `[-100, 0, ..., @@ -1133,21 +1138,21 @@ def set_output_embeddings(self, new_embeddings): @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) def forward( self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - labels=None, - past_key_values=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: r""" encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if @@ -1287,19 +1292,19 @@ def set_output_embeddings(self, new_embeddings): ) def forward( self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - labels=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, MaskedLMOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. 
Indices should be in `[-100, 0, ..., @@ -1379,18 +1384,18 @@ def __init__(self, config): @replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC) def forward( self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - labels=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, **kwargs - ): + ) -> Union[Tuple, NextSentencePredictorOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair @@ -1489,17 +1494,17 @@ def __init__(self, config): ) def forward( self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - labels=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., @@ -1588,17 +1593,17 @@ def __init__(self, config): ) def forward( self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - labels=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, MultipleChoiceModelOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the multiple choice classification loss. 
Indices should be in `[0, ..., @@ -1684,17 +1689,17 @@ def __init__(self, config): ) def forward( self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - labels=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. @@ -1765,18 +1770,18 @@ def __init__(self, config): ) def forward( self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - start_positions=None, - end_positions=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: r""" start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for position (index) of the start of the labelled span for computing the token classification loss. 
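A note on the `Union[Tuple, XxxOutput]` return annotations used throughout these hunks: they mirror the `return_dict` switch, since each forward returns a `ModelOutput` subclass when `return_dict=True` (or when the config defaults to it) and a plain tuple otherwise. A minimal sketch of what the two arms of the union look like to a caller; the tiny randomly initialized config below is illustrative only and not part of this patch:

    import torch
    from transformers import MegatronBertConfig, MegatronBertModel

    config = MegatronBertConfig(hidden_size=64, num_hidden_layers=2, num_attention_heads=2, intermediate_size=128)
    model = MegatronBertModel(config)
    input_ids = torch.randint(0, config.vocab_size, (1, 8))

    # return_dict=True -> BaseModelOutputWithPoolingAndCrossAttentions, attribute access
    outputs = model(input_ids, return_dict=True)
    sequence_output = outputs.last_hidden_state

    # return_dict=False -> plain tuple, the other arm of the Union annotation
    sequence_output, pooled_output = model(input_ids, return_dict=False)[:2]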
diff --git a/src/transformers/models/mobilebert/modeling_mobilebert.py b/src/transformers/models/mobilebert/modeling_mobilebert.py index 87db30757db9..6d2b2d3ce2e5 100644 --- a/src/transformers/models/mobilebert/modeling_mobilebert.py +++ b/src/transformers/models/mobilebert/modeling_mobilebert.py @@ -164,7 +164,7 @@ def __init__(self, feat_size, eps=None): self.bias = nn.Parameter(torch.zeros(feat_size)) self.weight = nn.Parameter(torch.ones(feat_size)) - def forward(self, input_tensor): + def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: return input_tensor * self.weight + self.bias @@ -194,7 +194,13 @@ def __init__(self, config): # position_ids (1, len position emb) is contiguous in memory and exported when serialized self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) - def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None): + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ) -> torch.Tensor: if input_ids is not None: input_shape = input_ids.size() else: @@ -260,13 +266,13 @@ def transpose_for_scores(self, x): def forward( self, - query_tensor, - key_tensor, - value_tensor, - attention_mask=None, - head_mask=None, - output_attentions=None, - ): + query_tensor: torch.Tensor, + key_tensor: torch.Tensor, + value_tensor: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + ) -> Tuple[torch.Tensor]: mixed_query_layer = self.query(query_tensor) mixed_key_layer = self.key(key_tensor) mixed_value_layer = self.value(value_tensor) @@ -306,7 +312,7 @@ def __init__(self, config): if not self.use_bottleneck: self.dropout = nn.Dropout(config.hidden_dropout_prob) - def forward(self, hidden_states, residual_tensor): + def forward(self, hidden_states: torch.Tensor, residual_tensor: torch.Tensor) -> torch.Tensor: layer_outputs = self.dense(hidden_states) if not self.use_bottleneck: layer_outputs = self.dropout(layer_outputs) @@ -341,14 +347,14 @@ def prune_heads(self, heads): def forward( self, - query_tensor, - key_tensor, - value_tensor, - layer_input, - attention_mask=None, - head_mask=None, - output_attentions=None, - ): + query_tensor: torch.Tensor, + key_tensor: torch.Tensor, + value_tensor: torch.Tensor, + layer_input: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + ) -> Tuple[torch.Tensor]: self_outputs = self.self( query_tensor, key_tensor, @@ -373,7 +379,7 @@ def __init__(self, config): else: self.intermediate_act_fn = config.hidden_act - def forward(self, hidden_states): + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states) return hidden_states @@ -386,7 +392,7 @@ def __init__(self, config): self.LayerNorm = NORM2FN[config.normalization_type](config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) - def forward(self, hidden_states, residual_tensor): + def forward(self, hidden_states: torch.Tensor, residual_tensor: torch.Tensor) -> torch.Tensor: layer_outputs = self.dense(hidden_states) layer_outputs = self.dropout(layer_outputs) layer_outputs = 
self.LayerNorm(layer_outputs + residual_tensor) @@ -404,7 +410,9 @@ def __init__(self, config): else: self.bottleneck = OutputBottleneck(config) - def forward(self, intermediate_states, residual_tensor_1, residual_tensor_2): + def forward( + self, intermediate_states: torch.Tensor, residual_tensor_1: torch.Tensor, residual_tensor_2: torch.Tensor + ) -> torch.Tensor: layer_output = self.dense(intermediate_states) if not self.use_bottleneck: layer_output = self.dropout(layer_output) @@ -421,7 +429,7 @@ def __init__(self, config): self.dense = nn.Linear(config.hidden_size, config.intra_bottleneck_size) self.LayerNorm = NORM2FN[config.normalization_type](config.intra_bottleneck_size, eps=config.layer_norm_eps) - def forward(self, hidden_states): + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: layer_input = self.dense(hidden_states) layer_input = self.LayerNorm(layer_input) return layer_input @@ -436,7 +444,7 @@ def __init__(self, config): if self.key_query_shared_bottleneck: self.attention = BottleneckLayer(config) - def forward(self, hidden_states): + def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor]: # This method can return three different tuples of values. These different values make use of bottlenecks, # which are linear layers used to project the hidden states to a lower-dimensional vector, reducing memory # usage. These linear layer have weights that are learned during training. @@ -469,7 +477,7 @@ def __init__(self, config): self.dense = nn.Linear(config.intermediate_size, config.true_hidden_size) self.LayerNorm = NORM2FN[config.normalization_type](config.true_hidden_size, eps=config.layer_norm_eps) - def forward(self, hidden_states, residual_tensor): + def forward(self, hidden_states: torch.Tensor, residual_tensor: torch.Tensor) -> torch.Tensor: layer_outputs = self.dense(hidden_states) layer_outputs = self.LayerNorm(layer_outputs + residual_tensor) return layer_outputs @@ -481,7 +489,7 @@ def __init__(self, config): self.intermediate = MobileBertIntermediate(config) self.output = FFNOutput(config) - def forward(self, hidden_states): + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: intermediate_output = self.intermediate(hidden_states) layer_outputs = self.output(intermediate_output, hidden_states) return layer_outputs @@ -503,11 +511,11 @@ def __init__(self, config): def forward( self, - hidden_states, - attention_mask=None, - head_mask=None, - output_attentions=None, - ): + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + ) -> Tuple[torch.Tensor]: if self.use_bottleneck: query_tensor, key_tensor, value_tensor, layer_input = self.bottleneck(hidden_states) else: @@ -557,13 +565,13 @@ def __init__(self, config): def forward( self, - hidden_states, - attention_mask=None, - head_mask=None, - output_attentions=False, - output_hidden_states=False, - return_dict=True, - ): + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[Tuple, BaseModelOutput]: all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None for i, layer_module in enumerate(self.layer): @@ -599,7 +607,7 @@ def __init__(self, config): if self.do_activate: self.dense = 
nn.Linear(config.hidden_size, config.hidden_size) - def forward(self, hidden_states): + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # We "pool" the model by simply taking the hidden state corresponding # to the first token. first_token_tensor = hidden_states[:, 0] @@ -621,7 +629,7 @@ def __init__(self, config): self.transform_act_fn = config.hidden_act self.LayerNorm = NORM2FN["layer_norm"](config.hidden_size, eps=config.layer_norm_eps) - def forward(self, hidden_states): + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.transform_act_fn(hidden_states) hidden_states = self.LayerNorm(hidden_states) @@ -640,7 +648,7 @@ def __init__(self, config): # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias - def forward(self, hidden_states): + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.transform(hidden_states) hidden_states = hidden_states.matmul(torch.cat([self.decoder.weight.t(), self.dense.weight], dim=0)) hidden_states += self.decoder.bias @@ -652,7 +660,7 @@ def __init__(self, config): super().__init__() self.predictions = MobileBertLMPredictionHead(config) - def forward(self, sequence_output): + def forward(self, sequence_output: torch.Tensor) -> torch.Tensor: prediction_scores = self.predictions(sequence_output) return prediction_scores @@ -663,7 +671,7 @@ def __init__(self, config): self.predictions = MobileBertLMPredictionHead(config) self.seq_relationship = nn.Linear(config.hidden_size, 2) - def forward(self, sequence_output, pooled_output): + def forward(self, sequence_output: torch.Tensor, pooled_output: torch.Tensor) -> Tuple[torch.Tensor]: prediction_scores = self.predictions(sequence_output) seq_relationship_score = self.seq_relationship(pooled_output) return prediction_scores, seq_relationship_score @@ -841,16 +849,16 @@ class PreTrainedModel ) def forward( self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - output_hidden_states=None, - output_attentions=None, - return_dict=None, - ): + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -943,18 +951,18 @@ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> nn.Em @replace_return_docstrings(output_type=MobileBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) def forward( self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - labels=None, - next_sentence_label=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: 
Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        next_sentence_label: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, MobileBertForPreTrainingOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
@@ -1059,17 +1067,17 @@ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> nn.Em
     )
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        labels=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, MaskedLMOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
@@ -1115,7 +1123,7 @@ def __init__(self, config):
         super().__init__()
         self.seq_relationship = nn.Linear(config.hidden_size, 2)
 
-    def forward(self, pooled_output):
+    def forward(self, pooled_output: torch.Tensor) -> torch.Tensor:
         seq_relationship_score = self.seq_relationship(pooled_output)
         return seq_relationship_score
 
@@ -1138,18 +1146,18 @@ def __init__(self, config):
     @replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        labels=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
         **kwargs,
-    ):
+    ) -> Union[Tuple, NextSentencePredictorOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for computing the next sequence prediction (classification) loss.
Input should be a sequence pair diff --git a/src/transformers/models/perceiver/modeling_perceiver.py b/src/transformers/models/perceiver/modeling_perceiver.py index b8acef6c226c..b0a306203944 100755 --- a/src/transformers/models/perceiver/modeling_perceiver.py +++ b/src/transformers/models/perceiver/modeling_perceiver.py @@ -19,7 +19,7 @@ from dataclasses import dataclass from functools import reduce from operator import __add__ -from typing import Any, Callable, Dict, Mapping, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union import numpy as np import torch @@ -177,7 +177,7 @@ def __init__(self, config): super().__init__() self.latents = nn.Parameter(torch.randn(config.num_latents, config.d_latents)) - def forward(self, batch_size): + def forward(self, batch_size: int): return self.latents.expand(batch_size, -1, -1) # Thanks, Phil Wang @@ -232,13 +232,13 @@ def transpose_for_scores(self, x, channels_per_head): def forward( self, - hidden_states, - attention_mask=None, - head_mask=None, - inputs=None, - inputs_mask=None, - output_attentions=False, - ): + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs: Optional[torch.FloatTensor] = None, + inputs_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: hidden_states = self.layernorm1(hidden_states) inputs = self.layernorm2(inputs) @@ -301,7 +301,7 @@ def __init__(self, config, input_channels, output_channels): super().__init__() self.dense = nn.Linear(input_channels, output_channels) - def forward(self, hidden_states): + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) return hidden_states @@ -377,13 +377,13 @@ def prune_heads(self, heads): def forward( self, - hidden_states, - attention_mask=None, - head_mask=None, - inputs=None, - inputs_mask=None, - output_attentions=False, - ): + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs: Optional[torch.FloatTensor] = None, + inputs_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: self_outputs = self.self( hidden_states, attention_mask, @@ -418,7 +418,7 @@ def __init__(self, config, input_size, widening_factor): self.intermediate_act_fn = config.hidden_act self.dense2 = nn.Linear(widening_factor * input_size, input_size) - def forward(self, hidden_states): + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.dense1(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states) hidden_states = self.dense2(hidden_states) @@ -456,13 +456,13 @@ def __init__( def forward( self, - hidden_states, - attention_mask=None, - head_mask=None, - inputs=None, - inputs_mask=None, - output_attentions=False, - ): + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs: Optional[torch.FloatTensor] = None, + inputs_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: attention_outputs = self.attention( hidden_states, attention_mask, @@ -543,15 +543,15 @@ def __init__(self, config, kv_dim=None): def forward( self, - hidden_states, - attention_mask=None, - head_mask=None, - inputs=None, - inputs_mask=None, - 
output_attentions=False, - output_hidden_states=False, - return_dict=True, - ): + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs: Optional[torch.FloatTensor] = None, + inputs_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[Tuple, BaseModelOutputWithCrossAttentions]: all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions else None @@ -754,14 +754,14 @@ class PreTrainedModel @replace_return_docstrings(output_type=PerceiverModelOutput, config_class=_CONFIG_FOR_DOC) def forward( self, - inputs, - attention_mask=None, - subsampled_output_points=None, - head_mask=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): + inputs: torch.FloatTensor, + attention_mask: Optional[torch.FloatTensor] = None, + subsampled_output_points: Optional[Dict[str, torch.Tensor]] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, PerceiverModelOutput]: r""" Returns: @@ -1871,7 +1871,7 @@ def forward( self, inputs: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, - subsampled_output_points: Optional[Dict[str, torch.tensor]] = None, + subsampled_output_points: Optional[Dict[str, torch.Tensor]] = None, head_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, @@ -2020,7 +2020,9 @@ def __init__(self, config): def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, subsampled_points=None): return None - def forward(self, query, z, query_mask=None): + def forward( + self, query: torch.Tensor, z: torch.FloatTensor, query_mask: Optional[torch.FloatTensor] = None + ) -> torch.FloatTensor: # (batch_size, num_latents, d_latents) -> (batch_size, d_latents) z = torch.mean(z, dim=1) # (batch_size, d_latents) -> (batch_size, config.num_labels) @@ -2044,11 +2046,11 @@ class PerceiverBasicDecoder(PerceiverAbstractDecoder): The type of position encoding to use. Can be either "trainable", "fourier", or "none". output_index_dims (`int`, *optional*): The number of dimensions of the output queries. Ignored if 'position_encoding_type' == 'none'. - num_channels (`int`, *optional*): + num_channels (`int`, *optional*, defaults to 128): The number of channels of the decoder queries. Ignored if 'position_encoding_type' == 'none'. qk_channels (`int`, *optional*): The number of channels of the queries and keys in the cross-attention layer. - v_channels (`int`, *optional*, defaults to 128): + v_channels (`int`, *optional*): The number of channels of the values in the cross-attention layer. num_heads (`int`, *optional*, defaults to 1): The number of attention heads in the cross-attention layer. 
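The constructor hunk that follows wraps several arguments in `Optional[...]` even when their defaults are not `None` (for example `num_heads: Optional[int] = 1`). Under PEP 484, `Optional[int]` means "int or None", so it belongs on arguments that default to or accept `None`; the third commit in this series tightens `output_num_channels` back to a bare `int` for exactly this reason. A short self-contained sketch of the distinction, with hypothetical argument names:

    from typing import Optional

    def decoder_stub(
        output_num_channels: int,                   # required; None is never valid
        num_channels: int = 128,                    # real default keeps the bare type
        qk_channels: Optional[int] = None,          # defaults to None, hence Optional
        position_encoding_type: str = "trainable",
    ) -> None:
        # A type checker rejects decoder_stub(None) but still allows
        # qk_channels to be omitted or passed explicitly as None.
        ...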
@@ -2066,23 +2068,23 @@ class PerceiverBasicDecoder(PerceiverAbstractDecoder): def __init__( self, - config, - output_num_channels, - position_encoding_type="trainable", + config: PerceiverConfig, + output_num_channels: Optional[int], + position_encoding_type: Optional[str] = "trainable", # The following 2 arguments are ignored if position_encoding_type == 'none': - output_index_dims=None, - num_channels=128, - subsampled_index_dims=None, - qk_channels=None, - v_channels=None, - num_heads=1, - widening_factor=1, - use_query_residual=False, - concat_preprocessed_input=False, - final_project=True, - position_encoding_only=False, + output_index_dims: Optional[int] = None, + num_channels: Optional[int] = 128, + subsampled_index_dims: Optional[int] = None, + qk_channels: Optional[int] = None, + v_channels: Optional[int] = None, + num_heads: Optional[int] = 1, + widening_factor: Optional[int] = 1, + use_query_residual: Optional[bool] = False, + concat_preprocessed_input: Optional[bool] = False, + final_project: Optional[bool] = True, + position_encoding_only: Optional[bool] = False, **position_encoding_kwargs, - ): + ) -> None: super().__init__() self.output_num_channels = output_num_channels @@ -2183,7 +2185,13 @@ def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, su return pos_emb - def forward(self, query, z, query_mask=None, output_attentions=False): + def forward( + self, + query: torch.Tensor, + z: torch.FloatTensor, + query_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> PerceiverDecoderOutput: # Cross-attention decoding. # key, value: B x N x K; query: B x M x K # Attention maps -> B x N x M @@ -2239,7 +2247,13 @@ def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, su inputs, modality_sizes, inputs_without_pos, subsampled_points=subsampled_points ) - def forward(self, query, z, query_mask=None, output_attentions=False): + def forward( + self, + query: torch.Tensor, + z: torch.FloatTensor, + query_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> PerceiverDecoderOutput: decoder_outputs = self.decoder(query, z, output_attentions=output_attentions) # B x 1 x num_classes -> B x num_classes @@ -2268,7 +2282,13 @@ def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, su raise ValueError("FlowDecoder doesn't support subsampling yet.") return inputs - def forward(self, query, z, query_mask=None, output_attentions=False): + def forward( + self, + query: torch.Tensor, + z: torch.FloatTensor, + query_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> PerceiverDecoderOutput: decoder_outputs = self.decoder(query, z, output_attentions=output_attentions) preds = decoder_outputs.logits # Output flow and rescale. @@ -2291,7 +2311,9 @@ class PerceiverBasicVideoAutoencodingDecoder(PerceiverAbstractDecoder): The type of position encoding to use. Can be either "trainable", "fourier", or "none". 
""" - def __init__(self, config, output_shape, position_encoding_type, **decoder_kwargs): + def __init__( + self, config: PerceiverConfig, output_shape: List[int], position_encoding_type: str, **decoder_kwargs + ) -> None: super().__init__() if len(output_shape) != 4: # B, T, H, W raise ValueError(f"Expected rank 4 output_shape, got {output_shape}.") @@ -2318,7 +2340,9 @@ def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, su subsampled_points=subsampled_points, ) - def forward(self, query, z, query_mask=None): + def forward( + self, query: torch.Tensor, z: torch.FloatTensor, query_mask: Optional[torch.FloatTensor] = None + ) -> PerceiverDecoderOutput: decoder_outputs = self.decoder(query, z) logits = decoder_outputs.logits @@ -2378,14 +2402,14 @@ class PerceiverMultimodalDecoder(PerceiverAbstractDecoder): def __init__( self, - config, - modalities, - num_outputs, - output_num_channels, - min_padding_size=2, - subsampled_index_dims=None, + config: PerceiverConfig, + modalities: Dict[str, PerceiverAbstractDecoder], + num_outputs: int, + output_num_channels: int, + min_padding_size: Optional[int] = 2, + subsampled_index_dims: Optional[Dict[str, PerceiverAbstractDecoder]] = None, **decoder_kwargs - ): + ) -> None: super().__init__() self.modalities = nn.ModuleDict(modalities) self.subsampled_index_dims = subsampled_index_dims @@ -2447,7 +2471,13 @@ def embed(modality, x): [embed(modality, decoder_queries[modality]) for modality in sorted(self.modalities.keys())], dim=1 ) - def forward(self, query, z, query_mask=None, output_attentions=False): + def forward( + self, + query: torch.Tensor, + z: torch.FloatTensor, + query_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> torch.Tensor: # B x 1 x num_classes -> B x num_classes decoder_outputs = self.decoder(query, z, output_attentions=output_attentions) @@ -2680,7 +2710,7 @@ def num_dimensions(self) -> int: def output_size(self, *args, **kwargs) -> int: return self._num_channels - def forward(self, batch_size): + def forward(self, batch_size: int) -> torch.Tensor: position_embeddings = self.position_embeddings if batch_size is not None: @@ -2741,7 +2771,9 @@ def output_size(self): return encoding_size - def forward(self, index_dims, batch_size, device, pos=None): + def forward( + self, index_dims: List[int], batch_size: int, device, pos: torch.FloatTensor = None + ) -> torch.FloatTensor: pos = _check_or_build_spatial_positions(pos, index_dims, batch_size) fourier_pos_enc = generate_fourier_features( pos, @@ -2771,7 +2803,7 @@ class PerceiverTextPreprocessor(AbstractPreprocessor): Model configuration. """ - def __init__(self, config): + def __init__(self, config: PerceiverConfig) -> None: super().__init__() self.config = config self.embeddings = nn.Embedding(num_embeddings=config.vocab_size, embedding_dim=config.d_model) @@ -2781,7 +2813,7 @@ def __init__(self, config): def num_channels(self) -> int: return self.config.d_model - def forward(self, inputs): + def forward(self, inputs: torch.LongTensor) -> torch.FloatTensor: embeddings = self.embeddings(inputs) seq_length = inputs.shape[1] @@ -2800,13 +2832,13 @@ class PerceiverEmbeddingDecoder(nn.Module): Model configuration. 
""" - def __init__(self, config): + def __init__(self, config: PerceiverConfig) -> None: super().__init__() self.config = config self.vocab_size = config.vocab_size self.bias = nn.Parameter(torch.zeros(self.vocab_size)) - def forward(self, hidden_states, embedding_layer): + def forward(self, hidden_states: torch.Tensor, embedding_layer: torch.Tensor) -> torch.Tensor: batch_size, seq_len, d_model = hidden_states.shape output = torch.matmul(hidden_states.reshape([-1, d_model]), embedding_layer.weight.T) # Flatten batch dim output = output + self.bias @@ -2859,7 +2891,7 @@ class PerceiverClassificationPostprocessor(nn.Module): Number of channels in the input. """ - def __init__(self, config, in_channels): + def __init__(self, config: PerceiverConfig, in_channels: int) -> None: super().__init__() self.classifier = nn.Linear(in_channels, config.num_labels) @@ -2881,7 +2913,7 @@ class PerceiverAudioPostprocessor(nn.Module): Postprocessor type to use. Currently, only "patches" is supported. """ - def __init__(self, config, in_channels, postproc_type: str = "patches"): + def __init__(self, config: PerceiverConfig, in_channels: int, postproc_type: str = "patches") -> None: super().__init__() if postproc_type not in ("patches",): # to be supported: 'conv', 'patches', 'pixels' @@ -2908,7 +2940,7 @@ class PerceiverProjectionPostprocessor(nn.Module): Number of channels in the output. """ - def __init__(self, in_channels, out_channels): + def __init__(self, in_channels: int, out_channels: int) -> None: super().__init__() self.classifier = nn.Linear(in_channels, out_channels) @@ -3155,7 +3187,7 @@ class PerceiverOneHotPreprocessor(AbstractPreprocessor): Model configuration. """ - def __init__(self, config): + def __init__(self, config: PerceiverConfig) -> None: super().__init__() self.config: PerceiverConfig = config diff --git a/src/transformers/models/retribert/modeling_retribert.py b/src/transformers/models/retribert/modeling_retribert.py index 7d0a42b5fb13..5a12c962e292 100644 --- a/src/transformers/models/retribert/modeling_retribert.py +++ b/src/transformers/models/retribert/modeling_retribert.py @@ -18,6 +18,7 @@ import math +from typing import Optional import torch import torch.utils.checkpoint as checkpoint @@ -85,7 +86,7 @@ def _init_weights(self, module): RETRIBERT_START_DOCSTRING, ) class RetriBertModel(RetriBertPreTrainedModel): - def __init__(self, config): + def __init__(self, config: RetriBertConfig) -> None: super().__init__(config) self.projection_dim = config.projection_dim @@ -173,8 +174,13 @@ def embed_answers( return self.project_doc(a_reps) def forward( - self, input_ids_query, attention_mask_query, input_ids_doc, attention_mask_doc, checkpoint_batch_size=-1 - ): + self, + input_ids_query: torch.LongTensor, + attention_mask_query: Optional[torch.FloatTensor], + input_ids_doc: torch.LongTensor, + attention_mask_doc: Optional[torch.FloatTensor], + checkpoint_batch_size: int = -1, + ) -> torch.FloatTensor: r""" Args: input_ids_query (`torch.LongTensor` of shape `(batch_size, sequence_length)`): diff --git a/src/transformers/models/segformer/modeling_segformer.py b/src/transformers/models/segformer/modeling_segformer.py index d8989e340c3c..55ac976b3544 100755 --- a/src/transformers/models/segformer/modeling_segformer.py +++ b/src/transformers/models/segformer/modeling_segformer.py @@ -112,7 +112,7 @@ def __init__(self, drop_prob=None): super().__init__() self.drop_prob = drop_prob - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: return drop_path(x, 
self.drop_prob, self.training) diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py index 51a19ab73b8c..727399f17f4d 100644 --- a/src/transformers/models/swin/modeling_swin.py +++ b/src/transformers/models/swin/modeling_swin.py @@ -18,7 +18,7 @@ import collections.abc import math from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Optional, Tuple, Union import torch import torch.utils.checkpoint @@ -272,7 +272,9 @@ def __init__(self, config, use_mask_token=False): self.norm = nn.LayerNorm(config.embed_dim) self.dropout = nn.Dropout(config.hidden_dropout_prob) - def forward(self, pixel_values, bool_masked_pos=None): + def forward( + self, pixel_values: Optional[torch.FloatTensor], bool_masked_pos: Optional[torch.BoolTensor] = None + ) -> Tuple[torch.Tensor]: embeddings, output_dimensions = self.patch_embeddings(pixel_values) embeddings = self.norm(embeddings) batch_size, seq_len, _ = embeddings.size() @@ -317,7 +319,7 @@ def maybe_pad(self, pixel_values, height, width): pixel_values = nn.functional.pad(pixel_values, pad_values) return pixel_values - def forward(self, pixel_values): + def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]: _, _, height, width = pixel_values.shape # pad the input to be divisible by self.patch_size, if needed pixel_values = self.maybe_pad(pixel_values, height, width) @@ -342,7 +344,7 @@ class SwinPatchMerging(nn.Module): Normalization layer class. """ - def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + def __init__(self, input_resolution: Tuple[int], dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None: super().__init__() self.input_resolution = input_resolution self.dim = dim @@ -357,7 +359,7 @@ def maybe_pad(self, input_feature, height, width): return input_feature - def forward(self, input_feature, input_dimensions): + def forward(self, input_feature: torch.Tensor, input_dimensions: Tuple[int, int]) -> torch.Tensor: height, width = input_dimensions # `dim` is height * width batch_size, dim, num_channels = input_feature.shape @@ -438,11 +440,11 @@ def transpose_for_scores(self, x): def forward( self, - hidden_states, - attention_mask=None, - head_mask=None, - output_attentions=False, - ): + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: batch_size, dim, num_channels = hidden_states.shape mixed_query_layer = self.query(hidden_states) @@ -499,7 +501,7 @@ def __init__(self, config, dim): self.dense = nn.Linear(dim, dim) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - def forward(self, hidden_states, input_tensor): + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) @@ -531,7 +533,13 @@ def prune_heads(self, heads): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) - def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False): + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: self_outputs = 
self.self(hidden_states, attention_mask, head_mask, output_attentions) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them @@ -547,7 +555,7 @@ def __init__(self, config, dim): else: self.intermediate_act_fn = config.hidden_act - def forward(self, hidden_states): + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states) return hidden_states @@ -559,7 +567,7 @@ def __init__(self, config, dim): self.dense = nn.Linear(int(config.mlp_ratio * dim), dim) self.dropout = nn.Dropout(config.hidden_dropout_prob) - def forward(self, hidden_states): + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) return hidden_states @@ -621,7 +629,13 @@ def maybe_pad(self, hidden_states, height, width): hidden_states = nn.functional.pad(hidden_states, pad_values) return hidden_states, pad_values - def forward(self, hidden_states, input_dimensions, head_mask=None, output_attentions=False): + def forward( + self, + hidden_states: torch.Tensor, + input_dimensions: Tuple[int, int], + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: self.set_shift_and_window_size(input_dimensions) height, width = input_dimensions batch_size, _, channels = hidden_states.size() @@ -703,7 +717,13 @@ def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, d self.pointing = False - def forward(self, hidden_states, input_dimensions, head_mask=None, output_attentions=False): + def forward( + self, + hidden_states: torch.Tensor, + input_dimensions: Tuple[int, int], + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: height, width = input_dimensions for i, layer_module in enumerate(self.blocks): @@ -752,13 +772,13 @@ def __init__(self, config, grid_size): def forward( self, - hidden_states, - input_dimensions, - head_mask=None, - output_attentions=False, - output_hidden_states=False, - return_dict=True, - ): + hidden_states: torch.Tensor, + input_dimensions: Tuple[int, int], + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[Tuple, SwinEncoderOutput]: all_input_dimensions = () all_hidden_states = () if output_hidden_states else None all_reshaped_hidden_states = () if output_hidden_states else None @@ -920,13 +940,13 @@ class PreTrainedModel ) def forward( self, - pixel_values=None, - bool_masked_pos=None, - head_mask=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): + pixel_values: Optional[torch.FloatTensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SwinModelOutput]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -999,13 +1019,13 @@ def __init__(self, config): 
@replace_return_docstrings(output_type=SwinMaskedImageModelingOutput, config_class=_CONFIG_FOR_DOC) def forward( self, - pixel_values=None, - bool_masked_pos=None, - head_mask=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): + pixel_values: Optional[torch.FloatTensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SwinMaskedImageModelingOutput]: r""" bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`): Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). @@ -1115,13 +1135,13 @@ def __init__(self, config): ) def forward( self, - pixel_values=None, - head_mask=None, - labels=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): + pixel_values: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SwinImageClassifierOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the image classification/regression loss. Indices should be in `[0, ..., diff --git a/src/transformers/models/transfo_xl/modeling_transfo_xl.py b/src/transformers/models/transfo_xl/modeling_transfo_xl.py index f566262ff294..556525cbf6c8 100644 --- a/src/transformers/models/transfo_xl/modeling_transfo_xl.py +++ b/src/transformers/models/transfo_xl/modeling_transfo_xl.py @@ -881,14 +881,14 @@ def _update_mems(self, hids, mems, mlen, qlen): ) def forward( self, - input_ids=None, - mems=None, - head_mask=None, - inputs_embeds=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): + input_ids: Optional[torch.LongTensor] = None, + mems: Optional[List[torch.FloatTensor]] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TransfoXLModelOutput]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -1071,15 +1071,15 @@ def init_mems(self, bsz): ) def forward( self, - input_ids=None, - mems=None, - head_mask=None, - inputs_embeds=None, - labels=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): + input_ids: Optional[torch.LongTensor] = None, + mems: Optional[List[torch.FloatTensor]] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TransfoXLLMHeadModelOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. 
you can set @@ -1215,11 +1215,11 @@ def __init__(self, config): ) def forward( self, - input_ids: Optional[torch.Tensor] = None, + input_ids: Optional[torch.LongTensor] = None, mems: Optional[List[torch.FloatTensor]] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, diff --git a/src/transformers/models/van/modeling_van.py b/src/transformers/models/van/modeling_van.py index 3ea59cbae8d3..7a7030c2f569 100644 --- a/src/transformers/models/van/modeling_van.py +++ b/src/transformers/models/van/modeling_van.py @@ -16,6 +16,7 @@ import math from collections import OrderedDict +from typing import Optional, Tuple, Union import torch import torch.utils.checkpoint @@ -81,7 +82,7 @@ def __init__(self, drop_prob=None): super().__init__() self.drop_prob = drop_prob - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: return drop_path(x, self.drop_prob, self.training) @@ -146,7 +147,7 @@ def __init__(self, hidden_size: int): super().__init__() self.attention = VanLargeKernelAttention(hidden_size) - def forward(self, hidden_state): + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: attention = self.attention(hidden_state) attended = hidden_state * attention return attended @@ -171,7 +172,7 @@ def __init__(self, hidden_size: int, hidden_act: str = "gelu"): self.attention_layer = VanLargeKernelAttentionLayer(hidden_size) self.post_projection = nn.Conv2d(hidden_size, hidden_size, kernel_size=1) - def forward(self, hidden_state): + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: residual = hidden_state hidden_state = self.pre_projection(hidden_state) hidden_state = self.attention_layer(hidden_state) @@ -189,7 +190,7 @@ def __init__(self, hidden_size: int, initial_value: float = 1e-2): super().__init__() self.weight = nn.Parameter(initial_value * torch.ones((hidden_size)), requires_grad=True) - def forward(self, hidden_state): + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: # unsqueezing for broadcasting hidden_state = self.weight.unsqueeze(-1).unsqueeze(-1) * hidden_state return hidden_state @@ -218,7 +219,7 @@ def __init__( ) self.mlp_scaling = VanLayerScaling(hidden_size, config.layer_scale_init_value) - def forward(self, hidden_state): + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: residual = hidden_state # attention hidden_state = self.pre_normomalization(hidden_state) @@ -269,7 +270,7 @@ def __init__( ) self.normalization = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps) - def forward(self, hidden_state): + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: hidden_state = self.embeddings(hidden_state) hidden_state = self.layers(hidden_state) # rearrange b c h w -> b (h w) c @@ -316,7 +317,12 @@ def __init__(self, config: VanConfig): ) ) - def forward(self, hidden_state, output_hidden_states=False, return_dict=True): + def forward( + self, + hidden_state: torch.Tensor, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[Tuple, BaseModelOutputWithNoAttention]: all_hidden_states = () if output_hidden_states else None for _, stage_module in enumerate(self.stages): @@ -411,7 +417,12 @@ def __init__(self, 
config): modality="vision", expected_output=_EXPECTED_OUTPUT_SHAPE, ) - def forward(self, pixel_values, output_hidden_states=None, return_dict=None): + def forward( + self, + pixel_values: Optional[torch.FloatTensor], + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPoolingAndNoAttention]: output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) @@ -463,7 +474,13 @@ def __init__(self, config): config_class=_CONFIG_FOR_DOC, expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, ) - def forward(self, pixel_values=None, labels=None, output_hidden_states=None, return_dict=None): + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, ImageClassifierOutputWithNoAttention]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the image classification/regression loss. Indices should be in `[0, ..., From e50750ae7f9a273e7f039a2e9e66231aa1cf5197 Mon Sep 17 00:00:00 2001 From: robotjelly Date: Wed, 4 May 2022 00:00:53 +0530 Subject: [PATCH 2/3] fixed import error --- .../decision_transformer/modeling_decision_transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py index f245f5654a82..1908c8cf8ebe 100755 --- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py +++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py @@ -17,7 +17,7 @@ import math import os from dataclasses import dataclass -from typing import Tuple, Union +from typing import Optional, Tuple, Union import torch import torch.utils.checkpoint From 9e3203338599e4c23bcd427dcfdc283454905b38 Mon Sep 17 00:00:00 2001 From: robotjelly Date: Wed, 4 May 2022 20:01:55 +0530 Subject: [PATCH 3/3] fixed some errors --- src/transformers/models/canine/modeling_canine.py | 4 ++-- .../decision_transformer/modeling_decision_transformer.py | 5 +---- src/transformers/models/gpt2/modeling_gpt2.py | 5 +---- src/transformers/models/perceiver/modeling_perceiver.py | 2 +- 4 files changed, 5 insertions(+), 11 deletions(-) diff --git a/src/transformers/models/canine/modeling_canine.py b/src/transformers/models/canine/modeling_canine.py index 4e7d9f2b1727..b93a7b3d4689 100644 --- a/src/transformers/models/canine/modeling_canine.py +++ b/src/transformers/models/canine/modeling_canine.py @@ -1139,12 +1139,12 @@ def forward( # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] # ourselves in which case we just need to make it broadcastable to all heads. 
- extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device) + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) molecule_attention_mask = self._downsample_attention_mask( attention_mask, downsampling_rate=self.config.downsampling_rate ) extended_molecule_attention_mask: torch.Tensor = self.get_extended_attention_mask( - molecule_attention_mask, (batch_size, molecule_attention_mask.shape[-1]), device + molecule_attention_mask, (batch_size, molecule_attention_mask.shape[-1]) ) # Prepare head mask if needed diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py index 1908c8cf8ebe..dcad3786d83c 100755 --- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py +++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py @@ -286,10 +286,7 @@ def forward( encoder_attention_mask: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, - ) -> Union[ - Tuple[torch.Tensor, Tuple[torch.Tensor]], - Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]], - ]: + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: if encoder_hidden_states is not None: if not hasattr(self, "q_attn"): raise ValueError( diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index 94d742cf18c5..c79dfd76a0f2 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -297,10 +297,7 @@ def forward( encoder_attention_mask: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, - ) -> Union[ - Tuple[torch.Tensor, Tuple[torch.Tensor]], - Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]], - ]: + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: if encoder_hidden_states is not None: if not hasattr(self, "q_attn"): raise ValueError( diff --git a/src/transformers/models/perceiver/modeling_perceiver.py b/src/transformers/models/perceiver/modeling_perceiver.py index b0a306203944..6dc1563b47f0 100755 --- a/src/transformers/models/perceiver/modeling_perceiver.py +++ b/src/transformers/models/perceiver/modeling_perceiver.py @@ -2069,7 +2069,7 @@ class PerceiverBasicDecoder(PerceiverAbstractDecoder): def __init__( self, config: PerceiverConfig, - output_num_channels: Optional[int], + output_num_channels: int, position_encoding_type: Optional[str] = "trainable", # The following 2 arguments are ignored if position_encoding_type == 'none': output_index_dims: Optional[int] = None,
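The final hunks replace the nested `Union[Tuple[...], Optional[Tuple[...]]]` annotation on the GPT-2 and Decision Transformer attention blocks with `Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]`: these methods always return a variable-length tuple whose elements are tensors or tuples of tensors, depending on `use_cache` and `output_attentions`. A simplified sketch of that return structure (the computation lines are stand-ins; only the tuple shape matters):

    from typing import Tuple, Union

    import torch

    def attn_stub(
        hidden_states: torch.Tensor,
        use_cache: bool = False,
        output_attentions: bool = False,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
        attn_output = hidden_states  # stand-in for the attention computation
        outputs: Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...] = (attn_output,)
        if use_cache:
            present = (hidden_states,)  # stand-in for the cached key/value pair
            outputs = outputs + (present,)
        if output_attentions:
            attn_weights = torch.ones(1)  # stand-in for the attention probabilities
            outputs = outputs + (attn_weights,)
        return outputs  # one to three elements, matching the annotation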