docs/source/en/model_doc/qwen2_5_vl.md (+5 -0)
@@ -232,10 +232,15 @@ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
 
 [[autodoc]] Qwen2_5_VLConfig
 
+## Qwen2_5_VLTextConfig
+
+[[autodoc]] Qwen2_5_VLTextConfig
+
 ## Qwen2_5_VLProcessor
 
 [[autodoc]] Qwen2_5_VLProcessor
 
+
 ## Qwen2_5_VLModel
 
 [[autodoc]] Qwen2_5_VLModel
docs/source/en/model_doc/qwen2_vl.md (+4 -0)
@@ -278,6 +278,10 @@ model = Qwen2VLForConditionalGeneration.from_pretrained(
 
 [[autodoc]] Qwen2VLConfig
 
+## Qwen2VLTextConfig
+
+[[autodoc]] Qwen2VLTextConfig
+
 ## Qwen2VLImageProcessor
 
 [[autodoc]] Qwen2VLImageProcessor
src/transformers/models/auto/configuration_auto.py (+6 -0)
@@ -258,10 +258,12 @@
         ("qwen2", "Qwen2Config"),
         ("qwen2_5_omni", "Qwen2_5OmniConfig"),
         ("qwen2_5_vl", "Qwen2_5_VLConfig"),
+        ("qwen2_5_vl_text", "Qwen2_5_VLTextConfig"),
         ("qwen2_audio", "Qwen2AudioConfig"),
         ("qwen2_audio_encoder", "Qwen2AudioEncoderConfig"),
         ("qwen2_moe", "Qwen2MoeConfig"),
         ("qwen2_vl", "Qwen2VLConfig"),
+        ("qwen2_vl_text", "Qwen2VLTextConfig"),
         ("qwen3", "Qwen3Config"),
         ("qwen3_moe", "Qwen3MoeConfig"),
         ("rag", "RagConfig"),
@@ -625,10 +627,12 @@
         ("qwen2", "Qwen2"),
         ("qwen2_5_omni", "Qwen2_5Omni"),
         ("qwen2_5_vl", "Qwen2_5_VL"),
+        ("qwen2_5_vl_text", "Qwen2_5_VL"),
         ("qwen2_audio", "Qwen2Audio"),
         ("qwen2_audio_encoder", "Qwen2AudioEncoder"),
         ("qwen2_moe", "Qwen2MoE"),
         ("qwen2_vl", "Qwen2VL"),
+        ("qwen2_vl_text", "Qwen2VL"),
         ("qwen3", "Qwen3"),
         ("qwen3_moe", "Qwen3MoE"),
         ("rag", "RAG"),
@@ -793,6 +797,8 @@
         ("chinese_clip_vision_model", "chinese_clip"),
         ("rt_detr_resnet", "rt_detr"),
         ("granitevision", "llava_next"),
+        ("qwen2_5_vl_text", "qwen2_5_vl"),
+        ("qwen2_vl_text", "qwen2_vl"),
         ("sam_vision_model", "sam"),
         ("llama4_text", "llama4"),
         ("blip_2_qformer", "blip_2"),
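The three mappings above register the new model types with the auto machinery: the first hunk resolves the type string to a config class, the second gives it a display name, and the third (the module-redirect table, judging by its neighbors) points the `*_text` types back at their parent packages, since there is no `qwen2_5_vl_text/` directory on disk. A minimal sketch of the effect, assuming the release includes these mappings:

```python
from transformers import AutoConfig

# `for_model` looks the type string up in the config mapping, then imports the
# class from the module the redirect table names (qwen2_5_vl, not qwen2_5_vl_text).
text_cfg = AutoConfig.for_model("qwen2_5_vl_text")
print(type(text_cfg).__name__)  # Qwen2_5_VLTextConfig
```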
src/transformers/models/auto/modeling_auto.py (+2 -0)
@@ -234,9 +234,11 @@
         ("qdqbert", "QDQBertModel"),
         ("qwen2", "Qwen2Model"),
         ("qwen2_5_vl", "Qwen2_5_VLModel"),
+        ("qwen2_5_vl_text", "Qwen2_5_VLModel"),
         ("qwen2_audio_encoder", "Qwen2AudioEncoder"),
         ("qwen2_moe", "Qwen2MoeModel"),
         ("qwen2_vl", "Qwen2VLModel"),
+        ("qwen2_vl_text", "Qwen2VLModel"),
         ("qwen3", "Qwen3Model"),
         ("qwen3_moe", "Qwen3MoeModel"),
         ("recurrent_gemma", "RecurrentGemmaModel"),
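Because both `*_text` types map to the existing `*Model` classes, `AutoModel` can now build the bare text backbone directly from a text config. A sketch with deliberately tiny, illustrative sizes (the defaults instantiate a 70B-scale model), assuming `Qwen2_5_VLTextConfig` is re-exported at the top level per the `__all__` change further down:

```python
from transformers import AutoModel, Qwen2_5_VLTextConfig

# Tiny illustrative sizes so this runs quickly on CPU; not a real checkpoint shape.
cfg = Qwen2_5_VLTextConfig(
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
)
backbone = AutoModel.from_config(cfg)
print(type(backbone).__name__)  # Qwen2_5_VLModel, per the mapping above
```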
@@ -1792,7 +1792,7 @@ def forward(
 
 
 class Qwen2_5OmniDecoderLayer(nn.Module):
-    def __init__(self, config: Qwen2_5OmniConfig, layer_idx: int):
+    def __init__(self, config: Qwen2_5OmniTextConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
 
src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py (+70 -16)
@@ -67,17 +67,16 @@ def __init__(
         self.initializer_range = initializer_range
 
 
-class Qwen2_5_VLConfig(PretrainedConfig):
+class Qwen2_5_VLTextConfig(PretrainedConfig):
     r"""
-    This is the configuration class to store the configuration of a [`Qwen2_5_VLModel`]. It is used to instantiate a
+    This is the configuration class to store the configuration of a [`Qwen2_5_VLTextModel`]. It is used to instantiate a
     Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration
     with the defaults will yield a similar configuration to that of
     Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct).
 
     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
     documentation from [`PretrainedConfig`] for more information.
 
-
     Args:
         vocab_size (`int`, *optional*, defaults to 152064):
             Vocabulary size of the Qwen2_5_VL model. Defines the number of different tokens that can be represented by the
@@ -120,8 +119,6 @@ class Qwen2_5_VLConfig(PretrainedConfig):
             The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
         attention_dropout (`float`, *optional*, defaults to 0.0):
             The dropout ratio for the attention probabilities.
-        vision_config (`Dict`, *optional*):
-            The config for the visual encoder initialization.
         rope_scaling (`Dict`, *optional*):
             Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
             and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
@@ -161,20 +158,20 @@ class Qwen2_5_VLConfig(PretrainedConfig):
             Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
 
     ```python
-    >>> from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLConfig
+    >>> from transformers import Qwen2_5_VLTextModel, Qwen2_5_VLConfig
 
     >>> # Initializing a Qwen2_5_VL style configuration
    >>> configuration = Qwen2_5_VLConfig()
 
     >>> # Initializing a model from the Qwen2-VL-7B style configuration
-    >>> model = Qwen2_5_VLForConditionalGeneration(configuration)
+    >>> model = Qwen2_5_VLTextModel(configuration)
 
     >>> # Accessing the model configuration
     >>> configuration = model.config
     ```"""
 
-    model_type = "qwen2_5_vl"
-    sub_configs = {"vision_config": Qwen2_5_VLVisionConfig}
+    model_type = "qwen2_5_vl_text"
+    base_config_key = "text_config"
     keys_to_ignore_at_inference = ["past_key_values"]
     # Default tensor parallel plan for base model `Qwen2_5_VL`
     base_model_tp_plan = {
@@ -211,15 +208,9 @@ def __init__(
         sliding_window=4096,
         max_window_layers=80,
         attention_dropout=0.0,
-        vision_config=None,
         rope_scaling=None,
         **kwargs,
     ):
-        if isinstance(vision_config, dict):
-            self.vision_config = self.sub_configs["vision_config"](**vision_config)
-        elif vision_config is None:
-            self.vision_config = self.sub_configs["vision_config"]()
-
         self.vocab_size = vocab_size
         self.max_position_embeddings = max_position_embeddings
         self.hidden_size = hidden_size
@@ -257,4 +248,67 @@ def __init__(
         super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
 
 
-__all__ = ["Qwen2_5_VLConfig"]
+class Qwen2_5_VLConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`Qwen2_5_VLModel`]. It is used to instantiate a
+    Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of
+    Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct).
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen2_5_VLTextConfig`):
+            The config object or dictionary of the text backbone.
+        vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen2_5_VLVisionConfig`):
+            The config object or dictionary of the vision backbone.
+        image_token_id (`int`, *optional*, defaults to 151655):
+            The image token index to encode the image prompt.
+        video_token_id (`int`, *optional*, defaults to 151656):
+            The video token index to encode the image prompt.
+
+    ```python
+    >>> from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLConfig
+
+    >>> # Initializing a Qwen2_5_VL style configuration
+    >>> configuration = Qwen2_5_VLConfig()
+
+    >>> # Initializing a model from the Qwen2-VL-7B style configuration
+    >>> model = Qwen2_5_VLForConditionalGeneration(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "qwen2_5_vl"
+    sub_configs = {"vision_config": Qwen2_5_VLVisionConfig, "text_config": Qwen2_5_VLTextConfig}
+    keys_to_ignore_at_inference = ["past_key_values"]
+
+    def __init__(
+        self,
+        text_config=None,
+        vision_config=None,
+        image_token_id=151655,
+        video_token_id=151656,
+        **kwargs,
+    ):
+        if isinstance(vision_config, dict):
+            self.vision_config = self.sub_configs["vision_config"](**vision_config)
+        elif vision_config is None:
+            self.vision_config = self.sub_configs["vision_config"]()
+
+        if isinstance(text_config, dict):
+            self.text_config = self.sub_configs["text_config"](**text_config)
+        elif text_config is None:
+            # For BC use all kwargs to init `TextConfig`
+            self.text_config = self.sub_configs["text_config"](**kwargs)
+
+        self.image_token_id = image_token_id
+        self.video_token_id = video_token_id
+
+        super().__init__(**kwargs)
+
+
+__all__ = ["Qwen2_5_VLConfig", "Qwen2_5_VLTextConfig"]
src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py (+13 -9)
@@ -48,7 +48,7 @@
     logging,
     replace_return_docstrings,
 )
-from .configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig
+from .configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLTextConfig, Qwen2_5_VLVisionConfig
 
 
 if is_flash_attn_available():
@@ -390,7 +390,7 @@ class Qwen2_5_VLPreTrainedModel(PreTrainedModel):
     _supports_static_cache = False  # TODO (joao): fix. torch.compile failing probably due to `cache_positions`
 
     def _init_weights(self, module):
-        std = self.config.initializer_range
+        std = self.config.get_text_config().initializer_range
         if isinstance(module, (nn.Linear, nn.Conv3d)):
             module.weight.data.normal_(mean=0.0, std=std)
             if module.bias is not None:
@@ -566,7 +566,7 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
 
 
 class Qwen2_5_VLRotaryEmbedding(nn.Module):
-    def __init__(self, config: Qwen2_5_VLConfig, device=None):
+    def __init__(self, config: Qwen2_5_VLTextConfig, device=None):
         super().__init__()
         # BC: "rope_type" was originally "type"
         if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
@@ -680,7 +680,7 @@ class Qwen2_5_VLAttention(nn.Module):
     and "Generating Long Sequences with Sparse Transformers".
     """
 
-    def __init__(self, config: Qwen2_5_VLConfig, layer_idx: Optional[int] = None):
+    def __init__(self, config: Qwen2_5_VLTextConfig, layer_idx: Optional[int] = None):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
@@ -989,7 +989,7 @@ def forward(
 
 
 class Qwen2_5_VLDecoderLayer(nn.Module):
-    def __init__(self, config: Qwen2_5_VLConfig, layer_idx: int):
+    def __init__(self, config: Qwen2_5_VLTextConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
 
@@ -1077,7 +1077,9 @@ def forward(
     Qwen2_5_VL_START_DOCSTRING,
 )
 class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
-    def __init__(self, config: Qwen2_5_VLConfig):
+    config_class = Qwen2_5_VLTextConfig
+
+    def __init__(self, config: Qwen2_5_VLTextConfig):
         super().__init__(config)
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
@@ -1497,9 +1499,11 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMixin):
     def __init__(self, config):
         super().__init__(config)
         self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(config.vision_config)
-        self.model = Qwen2_5_VLModel(config)
-        self.vocab_size = config.vocab_size
-        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        text_config = config.get_text_config()
+        self.model = Qwen2_5_VLModel._from_config(text_config)
+        self.vocab_size = text_config.vocab_size
+        self.lm_head = nn.Linear(text_config.hidden_size, text_config.vocab_size, bias=False)
         self.rope_deltas = None  # cache rope_deltas here
 
         # Initialize weights and apply final processing
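Two small but load-bearing changes above: `_init_weights` reads `initializer_range` through `config.get_text_config()`, which returns the nested text config on the composite and the config itself on a bare text config, so the same code path serves both; and `Qwen2_5_VLForConditionalGeneration` now builds its backbone with `_from_config(text_config)` instead of handing the composite straight through. A sketch of the lookup behavior:

```python
from transformers import Qwen2_5_VLConfig

composite = Qwen2_5_VLConfig()
text = composite.get_text_config()
print(type(text).__name__)             # Qwen2_5_VLTextConfig
print(text.get_text_config() is text)  # True: a text-only config returns itself
```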
src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py (+8 -3)
@@ -28,7 +28,7 @@
 import torch.utils.checkpoint
 from torch.nn import CrossEntropyLoss
 
-from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig
+from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig, Qwen2VLTextConfig
 from transformers.models.qwen2_vl.modeling_qwen2_vl import (
     PatchEmbed,
     PatchMerger,
@@ -110,9 +110,13 @@ def __init__(
         self.initializer_range = initializer_range
 
 
+class Qwen2_5_VLTextConfig(Qwen2VLTextConfig):
+    model_type = "qwen2_5_vl_text"
+
+
 class Qwen2_5_VLConfig(Qwen2VLConfig):
     model_type = "qwen2_5_vl"
-    sub_configs = {"vision_config": Qwen2_5_VLVisionConfig}
+    sub_configs = {"vision_config": Qwen2_5_VLVisionConfig, "text_config": Qwen2_5_VLTextConfig}
 
 
 class Qwen2_5_VLMLP(nn.Module):
@@ -227,7 +231,7 @@ def forward(
 
 class Qwen2_5_VLPreTrainedModel(Qwen2VLPreTrainedModel):
     def _init_weights(self, module):
-        std = self.config.initializer_range
+        std = self.config.get_text_config().initializer_range
         if isinstance(module, (nn.Linear, nn.Conv3d)):
             module.weight.data.normal_(mean=0.0, std=std)
             if module.bias is not None:
@@ -971,6 +975,7 @@ def __call__(
 
 __all__ = [
     "Qwen2_5_VLConfig",
+    "Qwen2_5_VLTextConfig",
     "Qwen2_5_VLForConditionalGeneration",
     "Qwen2_5_VLModel",
     "Qwen2_5_VLPreTrainedModel",
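In the modular file the new class is two lines: `Qwen2_5_VLTextConfig` inherits everything from `Qwen2VLTextConfig` and only overrides `model_type`; the expanded `configuration_qwen2_5_vl.py` shown earlier is what the modular converter generates from it. A quick check of what the subclass actually changes (a sketch; it assumes both classes are importable from the top level and that no other defaults are overridden, which matches the two-line definition above):

```python
from transformers import Qwen2VLTextConfig, Qwen2_5_VLTextConfig

base, derived = Qwen2VLTextConfig(), Qwen2_5_VLTextConfig()
print(derived.model_type)                       # "qwen2_5_vl_text"
print(derived.hidden_size == base.hidden_size)  # True: defaults are inherited
```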