src/transformers/modeling_utils.py (3 changes: 2 additions & 1 deletion)
@@ -118,14 +118,15 @@
     is_torch_xpu_available,
     logging,
 )
-from .utils.generic import _CAN_RECORD_REGISTRY, GeneralInterface, OutputRecorder, is_flash_attention_requested
+from .utils.generic import GeneralInterface, is_flash_attention_requested
 from .utils.hub import DownloadKwargs, create_and_tag_model_card, get_checkpoint_shard_files
 from .utils.import_utils import (
     is_huggingface_hub_greater_or_equal,
     is_sagemaker_mp_enabled,
     is_tracing,
 )
 from .utils.loading_report import LoadStateDictInfo, log_state_dict_report
+from .utils.output_capturing import _CAN_RECORD_REGISTRY, OutputRecorder
 from .utils.quantization_config import QuantizationMethod

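`OutputRecorder` and `_CAN_RECORD_REGISTRY` move from `transformers.utils.generic` to a new `transformers.utils.output_capturing` module. For third-party code that imported the helper from the old location, a minimal compatibility sketch (the fallback pattern is illustrative, not part of the PR):

```python
# Illustrative shim: prefer the new module introduced by this PR, and fall
# back to the old location on earlier transformers versions.
try:
    from transformers.utils.output_capturing import OutputRecorder
except ImportError:  # pre-PR releases kept the helper in utils.generic
    from transformers.utils.generic import OutputRecorder
```
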
src/transformers/models/aria/modeling_aria.py (5 changes: 3 additions & 2 deletions)
@@ -42,7 +42,7 @@
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, torch_compilable_check
-from ...utils.generic import check_model_inputs, maybe_autocast
+from ...utils.generic import check_model_inputs, maybe_autocast, merge_with_config_defaults
 from ..auto import AutoModel
 from .configuration_aria import AriaConfig, AriaTextConfig

@@ -915,7 +915,8 @@ def get_input_embeddings(self):
     def set_input_embeddings(self, value):
         self.language_model.set_input_embeddings(value)

-    @check_model_inputs(tie_last_hidden_states=False)
+    @can_return_tuple
+    @merge_with_config_defaults
     @auto_docstring(
         custom_intro="Obtains image last hidden states from the vision tower and apply multimodal projection."
     )

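The decorator swap above recurs throughout the PR: `@check_model_inputs(tie_last_hidden_states=False)` on `get_image_features` is replaced by `@can_return_tuple` stacked on the new `@merge_with_config_defaults`. A schematic sketch of the stack, assuming (from its name) that `merge_with_config_defaults` resolves keyword arguments the caller left unset from the model config, while `can_return_tuple` keeps honoring `return_dict=False`; the class, method body, and parameter names here are illustrative:

```python
from transformers.utils.generic import can_return_tuple, merge_with_config_defaults


class VisionModelSketch:  # illustrative stand-in, not the real AriaModel
    @can_return_tuple  # outermost: converts the ModelOutput to a tuple when return_dict=False
    @merge_with_config_defaults  # assumed: fills unset kwargs from self.config before the body runs
    def get_image_features(self, pixel_values, vision_feature_layer=None, **kwargs):
        ...  # body elided; the real method returns the projected image hidden states
```
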
src/transformers/models/aria/modular_aria.py (7 changes: 0 additions & 7 deletions)
@@ -41,7 +41,6 @@
 from ...processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
 from ...tokenization_python import PreTokenizedInput, TextInput
 from ...utils import TensorType, TransformersKwargs, auto_docstring, can_return_tuple, logging
-from ...utils.generic import check_model_inputs
 from ..auto import CONFIG_MAPPING, AutoConfig, AutoTokenizer
 from ..llama.configuration_llama import LlamaConfig
 from ..llama.modeling_llama import (
@@ -1260,10 +1259,6 @@ def _create_patch_attention_mask(self, pixel_mask):
         )
         return (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()

-    @check_model_inputs(tie_last_hidden_states=False)
-    @auto_docstring(
-        custom_intro="Obtains image last hidden states from the vision tower and apply multimodal projection."
-    )
     def get_image_features(
         self,
         pixel_values: torch.FloatTensor,
@@ -1290,8 +1285,6 @@ def get_image_features(

         return image_outputs

-    @can_return_tuple
-    @auto_docstring
     def forward(
         self,
         input_ids: torch.LongTensor | None = None,

src/transformers/models/aya_vision/modeling_aya_vision.py (9 changes: 5 additions & 4 deletions)
@@ -30,7 +30,7 @@
 from ...modeling_utils import PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, torch_compilable_check
-from ...utils.generic import check_model_inputs
+from ...utils.generic import can_return_tuple, merge_with_config_defaults
 from ..auto import AutoModel
 from .configuration_aya_vision import AyaVisionConfig

@@ -179,7 +179,8 @@ def get_input_embeddings(self):
     def set_input_embeddings(self, value):
         self.language_model.set_input_embeddings(value)

-    @check_model_inputs(tie_last_hidden_states=False)
+    @can_return_tuple
+    @merge_with_config_defaults
     @auto_docstring(
         custom_intro="Obtains image last hidden states from the vision tower and apply multimodal projection."
     )
@@ -241,7 +242,7 @@ def get_placeholder_mask(
         )
         return special_image_mask

-    @check_model_inputs
+    @can_return_tuple
     @auto_docstring
     def forward(
         self,
@@ -339,7 +340,7 @@ def get_image_features(
             **kwargs,
         )

-    @check_model_inputs(tie_last_hidden_states=False)
+    @can_return_tuple
     @auto_docstring
     def forward(
         self,

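For callers, the contract is meant to stay the same: `@can_return_tuple` preserves the `return_dict=False` escape hatch that `@check_model_inputs` used to handle. A toy sketch of that caller-facing behavior, assuming the decorator falls back to `self.config.use_return_dict` when the kwarg is not passed; `TinySketch` is not a real model:

```python
import torch
from transformers import PretrainedConfig
from transformers.modeling_outputs import BaseModelOutputWithPooling
from transformers.utils.generic import can_return_tuple


class TinySketch:  # toy class, only to exercise the decorator's caller-facing contract
    config = PretrainedConfig()  # consulted for use_return_dict when no kwarg is given

    @can_return_tuple
    def forward(self, x: torch.Tensor) -> BaseModelOutputWithPooling:
        return BaseModelOutputWithPooling(last_hidden_state=x, pooler_output=x.mean(dim=1))


m = TinySketch()
x = torch.randn(1, 4, 8)
out = m.forward(x)                     # ModelOutput with named fields
tup = m.forward(x, return_dict=False)  # the same values as a plain tuple
```
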
src/transformers/models/aya_vision/modular_aya_vision.py (7 changes: 4 additions & 3 deletions)
@@ -29,7 +29,7 @@
 from ...modeling_outputs import BaseModelOutputWithPooling
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, logging
-from ...utils.generic import check_model_inputs
+from ...utils.generic import can_return_tuple, merge_with_config_defaults
 from .configuration_aya_vision import AyaVisionConfig


@@ -104,7 +104,8 @@ class AyaVisionModelOutputWithPast(LlavaModelOutputWithPast):

 class AyaVisionModel(LlavaModel):
     # Unlike LLaVA, the model doesn't have to deal with Pixtral-style image states
-    @check_model_inputs(tie_last_hidden_states=False)
+    @can_return_tuple
+    @merge_with_config_defaults
     @auto_docstring(
         custom_intro="Obtains image last hidden states from the vision tower and apply multimodal projection."
     )
@@ -142,7 +143,7 @@ def get_image_features(

         return image_outputs

-    @check_model_inputs
+    @can_return_tuple
     @auto_docstring
     def forward(
         self,

src/transformers/models/blip_2/modeling_blip_2.py (3 changes: 2 additions & 1 deletion)
@@ -47,7 +47,8 @@
     logging,
     torch_int,
 )
-from ...utils.generic import OutputRecorder, check_model_inputs
+from ...utils.generic import check_model_inputs
+from ...utils.output_capturing import OutputRecorder
 from ..auto import AutoModelForCausalLM, AutoModelForSeq2SeqLM
 from .configuration_blip_2 import Blip2Config, Blip2QFormerConfig, Blip2VisionConfig

src/transformers/models/blt/modeling_blt.py (8 changes: 4 additions & 4 deletions)
@@ -38,7 +38,8 @@
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast
+from ...utils.generic import check_model_inputs, maybe_autocast
+from ...utils.output_capturing import OutputRecorder
 from .configuration_blt import (
     BltConfig,
     BltGlobalTransformerConfig,
@@ -437,8 +438,8 @@ class BltPreTrainedModel(PreTrainedModel):
     _supports_flex_attn = False
     _supports_attention_backend = False
     _can_record_outputs = {
-        "hidden_states": OutputRecorder(BltTransformerLayer, index=0, layer_name="local_decoder"),
-        "attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="local_decoder"),
+        "hidden_states": OutputRecorder(BltTransformerLayer, index=0),
+        "attentions": OutputRecorder(BltSelfAttention, index=1),
     }

     @torch.no_grad()
@@ -741,7 +742,6 @@ def __init__(self, config: BltLocalDecoderConfig):

         self.post_init()

-    @check_model_inputs
     def forward(
         self,
         input_ids: torch.LongTensor | None = None,

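Dropping `layer_name="local_decoder"` presumably widens the recorders: the `_can_record_outputs` entries now match the listed module classes wherever they occur in the model rather than only under the local decoder. A sketch of the registry shape after this hunk, using the new import path; the surrounding `PreTrainedModel` subclass is elided, and the reading of `index` as "which element of the module's return tuple to capture" is an assumption:

```python
from transformers.models.blt.modeling_blt import BltSelfAttention, BltTransformerLayer
from transformers.utils.output_capturing import OutputRecorder

# Each OutputRecorder names a submodule class to hook and (assumed) the index
# of the element in that module's return tuple to record into the output field.
_can_record_outputs = {
    "hidden_states": OutputRecorder(BltTransformerLayer, index=0),
    "attentions": OutputRecorder(BltSelfAttention, index=1),
}
```
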
src/transformers/models/blt/modular_blt.py (8 changes: 4 additions & 4 deletions)
@@ -29,7 +29,8 @@
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
-from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast
+from ...utils.generic import check_model_inputs, maybe_autocast
+from ...utils.output_capturing import OutputRecorder
 from ..cohere2.modeling_cohere2 import rotate_half  # noqa: F401
 from ..llama.modeling_llama import LlamaRotaryEmbedding
 from ..mllama.modeling_mllama import (
@@ -355,8 +356,8 @@ class BltPreTrainedModel(MllamaPreTrainedModel):
     _supports_flex_attn = False
     _no_split_modules = ["BltTransformerLayer"]
     _can_record_outputs = {
-        "hidden_states": OutputRecorder(BltTransformerLayer, index=0, layer_name="local_decoder"),
-        "attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="local_decoder"),
+        "hidden_states": OutputRecorder(BltTransformerLayer, index=0),
+        "attentions": OutputRecorder(BltSelfAttention, index=1),
     }

     # Weight initialization is adapted from:
@@ -673,7 +674,6 @@ def __init__(self, config: BltLocalDecoderConfig):

         self.post_init()

-    @check_model_inputs
     def forward(
         self,
         input_ids: torch.LongTensor | None = None,

src/transformers/models/clip/modeling_clip.py (43 changes: 28 additions & 15 deletions)
@@ -31,11 +31,10 @@
     ModelOutput,
     TransformersKwargs,
     auto_docstring,
-    can_return_tuple,
     logging,
     torch_int,
 )
-from ...utils.generic import check_model_inputs
+from ...utils.generic import can_return_tuple, check_model_inputs
 from .configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig


@@ -515,9 +514,14 @@ def forward(
         )


-class CLIPTextTransformer(nn.Module):
+class CLIPTextTransformer(CLIPPreTrainedModel):
+    config: CLIPTextConfig
+    input_modalities = ("text",)
+
+    _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]
+
     def __init__(self, config: CLIPTextConfig):
-        super().__init__()
+        super().__init__(config)
         self.config = config
         embed_dim = config.hidden_size
         self.embeddings = CLIPTextEmbeddings(config)
@@ -526,7 +530,9 @@ def __init__(self, config: CLIPTextConfig):

         # For `pooled_output` computation
         self.eos_token_id = config.eos_token_id
+        self.post_init()

+    @check_model_inputs(tie_last_hidden_states=False)
     @auto_docstring
     def forward(
         self,
@@ -587,8 +593,6 @@ def forward(
         return BaseModelOutputWithPooling(
             last_hidden_state=last_hidden_state,
             pooler_output=pooled_output,
-            hidden_states=encoder_outputs.hidden_states,
-            attentions=encoder_outputs.attentions,
         )


@@ -615,7 +619,6 @@ def get_input_embeddings(self) -> nn.Module:
     def set_input_embeddings(self, value):
         self.text_model.embeddings.token_embedding = value

-    @check_model_inputs(tie_last_hidden_states=False)
     @auto_docstring
     def forward(
         self,
@@ -648,17 +651,24 @@ def forward(
         )


-class CLIPVisionTransformer(nn.Module):
+class CLIPVisionTransformer(CLIPPreTrainedModel):
+    config: CLIPVisionConfig
+    main_input_name = "pixel_values"
+    input_modalities = ("image",)
+    _no_split_modules = ["CLIPEncoderLayer"]
+
     def __init__(self, config: CLIPVisionConfig):
-        super().__init__()
+        super().__init__(config)
         self.config = config
         embed_dim = config.hidden_size

         self.embeddings = CLIPVisionEmbeddings(config)
         self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
         self.encoder = CLIPEncoder(config)
         self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+        self.post_init()

+    @check_model_inputs(tie_last_hidden_states=False)
     @auto_docstring
     def forward(
         self,
@@ -684,8 +694,6 @@ def forward(
         return BaseModelOutputWithPooling(
             last_hidden_state=last_hidden_state,
             pooler_output=pooled_output,
-            hidden_states=encoder_outputs.hidden_states,
-            attentions=encoder_outputs.attentions,
         )


@@ -709,7 +717,6 @@ def __init__(self, config: CLIPVisionConfig):
     def get_input_embeddings(self) -> nn.Module:
         return self.vision_model.embeddings.patch_embedding

-    @check_model_inputs(tie_last_hidden_states=False)
     @auto_docstring
     def forward(
         self,
@@ -967,7 +974,7 @@ def get_input_embeddings(self) -> nn.Module:
     def set_input_embeddings(self, value):
         self.text_model.embeddings.token_embedding = value

-    @check_model_inputs(tie_last_hidden_states=False)
+    @can_return_tuple
     @auto_docstring
     def forward(
         self,
@@ -1005,6 +1012,8 @@ def forward(
         return CLIPTextModelOutput(
             text_embeds=text_embeds,
             last_hidden_state=text_outputs.last_hidden_state,
+            hidden_states=text_outputs.hidden_states,
+            attentions=text_outputs.attentions,
         )


@@ -1028,7 +1037,7 @@ def __init__(self, config: CLIPVisionConfig):
     def get_input_embeddings(self) -> nn.Module:
         return self.vision_model.embeddings.patch_embedding

-    @check_model_inputs(tie_last_hidden_states=False)
+    @can_return_tuple
     @auto_docstring
     def forward(
         self,
@@ -1068,6 +1077,8 @@ def forward(
         return CLIPVisionModelOutput(
             image_embeds=image_embeds,
             last_hidden_state=vision_outputs.last_hidden_state,
+            hidden_states=vision_outputs.hidden_states,
+            attentions=vision_outputs.attentions,
         )


@@ -1096,7 +1107,7 @@ def __init__(self, config: CLIPConfig) -> None:
         # Initialize weights and apply final processing
         self.post_init()

-    @check_model_inputs(tie_last_hidden_states=False)
+    @can_return_tuple
     @auto_docstring
     def forward(
         self,
@@ -1127,6 +1138,8 @@ def forward(
         return ImageClassifierOutput(
             loss=loss,
             logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
         )

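CLIP gets the deepest rework in this diff: `CLIPTextTransformer` and `CLIPVisionTransformer` become `CLIPPreTrainedModel` subclasses and carry `@check_model_inputs` themselves, so the inner return no longer sets `hidden_states`/`attentions` explicitly, and the outer `*WithProjection` and classification heads forward those fields from the inner output instead. A simplified sketch of that plumbing, mirroring the `CLIPTextModelWithProjection.forward` hunk above; the free function is illustrative, not part of the PR:

```python
import torch
from transformers.modeling_outputs import BaseModelOutputWithPooling
from transformers.models.clip.modeling_clip import CLIPTextModelOutput


def project_text(
    text_outputs: BaseModelOutputWithPooling, text_embeds: torch.Tensor
) -> CLIPTextModelOutput:
    # hidden_states/attentions are recorded by the inner transformer (via its
    # own @check_model_inputs) and simply passed through by the outer head.
    return CLIPTextModelOutput(
        text_embeds=text_embeds,
        last_hidden_state=text_outputs.last_hidden_state,
        hidden_states=text_outputs.hidden_states,
        attentions=text_outputs.attentions,
    )
```
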
src/transformers/models/cohere2_vision/modeling_cohere2_vision.py
@@ -29,7 +29,6 @@
 from ...modeling_utils import PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, torch_compilable_check
-from ...utils.generic import check_model_inputs
 from ..auto import AutoModel
 from .configuration_cohere2_vision import Cohere2VisionConfig

@@ -202,7 +201,7 @@ def get_placeholder_mask(
         )
         return special_image_mask

-    @check_model_inputs
+    @can_return_tuple
     @auto_docstring
     def forward(
         self,
@@ -279,7 +278,7 @@ def get_image_features(
     ) -> tuple | BaseModelOutputWithPooling:
         return self.model.get_image_features(pixel_values=pixel_values, **kwargs)

-    @check_model_inputs
+    @can_return_tuple
     @auto_docstring
     def forward(
         self,