3 changes: 3 additions & 0 deletions src/transformers/generation/utils.py
@@ -366,6 +366,9 @@ class GenerationMixin(ContinuousMixin):
To learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
"""

# Should be overwritten by models that can generate non-text output
output_modalities = "text"

def adjust_generation_fn(
self,
generation_config,
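A minimal sketch of how this new default is meant to be used, assuming only the attribute added above: a model whose generate() produces something other than text overrides the class attribute. The class name below is hypothetical; Bark further down in this diff does exactly this with output_modalities = "audio".

from transformers import GenerationMixin, PreTrainedModel

# Hypothetical text-to-audio model: generate() returns audio, not text, so it
# overrides the class-level default inherited from GenerationMixin.
class MyTextToAudioModel(PreTrainedModel, GenerationMixin):
    output_modalities = "audio"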
4 changes: 4 additions & 0 deletions src/transformers/modeling_utils.py
@@ -1734,6 +1734,10 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
_supports_attention_backend = False
_can_record_outputs = None

# Attributes used mainly in multimodal LLMs, though all models contain a valid field for these
# Possible values are: text, image, video, audio and time
input_modalities: Union[str, list[str]] = "text" # most models are text

@property
@torch._dynamo.allow_in_graph
def can_record_outputs(self) -> dict[str, OutputRecorder]:
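For illustration, a short sketch of how the new field can be read, assuming only the attribute added above; it is a plain class attribute, so no checkpoint has to be loaded. The AST value mirrors the change to ASTPreTrainedModel later in this diff, and BertModel simply inherits the "text" default.

from transformers import ASTModel, BertModel

# Class-level inspection: no config or weights required.
print(BertModel.input_modalities)  # "text"  (inherited default from PreTrainedModel)
print(ASTModel.input_modalities)   # "audio" (set on ASTPreTrainedModel in this diff)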
1 change: 1 addition & 0 deletions src/transformers/models/aimv2/modeling_aimv2.py
@@ -394,6 +394,7 @@ class Aimv2PreTrainedModel(PreTrainedModel):

config: Aimv2Config
base_model_prefix = "aimv2"
input_modalities = "image"
supports_gradient_checkpointing = True
_no_split_modules = [
"Aimv2EncoderLayer",
1 change: 1 addition & 0 deletions src/transformers/models/aimv2/modular_aimv2.py
@@ -437,6 +437,7 @@ class Aimv2PreTrainedModel(PreTrainedModel):

config: Aimv2Config
base_model_prefix = "aimv2"
input_modalities = "image"
supports_gradient_checkpointing = True
_no_split_modules = [
"Aimv2EncoderLayer",
3 changes: 3 additions & 0 deletions src/transformers/models/align/modeling_align.py
@@ -820,6 +820,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
class AlignPreTrainedModel(PreTrainedModel):
config: AlignConfig
base_model_prefix = "align"
input_modalities = ["image", "text"]
supports_gradient_checkpointing = True

def _init_weights(self, module: nn.Module):
@@ -849,6 +850,7 @@ def _init_weights(self, module: nn.Module):
)
class AlignTextModel(AlignPreTrainedModel):
config: AlignTextConfig
input_modalities = "text"
_no_split_modules = ["AlignTextEmbeddings"]

def __init__(self, config: AlignTextConfig, add_pooling_layer: bool = True):
@@ -969,6 +971,7 @@ def forward(
class AlignVisionModel(AlignPreTrainedModel):
config: AlignVisionConfig
main_input_name = "pixel_values"
input_modalities = "image"
supports_gradient_checkpointing = False

def __init__(self, config: AlignVisionConfig):
3 changes: 3 additions & 0 deletions src/transformers/models/altclip/modeling_altclip.py
@@ -766,6 +766,7 @@ def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=Fals
class AltCLIPPreTrainedModel(PreTrainedModel):
config: AltCLIPConfig
base_model_prefix = "altclip"
input_modalities = ["image", "text"]
supports_gradient_checkpointing = True
_no_split_module = []

@@ -870,6 +871,7 @@ def forward(
class AltCLIPVisionModel(AltCLIPPreTrainedModel):
config: AltCLIPVisionConfig
main_input_name = "pixel_values"
input_modalities = "image"

def __init__(self, config: AltCLIPVisionConfig):
super().__init__(config)
@@ -1028,6 +1030,7 @@ def forward(

class AltCLIPTextModel(AltCLIPPreTrainedModel):
config: AltCLIPTextConfig
input_modalities = "text"

def __init__(self, config):
super().__init__(config)
1 change: 1 addition & 0 deletions src/transformers/models/aria/modeling_aria.py
@@ -572,6 +572,7 @@ def forward(
class AriaTextPreTrainedModel(PreTrainedModel):
config: AriaTextConfig
base_model_prefix = "model"
input_modalities = ["image", "text"]
_no_split_modules = ["AriaTextDecoderLayer", "AriaGroupedExpertsGemm"]
supports_gradient_checkpointing = True
_skip_keys_device_placement = "past_key_values"
1 change: 1 addition & 0 deletions src/transformers/models/aria/modular_aria.py
@@ -1215,6 +1215,7 @@ def __init__(self, config: AriaTextConfig, layer_idx: int):
class AriaTextPreTrainedModel(PreTrainedModel):
config: AriaTextConfig
base_model_prefix = "model"
input_modalities = ["image", "text"]
_no_split_modules = ["AriaTextDecoderLayer", "AriaGroupedExpertsGemm"]
supports_gradient_checkpointing = True
_skip_keys_device_placement = "past_key_values"
@@ -285,6 +285,7 @@ def forward(self, hidden_states: torch.Tensor) -> BaseModelOutput:
class ASTPreTrainedModel(PreTrainedModel):
config: ASTConfig
base_model_prefix = "audio_spectrogram_transformer"
input_modalities = "audio"
main_input_name = "input_values"
supports_gradient_checkpointing = True
_supports_sdpa = True
1 change: 1 addition & 0 deletions src/transformers/models/autoformer/modeling_autoformer.py
@@ -822,6 +822,7 @@ def forward(
class AutoformerPreTrainedModel(PreTrainedModel):
config: AutoformerConfig
base_model_prefix = "model"
input_modalities = "time"
main_input_name = "past_values"
supports_gradient_checkpointing = True

1 change: 1 addition & 0 deletions src/transformers/models/aya_vision/modeling_aya_vision.py
@@ -91,6 +91,7 @@ def pixel_shuffle(self, image_features): # B, S, D
class AyaVisionPreTrainedModel(PreTrainedModel):
config: AyaVisionConfig
base_model_prefix = ""
input_modalities = ["image", "text"]
supports_gradient_checkpointing = True
_skip_keys_device_placement = "past_key_values"

1 change: 1 addition & 0 deletions src/transformers/models/bark/modeling_bark.py
@@ -370,6 +370,7 @@ def device(self) -> torch.device:
# GPT2-like autoregressive model
class BarkCausalModel(BarkPreTrainedModel, GenerationMixin):
config: BarkSubModelConfig
output_modalities = "audio"

def __init__(self, config):
super().__init__(config)
1 change: 1 addition & 0 deletions src/transformers/models/beit/modeling_beit.py
@@ -685,6 +685,7 @@ def forward(
class BeitPreTrainedModel(PreTrainedModel):
config: BeitConfig
base_model_prefix = "beit"
input_modalities = "image"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_no_split_modules = ["BeitLayer"]
1 change: 1 addition & 0 deletions src/transformers/models/bit/modeling_bit.py
@@ -624,6 +624,7 @@ def forward(
class BitPreTrainedModel(PreTrainedModel):
config: BitConfig
base_model_prefix = "bit"
input_modalities = "image"
main_input_name = "pixel_values"
_no_split_modules = ["BitEmbeddings"]

2 changes: 2 additions & 0 deletions src/transformers/models/blip/modeling_blip.py
@@ -414,6 +414,7 @@ def forward(
class BlipPreTrainedModel(PreTrainedModel):
config: BlipConfig
base_model_prefix = "blip"
input_modalities = ["image", "text"]
supports_gradient_checkpointing = True
_no_split_modules = ["BlipEncoderLayer", "BlipTextEmbeddings"]
_skip_keys_device_placement = ["past_key_values"]
@@ -482,6 +483,7 @@ def forward(

class BlipVisionModel(BlipPreTrainedModel):
main_input_name = "pixel_values"
input_modalities = "image"
config: BlipVisionConfig
_can_record_outputs = {
"hidden_states": BlipEncoderLayer,
4 changes: 4 additions & 0 deletions src/transformers/models/blip_2/modeling_blip_2.py
@@ -392,6 +392,7 @@ def forward(
class Blip2PreTrainedModel(PreTrainedModel):
config: Blip2Config
base_model_prefix = "blip"
input_modalities = ["image", "text"]
supports_gradient_checkpointing = True
_supports_attention_backend = True
_supports_flash_attn = True
@@ -474,6 +475,7 @@ def forward(
# Copied from transformers.models.blip.modeling_blip.BlipVisionModel with Blip->Blip2, BLIP->BLIP_2
class Blip2VisionModel(Blip2PreTrainedModel):
main_input_name = "pixel_values"
input_modalities = "image"
config: Blip2VisionConfig
_can_record_outputs = {
"hidden_states": Blip2EncoderLayer,
@@ -1489,6 +1491,7 @@ def forward(
@auto_docstring
class Blip2VisionModelWithProjection(Blip2PreTrainedModel):
main_input_name = "pixel_values"
input_modalities = "image"
_keep_in_fp32_modules = ["query_tokens", "qformer"]
_supports_flash_attn = False # because self.qformer does not support FA2

@@ -1960,6 +1963,7 @@ def generate(
)
class Blip2ForImageTextRetrieval(Blip2PreTrainedModel):
main_input_name = "pixel_values"
input_modalities = "image"
_keep_in_fp32_modules = ["query_tokens", "qformer"]
_supports_flash_attn = False # because self.qformer does not support FA2

1 change: 1 addition & 0 deletions src/transformers/models/blt/modeling_blt.py
@@ -419,6 +419,7 @@ def forward(
class BltPreTrainedModel(PreTrainedModel):
config: BltConfig
base_model_prefix = ""
input_modalities = ["image", "text"]
supports_gradient_checkpointing = True
_no_split_modules = ["BltTransformerLayer"]
_can_compile_fullgraph = False # static cache cannot have different shapes for each layer
3 changes: 3 additions & 0 deletions src/transformers/models/bridgetower/modeling_bridgetower.py
@@ -914,6 +914,7 @@ def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_l
class BridgeTowerPreTrainedModel(PreTrainedModel):
config: BridgeTowerConfig
base_model_prefix = "bridgetower"
input_modalities = ["image", "text"]
supports_gradient_checkpointing = False
_no_split_modules = ["BridgeTowerSelfAttention", "BridgeTowerResidualAttention"]
_skip_keys_device_placement = "past_key_values"
@@ -947,6 +948,7 @@ def _init_weights(self, module: nn.Module):

class BridgeTowerVisionModel(BridgeTowerPreTrainedModel):
config: BridgeTowerVisionConfig
input_modalities = "image"

def __init__(self, config):
super().__init__(config)
@@ -976,6 +978,7 @@ def forward(self, image, image_mask=None, interpolate_pos_encoding=False):
)
class BridgeTowerTextModel(BridgeTowerPreTrainedModel):
config: BridgeTowerTextConfig
input_modalities = "text"

def __init__(self, config, add_pooling_layer=True):
r"""
1 change: 1 addition & 0 deletions src/transformers/models/chameleon/modeling_chameleon.py
@@ -796,6 +796,7 @@ def convert_img2bpe(self, img_batch: torch.Tensor) -> torch.Tensor:
class ChameleonPreTrainedModel(PreTrainedModel):
config: ChameleonConfig
base_model_prefix = "model"
input_modalities = ["image", "text"]
supports_gradient_checkpointing = True
_no_split_modules = ["ChameleonDecoderLayer", "ChameleonSwinDecoderLayer"]
_skip_keys_device_placement = ["past_key_values", "causal_mask"]
3 changes: 3 additions & 0 deletions src/transformers/models/chinese_clip/modeling_chinese_clip.py
@@ -559,6 +559,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
class ChineseCLIPPreTrainedModel(PreTrainedModel):
config: ChineseCLIPConfig
base_model_prefix = "chinese_clip"
input_modalities = ["image", "text"]
supports_gradient_checkpointing = True

def _init_weights(self, module):
@@ -795,6 +796,7 @@ class ChineseCLIPTextModel(ChineseCLIPPreTrainedModel):
"""

config: ChineseCLIPTextConfig
input_modalities = "text"
_no_split_modules = ["ChineseCLIPTextEmbeddings"]

def __init__(self, config, add_pooling_layer=True):
@@ -902,6 +904,7 @@ def forward(
class ChineseCLIPVisionModel(ChineseCLIPPreTrainedModel):
config: ChineseCLIPVisionConfig
main_input_name = "pixel_values"
input_modalities = "image"
_no_split_modules = ["ChineseCLIPVisionEmbeddings", "ChineseCLIPVisionAttention"]

def __init__(self, config: ChineseCLIPVisionConfig):
5 changes: 5 additions & 0 deletions src/transformers/models/clap/modeling_clap.py
@@ -1305,6 +1305,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
class ClapPreTrainedModel(PreTrainedModel):
config: ClapConfig
base_model_prefix = "clap"
input_modalities = ["audio", "text"]
supports_gradient_checkpointing = False

def _init_weights(self, module: nn.Module):
@@ -1334,6 +1335,7 @@ def _init_weights(self, module: nn.Module):
class ClapAudioModel(ClapPreTrainedModel):
config: ClapAudioConfig
main_input_name = "input_features"
input_modalities = "audio"

def __init__(self, config: ClapAudioConfig):
super().__init__(config)
@@ -1406,6 +1408,7 @@ def forward(
)
class ClapTextModel(ClapPreTrainedModel):
config: ClapTextConfig
input_modalities = "text"

def __init__(self, config, add_pooling_layer=True):
r"""
@@ -1710,6 +1713,7 @@ def forward(
@auto_docstring
class ClapTextModelWithProjection(ClapPreTrainedModel):
config: ClapTextConfig
input_modalities = "text"

def __init__(self, config: ClapTextConfig):
super().__init__(config)
@@ -1776,6 +1780,7 @@ def forward(
class ClapAudioModelWithProjection(ClapPreTrainedModel):
config: ClapAudioConfig
main_input_name = "input_features"
input_modalities = "audio"

def __init__(self, config: ClapAudioConfig):
super().__init__(config)
6 changes: 6 additions & 0 deletions src/transformers/models/clip/modeling_clip.py
@@ -420,6 +420,7 @@ def forward(
class CLIPPreTrainedModel(PreTrainedModel):
config: CLIPConfig
base_model_prefix = "clip"
input_modalities = ["image", "text"]
supports_gradient_checkpointing = True
_supports_sdpa = True
_supports_flash_attn = True
@@ -662,6 +663,7 @@ def forward(
)
class CLIPTextModel(CLIPPreTrainedModel):
config: CLIPTextConfig
input_modalities = "text"

_no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]
_supports_flash_attn = False # mask creation only accounts for sdpa/eager
@@ -769,6 +771,7 @@ def forward(
class CLIPVisionModel(CLIPPreTrainedModel):
config: CLIPVisionConfig
main_input_name = "pixel_values"
input_modalities = "image"
_no_split_modules = ["CLIPEncoderLayer"]

def __init__(self, config: CLIPVisionConfig):
@@ -1029,6 +1032,7 @@ def forward(
@auto_docstring
class CLIPTextModelWithProjection(CLIPPreTrainedModel):
config: CLIPTextConfig
input_modalities = "text"

_supports_flash_attn = False
_no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]
@@ -1099,6 +1103,7 @@ def forward(
class CLIPVisionModelWithProjection(CLIPPreTrainedModel):
config: CLIPVisionConfig
main_input_name = "pixel_values"
input_modalities = "image"

def __init__(self, config: CLIPVisionConfig):
super().__init__(config)
@@ -1169,6 +1174,7 @@ def forward(
)
class CLIPForImageClassification(CLIPPreTrainedModel):
main_input_name = "pixel_values"
input_modalities = "image"

def __init__(self, config: CLIPConfig) -> None:
super().__init__(config)
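The CLIP hunks above show the pattern followed throughout this diff: the composite dual-encoder class advertises both modalities, while each sub-model narrows the field to its own input. A small illustrative sketch of how downstream code could rely on that, assuming only the attributes added here (the accepts_images helper is hypothetical, not part of this PR):

from transformers import CLIPModel, CLIPTextModel, CLIPVisionModel

# Composite model vs. its single-modality sub-models, as set in the hunks above.
print(CLIPModel.input_modalities)        # ["image", "text"]
print(CLIPTextModel.input_modalities)    # "text"
print(CLIPVisionModel.input_modalities)  # "image"

def accepts_images(model_cls) -> bool:
    """Hypothetical helper: does this class take image inputs?"""
    modalities = model_cls.input_modalities
    modalities = [modalities] if isinstance(modalities, str) else modalities
    return "image" in modalities

print(accepts_images(CLIPVisionModel))  # True
print(accepts_images(CLIPTextModel))    # False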
3 changes: 3 additions & 0 deletions src/transformers/models/clipseg/modeling_clipseg.py
@@ -424,6 +424,7 @@ def forward(
class CLIPSegPreTrainedModel(PreTrainedModel):
config: CLIPSegConfig
base_model_prefix = "clip"
input_modalities = ["image", "text"]
supports_gradient_checkpointing = True

def _init_weights(self, module):
@@ -648,6 +649,7 @@ def forward(

class CLIPSegTextModel(CLIPSegPreTrainedModel):
config: CLIPSegTextConfig
input_modalities = "text"

_no_split_modules = ["CLIPSegTextEmbeddings", "CLIPSegEncoderLayer"]

@@ -753,6 +755,7 @@ def forward(
class CLIPSegVisionModel(CLIPSegPreTrainedModel):
config: CLIPSegVisionConfig
main_input_name = "pixel_values"
input_modalities = "image"

def __init__(self, config: CLIPSegVisionConfig):
super().__init__(config)
@@ -130,6 +130,7 @@ class Cohere2VisionCausalLMOutputWithPast(ModelOutput):
class Cohere2VisionPreTrainedModel(PreTrainedModel):
config: Cohere2VisionConfig
base_model_prefix = ""
input_modalities = ["image", "text"]
supports_gradient_checkpointing = True
_skip_keys_device_placement = "past_key_values"

1 change: 1 addition & 0 deletions src/transformers/models/colpali/modeling_colpali.py
@@ -32,6 +32,7 @@
class ColPaliPreTrainedModel(PreTrainedModel):
config: ColPaliConfig
base_model_prefix = "model"
input_modalities = ["image", "text"]
_no_split_modules = []
_supports_sdpa = True
_supports_flash_attn = True
1 change: 1 addition & 0 deletions src/transformers/models/colqwen2/modeling_colqwen2.py
@@ -40,6 +40,7 @@
class ColQwen2PreTrainedModel(PreTrainedModel):
config: ColQwen2Config
base_model_prefix = "model"
input_modalities = ["image", "text"]
_no_split_modules = []
_supports_sdpa = True
_supports_flash_attn = True