diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 894faad97318..9865e69c0bd8 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -721,6 +721,10 @@
title: Qwen2MoE
- local: model_doc/qwen3
title: Qwen3
+ - local: model_doc/qwen3_5
+ title: Qwen3.5
+ - local: model_doc/qwen3_5_moe
+ title: Qwen3.5 MoE
- local: model_doc/qwen3_moe
title: Qwen3MoE
- local: model_doc/qwen3_next
diff --git a/docs/source/en/model_doc/qwen3_5.md b/docs/source/en/model_doc/qwen3_5.md
new file mode 100644
index 000000000000..24111e180295
--- /dev/null
+++ b/docs/source/en/model_doc/qwen3_5.md
@@ -0,0 +1,76 @@
+
+*This model was released on 2026-01-01 and added to Hugging Face Transformers on 2026-02-09.*
+
+
+
+# Qwen3.5
+
+[Qwen3.5](https://huggingface.co/papers/2502.13923) TODO @shuaibai @bozheng
+
+Model usage
+
+
+
+
+```py
+TODO
+```
+
+
+
+
+## Qwen3_5Config
+
+[[autodoc]] Qwen3_5Config
+
+## Qwen3_5TextConfig
+
+[[autodoc]] Qwen3_5TextConfig
+
+## Qwen3_5VisionModel
+
+[[autodoc]] Qwen3_5VisionModel
+ - forward
+
+## Qwen3_5TextModel
+
+[[autodoc]] Qwen3_5TextModel
+ - forward
+
+## Qwen3_5Model
+
+[[autodoc]] Qwen3_5Model
+ - forward
+
+## Qwen3_5ForCausalLM
+
+[[autodoc]] Qwen3_5ForCausalLM
+ - forward
+
+## Qwen3_5ForConditionalGeneration
+
+[[autodoc]] Qwen3_5ForConditionalGeneration
+ - forward
+
+## Qwen3_5Tokenizer
+
+[[autodoc]] Qwen3_5Tokenizer
diff --git a/docs/source/en/model_doc/qwen3_5_moe.md b/docs/source/en/model_doc/qwen3_5_moe.md
new file mode 100644
index 000000000000..768839d5276c
--- /dev/null
+++ b/docs/source/en/model_doc/qwen3_5_moe.md
@@ -0,0 +1,72 @@
+
+*This model was released on 2026-01-01 and added to Hugging Face Transformers on 2026-02-09.*
+
+
+
+# Qwen3.5 MoE
+
+[Qwen3.5 MoE](https://huggingface.co/papers/2502.13923) TODO @shuaibai @bozheng
+
+Model usage
+
+
+
+
+```py
+TODO
+```
+
+
+
+
+## Qwen3_5MoeConfig
+
+[[autodoc]] Qwen3_5MoeConfig
+
+## Qwen3_5MoeTextConfig
+
+[[autodoc]] Qwen3_5MoeTextConfig
+
+## Qwen3_5MoeVisionModel
+
+[[autodoc]] Qwen3_5MoeVisionModel
+ - forward
+
+## Qwen3_5MoeTextModel
+
+[[autodoc]] Qwen3_5MoeTextModel
+ - forward
+
+## Qwen3_5MoeModel
+
+[[autodoc]] Qwen3_5MoeModel
+ - forward
+
+## Qwen3_5MoeForCausalLM
+
+[[autodoc]] Qwen3_5MoeForCausalLM
+ - forward
+
+## Qwen3_5MoeForConditionalGeneration
+
+[[autodoc]] Qwen3_5MoeForConditionalGeneration
+ - forward
diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py
old mode 100644
new mode 100755
index bba6a4b6ea34..e06c37a0a30a
--- a/src/transformers/conversion_mapping.py
+++ b/src/transformers/conversion_mapping.py
@@ -59,6 +59,7 @@
"qwen3_omni_moe": "qwen2_moe",
"qwen3_omni_moe_thinker": "qwen2_moe",
"qwen3_next": "qwen2_moe",
+ "qwen3_5_moe": "qwen2_moe",
"hunyuan_v1_moe": "qwen2_moe",
"flex_olmo": "qwen2_moe",
"olmoe": "qwen2_moe",
@@ -70,6 +71,9 @@
def _build_checkpoint_conversion_mapping():
mapping = {
+ "qwen3_5_text": [
+ WeightRenaming(source_patterns=r"^model.language_model", target_patterns="model"),
+ ],
"t5gemma2": [
WeightRenaming(r"(?>> from transformers import Qwen3_5TextModel, Qwen3_5TextConfig
+
+ >>> # Initializing a Qwen3.5 style configuration
+ >>> configuration = Qwen3_5TextConfig()
+
+ >>> # Initializing a model from the Qwen3.5-9B style configuration
+ >>> model = Qwen3_5TextModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```
+ """
+
+ model_type = "qwen3_5_text"
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ base_model_tp_plan = {
+ "layers.*.self_attn.q_proj": "colwise",
+ "layers.*.self_attn.k_proj": "colwise",
+ "layers.*.self_attn.v_proj": "colwise",
+ "layers.*.self_attn.o_proj": "rowwise",
+ "layers.*.mlp.gate_proj": "colwise",
+ "layers.*.mlp.up_proj": "colwise",
+ "layers.*.mlp.down_proj": "rowwise",
+ }
+ base_model_pp_plan = {
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+ "norm": (["hidden_states"], ["hidden_states"]),
+ }
+ base_config_key = "text_config"
+
+ def __init__(
+ self,
+ vocab_size=248320,
+ hidden_size=4096,
+ intermediate_size=12288,
+ num_hidden_layers=32,
+ num_attention_heads=16,
+ num_key_value_heads=4,
+ hidden_act="silu",
+ max_position_embeddings=32768,
+ initializer_range=0.02,
+ rms_norm_eps=1e-6,
+ use_cache=True,
+ tie_word_embeddings=False,
+ rope_parameters: RopeParameters | dict[str, RopeParameters] | None = None,
+ attention_bias=False,
+ attention_dropout=0.0,
+ head_dim=256,
+ linear_conv_kernel_dim=4,
+ linear_key_head_dim=128,
+ linear_value_head_dim=128,
+ linear_num_key_heads=16,
+ linear_num_value_heads=32,
+ layer_types=None,
+ pad_token_id: int | None = None,
+ bos_token_id: int | None = None,
+ eos_token_id: int | None = None,
+ **kwargs,
+ ):
+ kwargs["ignore_keys_at_rope_validation"] = {"mrope_section", "mrope_interleaved"}
+ self.pad_token_id = pad_token_id
+ self.bos_token_id = bos_token_id
+ self.eos_token_id = eos_token_id
+ self.tie_word_embeddings = tie_word_embeddings
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+ self.head_dim = head_dim
+ self.rope_parameters = rope_parameters
+ kwargs.setdefault("partial_rotary_factor", 0.25) # assign default for BC
+
+ self.layer_types = layer_types
+ if self.layer_types is None:
+ interval_pattern = kwargs.get("full_attention_interval", 4)
+ self.layer_types = [
+ "linear_attention" if bool((i + 1) % interval_pattern) else "full_attention"
+ for i in range(self.num_hidden_layers)
+ ]
+ layer_type_validation(self.layer_types, self.num_hidden_layers)
+
+ # linear attention part
+ self.linear_conv_kernel_dim = linear_conv_kernel_dim
+ self.linear_key_head_dim = linear_key_head_dim
+ self.linear_value_head_dim = linear_value_head_dim
+ self.linear_num_key_heads = linear_num_key_heads
+ self.linear_num_value_heads = linear_num_value_heads
+ super().__init__(**kwargs)
+
+
+class Qwen3_5VisionConfig(PreTrainedConfig):
+ model_type = "qwen3_5"
+ base_config_key = "vision_config"
+
+ def __init__(
+ self,
+ depth=27,
+ hidden_size=1152,
+ hidden_act="gelu_pytorch_tanh",
+ intermediate_size=4304,
+ num_heads=16,
+ in_channels=3,
+ patch_size=16,
+ spatial_merge_size=2,
+ temporal_patch_size=2,
+ out_hidden_size=3584,
+ num_position_embeddings=2304,
+ initializer_range=0.02,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.depth = depth
+ self.hidden_size = hidden_size
+ self.hidden_act = hidden_act
+ self.intermediate_size = intermediate_size
+ self.num_heads = num_heads
+ self.in_channels = in_channels
+ self.patch_size = patch_size
+ self.spatial_merge_size = spatial_merge_size
+ self.temporal_patch_size = temporal_patch_size
+ self.out_hidden_size = out_hidden_size
+ self.num_position_embeddings = num_position_embeddings
+ self.initializer_range = initializer_range
+
+
+class Qwen3_5Config(PreTrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Qwen3_5Model`]. It is used to instantiate a
+ Qwen3.5 model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of
+ Qwen3.5-9B-Instruct [Qwen/Qwen3.5-9B-Instruct](https://huggingface.co/Qwen/Qwen3.5-9B-Instruct).
+
+ Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PreTrainedConfig`] for more information.
+
+
+ Args:
+ text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3_5TextConfig`):
+ The config object or dictionary of the text backbone.
+ vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3_5VisionConfig`):
+ The config object or dictionary of the vision backbone.
+ image_token_id (`int`, *optional*, defaults to 248056):
+ The image token index to encode the image prompt.
+        video_token_id (`int`, *optional*, defaults to 248057):
+            The video token index to encode the video prompt.
+ vision_start_token_id (`int`, *optional*, defaults to 248053):
+ The start token index to encode the image prompt.
+ vision_end_token_id (`int`, *optional*, defaults to 248054):
+ The end token index to encode the image prompt.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether to tie the word embeddings.
+
+ ```python
+ >>> from transformers import Qwen3_5ForConditionalGeneration, Qwen3_5Config
+
+ >>> # Initializing a Qwen3.5 style configuration
+ >>> configuration = Qwen3_5Config()
+
+ >>> # Initializing a model from the Qwen3.5-9B style configuration
+ >>> model = Qwen3_5ForConditionalGeneration(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "qwen3_5"
+ sub_configs = {"vision_config": Qwen3_5VisionConfig, "text_config": Qwen3_5TextConfig}
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ text_config=None,
+ vision_config=None,
+ image_token_id=248056,
+ video_token_id=248057,
+ vision_start_token_id=248053,
+ vision_end_token_id=248054,
+ tie_word_embeddings=False,
+ **kwargs,
+ ):
+ if isinstance(vision_config, dict):
+ self.vision_config = self.sub_configs["vision_config"](**vision_config)
+ elif vision_config is None:
+ self.vision_config = self.sub_configs["vision_config"]()
+
+ if isinstance(text_config, dict):
+ self.text_config = self.sub_configs["text_config"](**text_config)
+ elif text_config is None:
+ self.text_config = self.sub_configs["text_config"]()
+
+ self.image_token_id = image_token_id
+ self.video_token_id = video_token_id
+ self.vision_start_token_id = vision_start_token_id
+ self.vision_end_token_id = vision_end_token_id
+ self.tie_word_embeddings = tie_word_embeddings
+ super().__init__(**kwargs)
+
+
+__all__ = ["Qwen3_5Config", "Qwen3_5TextConfig"]
diff --git a/src/transformers/models/qwen3_5/modeling_qwen3_5.py b/src/transformers/models/qwen3_5/modeling_qwen3_5.py
new file mode 100644
index 000000000000..6aecdb06570e
--- /dev/null
+++ b/src/transformers/models/qwen3_5/modeling_qwen3_5.py
@@ -0,0 +1,2194 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/qwen3_5/modular_qwen3_5.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_qwen3_5.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections.abc import Callable
+from dataclasses import dataclass
+from typing import Any, Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from ... import initialization as init
+from ...activations import ACT2FN
+from ...cache_utils import Cache
+from ...generation import GenerationMixin
+from ...integrations import use_kernelized_func
+from ...masking_utils import create_causal_mask
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import (
+ BaseModelOutputWithPast,
+ BaseModelOutputWithPooling,
+ CausalLMOutputWithPast,
+ ModelOutput,
+)
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check
+from ...utils.generic import check_model_inputs, is_flash_attention_requested, maybe_autocast
+from ...utils.import_utils import is_causal_conv1d_available, is_flash_linear_attention_available
+from .configuration_qwen3_5 import Qwen3_5Config, Qwen3_5TextConfig, Qwen3_5VisionConfig
+
+
+if is_causal_conv1d_available():
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
+else:
+ causal_conv1d_update, causal_conv1d_fn = None, None
+
+if is_flash_linear_attention_available():
+ from fla.modules import FusedRMSNormGated
+ from fla.ops.gated_delta_rule import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule
+else:
+ chunk_gated_delta_rule, fused_recurrent_gated_delta_rule = None, None
+ FusedRMSNormGated = None
+
+logger = logging.get_logger(__name__)
+
+
+class Qwen3_5DynamicCache:
+ """
+ A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the linear attention
+ cache (which has a constant shape regardless of seq_len).
+
+ This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states`
+    and `recurrent_states` for the gated deltanet cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor is as follows.
+    For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`,
+    while `conv_states` and `recurrent_states` have a shape of `(batch_size, 0)` (empty tensors).
+ For linear attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors),
+ while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`,
+ and `recurrent_states` represents the recurrent state and has a shape of `(batch_size, d_inner, d_state)`.
+ """
+
+ is_compileable = False
+
+ def __init__(self, config: Qwen3_5Config):
+ super().__init__()
+ self.layer_types = config.layer_types
+ self.transformer_layers = [
+ i for i in range(config.num_hidden_layers) if self.layer_types[i] == "full_attention"
+ ]
+ self.last_linear_layer = len(self.layer_types) - 1 - self.layer_types[::-1].index("linear_attention")
+
+ # Initialize everything to None -> will be lazy initialized to allow multi-gpu (device_map) inference
+ self.conv_states = [None for _ in range(config.num_hidden_layers)]
+ self.recurrent_states = [None for _ in range(config.num_hidden_layers)]
+ self.key_cache = [None for _ in range(config.num_hidden_layers)]
+ self.value_cache = [None for _ in range(config.num_hidden_layers)]
+
+ def __len__(self):
+ return len(self.layer_types)
+
+ def update(
+ self,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ layer_idx: int,
+ cache_kwargs: dict[str, Any] | None = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ if self.key_cache[layer_idx] is None:
+ self.key_cache[layer_idx] = key_states
+ self.value_cache[layer_idx] = value_states
+ else:
+ self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2)
+ self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2)
+
+ return self.key_cache[layer_idx], self.value_cache[layer_idx]
+
+ def reorder_cache(self, beam_idx: torch.LongTensor):
+ """Reorders the cache for beam search, given the selected beam indices."""
+ for layer_idx in range(len(self.key_cache)):
+ if self.key_cache[layer_idx] is not None:
+ device = self.key_cache[layer_idx].device
+ beam_idx = beam_idx.to(device)
+ self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx)
+ self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx)
+
+ if self.conv_states[layer_idx] is not None:
+ device = self.conv_states[layer_idx].device
+ beam_idx = beam_idx.to(device)
+ self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx)
+ self.recurrent_states[layer_idx] = self.recurrent_states[layer_idx].index_select(0, beam_idx)
+
+ def get_seq_length(self, layer_idx: int | None = 0) -> int:
+ """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+ # take any layer that contains cache and not empty tensor
+ layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx
+ if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx] is None:
+ return 0
+ return self.key_cache[layer_idx].shape[-2]
+
+ def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
+ """
+ Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
+ the given layer at `layer_idx`.
+ The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns for each layer.
+ """
+ kv_offset = 0
+ query_length = cache_position.shape[0]
+ past_seen_tokens = self.get_seq_length(layer_idx)
+ kv_length = query_length + past_seen_tokens
+ return kv_length, kv_offset
+
+ @property
+ def has_previous_state(self):
+ """We have a previous state if the last linear (conv) layer was already updated."""
+ return self.conv_states[self.last_linear_layer] is not None
+
+
+class Qwen3_5VisionRotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, dim: int, theta: float = 10000.0) -> None:
+ super().__init__()
+ self.dim = dim
+ self.theta = theta
+ inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+ def forward(self, seqlen: int) -> torch.Tensor:
+ seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+ freqs = torch.outer(seq, self.inv_freq)
+ return freqs
+
+
+class Qwen3_5TextRotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, config: Qwen3_5TextConfig, device=None):
+ super().__init__()
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+
+ self.rope_type = self.config.rope_parameters["rope_type"]
+ rope_init_fn: Callable = self.compute_default_rope_parameters
+ if self.rope_type != "default":
+ rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+ inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
+
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
+ self.mrope_section = config.rope_parameters.get("mrope_section", [11, 11, 10])
+
+ @staticmethod
+ def compute_default_rope_parameters(
+ config: Qwen3_5TextConfig | None = None,
+ device: Optional["torch.device"] = None,
+ seq_len: int | None = None,
+ ) -> tuple["torch.Tensor", float]:
+ """
+ Computes the inverse frequencies according to the original RoPE implementation
+ Args:
+ config ([`~transformers.PreTrainedConfig`]):
+ The model configuration.
+ device (`torch.device`):
+ The device to use for initialization of the inverse frequencies.
+ seq_len (`int`, *optional*):
+ The current sequence length. Unused for this type of RoPE.
+ Returns:
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+ post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
+ """
+ base = config.rope_parameters["rope_theta"]
+ partial_rotary_factor = config.rope_parameters.get("partial_rotary_factor", 1.0)
+ head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
+ dim = int(head_dim * partial_rotary_factor)
+
+ attention_factor = 1.0 # Unused in this type of RoPE
+
+ # Compute the inverse frequencies
+ inv_freq = 1.0 / (
+ base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
+ )
+ return inv_freq, attention_factor
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ # In contrast to other models, Qwen3_5 has different position ids for the grids
+ # So we expand the inv_freq to shape (3, ...)
+ if position_ids.ndim == 2:
+ position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
+ inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
+ position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions)
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
+ freqs = self.apply_interleaved_mrope(freqs, self.mrope_section)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+ def apply_interleaved_mrope(self, freqs, mrope_section):
+ """Apply interleaved MRoPE to 3D rotary embeddings.
+ Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
+ interleaved [THWTHWTHW...TT], preserving frequency continuity.
+        args:
+            freqs: (3, bs, seq_len, head_dim // 2)
+            mrope_section: (3,)
+        returns:
+            freqs_t: (bs, seq_len, head_dim // 2)
+ """
+ freqs_t = freqs[0] # just overwrite the first dimension T
+ for dim, offset in enumerate((1, 2), start=1): # H, W
+ length = mrope_section[dim] * 3
+ idx = slice(offset, length, 3)
+ freqs_t[..., idx] = freqs[dim, ..., idx]
+ return freqs_t
+
+
+class Qwen3_5RMSNormGated(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6, **kwargs):
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states, gate=None):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ # Norm before gate
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ hidden_states = self.weight * hidden_states.to(input_dtype)
+ hidden_states = hidden_states * F.silu(gate.to(torch.float32))
+
+ return hidden_states.to(input_dtype)
+
+
+def apply_mask_to_padding_states(hidden_states, attention_mask):
+ """
+ Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66
+ """
+ # NOTE: attention mask is a 2D boolean tensor
+ if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
+ dtype = hidden_states.dtype
+ hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
+
+ return hidden_states
+
+
+is_fast_path_available = all(
+ (causal_conv1d_fn, causal_conv1d_update, chunk_gated_delta_rule, fused_recurrent_gated_delta_rule)
+)
+
+
+def torch_causal_conv1d_update(
+ hidden_states,
+ conv_state,
+ weight,
+ bias=None,
+ activation=None,
+):
+ _, hidden_size, seq_len = hidden_states.shape
+ state_len = conv_state.shape[-1]
+
+ hidden_states_new = torch.cat([conv_state, hidden_states], dim=-1).to(weight.dtype)
+ conv_state.copy_(hidden_states_new[:, :, -state_len:])
+ out = F.conv1d(hidden_states_new, weight.unsqueeze(1), bias, padding=0, groups=hidden_size)
+ out = F.silu(out[:, :, -seq_len:])
+ out = out.to(hidden_states.dtype)
+ return out
+
+
+def l2norm(x: torch.FloatTensor, dim: int = -1, eps: float = 1e-6):
+ """This function is intended to align with the l2norm implementation in the FLA library."""
+ inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps)
+ return x * inv_norm
+
+
+def torch_chunk_gated_delta_rule(
+ query,
+ key,
+ value,
+ g,
+ beta,
+ chunk_size=64,
+ initial_state=None,
+ output_final_state=False,
+ use_qk_l2norm_in_kernel=False,
+):
+ initial_dtype = query.dtype
+ if use_qk_l2norm_in_kernel:
+ query = l2norm(query, dim=-1, eps=1e-6)
+ key = l2norm(key, dim=-1, eps=1e-6)
+ query, key, value, beta, g = [
+ x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)
+ ]
+
+ batch_size, num_heads, sequence_length, k_head_dim = key.shape
+ v_head_dim = value.shape[-1]
+ pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
+ query = F.pad(query, (0, 0, 0, pad_size))
+ key = F.pad(key, (0, 0, 0, pad_size))
+ value = F.pad(value, (0, 0, 0, pad_size))
+ beta = F.pad(beta, (0, pad_size))
+ g = F.pad(g, (0, pad_size))
+ total_sequence_length = sequence_length + pad_size
+ scale = 1 / (query.shape[-1] ** 0.5)
+ query = query * scale
+
+ v_beta = value * beta.unsqueeze(-1)
+ k_beta = key * beta.unsqueeze(-1)
+ # reshape to chunks
+ query, key, value, k_beta, v_beta = [
+ x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) for x in (query, key, value, k_beta, v_beta)
+ ]
+ g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
+ mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=0)
+
+ # chunk decay
+ g = g.cumsum(dim=-1)
+ decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril()
+ attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0)
+ for i in range(1, chunk_size):
+ row = attn[..., i, :i].clone()
+ sub = attn[..., :i, :i].clone()
+ attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
+ attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
+ value = attn @ v_beta
+ k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
+ last_recurrent_state = (
+ torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value)
+ if initial_state is None
+ else initial_state.to(value)
+ )
+ core_attn_out = torch.zeros_like(value)
+ mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=1)
+
+ # for each chunk
+ for i in range(0, total_sequence_length // chunk_size):
+ q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i]
+ attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
+ v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
+ v_new = v_i - v_prime
+ attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
+ core_attn_out[:, :, i] = attn_inter + attn @ v_new
+ last_recurrent_state = (
+ last_recurrent_state * g[:, :, i, -1, None, None].exp()
+ + (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(-1, -2) @ v_new
+ )
+
+ if not output_final_state:
+ last_recurrent_state = None
+ core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1])
+ core_attn_out = core_attn_out[:, :, :sequence_length]
+ core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
+ return core_attn_out, last_recurrent_state
+
+
+def torch_recurrent_gated_delta_rule(
+ query, key, value, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel=False
+):
+ initial_dtype = query.dtype
+ if use_qk_l2norm_in_kernel:
+ query = l2norm(query, dim=-1, eps=1e-6)
+ key = l2norm(key, dim=-1, eps=1e-6)
+ query, key, value, beta, g = [
+ x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)
+ ]
+
+ batch_size, num_heads, sequence_length, k_head_dim = key.shape
+ v_head_dim = value.shape[-1]
+ scale = 1 / (query.shape[-1] ** 0.5)
+ query = query * scale
+
+ core_attn_out = torch.zeros(batch_size, num_heads, sequence_length, v_head_dim).to(value)
+ last_recurrent_state = (
+ torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value)
+ if initial_state is None
+ else initial_state.to(value)
+ )
+
+ for i in range(sequence_length):
+ q_t = query[:, :, i]
+ k_t = key[:, :, i]
+ v_t = value[:, :, i]
+ g_t = g[:, :, i].exp().unsqueeze(-1).unsqueeze(-1)
+ beta_t = beta[:, :, i].unsqueeze(-1)
+
+ last_recurrent_state = last_recurrent_state * g_t
+ kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2)
+ delta = (v_t - kv_mem) * beta_t
+ last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
+ core_attn_out[:, :, i] = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2)
+
+ if not output_final_state:
+ last_recurrent_state = None
+ core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
+ return core_attn_out, last_recurrent_state
+
+
+class Qwen3_5GatedDeltaNet(nn.Module):
+ def __init__(self, config: Qwen3_5Config, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.num_v_heads = config.linear_num_value_heads
+ self.num_k_heads = config.linear_num_key_heads
+ self.head_k_dim = config.linear_key_head_dim
+ self.head_v_dim = config.linear_value_head_dim
+ self.key_dim = self.head_k_dim * self.num_k_heads
+ self.value_dim = self.head_v_dim * self.num_v_heads
+
+ self.conv_kernel_size = config.linear_conv_kernel_dim
+ self.layer_idx = layer_idx
+ self.activation = config.hidden_act
+ self.act = ACT2FN[config.hidden_act]
+ self.layer_norm_epsilon = config.rms_norm_eps
+
+ # QKV
+ self.conv_dim = self.key_dim * 2 + self.value_dim
+ self.conv1d = nn.Conv1d(
+ in_channels=self.conv_dim,
+ out_channels=self.conv_dim,
+ bias=False,
+ kernel_size=self.conv_kernel_size,
+ groups=self.conv_dim,
+ padding=self.conv_kernel_size - 1,
+ )
+
+ # time step projection (discretization)
+ # instantiate once and copy inv_dt in init_weights of PretrainedModel
+ self.dt_bias = nn.Parameter(torch.ones(self.num_v_heads))
+
+ A = torch.empty(self.num_v_heads).uniform_(0, 16)
+ self.A_log = nn.Parameter(torch.log(A))
+
+ self.norm = (
+ Qwen3_5RMSNormGated(self.head_v_dim, eps=self.layer_norm_epsilon)
+ if FusedRMSNormGated is None
+ else FusedRMSNormGated(
+ self.head_v_dim,
+ eps=self.layer_norm_epsilon,
+ activation=self.activation,
+ device=torch.cuda.current_device(),
+ dtype=config.dtype if config.dtype is not None else torch.get_default_dtype(),
+ )
+ )
+
+ self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False)
+
+ self.causal_conv1d_fn = causal_conv1d_fn
+ self.causal_conv1d_update = causal_conv1d_update or torch_causal_conv1d_update
+ self.chunk_gated_delta_rule = chunk_gated_delta_rule or torch_chunk_gated_delta_rule
+ self.recurrent_gated_delta_rule = fused_recurrent_gated_delta_rule or torch_recurrent_gated_delta_rule
+
+ if not is_fast_path_available:
+ logger.warning_once(
+ "The fast path is not available because one of the required library is not installed. Falling back to "
+ "torch implementation. To install follow https://github.com/fla-org/flash-linear-attention#installation and"
+ " https://github.com/Dao-AILab/causal-conv1d"
+ )
+
+ self.in_proj_qkv = nn.Linear(self.hidden_size, self.key_dim * 2 + self.value_dim, bias=False)
+ self.in_proj_z = nn.Linear(self.hidden_size, self.value_dim, bias=False)
+ self.in_proj_b = nn.Linear(self.hidden_size, self.num_v_heads, bias=False)
+ self.in_proj_a = nn.Linear(self.hidden_size, self.num_v_heads, bias=False)
+
    def forward(
        self,
        hidden_states: torch.Tensor,
        cache_params: Qwen3_5DynamicCache | None = None,
        cache_position: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
    ):
        """Run the gated delta-rule linear-attention mixer.

        Projects the input into q/k/v plus gating branches, applies a short causal
        convolution to the q/k/v stream, then runs either the chunked kernel
        (multi-token prefill/training) or the recurrent kernel (single-token decode
        with a warm cache). Conv and recurrent states are read from and written
        back to `cache_params` in place.
        """
        # Zero out padded positions so padding cannot leak into conv/recurrent states.
        hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)

        # Set up dimensions for reshapes later
        batch_size, seq_len, _ = hidden_states.shape

        # Fast decode path: exactly one new token and previously cached states.
        use_precomputed_states = (
            cache_params is not None
            and cache_params.has_previous_state
            and seq_len == 1
            and cache_position is not None
        )

        # getting projected states from cache if it exists
        if cache_params is not None:
            conv_state = cache_params.conv_states[self.layer_idx]
            recurrent_state = cache_params.recurrent_states[self.layer_idx]

        # Joint q/k/v projection; transpose to (batch, channels, seq) for the conv.
        mixed_qkv = self.in_proj_qkv(hidden_states)
        mixed_qkv = mixed_qkv.transpose(1, 2)

        # `z` is the gating input of the gated RMS norm applied after attention.
        z = self.in_proj_z(hidden_states)
        z = z.reshape(batch_size, seq_len, -1, self.head_v_dim)

        # Per-head scalars: `b` -> beta (write strength), `a` -> decay logits.
        b = self.in_proj_b(hidden_states)
        a = self.in_proj_a(hidden_states)

        if use_precomputed_states:
            # 2. Convolution sequence transformation
            # NOTE: the conv state is updated in `causal_conv1d_update`
            mixed_qkv = self.causal_conv1d_update(
                mixed_qkv,
                conv_state,
                self.conv1d.weight.squeeze(1),
                self.conv1d.bias,
                self.activation,
            )
        else:
            if cache_params is not None:
                # Keep the last `conv_kernel_size` inputs (left-padded) for future decode steps.
                conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0))
                cache_params.conv_states[self.layer_idx] = conv_state
            if self.causal_conv1d_fn is not None:
                mixed_qkv = self.causal_conv1d_fn(
                    x=mixed_qkv,
                    weight=self.conv1d.weight.squeeze(1),
                    bias=self.conv1d.bias,
                    activation=self.activation,
                    seq_idx=None,
                )
            else:
                # Torch fallback: conv output is cropped back to seq_len (causal padding).
                mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len])

        mixed_qkv = mixed_qkv.transpose(1, 2)
        query, key, value = torch.split(
            mixed_qkv,
            [
                self.key_dim,
                self.key_dim,
                self.value_dim,
            ],
            dim=-1,
        )

        query = query.reshape(batch_size, seq_len, -1, self.head_k_dim)
        key = key.reshape(batch_size, seq_len, -1, self.head_k_dim)
        value = value.reshape(batch_size, seq_len, -1, self.head_v_dim)

        beta = b.sigmoid()
        # If the model is loaded in fp16, without the .float() here, A might be -inf
        g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
        if self.num_v_heads // self.num_k_heads > 1:
            # GQA-style: replicate q/k heads so they match the number of value heads.
            query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
            key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)

        if not use_precomputed_states:
            # Chunked kernel for multi-token inputs; starts from a fresh state.
            core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule(
                query,
                key,
                value,
                g=g,
                beta=beta,
                initial_state=None,
                output_final_state=cache_params is not None,
                use_qk_l2norm_in_kernel=True,
            )

        else:
            # Recurrent kernel for one-token decode, resuming from the cached state.
            core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule(
                query,
                key,
                value,
                g=g,
                beta=beta,
                initial_state=recurrent_state,
                output_final_state=cache_params is not None,
                use_qk_l2norm_in_kernel=True,
            )

        # Update cache
        if cache_params is not None:
            cache_params.recurrent_states[self.layer_idx] = last_recurrent_state

        # reshape input data into 2D tensor
        core_attn_out = core_attn_out.reshape(-1, self.head_v_dim)
        z = z.reshape(-1, self.head_v_dim)
        core_attn_out = self.norm(core_attn_out, z)
        core_attn_out = core_attn_out.reshape(batch_size, seq_len, -1)

        output = self.out_proj(core_attn_out)
        return output
+
+
def rotate_half(x):
    """Split the last dimension into halves (a, b) and return their rotation (-b, a)."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
+
+
# Adapted from transformers.models.glm.modular_glm.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Applies (possibly partial) Rotary Position Embedding to the query and key tensors.

    Removes the interleaving of cos and sin from GLM. Only the first `cos.shape[-1]`
    channels of the head dimension are rotated; any remaining channels pass through
    unchanged (partial RoPE).

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension along which cos/sin are unsqueezed so they broadcast against q and k.
            Use 1 for `[batch, heads, seq, head_dim]` layouts and 2 for
            `[batch, seq, heads, head_dim]` layouts.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    rotary_dim = cos.shape[-1]

    def _rotate(t):
        # Split into the rotated slice and the pass-through remainder.
        t_rot, t_pass = t[..., :rotary_dim], t[..., rotary_dim:]
        t_rot = (t_rot * cos) + (rotate_half(t_rot) * sin)
        return torch.cat([t_rot, t_pass], dim=-1)

    return _rotate(q), _rotate(k)
+
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    if n_rep == 1:
        # Nothing to repeat; return the tensor untouched.
        return hidden_states
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    # Insert a repeat axis, expand (no copy), then fold it into the head axis.
    expanded = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return expanded.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """Reference (pure PyTorch) scaled-dot-product attention with GQA key/value head repetition."""
    # Expand grouped K/V heads so they line up one-to-one with the query heads.
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    scores = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        # The mask is additive; crop it to the actual key length.
        scores = scores + attention_mask[:, :, :, : key_states.shape[-2]]

    # Softmax in float32 for numerical stability, then cast back to the query dtype.
    attn_weights = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

    attn_output = torch.matmul(attn_weights, value_states).transpose(1, 2).contiguous()
    return attn_output, attn_weights
+
+
@use_kernelized_func(apply_rotary_pos_emb)
class Qwen3_5Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: Qwen3_5Config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True
        # q_proj is twice the usual width: it produces the queries AND a per-channel
        # sigmoid output gate (the two halves are split apart in `forward`).
        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim * 2, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )
        self.q_norm = Qwen3_5RMSNorm(self.head_dim, eps=config.rms_norm_eps)  # unlike olmo, only on the head dim!
        self.k_norm = Qwen3_5RMSNorm(self.head_dim, eps=config.rms_norm_eps)  # thus post q_norm does not need reshape

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: torch.Tensor | None,
        past_key_values: Cache | None = None,
        cache_position: torch.LongTensor | None = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # Split the doubled q projection into queries and the output gate.
        query_states, gate = torch.chunk(
            self.q_proj(hidden_states).view(*input_shape, -1, self.head_dim * 2), 2, dim=-1
        )
        gate = gate.reshape(*input_shape, -1)

        # Per-head RMSNorm on q/k before RoPE, then move heads to dim 1.
        query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2)
        key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Dispatch to the configured attention backend, falling back to the eager path.
        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        # Apply the sigmoid output gate produced by q_proj.
        attn_output = attn_output * torch.sigmoid(gate)

        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
+
+
class Qwen3_5MLP(nn.Module):
    """SwiGLU-style feed-forward block: down_proj(act(gate_proj(x)) * up_proj(x))."""

    def __init__(self, config: Qwen3_5Config, intermediate_size: int):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        # Gated activation: the gate branch is passed through the nonlinearity
        # and modulates the linear up-projection.
        gated = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(gated)
+
+
class Qwen3_5RMSNorm(nn.Module):
    """RMSNorm with a zero-centered scale: output = rms_norm(x) * (1 + weight).

    The weight is stored as an offset from 1 (initialized to zeros). Unlike Llama's
    `x.to(dtype) * w`, the multiplication happens in float32 and is cast back to the
    input dtype afterwards (see https://github.com/huggingface/transformers/pull/29402).
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        # Root-mean-square normalization over the last dimension.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        normed = self._norm(x.float()) * (1.0 + self.weight.float())
        return normed.type_as(x)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.eps}"
+
+
class Qwen3_5DecoderLayer(GradientCheckpointingLayer):
    """One decoder layer: a token mixer (gated delta-net or full attention, chosen per
    layer by `config.layer_types`) followed by a SwiGLU MLP, each on a pre-norm
    residual branch."""

    def __init__(self, config: Qwen3_5TextConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        # The mixer flavor for this layer comes from the per-layer schedule.
        self.layer_type = config.layer_types[layer_idx]
        if self.layer_type == "linear_attention":
            self.linear_attn = Qwen3_5GatedDeltaNet(config, layer_idx)
        elif self.layer_type == "full_attention":
            self.self_attn = Qwen3_5Attention(config, layer_idx)
        self.mlp = Qwen3_5MLP(config, config.intermediate_size)
        self.input_layernorm = Qwen3_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen3_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        cache_position: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.FloatTensor:
        # --- token mixer sub-block (pre-norm residual) ---
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        if self.layer_type == "linear_attention":
            hidden_states = self.linear_attn(
                hidden_states=hidden_states,
                cache_params=past_key_values,
                cache_position=cache_position,
                attention_mask=attention_mask,
            )
        elif self.layer_type == "full_attention":
            hidden_states, _ = self.self_attn(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )
        hidden_states = residual + hidden_states

        # --- MLP sub-block (pre-norm residual) ---
        hidden_states = hidden_states + self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states
+
+
class Qwen3_5PreTrainedModel(PreTrainedModel):
    """Shared base class: config type, weight initialization, and feature-support flags
    for all Qwen3.5 model heads."""

    config: Qwen3_5Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Qwen3_5DecoderLayer", "Qwen3_5VisionBlock"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn = True
    _supports_sdpa = True
    # Checkpoints may ship multi-token-prediction ("mtp") weights this model does not load.
    _keys_to_ignore_on_load_unexpected = [r"^mtp.*"]
    _can_record_outputs = {
        "hidden_states": Qwen3_5DecoderLayer,
        "attentions": Qwen3_5Attention,
    }
    _is_stateful = True

    @torch.no_grad()
    def _init_weights(self, module):
        """Extend the generic init with Qwen3.5-specific parameters."""
        super()._init_weights(module)
        if isinstance(module, Qwen3_5GatedDeltaNet):
            init.ones_(module.dt_bias)
            # Matches the constructor: A ~ Uniform(0, 16), stored as log(A).
            init.copy_(module.A_log, torch.empty_like(module.A_log).uniform_(0, 16).log_())
        # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
        elif isinstance(module, Qwen3_5RMSNorm):
            init.zeros_(module.weight)
        elif isinstance(module, Qwen3_5VisionRotaryEmbedding):
            inv_freq = 1.0 / (module.theta ** (torch.arange(0, module.dim, 2, dtype=torch.float) / module.dim))
            init.copy_(module.inv_freq, inv_freq)
+
+
class Qwen3_5VisionMLP(nn.Module):
    """Plain two-layer feed-forward block used inside the vision tower."""

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.linear_fc1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=True)
        self.linear_fc2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=True)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_state):
        projected = self.linear_fc1(hidden_state)
        return self.linear_fc2(self.act_fn(projected))
+
+
class Qwen3_5VisionPatchEmbed(nn.Module):
    """Embed flattened pixel patches with a 3D convolution (time x height x width)."""

    def __init__(self, config) -> None:
        super().__init__()
        self.patch_size = config.patch_size
        self.temporal_patch_size = config.temporal_patch_size
        self.in_channels = config.in_channels
        self.embed_dim = config.hidden_size

        # Kernel equals stride, so patches are non-overlapping.
        kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size]
        self.proj = nn.Conv3d(self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Cast to the conv weight dtype so fp16/bf16 checkpoints accept fp32 inputs.
        target_dtype = self.proj.weight.dtype
        patches = hidden_states.view(
            -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
        )
        return self.proj(patches.to(dtype=target_dtype)).view(-1, self.embed_dim)
+
+
class Qwen3_5VisionPatchMerger(nn.Module):
    """Merge `spatial_merge_size ** 2` neighboring patch embeddings into one token and
    project it to the LLM hidden size via LayerNorm + a 2-layer MLP.

    The LayerNorm is applied either before (default) or after the regrouping,
    depending on `use_postshuffle_norm`.
    """

    def __init__(self, config: Qwen3_5VisionConfig, use_postshuffle_norm=False) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size * (config.spatial_merge_size**2)
        self.use_postshuffle_norm = use_postshuffle_norm
        self.norm = nn.LayerNorm(self.hidden_size if use_postshuffle_norm else config.hidden_size, eps=1e-6)
        self.linear_fc1 = nn.Linear(self.hidden_size, self.hidden_size)
        self.act_fn = nn.GELU()
        self.linear_fc2 = nn.Linear(self.hidden_size, config.out_hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.use_postshuffle_norm:
            # Regroup first, then normalize over the merged dimension.
            x = self.norm(x.view(-1, self.hidden_size))
        else:
            # Normalize per patch, then regroup into merged tokens.
            x = self.norm(x).view(-1, self.hidden_size)
        return self.linear_fc2(self.act_fn(self.linear_fc1(x)))
+
+
def apply_rotary_pos_emb_vision(
    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary embeddings to vision q/k in float32, then cast back to the original dtypes."""
    q_dtype, k_dtype = q.dtype, k.dtype
    # Broadcast cos/sin over the head dimension and compute in fp32 for precision.
    cos_f = cos.unsqueeze(-2).float()
    sin_f = sin.unsqueeze(-2).float()
    q_f, k_f = q.float(), k.float()
    q_embed = (q_f * cos_f) + (rotate_half(q_f) * sin_f)
    k_embed = (k_f * cos_f) + (rotate_half(k_f) * sin_f)
    return q_embed.to(q_dtype), k_embed.to(k_dtype)
+
+
class Qwen3_5VisionAttention(nn.Module):
    """Self-attention over a packed sequence of vision patches.

    The input has no batch dimension: all images/videos are packed into one long
    sequence, and `cu_seqlens` holds the cumulative boundaries between them.
    """

    def __init__(self, config: Qwen3_5VisionConfig) -> None:
        super().__init__()
        self.dim = config.hidden_size
        self.num_heads = config.num_heads
        self.head_dim = self.dim // self.num_heads
        self.num_key_value_groups = 1  # needed for eager attention
        # Fused q/k/v projection; the three parts are split after the matmul.
        self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True)
        self.proj = nn.Linear(self.dim, self.dim)
        self.scaling = self.head_dim**-0.5
        self.config = config
        self.attention_dropout = 0.0
        self.is_causal = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor | None = None,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        **kwargs,
    ) -> torch.Tensor:
        seq_length = hidden_states.shape[0]
        # (seq, 3*dim) -> three tensors of shape (seq, num_heads, head_dim).
        query_states, key_states, value_states = (
            self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
        )
        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)

        # Add a batch dim of 1 and move heads first: (1, num_heads, seq, head_dim).
        query_states = query_states.transpose(0, 1).unsqueeze(0)
        key_states = key_states.transpose(0, 1).unsqueeze(0)
        value_states = value_states.transpose(0, 1).unsqueeze(0)

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        if is_flash_attention_requested(self.config):
            # Flash Attention: Use cu_seqlens for variable length attention
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
            attn_output, _ = attention_interface(
                self,
                query_states,
                key_states,
                value_states,
                attention_mask=None,
                scaling=self.scaling,
                dropout=0.0 if not self.training else self.attention_dropout,
                cu_seq_lens_q=cu_seqlens,
                cu_seq_lens_k=cu_seqlens,
                max_length_q=max_seqlen,
                max_length_k=max_seqlen,
                is_causal=False,
                **kwargs,
            )
        else:
            # Other implementations: Process each chunk separately
            lengths = cu_seqlens[1:] - cu_seqlens[:-1]
            splits = [
                torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states)
            ]

            # Run attention per image/video so tokens never attend across boundaries.
            attn_outputs = [
                attention_interface(
                    self,
                    q,
                    k,
                    v,
                    attention_mask=None,
                    scaling=self.scaling,
                    dropout=0.0 if not self.training else self.attention_dropout,
                    is_causal=False,
                    **kwargs,
                )[0]
                for q, k, v in zip(*splits)
            ]
            attn_output = torch.cat(attn_outputs, dim=1)

        attn_output = attn_output.reshape(seq_length, -1).contiguous()
        attn_output = self.proj(attn_output)
        return attn_output
+
+
class Qwen3_5VisionBlock(GradientCheckpointingLayer):
    """Pre-norm transformer block for the vision tower: attention then MLP, each residual."""

    def __init__(self, config, attn_implementation: str = "sdpa") -> None:
        super().__init__()
        self.norm1 = nn.LayerNorm(config.hidden_size, eps=1e-6)
        self.norm2 = nn.LayerNorm(config.hidden_size, eps=1e-6)
        self.attn = Qwen3_5VisionAttention(config=config)
        self.mlp = Qwen3_5VisionMLP(config=config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor | None = None,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        **kwargs,
    ) -> torch.Tensor:
        # Attention sub-block with residual connection.
        attn_out = self.attn(
            self.norm1(hidden_states),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = hidden_states + attn_out
        # MLP sub-block with residual connection.
        return hidden_states + self.mlp(self.norm2(hidden_states))
+
+
class Qwen3_5VisionModel(Qwen3_5PreTrainedModel):
    """Vision tower: patch embedding + interpolated absolute position embeddings +
    rotary-encoded transformer blocks, followed by a patch merger projecting to the
    LLM hidden size."""

    config: Qwen3_5VisionConfig
    _no_split_modules = ["Qwen3_5VisionBlock"]
    _can_record_outputs = {
        "hidden_states": Qwen3_5VisionBlock,
        "attentions": Qwen3_5VisionAttention,
    }

    def __init__(self, config, *inputs, **kwargs) -> None:
        super().__init__(config, *inputs, **kwargs)
        self.spatial_merge_size = config.spatial_merge_size
        self.patch_size = config.patch_size
        # Number of patches folded into one merged output token.
        self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size

        self.patch_embed = Qwen3_5VisionPatchEmbed(
            config=config,
        )

        # Learned absolute position embeddings laid out on a square grid,
        # bilinearly interpolated at runtime to each image's actual grid size.
        self.pos_embed = nn.Embedding(config.num_position_embeddings, config.hidden_size)
        self.num_grid_per_side = int(config.num_position_embeddings**0.5)

        head_dim = config.hidden_size // config.num_heads
        self.rotary_pos_emb = Qwen3_5VisionRotaryEmbedding(head_dim // 2)

        self.blocks = nn.ModuleList([Qwen3_5VisionBlock(config) for _ in range(config.depth)])
        self.merger = Qwen3_5VisionPatchMerger(
            config=config,
            use_postshuffle_norm=False,
        )

        self.gradient_checkpointing = False

        self.post_init()

    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
        """Build per-token 2D rotary frequencies (row, col) for every image/video in `grid_thw`."""
        merge_size = self.spatial_merge_size

        # One shared frequency table covering the largest spatial extent seen.
        max_hw = int(grid_thw[:, 1:].max().item())
        freq_table = self.rotary_pos_emb(max_hw)  # (max_hw, dim // 2)
        device = freq_table.device

        total_tokens = int(torch.prod(grid_thw, dim=1).sum().item())
        pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device)

        offset = 0
        for num_frames, height, width in grid_thw:
            merged_h, merged_w = height // merge_size, width // merge_size

            block_rows = torch.arange(merged_h, device=device)  # block row indices
            block_cols = torch.arange(merged_w, device=device)  # block col indices
            intra_row = torch.arange(merge_size, device=device)  # intra-block row offsets
            intra_col = torch.arange(merge_size, device=device)  # intra-block col offsets

            # Compute full-resolution positions
            row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None]
            col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :]

            row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1)
            col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1)

            coords = torch.stack((row_idx, col_idx), dim=-1)

            # All frames of a video share the same spatial coordinates.
            if num_frames > 1:
                coords = coords.repeat(num_frames, 1)

            num_tokens = coords.shape[0]
            pos_ids[offset : offset + num_tokens] = coords
            offset += num_tokens

        embeddings = freq_table[pos_ids]  # lookup rotary embeddings
        embeddings = embeddings.flatten(1)
        return embeddings

    def fast_pos_embed_interpolate(self, grid_thw):
        """Bilinearly interpolate the learned grid position embeddings to each sample's (h, w).

        For every target position the four surrounding grid points are gathered and
        blended with bilinear weights, then the result is permuted into the
        merge-window token order used by the transformer blocks.
        """
        grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2]
        device = self.pos_embed.weight.device

        # One (index, weight) list per bilinear corner:
        # floor/floor, floor/ceil, ceil/floor, ceil/ceil.
        idx_list = [[] for _ in range(4)]
        weight_list = [[] for _ in range(4)]

        for t, h, w in zip(grid_ts, grid_hs, grid_ws):
            h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h)
            w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w)

            h_idxs_floor = h_idxs.int()
            w_idxs_floor = w_idxs.int()
            h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
            w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)

            # Fractional parts act as interpolation weights toward the ceil side.
            dh = h_idxs - h_idxs_floor
            dw = w_idxs - w_idxs_floor

            # Flattened row offsets into the square embedding grid.
            base_h = h_idxs_floor * self.num_grid_per_side
            base_h_ceil = h_idxs_ceil * self.num_grid_per_side

            indices = [
                (base_h[None].T + w_idxs_floor[None]).flatten(),
                (base_h[None].T + w_idxs_ceil[None]).flatten(),
                (base_h_ceil[None].T + w_idxs_floor[None]).flatten(),
                (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(),
            ]

            weights = [
                ((1 - dh)[None].T * (1 - dw)[None]).flatten(),
                ((1 - dh)[None].T * dw[None]).flatten(),
                (dh[None].T * (1 - dw)[None]).flatten(),
                (dh[None].T * dw[None]).flatten(),
            ]

            for i in range(4):
                idx_list[i].extend(indices[i].tolist())
                weight_list[i].extend(weights[i].tolist())

        idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=device)
        weight_tensor = torch.tensor(weight_list, dtype=self.pos_embed.weight.dtype, device=device)
        pos_embeds = self.pos_embed(idx_tensor).to(device) * weight_tensor[:, :, None]
        # Sum the four weighted corners -> bilinear interpolation.
        patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3]

        patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)])

        patch_pos_embeds_permute = []
        merge_size = self.config.spatial_merge_size
        for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws):
            # Repeat spatial embeddings across frames, then reorder into merge windows.
            pos_embed = pos_embed.repeat(t, 1)
            pos_embed = (
                pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1)
                .permute(0, 1, 3, 2, 4, 5)
                .flatten(0, 4)
            )
            patch_pos_embeds_permute.append(pos_embed)
        patch_pos_embeds = torch.cat(patch_pos_embeds_permute)
        return patch_pos_embeds

    @check_model_inputs
    def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`):
                The final hidden states of the model.
            grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
                The temporal, height and width of feature shape of each image in LLM.

        Returns:
            `BaseModelOutputWithPooling`: per-patch hidden states plus the merged
            (pooled) tokens produced by the patch merger.
        """
        hidden_states = self.patch_embed(hidden_states)

        pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
        hidden_states = hidden_states + pos_embeds

        rotary_pos_emb = self.rot_pos_emb(grid_thw)

        seq_len, _ = hidden_states.size()
        hidden_states = hidden_states.reshape(seq_len, -1)
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
        # Duplicate frequencies so cos/sin span the full rotated head dimension.
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        position_embeddings = (emb.cos(), emb.sin())

        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
            dim=0,
            # Select dtype based on the following factors:
            # - FA2 requires that cu_seqlens_q must have dtype int32
            # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
            # See https://github.com/huggingface/transformers/pull/34852 for more information
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

        for blk in self.blocks:
            hidden_states = blk(
                hidden_states,
                cu_seqlens=cu_seqlens,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        # Merge neighboring patches into LLM-sized tokens; exposed as pooler_output.
        merged_hidden_states = self.merger(hidden_states)

        return BaseModelOutputWithPooling(
            last_hidden_state=hidden_states,
            pooler_output=merged_hidden_states,
        )
+
+
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Qwen3_5 outputs, with hidden states and attentions.
    """
)
class Qwen3_5ModelOutputWithPast(ModelOutput):
    r"""
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
        The rope index difference between sequence length and multimodal rope.
    """

    last_hidden_state: torch.FloatTensor | None = None
    past_key_values: Cache | None = None
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
    rope_deltas: torch.LongTensor | None = None
+
+
class Qwen3_5TextModel(Qwen3_5PreTrainedModel):
    """Decoder-only text backbone mixing full-attention and linear-attention layers."""

    def __init__(self, config: Qwen3_5TextConfig):
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
        self.layers = nn.ModuleList(
            [Qwen3_5DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Qwen3_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen3_5TextRotaryEmbedding(config=config)
        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    @check_model_inputs
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        cache_position: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # The hybrid cache holds both KV states (full-attention layers) and
        # conv/recurrent states (linear-attention layers).
        if use_cache and past_key_values is None:
            past_key_values = Qwen3_5DynamicCache(config=self.config)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        # mrope: the hard coded `3` is for temporal, height and width.
        if position_ids is None:
            position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
        elif position_ids.ndim == 2:
            position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)

        # A 4-row position_ids carries an extra leading row of plain text positions
        # (used for mask construction only); the remaining 3 rows are the mrope axes.
        if position_ids.ndim == 3 and position_ids.shape[0] == 4:
            text_position_ids = position_ids[0]
            position_ids = position_ids[1:]
        else:
            text_position_ids = position_ids[0]

        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=text_position_ids,
        )
        linear_attn_mask = self._update_linear_attn_mask(attention_mask, cache_position)

        hidden_states = inputs_embeds
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
            # Each layer type consumes its own mask flavor.
            layer_mask = linear_attn_mask if decoder_layer.layer_type == "linear_attention" else causal_mask

            hidden_states = decoder_layer(
                hidden_states,
                position_embeddings=position_embeddings,
                attention_mask=layer_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)

        return Qwen3_5ModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )

    def _update_linear_attn_mask(self, attention_mask, cache_position):
        """
        NOTE: Left-padding is used for linear attention mask.
        No need for zeroing states when
        1. Cached forward
        2. Attending to all inputs
        """
        linear_attn_mask = attention_mask
        if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)):
            linear_attn_mask = None
        return linear_attn_mask
+
+
+@auto_docstring
+class Qwen3_5Model(Qwen3_5PreTrainedModel):
+ base_model_prefix = "model"
+ _checkpoint_conversion_mapping = {}
+ # Reference: fix gemma3 grad acc #37208
+ accepts_loss_kwargs = False
+ config: Qwen3_5Config
+ _no_split_modules = ["Qwen3_5TextDecoderLayer", "Qwen3_5VisionBlock"]
+
    def __init__(self, config):
        """Build the vision tower and the language model from their respective sub-configs."""
        super().__init__(config)
        self.visual = Qwen3_5VisionModel._from_config(config.vision_config)
        self.language_model = Qwen3_5TextModel._from_config(config.text_config)
        self.rope_deltas = None  # cache rope_deltas here

        # Initialize weights and apply final processing
        self.post_init()
+
    def get_input_embeddings(self):
        # Delegate to the wrapped language model's token embedding table.
        return self.language_model.get_input_embeddings()
+
    def set_input_embeddings(self, value):
        # Delegate to the wrapped language model's token embedding table.
        self.language_model.set_input_embeddings(value)
+
    def get_rope_index(
        self,
        input_ids: torch.LongTensor | None = None,
        image_grid_thw: torch.LongTensor | None = None,
        video_grid_thw: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Different from the original implementation, Qwen3_5 use timestamps rather than absolute time position ids.

        Returns a `(position_ids, mrope_position_deltas)` pair where `position_ids` has shape
        `(3, batch_size, seq_len)` (one plane each for t/h/w) and `mrope_position_deltas` has
        shape `(batch_size, 1)`.
        """

        # Since we use timestamps to separate videos, the video_grid_thw should also be split:
        # every frame becomes its own grid entry with temporal extent 1.
        if video_grid_thw is not None:
            video_grid_thw = torch.repeat_interleave(video_grid_thw, video_grid_thw[:, 0], dim=0)
            video_grid_thw[:, 0] = 1

        spatial_merge_size = self.config.vision_config.spatial_merge_size
        image_token_id = self.config.image_token_id
        video_token_id = self.config.video_token_id
        vision_start_token_id = self.config.vision_start_token_id
        mrope_position_deltas = []
        if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
            total_input_ids = input_ids
            if attention_mask is None:
                attention_mask = torch.ones_like(total_input_ids)
            position_ids = torch.ones(
                3,
                input_ids.shape[0],
                input_ids.shape[1],
                dtype=input_ids.dtype,
                device=input_ids.device,
            )
            image_index, video_index = 0, 0
            attention_mask = attention_mask.to(total_input_ids.device)
            # NOTE: the loop variable deliberately shadows `input_ids` with the per-sample row.
            for i, input_ids in enumerate(total_input_ids):
                input_ids = input_ids[attention_mask[i] == 1]
                image_nums, video_nums = 0, 0
                # A vision segment is announced by a vision-start token; the token right
                # after it tells whether the segment is an image or a video.
                vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
                vision_tokens = input_ids[vision_start_indices + 1]
                image_nums = (vision_tokens == image_token_id).sum()
                video_nums = (vision_tokens == video_token_id).sum()
                input_tokens = input_ids.tolist()
                llm_pos_ids_list: list = []
                st = 0
                remain_images, remain_videos = image_nums, video_nums
                # Walk through the sequence, alternating text spans and vision grids.
                for _ in range(image_nums + video_nums):
                    if image_token_id in input_tokens and remain_images > 0:
                        ed_image = input_tokens.index(image_token_id, st)
                    else:
                        # Sentinel past the end: this modality does not occur again.
                        ed_image = len(input_tokens) + 1
                    if video_token_id in input_tokens and remain_videos > 0:
                        ed_video = input_tokens.index(video_token_id, st)
                    else:
                        ed_video = len(input_tokens) + 1
                    if ed_image < ed_video:
                        t, h, w = (
                            image_grid_thw[image_index][0],
                            image_grid_thw[image_index][1],
                            image_grid_thw[image_index][2],
                        )
                        image_index += 1
                        remain_images -= 1
                        ed = ed_image

                    else:
                        t, h, w = (
                            video_grid_thw[video_index][0],
                            video_grid_thw[video_index][1],
                            video_grid_thw[video_index][2],
                        )
                        video_index += 1
                        remain_videos -= 1
                        ed = ed_video
                    # Spatial dims shrink by the merge factor; temporal dim is kept as-is.
                    llm_grid_t, llm_grid_h, llm_grid_w = (
                        t.item(),
                        h.item() // spatial_merge_size,
                        w.item() // spatial_merge_size,
                    )
                    text_len = ed - st

                    # Text tokens get identical t/h/w positions continuing after the last block.
                    st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                    llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

                    # t_index is always 0 because llm_grid_t is always 1 (we use timestamps to encode the temporal information for videos)
                    t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
                    h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
                    w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
                    llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
                    st = ed + llm_grid_t * llm_grid_h * llm_grid_w

                # Trailing text after the last vision segment.
                if st < len(input_tokens):
                    st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                    text_len = len(input_tokens) - st
                    llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

                llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
                # Scatter per-sample positions back into the padded batch layout.
                position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
                mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
            mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
            return position_ids, mrope_position_deltas
        else:
            # Text-only path: plain sequential positions replicated over the 3 rope planes.
            if attention_mask is not None:
                position_ids = attention_mask.long().cumsum(-1) - 1
                position_ids.masked_fill_(attention_mask == 0, 1)
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
                max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
                mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
            else:
                position_ids = (
                    torch.arange(input_ids.shape[1], device=input_ids.device)
                    .view(1, 1, -1)
                    .expand(3, input_ids.shape[0], -1)
                )
                mrope_position_deltas = torch.zeros(
                    [input_ids.shape[0], 1],
                    device=input_ids.device,
                    dtype=input_ids.dtype,
                )

            return position_ids, mrope_position_deltas
+
+ @can_return_tuple
+ @auto_docstring
+ def get_video_features(
+ self,
+ pixel_values_videos: torch.FloatTensor,
+ video_grid_thw: torch.LongTensor | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple | BaseModelOutputWithPooling:
+ r"""
+ pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+ The tensors corresponding to the input videos.
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+ The temporal, height and width of feature shape of each video in LLM.
+ """
+ # Same implementation as for images
+ return self.get_image_features(pixel_values_videos, video_grid_thw, **kwargs)
+
    @can_return_tuple
    @auto_docstring
    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        image_grid_thw: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BaseModelOutputWithPooling:
        r"""
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
            The tensors corresponding to the input images.
        image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
            The temporal, height and width of feature shape of each image in LLM.
        """
        # Cast inputs to the vision tower's parameter dtype before encoding.
        pixel_values = pixel_values.type(self.visual.dtype)
        vision_output: BaseModelOutputWithPooling = self.visual(
            pixel_values, grid_thw=image_grid_thw, return_dict=True, **kwargs
        )
        image_embeds = vision_output.pooler_output
        # Per-image token counts after spatial merging; NOTE(review): despite being
        # marked *optional*, `image_grid_thw` must be provided on this path — confirm.
        split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
        image_embeds = torch.split(image_embeds, split_sizes)
        # pooler_output is replaced in place with the per-image tuple of embeddings.
        vision_output.pooler_output = image_embeds

        return vision_output
+
    def get_placeholder_mask(
        self,
        input_ids: torch.LongTensor,
        inputs_embeds: torch.FloatTensor,
        image_features: torch.FloatTensor | None = None,
        video_features: torch.FloatTensor | None = None,
    ):
        """
        Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
        equal to the length of multimodal features. If the lengths are different, an error is raised.
        """
        if input_ids is None:
            # No token ids available: identify placeholders by comparing each embedding
            # row against the embedding of the special image/video token.
            special_image_mask = inputs_embeds == self.get_input_embeddings()(
                torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
            )
            special_image_mask = special_image_mask.all(-1)
            special_video_mask = inputs_embeds == self.get_input_embeddings()(
                torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
            )
            special_video_mask = special_video_mask.all(-1)
        else:
            special_image_mask = input_ids == self.config.image_token_id
            special_video_mask = input_ids == self.config.video_token_id

        n_image_tokens = special_image_mask.sum()
        # Broadcast the (batch, seq) mask over the hidden dimension for masked_scatter.
        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        if image_features is not None:
            # torch.compile-friendly consistency check (no graph break on failure path).
            torch_compilable_check(
                inputs_embeds[special_image_mask].numel() == image_features.numel(),
                f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {image_features.shape[0]}",
            )

        n_video_tokens = special_video_mask.sum()
        special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        if video_features is not None:
            torch_compilable_check(
                inputs_embeds[special_video_mask].numel() == video_features.numel(),
                f"Video features and video tokens do not match, tokens: {n_video_tokens}, features: {video_features.shape[0]}",
            )
        return special_image_mask, special_video_mask
+
    @auto_docstring
    @can_return_tuple
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        pixel_values: torch.Tensor | None = None,
        pixel_values_videos: torch.FloatTensor | None = None,
        image_grid_thw: torch.LongTensor | None = None,
        video_grid_thw: torch.LongTensor | None = None,
        cache_position: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | Qwen3_5ModelOutputWithPast:
        r"""
        image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
            The temporal, height and width of feature shape of each image in LLM.
        video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
            The temporal, height and width of feature shape of each video in LLM.
        """
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        # Encode images and scatter their embeddings into the placeholder positions.
        if pixel_values is not None:
            image_outputs: BaseModelOutputWithPooling = self.get_image_features(
                pixel_values, image_grid_thw, return_dict=True
            )
            image_embeds = image_outputs.pooler_output
            image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
            image_mask, _ = self.get_placeholder_mask(
                input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds
            )
            inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

        # Same merge for video frame embeddings.
        if pixel_values_videos is not None:
            video_outputs: BaseModelOutputWithPooling = self.get_video_features(
                pixel_values_videos, video_grid_thw, return_dict=True
            )
            video_embeds = video_outputs.pooler_output
            video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
            _, video_mask = self.get_placeholder_mask(
                input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds
            )
            inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

        if position_ids is None:
            past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length()
            # Full rope-index computation only at prefill (or when no deltas are cached).
            if self.rope_deltas is None or past_key_values_length == 0:
                position_ids, rope_deltas = self.get_rope_index(
                    input_ids,
                    image_grid_thw,
                    video_grid_thw,
                    attention_mask=attention_mask,
                )
                self.rope_deltas = rope_deltas
            # then use the prev pre-calculated rope-deltas to get the correct position ids
            else:
                batch_size, seq_length, _ = inputs_embeds.shape
                delta = (past_key_values_length + self.rope_deltas).to(inputs_embeds.device)
                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                if cache_position is not None:  # otherwise `deltas` is an int `0`
                    # Beam search / num_return_sequences may enlarge the batch; repeat deltas to match.
                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
                position_ids = position_ids.add(delta)
                # Replicate over the 3 mrope planes (t/h/w).
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

        outputs = self.language_model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            **kwargs,
        )

        return Qwen3_5ModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            rope_deltas=self.rope_deltas,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
+
+
@auto_docstring
class Qwen3_5ForCausalLM(Qwen3_5PreTrainedModel, GenerationMixin):
    # Text-only causal LM head on top of the Qwen3_5 text backbone.
    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
    _tp_plan = {"lm_head": "colwise_gather_output"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
    config: Qwen3_5TextConfig
    # Ignore MTP and vision weights when loading a full multimodal checkpoint.
    _keys_to_ignore_on_load_unexpected = [r"^mtp.*", r"^model.visual.*"]

    def __init__(self, config):
        super().__init__(config)
        self.model = Qwen3_5TextModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()
+
    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        cache_position: torch.LongTensor | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, Qwen3_5ForCausalLM

        >>> model = Qwen3_5ForCausalLM.from_pretrained("Qwen/Qwen3_5-8B")
        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3_5-8B")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
+
+
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Qwen3_5 causal language model (or autoregressive) outputs.
    """
)
class Qwen3_5CausalLMOutputWithPast(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
        The rope index difference between sequence length and multimodal rope.
    """

    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    past_key_values: Cache | None = None
    # Per-layer hidden states when `output_hidden_states=True`.
    hidden_states: tuple[torch.FloatTensor] | None = None
    # Per-layer attention weights when `output_attentions=True`.
    attentions: tuple[torch.FloatTensor] | None = None
    rope_deltas: torch.LongTensor | None = None
+
+
class Qwen3_5ForConditionalGeneration(Qwen3_5PreTrainedModel, GenerationMixin):
    # Multimodal (vision + text) generation model: Qwen3_5Model backbone + LM head.
    _checkpoint_conversion_mapping = {}
    _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
    # Reference: fix gemma3 grad acc #37208
    accepts_loss_kwargs = False
    config: Qwen3_5Config

    def __init__(self, config):
        super().__init__(config)
        self.model = Qwen3_5Model(config)
        self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)

        self.post_init()

    def get_input_embeddings(self):
        # Delegates to the backbone, which in turn delegates to the text model.
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.model.set_input_embeddings(value)
+
+ @auto_docstring
+ def get_video_features(
+ self,
+ pixel_values_videos: torch.FloatTensor,
+ video_grid_thw: torch.LongTensor | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple | BaseModelOutputWithPooling:
+ r"""
+ pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+ The tensors corresponding to the input videos.
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+ The temporal, height and width of feature shape of each video in LLM.
+ """
+ return self.model.get_video_features(
+ pixel_values_videos=pixel_values_videos, video_grid_thw=video_grid_thw, **kwargs
+ )
+
+ @auto_docstring
+ def get_image_features(
+ self,
+ pixel_values: torch.FloatTensor,
+ image_grid_thw: torch.LongTensor | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple | BaseModelOutputWithPooling:
+ r"""
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+ The tensors corresponding to the input images.
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ """
+ return self.model.get_image_features(pixel_values=pixel_values, image_grid_thw=image_grid_thw, **kwargs)
+
    @can_return_tuple
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        pixel_values: torch.Tensor | None = None,
        pixel_values_videos: torch.FloatTensor | None = None,
        image_grid_thw: torch.LongTensor | None = None,
        video_grid_thw: torch.LongTensor | None = None,
        cache_position: torch.LongTensor | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | Qwen3_5CausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
            The temporal, height and width of feature shape of each image in LLM.
        video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
            The temporal, height and width of feature shape of each video in LLM.

        Example:

        ```python
        >>> from transformers import AutoProcessor, Qwen3_5ForConditionalGeneration

        >>> model = Qwen3_5ForConditionalGeneration.from_pretrained("Qwen/Qwen3-VL-8B-Instruct")
        >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen3-VL-8B-Instruct")

        >>> messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg",
                    },
                    {"type": "text", "text": "Describe the image."},
                ],
            }
        ]

        >>> inputs = processor.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt"
        )

        >>> # Generate
        >>> generated_ids = model.generate(**inputs, max_new_tokens=1024)
        >>> generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
        >>> output_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        >>> print(output_text)
        ```
        """

        # The backbone merges vision features into the text embeddings and runs the decoder.
        outputs = self.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            pixel_values_videos=pixel_values_videos,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs[0]

        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            # NOTE: loss kwargs are intentionally not forwarded (accepts_loss_kwargs = False).
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)

        return Qwen3_5CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            rope_deltas=outputs.rope_deltas,
        )
+
    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        pixel_values=None,
        pixel_values_videos=None,
        image_grid_thw=None,
        video_grid_thw=None,
        is_first_iteration=False,
        **kwargs,
    ):
        # Overwritten -- in specific circumstances we don't want to forward image inputs to the model

        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            position_ids=position_ids,
            pixel_values=pixel_values,
            pixel_values_videos=pixel_values_videos,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            use_cache=use_cache,
            is_first_iteration=is_first_iteration,
            **kwargs,
        )

        # Qwen3_5 position_ids are prepared with rope_deltas
        if position_ids is None:
            # Calculate RoPE index once per generation in the pre-fill stage only.
            # When compiling, we can't check tensor values thus we check only input length
            # It is safe to assume that `length!=1` means we're in pre-fill because compiled
            # models currently cannot do assisted decoding
            if model_inputs["cache_position"][0] == 0 or self.model.rope_deltas is None:
                vision_positions, rope_deltas = self.model.get_rope_index(
                    model_inputs.get("input_ids", None),
                    image_grid_thw=image_grid_thw,
                    video_grid_thw=video_grid_thw,
                    attention_mask=attention_mask,
                )
                self.model.rope_deltas = rope_deltas
            # then use the prev pre-calculated rope-deltas to get the correct position ids
            # NOTE(review): if neither branch runs (decode step with cached deltas but no
            # "position_ids" in model_inputs), `vision_positions` below would be unbound —
            # presumably the parent always populates "position_ids"; confirm.
            elif "position_ids" in model_inputs:
                batch_size, seq_length = model_inputs["position_ids"].shape
                device = model_inputs["position_ids"].device
                position_ids = torch.arange(seq_length, device=device)
                position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1)
                delta = cache_position[0] + self.model.rope_deltas
                delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
                vision_positions = position_ids + delta.expand_as(position_ids)

            # Concatenate "text + vision" positions into [4, bs, seq-len]
            text_positions = model_inputs["position_ids"][None, ...]
            model_inputs["position_ids"] = torch.cat([text_positions, vision_positions], dim=0)

        # After the first step, the vision features are already cached in the merged
        # embeddings, so pixel inputs must not be re-forwarded.
        if not is_first_iteration and use_cache:
            model_inputs["pixel_values"] = None
            model_inputs["pixel_values_videos"] = None

        return model_inputs
+
    def _get_image_nums_and_video_nums(
        self,
        input_ids: torch.LongTensor | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Get the number of images and videos for each sample to calculate the separation length of the sample tensor.
        These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary.

        Returns:
            image_nums (`torch.LongTensor` of shape `(batch_size,)`)
            video_nums (`torch.LongTensor` of shape `(batch_size,)`)
        """
        image_token_id = self.config.image_token_id
        video_token_id = self.config.video_token_id
        vision_start_token_id = self.config.vision_start_token_id

        if inputs_embeds is not None:
            # No ids available: detect special tokens by embedding equality; comparing
            # only the first hidden component ([..., 0]) is used as a cheap proxy.
            vision_start_mask = (
                inputs_embeds
                == self.get_input_embeddings()(
                    torch.tensor(vision_start_token_id, dtype=torch.long, device=inputs_embeds.device)
                )
            )[..., 0]
            image_mask = (
                inputs_embeds
                == self.get_input_embeddings()(
                    torch.tensor(image_token_id, dtype=torch.long, device=inputs_embeds.device)
                )
            )[..., 0]
            video_mask = (
                inputs_embeds
                == self.get_input_embeddings()(
                    torch.tensor(video_token_id, dtype=torch.long, device=inputs_embeds.device)
                )
            )[..., 0]
        else:
            vision_start_mask = input_ids == vision_start_token_id
            image_mask = input_ids == image_token_id
            video_mask = input_ids == video_token_id

        # A vision segment counts only when its image/video token directly follows a
        # vision-start token (roll shifts the start mask onto the following position).
        vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1)
        image_nums = torch.sum(vision_first_mask & image_mask, dim=1)
        video_nums = torch.sum(vision_first_mask & video_mask, dim=1)

        return image_nums, video_nums
+
    def _expand_inputs_for_generation(
        self,
        expand_size: int = 1,
        is_encoder_decoder: bool = False,
        input_ids: torch.LongTensor | None = None,
        **model_kwargs,
    ) -> tuple[torch.LongTensor, dict[str, Any]]:
        # Overwritten -- Qwen3_5 use timestamps and remove second_per_grid_ts
        # Support for expanding tensors without a batch size dimension
        # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw
        # pixel_values.shape[0] is sum(seqlen_images for samples)
        # image_grid_thw.shape[0] is sum(num_images for samples)

        if expand_size == 1:
            # Nothing to expand (e.g. greedy search with one return sequence).
            return input_ids, model_kwargs

        visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw"]

        def _expand_dict_for_generation_visual(dict_to_expand):
            # Expand flattened vision tensors sample-by-sample so each expanded copy of a
            # sample keeps its own images/videos together.
            image_grid_thw = model_kwargs.get("image_grid_thw", None)
            video_grid_thw = model_kwargs.get("video_grid_thw", None)
            image_nums, video_nums = self._get_image_nums_and_video_nums(
                input_ids, inputs_embeds=model_kwargs.get("inputs_embeds", None)
            )

            # video_nums: (batch_size,)
            # since video_nums is the number of videos in the input dependent on the input_ids(vision_start),
            # but Qwen3_5 append vision_start to each frame of each video, so we need to recover the real video_nums according to video_grid_thw
            if video_grid_thw is not None:
                cumulative_frame_counts = torch.cumsum(video_grid_thw[:, 0], dim=0)
                cumulative_token_video_counts = torch.cumsum(video_nums, dim=0)
                # Find video boundaries in cumulative_frame_counts
                video_boundary_indices = torch.searchsorted(cumulative_frame_counts, cumulative_token_video_counts)
                # example: video_boundary_indices = [3, 5] means video_nums = [4, 2]
                video_nums = torch.diff(torch.cat([-video_boundary_indices.new_ones(1), video_boundary_indices]))

            def _repeat_interleave_samples(x, lengths, repeat_times):
                # Split x into per-sample chunks, tile each chunk, and re-concatenate.
                samples = torch.split(x, lengths)
                repeat_args = [repeat_times] + [1] * (x.dim() - 1)
                result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0)
                return result

            for key in dict_to_expand:
                if key == "pixel_values":
                    # split images into samples
                    samples = torch.split(image_grid_thw, list(image_nums))
                    # compute the sequence length of images for each sample
                    lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
                    dict_to_expand[key] = _repeat_interleave_samples(
                        dict_to_expand[key], lengths=lengths, repeat_times=expand_size
                    )
                elif key == "image_grid_thw":
                    # get the num of images for each sample
                    lengths = list(image_nums)
                    dict_to_expand[key] = _repeat_interleave_samples(
                        dict_to_expand[key], lengths=lengths, repeat_times=expand_size
                    )
                elif key == "pixel_values_videos":
                    samples = torch.split(video_grid_thw, list(video_nums))
                    lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
                    dict_to_expand[key] = _repeat_interleave_samples(
                        dict_to_expand[key], lengths=lengths, repeat_times=expand_size
                    )
                elif key == "video_grid_thw":
                    lengths = list(video_nums)
                    dict_to_expand[key] = _repeat_interleave_samples(
                        dict_to_expand[key], lengths=lengths, repeat_times=expand_size
                    )
            return dict_to_expand

        def _expand_dict_for_generation(dict_to_expand):
            # Batch-dim repeat for every ordinary tensor kwarg (vision keys handled above).
            for key in dict_to_expand:
                if (
                    key != "cache_position"
                    and dict_to_expand[key] is not None
                    and isinstance(dict_to_expand[key], torch.Tensor)
                    and key not in visual_keys
                ):
                    dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
            return dict_to_expand

        # Visual expansion must run first: it needs input_ids at the original batch size.
        model_kwargs = _expand_dict_for_generation_visual(model_kwargs)

        if input_ids is not None:
            input_ids = input_ids.repeat_interleave(expand_size, dim=0)

        model_kwargs = _expand_dict_for_generation(model_kwargs)

        if is_encoder_decoder:
            if model_kwargs.get("encoder_outputs") is None:
                raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
            model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])

        return input_ids, model_kwargs
+
+
# Public API of this module, consumed by the package-level lazy importer.
__all__ = [
    "Qwen3_5VisionModel",
    "Qwen3_5TextModel",
    "Qwen3_5Model",
    "Qwen3_5ForCausalLM",
    "Qwen3_5ForConditionalGeneration",
    "Qwen3_5PreTrainedModel",
]
diff --git a/src/transformers/models/qwen3_5/modular_qwen3_5.py b/src/transformers/models/qwen3_5/modular_qwen3_5.py
new file mode 100644
index 000000000000..9f61318c5740
--- /dev/null
+++ b/src/transformers/models/qwen3_5/modular_qwen3_5.py
@@ -0,0 +1,841 @@
+# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Qwen3.5 model."""
+
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from ... import initialization as init
+from ...cache_utils import Cache
+from ...masking_utils import create_causal_mask
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling
+from ...modeling_rope_utils import RopeParameters
+from ...modeling_utils import PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
+from ...utils.generic import check_model_inputs
+from ..qwen3.modeling_qwen3 import Qwen3ForCausalLM
+from ..qwen3_next.configuration_qwen3_next import Qwen3NextConfig
+from ..qwen3_next.modeling_qwen3_next import (
+ Qwen3NextAttention,
+ Qwen3NextDynamicCache,
+ Qwen3NextGatedDeltaNet,
+ Qwen3NextMLP,
+ Qwen3NextModel,
+ Qwen3NextPreTrainedModel,
+ Qwen3NextRMSNorm,
+ apply_mask_to_padding_states,
+)
+from ..qwen3_vl.configuration_qwen3_vl import Qwen3VLConfig, Qwen3VLVisionConfig
+from ..qwen3_vl.modeling_qwen3_vl import (
+ Qwen3VLForConditionalGeneration,
+ Qwen3VLModel,
+ Qwen3VLModelOutputWithPast,
+ Qwen3VLTextRotaryEmbedding,
+ Qwen3VLVisionModel,
+ Qwen3VLVisionRotaryEmbedding,
+)
+
+
+logger = logging.get_logger(__name__)
+
+
class Qwen3_5TextConfig(Qwen3NextConfig):
    r"""
    This is the configuration class to store the configuration of a [`Qwen3_5TextModel`]. It is used to instantiate a
    Qwen3_5 model according to the specified arguments, defining the model architecture.
    Instantiating a configuration with the defaults will yield a similar configuration to that of
    Qwen3.5-9B-Instruct [Qwen/Qwen3.5-9B-Instruct](https://huggingface.co/Qwen/Qwen3.5-9B-Instruct).

    Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PreTrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 248320):
            Vocabulary size of the model. Defines the number of different tokens that can be represented by the
            `inputs_ids`.
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 12288):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_key_value_heads (`int`, *optional*, defaults to 4):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `4`.
        hidden_act (`str`, *optional*, defaults to `"silu"`):
            The non-linear activation function in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 32768):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the model's input and output word embeddings should be tied.
        rope_parameters (`RopeParameters`, *optional*):
            Dictionary containing the configuration parameters for the RoPE embeddings. The dictionary should contain
            a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE
            with longer `max_position_embeddings`.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        head_dim (`int`, *optional*, defaults to 256):
            Projection weights dimension in multi-head attention.
        linear_conv_kernel_dim (`int`, *optional*, defaults to 4):
            Kernel size of the convolution used in linear attention layers.
        linear_key_head_dim (`int`, *optional*, defaults to 128):
            Dimension of each key head in linear attention.
        linear_value_head_dim (`int`, *optional*, defaults to 128):
            Dimension of each value head in linear attention.
        linear_num_key_heads (`int`, *optional*, defaults to 16):
            Number of key heads used in linear attention layers.
        linear_num_value_heads (`int`, *optional*, defaults to 32):
            Number of value heads used in linear attention layers.
        layer_types (`list[str]`, *optional*):
            Types of each layer (attention or linear).
        pad_token_id (`int`, *optional*):
            Padding token id.
        bos_token_id (`int`, *optional*):
            Beginning of stream token id.
        eos_token_id (`int`, *optional*):
            End of stream token id.

    ```python
    >>> from transformers import Qwen3_5TextModel, Qwen3_5TextConfig

    >>> # Initializing a Qwen3.5 style configuration
    >>> configuration = Qwen3_5TextConfig()

    >>> # Initializing a model from the Qwen3.5-9B style configuration
    >>> model = Qwen3_5TextModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    model_type = "qwen3_5_text"
    base_config_key = "text_config"

    # Tensor-parallel sharding plan for the dense projections.
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }

    def __init__(
        self,
        vocab_size=248320,
        hidden_size=4096,
        intermediate_size=12288,
        num_hidden_layers=32,
        num_attention_heads=16,
        num_key_value_heads=4,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        tie_word_embeddings=False,
        rope_parameters: RopeParameters | dict[str, RopeParameters] | None = None,
        attention_bias=False,
        attention_dropout=0.0,
        head_dim=256,
        linear_conv_kernel_dim=4,
        linear_key_head_dim=128,
        linear_value_head_dim=128,
        linear_num_key_heads=16,
        linear_num_value_heads=32,
        layer_types=None,
        pad_token_id: int | None = None,
        bos_token_id: int | None = None,
        eos_token_id: int | None = None,
        **kwargs,
    ):
        # Allow the multimodal-RoPE keys to appear in `rope_parameters` without
        # tripping the standard rope validation.
        kwargs["ignore_keys_at_rope_validation"] = {"mrope_section", "mrope_interleaved"}
        # NOTE(review): only `tie_word_embeddings` and **kwargs are forwarded here; the
        # explicit arguments above are consumed by the modular converter when generating
        # the final configuration file — confirm against the generated config.
        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
        # Dense Qwen3.5 does not use the MoE machinery inherited from Qwen3NextConfig;
        # these deletions also instruct the modular converter to drop the attributes.
        del self.decoder_sparse_step
        del self.norm_topk_prob
        del self.mlp_only_layers
        del self.moe_intermediate_size
        del self.shared_expert_intermediate_size
        del self.num_experts_per_tok
        del self.num_experts
        del self.output_router_logits
        del self.router_aux_loss_coef
+
+
class Qwen3_5VisionConfig(Qwen3VLVisionConfig):
    """Vision-backbone configuration for Qwen3.5: the Qwen3VL vision config without deepstack."""

    model_type = "qwen3_5"

    def __init__(
        self,
        depth=27,
        hidden_size=1152,
        hidden_act="gelu_pytorch_tanh",
        intermediate_size=4304,
        num_heads=16,
        in_channels=3,
        patch_size=16,
        spatial_merge_size=2,
        temporal_patch_size=2,
        out_hidden_size=3584,
        num_position_embeddings=2304,
        initializer_range=0.02,
        **kwargs,
    ):
        # NOTE(review): only **kwargs are forwarded here; the explicit defaults above are
        # picked up by the modular converter when generating the final config — confirm
        # against the generated configuration file.
        super().__init__(**kwargs)
        # Qwen3.5 does not use deepstack feature injection.
        del self.deepstack_visual_indexes
+
+
class Qwen3_5Config(Qwen3VLConfig):
    r"""
    This is the configuration class to store the configuration of a [`Qwen3_5Model`]. It is used to instantiate a
    Qwen3.5 model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of
    Qwen3.5-9B-Instruct [Qwen/Qwen3.5-9B-Instruct](https://huggingface.co/Qwen/Qwen3.5-9B-Instruct).

    Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PreTrainedConfig`] for more information.


    Args:
        text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3_5TextConfig`):
            The config object or dictionary of the text backbone.
        vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3_5VisionConfig`):
            The config object or dictionary of the vision backbone.
        image_token_id (`int`, *optional*, defaults to 248056):
            The image token index to encode the image prompt.
        video_token_id (`int`, *optional*, defaults to 248057):
            The video token index to encode the image prompt.
        vision_start_token_id (`int`, *optional*, defaults to 248053):
            The start token index to encode the image prompt.
        vision_end_token_id (`int`, *optional*, defaults to 248054):
            The end token index to encode the image prompt.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie the word embeddings.

    ```python
    >>> from transformers import Qwen3_5ForConditionalGeneration, Qwen3_5Config

    >>> # Initializing a Qwen3.5 style configuration
    >>> configuration = Qwen3_5Config()

    >>> # Initializing a model from the Qwen3.5-9B style configuration
    >>> model = Qwen3_5ForConditionalGeneration(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "qwen3_5"
    # Composite config: holds a vision sub-config and a text sub-config.
    sub_configs = {"vision_config": Qwen3_5VisionConfig, "text_config": Qwen3_5TextConfig}

    def __init__(
        self,
        text_config=None,
        vision_config=None,
        image_token_id=248056,
        video_token_id=248057,
        vision_start_token_id=248053,
        vision_end_token_id=248054,
        tie_word_embeddings=False,
        **kwargs,
    ):
        super().__init__(
            text_config=text_config,
            vision_config=vision_config,
            image_token_id=image_token_id,
            video_token_id=video_token_id,
            vision_start_token_id=vision_start_token_id,
            vision_end_token_id=vision_end_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
+
+
class Qwen3_5DynamicCache(Qwen3NextDynamicCache):
    """Hybrid cache: per-layer conv/recurrent states for linear-attention layers and KV states for full-attention layers."""

    pass
+
+
class Qwen3_5VisionRotaryEmbedding(Qwen3VLVisionRotaryEmbedding):
    """Rotary embedding for the vision tower (unchanged from Qwen3VL)."""

    pass
+
+
class Qwen3_5TextRotaryEmbedding(Qwen3VLTextRotaryEmbedding):
    """Multimodal (mrope) rotary embedding for the Qwen3.5 text backbone."""

    def __init__(self, config: Qwen3_5TextConfig, device=None):
        super().__init__()
        # How many rotary dims are allotted to the temporal, height and width axes.
        self.mrope_section = config.rope_parameters.get("mrope_section", [11, 11, 10])

    # NOTE(review): declared without `self`/`@staticmethod`; presumably consumed as a
    # rope-init function by the modular converter rather than called as a bound
    # method — confirm against the generated modeling file.
    def compute_default_rope_parameters(
        config: Qwen3_5TextConfig | None = None,
        device: Optional["torch.device"] = None,
        seq_len: int | None = None,
    ) -> tuple["torch.Tensor", float]:
        """Return the default (unscaled) inverse RoPE frequencies and attention factor for `config`."""
        base = config.rope_parameters["rope_theta"]
        partial_rotary_factor = config.rope_parameters.get("partial_rotary_factor", 1.0)
        head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
        dim = int(head_dim * partial_rotary_factor)

        attention_factor = 1.0  # Unused in this type of RoPE

        # Compute the inverse frequencies
        inv_freq = 1.0 / (
            base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
        )
        return inv_freq, attention_factor
+
+
class Qwen3_5GatedDeltaNet(Qwen3NextGatedDeltaNet):
    """Gated DeltaNet (linear attention) layer using split q/k/v, z, b and a input projections."""

    def __init__(self, config: Qwen3_5Config, layer_idx: int):
        super().__init__(config, layer_idx)

        # Modular-converter directives: drop the fused qkvz/ba projections from
        # Qwen3Next; Qwen3.5 checkpoints ship separate projections instead.
        del projection_size_qkvz  # noqa: F821
        del projection_size_ba  # noqa: F821
        del self.in_proj_qkvz
        del self.in_proj_ba

        # Separate projections: fused conv input (q, k, v), output gate (z),
        # beta gate (b), and decay input (a).
        self.in_proj_qkv = nn.Linear(self.hidden_size, self.key_dim * 2 + self.value_dim, bias=False)
        self.in_proj_z = nn.Linear(self.hidden_size, self.value_dim, bias=False)
        self.in_proj_b = nn.Linear(self.hidden_size, self.num_v_heads, bias=False)
        self.in_proj_a = nn.Linear(self.hidden_size, self.num_v_heads, bias=False)

    def fix_query_key_value_ordering(self):
        # The parent's fused-projection reordering does not apply to split projections.
        raise AttributeError("Not needed for Qwen3.5 Series")

    def forward(
        self,
        hidden_states: torch.Tensor,
        cache_params: Qwen3_5DynamicCache | None = None,
        cache_position: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
    ):
        """Run gated delta-rule linear attention over `hidden_states`.

        Uses the chunked kernel for prefill and the recurrent kernel for single-token
        decoding with a warm cache; conv and recurrent states are read from and
        written back to `cache_params` when it is provided.
        """
        hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)

        # Set up dimensions for reshapes later
        batch_size, seq_len, _ = hidden_states.shape

        # Single-token decode with an existing cache -> recurrent (step-wise) path.
        use_precomputed_states = (
            cache_params is not None
            and cache_params.has_previous_state
            and seq_len == 1
            and cache_position is not None
        )

        # getting projected states from cache if it exists
        if cache_params is not None:
            conv_state = cache_params.conv_states[self.layer_idx]
            recurrent_state = cache_params.recurrent_states[self.layer_idx]

        # (batch, seq, qkv_dim) -> (batch, qkv_dim, seq) for the causal conv.
        mixed_qkv = self.in_proj_qkv(hidden_states)
        mixed_qkv = mixed_qkv.transpose(1, 2)

        z = self.in_proj_z(hidden_states)
        z = z.reshape(batch_size, seq_len, -1, self.head_v_dim)

        b = self.in_proj_b(hidden_states)
        a = self.in_proj_a(hidden_states)

        if use_precomputed_states:
            # 2. Convolution sequence transformation
            # NOTE: the conv state is updated in `causal_conv1d_update`
            mixed_qkv = self.causal_conv1d_update(
                mixed_qkv,
                conv_state,
                self.conv1d.weight.squeeze(1),
                self.conv1d.bias,
                self.activation,
            )
        else:
            if cache_params is not None:
                # Stash the last `conv_kernel_size` inputs (left-padded) for future decode steps.
                conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0))
                cache_params.conv_states[self.layer_idx] = conv_state
            if self.causal_conv1d_fn is not None:
                # Fast fused kernel when available.
                mixed_qkv = self.causal_conv1d_fn(
                    x=mixed_qkv,
                    weight=self.conv1d.weight.squeeze(1),
                    bias=self.conv1d.bias,
                    activation=self.activation,
                    seq_idx=None,
                )
            else:
                # Fallback: plain conv1d, truncated back to seq_len (causal).
                mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len])

        mixed_qkv = mixed_qkv.transpose(1, 2)
        query, key, value = torch.split(
            mixed_qkv,
            [
                self.key_dim,
                self.key_dim,
                self.value_dim,
            ],
            dim=-1,
        )

        query = query.reshape(batch_size, seq_len, -1, self.head_k_dim)
        key = key.reshape(batch_size, seq_len, -1, self.head_k_dim)
        value = value.reshape(batch_size, seq_len, -1, self.head_v_dim)

        beta = b.sigmoid()
        # If the model is loaded in fp16, without the .float() here, A might be -inf
        g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
        # Broadcast k/q heads up to the number of value heads (GQA-style grouping).
        if self.num_v_heads // self.num_k_heads > 1:
            query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
            key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)

        if not use_precomputed_states:
            # Prefill: chunked delta rule over the full sequence.
            core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule(
                query,
                key,
                value,
                g=g,
                beta=beta,
                initial_state=None,
                output_final_state=cache_params is not None,
                use_qk_l2norm_in_kernel=True,
            )

        else:
            # Decode: single-step recurrent delta rule from the cached state.
            core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule(
                query,
                key,
                value,
                g=g,
                beta=beta,
                initial_state=recurrent_state,
                output_final_state=cache_params is not None,
                use_qk_l2norm_in_kernel=True,
            )

        # Update cache
        if cache_params is not None:
            cache_params.recurrent_states[self.layer_idx] = last_recurrent_state

        # reshape input data into 2D tensor
        core_attn_out = core_attn_out.reshape(-1, self.head_v_dim)
        z = z.reshape(-1, self.head_v_dim)
        # Gated normalization: z modulates the normalized attention output.
        core_attn_out = self.norm(core_attn_out, z)
        core_attn_out = core_attn_out.reshape(batch_size, seq_len, -1)

        output = self.out_proj(core_attn_out)
        return output
+
+
class Qwen3_5Attention(Qwen3NextAttention):
    """Full (softmax) attention layer, unchanged from Qwen3Next."""

    pass
+
+
class Qwen3_5MLP(Qwen3NextMLP):
    """Feed-forward block; per-layer `intermediate_size` is passed explicitly."""

    def __init__(self, config: Qwen3_5Config, intermediate_size: int):
        super().__init__(config, intermediate_size)
        # NOTE(review): presumably redundant with the parent init; likely kept so the
        # attribute survives modular conversion — confirm against the generated file.
        self.intermediate_size = intermediate_size
+
+
class Qwen3_5RMSNorm(Qwen3NextRMSNorm):
    """RMS normalization layer, unchanged from Qwen3Next (weight is 1-centered: applies 1 + weight)."""

    pass
+
+
class Qwen3_5DecoderLayer(GradientCheckpointingLayer):
    """Hybrid decoder layer: a linear-attention (GatedDeltaNet) or full-attention
    token mixer — selected per layer by `config.layer_types[layer_idx]` — followed
    by an MLP, each with a pre-RMSNorm and residual connection."""

    def __init__(self, config: Qwen3_5TextConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.layer_type = config.layer_types[layer_idx]
        if self.layer_type == "linear_attention":
            self.linear_attn = Qwen3_5GatedDeltaNet(config, layer_idx)
        elif self.layer_type == "full_attention":
            self.self_attn = Qwen3_5Attention(config, layer_idx)
        self.mlp = Qwen3_5MLP(config, config.intermediate_size)
        self.input_layernorm = Qwen3_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen3_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        cache_position: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.FloatTensor:
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Token Mixer
        if self.layer_type == "linear_attention":
            hidden_states = self.linear_attn(
                hidden_states=hidden_states,
                cache_params=past_key_values,
                cache_position=cache_position,
                attention_mask=attention_mask,
            )
        elif self.layer_type == "full_attention":
            # Self Attention
            hidden_states, _ = self.self_attn(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states
+
+
class Qwen3_5PreTrainedModel(Qwen3NextPreTrainedModel):
    """Base class handling weight initialization and shared model plumbing for Qwen3.5."""

    config: Qwen3_5Config
    _no_split_modules = ["Qwen3_5DecoderLayer", "Qwen3_5VisionBlock"]
    _can_record_outputs = {
        "hidden_states": Qwen3_5DecoderLayer,
        "attentions": Qwen3_5Attention,
    }

    @torch.no_grad()
    def _init_weights(self, module):
        """Default init plus special handling for DeltaNet params, RMSNorm and vision rotary buffers."""
        PreTrainedModel._init_weights(self, module)
        if isinstance(module, Qwen3_5GatedDeltaNet):
            init.ones_(module.dt_bias)
            # A_log ~ log(Uniform(0, 16)).
            init.copy_(module.A_log, torch.empty_like(module.A_log).uniform_(0, 16).log_())
        # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
        elif isinstance(module, Qwen3_5RMSNorm):
            init.zeros_(module.weight)
        elif isinstance(module, Qwen3_5VisionRotaryEmbedding):
            # Recompute the standard inverse-frequency buffer for the vision rotary embedding.
            inv_freq = 1.0 / (module.theta ** (torch.arange(0, module.dim, 2, dtype=torch.float) / module.dim))
            init.copy_(module.inv_freq, inv_freq)
+
+
class Qwen3_5VisionModel(Qwen3VLVisionModel):
    """Qwen3.5 vision tower: Qwen3VL's vision model without the deepstack feature path."""

    config: Qwen3_5VisionConfig
    _no_split_modules = ["Qwen3_5VisionBlock"]

    def __init__(self, config, *inputs, **kwargs) -> None:
        super().__init__(config, *inputs, **kwargs)
        # Qwen3.5 does not use deepstack feature injection.
        del self.deepstack_visual_indexes
        del self.deepstack_merger_list

    @check_model_inputs
    def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`):
                The final hidden states of the model.
            grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
                The temporal, height and width of feature shape of each image in LLM.

        Returns:
            `torch.Tensor`: hidden_states.
        """
        hidden_states = self.patch_embed(hidden_states)

        # Interpolated learned position embeddings added on top of patch embeddings.
        pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
        hidden_states = hidden_states + pos_embeds

        rotary_pos_emb = self.rot_pos_emb(grid_thw)

        seq_len, _ = hidden_states.size()
        hidden_states = hidden_states.reshape(seq_len, -1)
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        position_embeddings = (emb.cos(), emb.sin())

        # Cumulative sequence lengths delimit each frame's patches for attention.
        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
            dim=0,
            # Select dtype based on the following factors:
            #  - FA2 requires that cu_seqlens_q must have dtype int32
            #  - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
            # See https://github.com/huggingface/transformers/pull/34852 for more information
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

        for blk in self.blocks:
            hidden_states = blk(
                hidden_states,
                cu_seqlens=cu_seqlens,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        # Spatial merge of neighboring patches into the LLM embedding space.
        merged_hidden_states = self.merger(hidden_states)

        return BaseModelOutputWithPooling(
            last_hidden_state=hidden_states,
            pooler_output=merged_hidden_states,
        )
+
+
class Qwen3_5ModelOutputWithPast(Qwen3VLModelOutputWithPast):
    """Model output carrying `rope_deltas` in addition to the standard past/hidden/attention fields."""

    pass
+
+
class Qwen3_5TextModel(Qwen3NextModel):
    """Qwen3.5 text backbone: Qwen3Next hybrid stack with a multimodal (mrope) rotary embedding."""

    def __init__(self, config: Qwen3_5TextConfig):
        super().__init__(config)
        self.rotary_emb = Qwen3_5TextRotaryEmbedding(config=config)

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        cache_position: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = Qwen3_5DynamicCache(config=self.config)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        # mrope: the hard coded `3` is for temporal, height and width.
        if position_ids is None:
            position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
        elif position_ids.ndim == 2:
            position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)

        # With 4 rows, row 0 carries plain text position ids (used for masking) and
        # rows 1:4 the temporal/height/width mrope ids; with 3 rows, row 0 doubles
        # as the text position ids.
        if position_ids.ndim == 3 and position_ids.shape[0] == 4:
            text_position_ids = position_ids[0]
            position_ids = position_ids[1:]
        else:
            text_position_ids = position_ids[0]

        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=text_position_ids,
        )
        # Linear-attention layers use a simpler padding mask instead of the causal mask.
        linear_attn_mask = self._update_linear_attn_mask(attention_mask, cache_position)

        hidden_states = inputs_embeds
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
            # Pick the mask matching the layer's token-mixer type.
            layer_mask = linear_attn_mask if decoder_layer.layer_type == "linear_attention" else causal_mask

            hidden_states = decoder_layer(
                hidden_states,
                position_embeddings=position_embeddings,
                attention_mask=layer_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)

        return Qwen3_5ModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )
+
+
class Qwen3_5Model(Qwen3VLModel):
    """Multimodal Qwen3.5 model: vision tower + hybrid text backbone, without deepstack."""

    def get_video_features(
        self,
        **super_kwargs,
    ) -> tuple | BaseModelOutputWithPooling:
        # Same implementation as for images
        return super().get_video_features(**super_kwargs)

    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        image_grid_thw: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BaseModelOutputWithPooling:
        """Encode images with the vision tower and split the pooled embeddings per image."""
        pixel_values = pixel_values.type(self.visual.dtype)
        vision_output: BaseModelOutputWithPooling = self.visual(
            pixel_values, grid_thw=image_grid_thw, return_dict=True, **kwargs
        )
        image_embeds = vision_output.pooler_output
        # Each image yields prod(t, h, w) / merge_size^2 merged tokens.
        split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
        image_embeds = torch.split(image_embeds, split_sizes)
        vision_output.pooler_output = image_embeds

        return vision_output

    @auto_docstring
    @can_return_tuple
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        pixel_values: torch.Tensor | None = None,
        pixel_values_videos: torch.FloatTensor | None = None,
        image_grid_thw: torch.LongTensor | None = None,
        video_grid_thw: torch.LongTensor | None = None,
        cache_position: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | Qwen3_5ModelOutputWithPast:
        r"""
        image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
            The temporal, height and width of feature shape of each image in LLM.
        video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
            The temporal, height and width of feature shape of each video in LLM.
        """
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        # Scatter image embeddings into the positions of image placeholder tokens.
        if pixel_values is not None:
            image_outputs: BaseModelOutputWithPooling = self.get_image_features(
                pixel_values, image_grid_thw, return_dict=True
            )
            image_embeds = image_outputs.pooler_output
            image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
            image_mask, _ = self.get_placeholder_mask(
                input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds
            )
            inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

        # Same for video embeddings and video placeholder tokens.
        if pixel_values_videos is not None:
            video_outputs: BaseModelOutputWithPooling = self.get_video_features(
                pixel_values_videos, video_grid_thw, return_dict=True
            )
            video_embeds = video_outputs.pooler_output
            video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
            _, video_mask = self.get_placeholder_mask(
                input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds
            )
            inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

        if position_ids is None:
            past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length()
            # Prefill (or no prior deltas): compute mrope position ids from the visual grids.
            if self.rope_deltas is None or past_key_values_length == 0:
                position_ids, rope_deltas = self.get_rope_index(
                    input_ids,
                    image_grid_thw,
                    video_grid_thw,
                    attention_mask=attention_mask,
                )
                self.rope_deltas = rope_deltas
            # then use the prev pre-calculated rope-deltas to get the correct position ids
            else:
                batch_size, seq_length, _ = inputs_embeds.shape
                delta = (past_key_values_length + self.rope_deltas).to(inputs_embeds.device)
                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                if cache_position is not None:  # otherwise `deltas` is an int `0`
                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
                position_ids = position_ids.add(delta)
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

        outputs = self.language_model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            **kwargs,
        )

        return Qwen3_5ModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            rope_deltas=self.rope_deltas,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
+
+
class Qwen3_5ForCausalLM(Qwen3ForCausalLM):
    """Text-only causal LM head on the Qwen3.5 hybrid backbone."""

    config: Qwen3_5TextConfig
    # When loading from a multimodal checkpoint, silently drop MTP heads and
    # vision-tower weights.
    _keys_to_ignore_on_load_unexpected = [r"^mtp.*", r"^model.visual.*"]

    def __init__(self, config):
        super().__init__(config)
        # Swap in the Qwen3.5 hybrid (linear + full attention) text backbone.
        self.model = Qwen3_5TextModel(config)
+
+
class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration):
    """Multimodal generation head; feature extraction delegates to Qwen3VL
    (re-declared so the modular converter keeps the methods on this class)."""

    def get_video_features(
        self,
        **super_kwargs,
    ) -> tuple | BaseModelOutputWithPooling:
        return super().get_video_features(**super_kwargs)

    def get_image_features(
        self,
        **super_kwargs,
    ) -> tuple | BaseModelOutputWithPooling:
        return super().get_image_features(**super_kwargs)
+
+
# Public API of the modular file; expanded by the converter into the generated files.
# NOTE(review): `Qwen3_5VisionConfig` is defined above but not exported — confirm
# this omission is intentional.
__all__ = [
    "Qwen3_5Config",
    "Qwen3_5TextConfig",
    "Qwen3_5VisionModel",
    "Qwen3_5TextModel",
    "Qwen3_5Model",
    "Qwen3_5ForCausalLM",
    "Qwen3_5ForConditionalGeneration",
    "Qwen3_5PreTrainedModel",
]
diff --git a/src/transformers/models/qwen3_5/tokenization_qwen3_5.py b/src/transformers/models/qwen3_5/tokenization_qwen3_5.py
new file mode 100644
index 000000000000..28049a4deb4c
--- /dev/null
+++ b/src/transformers/models/qwen3_5/tokenization_qwen3_5.py
@@ -0,0 +1,94 @@
+# Copyright 2024 The Qwen team, Alibaba Group and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for Qwen3.5."""
+
+from tokenizers import Regex, Tokenizer, decoders, normalizers, pre_tokenizers
+from tokenizers.models import BPE
+
+from ...tokenization_utils_tokenizers import TokenizersBackend
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
# GPT-2-style byte-level BPE split pattern (case-insensitive contractions, letters
# with combining marks, digits, punctuation runs, and whitespace handling).
PRETOKENIZE_REGEX = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?[\p{L}\p{M}]+|\p{N}| ?[^\s\p{L}\p{M}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
+
+
class Qwen3_5Tokenizer(TokenizersBackend):
    """Byte-level BPE tokenizer for Qwen3.5, built on the `tokenizers` backend.

    Pipeline: NFC normalization, a GPT-2-style split (`PRETOKENIZE_REGEX`) followed by
    byte-level pre-tokenization, and a BPE model without byte fallback or
    unknown-token substitution; decoding uses a byte-level decoder.
    """

    model_input_names = ["input_ids", "attention_mask"]
    model = BPE

    def __init__(
        self,
        vocab: str | dict[str, int] | None = None,
        merges: str | list[str] | None = None,
        vocab_file=None,
        merges_file=None,
        unk_token: str = "<|endoftext|>",
        bos_token=None,
        eos_token: str = "<|endoftext|>",
        pad_token: str = "<|endoftext|>",
        add_prefix_space=None,
        **kwargs,
    ):
        # Qwen tokenizers do not add a prefix space by default.
        self.add_prefix_space = add_prefix_space if add_prefix_space is not None else False
        # Minimal fallback vocab so the object can be constructed without files.
        self._vocab = (
            vocab
            if vocab is not None
            else {
                "<|endoftext|>": 0,
            }
        )
        self._merges = merges or []
        self._tokenizer = Tokenizer(
            BPE(
                vocab=self._vocab,
                merges=self._merges,
                dropout=None,
                unk_token=None,
                continuing_subword_prefix="",
                end_of_word_suffix="",
                fuse_unk=False,
                byte_fallback=False,
            )
        )
        self._tokenizer.decoder = decoders.ByteLevel()
        self._tokenizer.normalizer = normalizers.NFC()
        self._tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
            [
                pre_tokenizers.Split(
                    Regex(PRETOKENIZE_REGEX),
                    behavior="isolated",
                    invert=False,
                ),
                # Splitting already happened above, so ByteLevel only byte-encodes.
                pre_tokenizers.ByteLevel(
                    add_prefix_space=self.add_prefix_space,
                    use_regex=False,
                ),
            ]
        )

        super().__init__(
            vocab_file=vocab_file,
            merges_file=merges_file,
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            add_prefix_space=add_prefix_space,
            **kwargs,
        )
+
+
+__all__ = ["Qwen3_5Tokenizer"]
diff --git a/src/transformers/models/qwen3_5_moe/__init__.py b/src/transformers/models/qwen3_5_moe/__init__.py
new file mode 100644
index 000000000000..fabf00e524e6
--- /dev/null
+++ b/src/transformers/models/qwen3_5_moe/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    # Static type checkers import the real symbols directly.
+    from .configuration_qwen3_5_moe import *
+    from .modeling_qwen3_5_moe import *
+else:
+    import sys
+
+    # At runtime, replace this module with a lazy proxy so the heavy submodules are only
+    # imported when one of their attributes is first accessed.
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/src/transformers/models/qwen3_5_moe/configuration_qwen3_5_moe.py b/src/transformers/models/qwen3_5_moe/configuration_qwen3_5_moe.py
new file mode 100644
index 000000000000..13ee7ba77a42
--- /dev/null
+++ b/src/transformers/models/qwen3_5_moe/configuration_qwen3_5_moe.py
@@ -0,0 +1,330 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/qwen3_5_moe/modular_qwen3_5_moe.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_qwen3_5_moe.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from ...configuration_utils import PreTrainedConfig, layer_type_validation
+from ...modeling_rope_utils import RopeParameters
+
+
+class Qwen3_5MoeTextConfig(PreTrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`Qwen3_5MoeTextModel`]. It is used to instantiate a
+    Qwen3.5-MoE model according to the specified arguments, defining the model architecture.
+    Instantiating a configuration with the defaults will yield a similar configuration to that of
+    Qwen3.5-35B-A3B-Instruct [Qwen/Qwen3.5-35B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3.5-35B-A3B-Instruct).
+
+    Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PreTrainedConfig`] for more information.
+
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 248320):
+            Vocabulary size of the model. Defines the number of different tokens that can be represented by the
+            `inputs_ids`.
+        hidden_size (`int`, *optional*, defaults to 2048):
+            Dimension of the hidden representations.
+        num_hidden_layers (`int`, *optional*, defaults to 40):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        num_key_value_heads (`int`, *optional*, defaults to 2):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details checkout [this
+            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `2`.
+        hidden_act (`str`, *optional*, defaults to `"silu"`):
+            The non-linear activation function in the decoder.
+        max_position_embeddings (`int`, *optional*, defaults to 32768):
+            The maximum sequence length that this model might ever be used with.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+            The epsilon used by the rms normalization layers.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether the model's input and output word embeddings should be tied.
+        rope_parameters (`RopeParameters`, *optional*):
+            Dictionary containing the configuration parameters for the RoPE embeddings. The dictionary should contain
+            a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE
+            with longer `max_position_embeddings`.
+        attention_bias (`bool`, *optional*, defaults to `False`):
+            Whether to use a bias in the query, key, value and output projection layers during self-attention.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        head_dim (`int`, *optional*, defaults to 256):
+            Projection weights dimension in multi-head attention.
+        linear_conv_kernel_dim (`int`, *optional*, defaults to 4):
+            Kernel size of the convolution used in linear attention layers.
+        linear_key_head_dim (`int`, *optional*, defaults to 128):
+            Dimension of each key head in linear attention.
+        linear_value_head_dim (`int`, *optional*, defaults to 128):
+            Dimension of each value head in linear attention.
+        linear_num_key_heads (`int`, *optional*, defaults to 16):
+            Number of key heads used in linear attention layers.
+        linear_num_value_heads (`int`, *optional*, defaults to 32):
+            Number of value heads used in linear attention layers.
+        moe_intermediate_size (`int`, *optional*, defaults to 512):
+            Intermediate size of the routed expert.
+        shared_expert_intermediate_size (`int`, *optional*, defaults to 512):
+            Intermediate size of the shared expert.
+        num_experts_per_tok (`int`, *optional*, defaults to 8):
+            Number of selected experts.
+        num_experts (`int`, *optional*, defaults to 256):
+            Number of routed experts.
+        output_router_logits (`bool`, *optional*, defaults to `False`):
+            Whether or not the router logits should be returned by the model. Enabling this will also
+            allow the model to output the auxiliary loss, including load balancing loss and router z-loss.
+        router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
+            The aux loss factor for the total loss.
+        layer_types (`list[str]`, *optional*):
+            Types of each layer (attention or linear). If unset, a pattern is generated from the
+            `full_attention_interval` kwarg (default 4): every `full_attention_interval`-th layer is
+            `"full_attention"`, all others are `"linear_attention"`.
+        pad_token_id (`int`, *optional*):
+            Padding token id.
+        bos_token_id (`int`, *optional*):
+            Beginning of stream token id.
+        eos_token_id (`int`, *optional*):
+            End of stream token id.
+
+    ```python
+    >>> from transformers import Qwen3_5MoeTextModel, Qwen3_5MoeTextConfig
+
+    >>> # Initializing a Qwen3.5-MoE style configuration
+    >>> configuration = Qwen3_5MoeTextConfig()
+
+    >>> # Initializing a model from the Qwen3.5-35B-A3B style configuration
+    >>> model = Qwen3_5MoeTextModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```
+    """
+
+    model_type = "qwen3_5_moe_text"
+    keys_to_ignore_at_inference = ["past_key_values"]
+
+    # Tensor-parallel sharding plan for the attention projections and MoE experts.
+    base_model_tp_plan = {
+        "layers.*.self_attn.q_proj": "colwise",
+        "layers.*.self_attn.k_proj": "colwise",
+        "layers.*.self_attn.v_proj": "colwise",
+        "layers.*.self_attn.o_proj": "rowwise",
+        "layers.*.mlp.experts.gate_up_proj": "packed_colwise",
+        "layers.*.mlp.experts.down_proj": "rowwise",
+        "layers.*.mlp.shared_expert.gate_proj": "colwise",
+        "layers.*.mlp.shared_expert.up_proj": "colwise",
+        "layers.*.mlp.shared_expert.down_proj": "rowwise",
+    }
+    # Pipeline-parallel stage boundaries: module name -> (input names, output names).
+    base_model_pp_plan = {
+        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+        "norm": (["hidden_states"], ["hidden_states"]),
+    }
+    base_config_key = "text_config"
+
+    def __init__(
+        self,
+        vocab_size=248320,
+        hidden_size=2048,
+        num_hidden_layers=40,
+        num_attention_heads=16,
+        num_key_value_heads=2,
+        hidden_act="silu",
+        max_position_embeddings=32768,
+        initializer_range=0.02,
+        rms_norm_eps=1e-6,
+        use_cache=True,
+        tie_word_embeddings=False,
+        rope_parameters: RopeParameters | dict[str, RopeParameters] | None = None,
+        attention_bias=False,
+        attention_dropout=0.0,
+        head_dim=256,
+        linear_conv_kernel_dim=4,
+        linear_key_head_dim=128,
+        linear_value_head_dim=128,
+        linear_num_key_heads=16,
+        linear_num_value_heads=32,
+        moe_intermediate_size=512,
+        shared_expert_intermediate_size=512,
+        num_experts_per_tok=8,
+        num_experts=256,
+        output_router_logits=False,
+        router_aux_loss_coef=0.001,
+        layer_types=None,
+        pad_token_id: int | None = None,
+        bos_token_id: int | None = None,
+        eos_token_id: int | None = None,
+        **kwargs,
+    ):
+        # mrope keys are model-specific extras; exempt them from the base RoPE validation.
+        kwargs["ignore_keys_at_rope_validation"] = {"mrope_section", "mrope_interleaved"}
+        self.pad_token_id = pad_token_id
+        self.bos_token_id = bos_token_id
+        self.eos_token_id = eos_token_id
+        self.tie_word_embeddings = tie_word_embeddings
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.num_key_value_heads = num_key_value_heads
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = rms_norm_eps
+        self.use_cache = use_cache
+        self.attention_bias = attention_bias
+        self.attention_dropout = attention_dropout
+        self.head_dim = head_dim
+        self.rope_parameters = rope_parameters
+        kwargs.setdefault("partial_rotary_factor", 0.25)  # assign default for BC
+
+        self.layer_types = layer_types
+        if self.layer_types is None:
+            # `kwargs.get` (not `pop`) keeps `full_attention_interval` in kwargs for serialization.
+            interval_pattern = kwargs.get("full_attention_interval", 4)
+            # Every `interval_pattern`-th layer (1-indexed) is full attention, the rest linear.
+            self.layer_types = [
+                "linear_attention" if bool((i + 1) % interval_pattern) else "full_attention"
+                for i in range(self.num_hidden_layers)
+            ]
+        layer_type_validation(self.layer_types, self.num_hidden_layers)
+
+        # linear attention part
+        self.linear_conv_kernel_dim = linear_conv_kernel_dim
+        self.linear_key_head_dim = linear_key_head_dim
+        self.linear_value_head_dim = linear_value_head_dim
+        self.linear_num_key_heads = linear_num_key_heads
+        self.linear_num_value_heads = linear_num_value_heads
+        self.moe_intermediate_size = moe_intermediate_size
+        self.shared_expert_intermediate_size = shared_expert_intermediate_size
+        self.num_experts_per_tok = num_experts_per_tok
+        self.num_experts = num_experts
+        self.output_router_logits = output_router_logits
+        self.router_aux_loss_coef = router_aux_loss_coef
+        super().__init__(**kwargs)
+
+
+class Qwen3_5MoeVisionConfig(PreTrainedConfig):
+    r"""
+    Configuration for the Qwen3.5-MoE vision encoder, stored under the `vision_config` key of
+    [`Qwen3_5MoeConfig`].
+
+    NOTE(review): the argument descriptions below follow the conventions of the Qwen2.5-VL family
+    of vision configs — confirm against the reference implementation.
+
+    Args:
+        depth (`int`, *optional*, defaults to 27):
+            Number of transformer blocks in the vision encoder.
+        hidden_size (`int`, *optional*, defaults to 1152):
+            Dimension of the hidden representations.
+        hidden_act (`str`, *optional*, defaults to `"gelu_pytorch_tanh"`):
+            The non-linear activation function used in the encoder.
+        intermediate_size (`int`, *optional*, defaults to 4304):
+            Dimension of the MLP representations.
+        num_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads per attention layer.
+        in_channels (`int`, *optional*, defaults to 3):
+            Number of input image channels.
+        patch_size (`int`, *optional*, defaults to 16):
+            Spatial size of each image patch.
+        spatial_merge_size (`int`, *optional*, defaults to 2):
+            Factor by which adjacent patches are merged spatially.
+        temporal_patch_size (`int`, *optional*, defaults to 2):
+            Number of frames grouped into one temporal patch (for video inputs).
+        out_hidden_size (`int`, *optional*, defaults to 3584):
+            Dimension of the vision features projected to the language model.
+        num_position_embeddings (`int`, *optional*, defaults to 2304):
+            Number of learned position embeddings.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+    """
+
+    model_type = "qwen3_5_moe"
+    base_config_key = "vision_config"
+
+    def __init__(
+        self,
+        depth=27,
+        hidden_size=1152,
+        hidden_act="gelu_pytorch_tanh",
+        intermediate_size=4304,
+        num_heads=16,
+        in_channels=3,
+        patch_size=16,
+        spatial_merge_size=2,
+        temporal_patch_size=2,
+        out_hidden_size=3584,
+        num_position_embeddings=2304,
+        initializer_range=0.02,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.depth = depth
+        self.hidden_size = hidden_size
+        self.hidden_act = hidden_act
+        self.intermediate_size = intermediate_size
+        self.num_heads = num_heads
+        self.in_channels = in_channels
+        self.patch_size = patch_size
+        self.spatial_merge_size = spatial_merge_size
+        self.temporal_patch_size = temporal_patch_size
+        self.out_hidden_size = out_hidden_size
+        self.num_position_embeddings = num_position_embeddings
+        self.initializer_range = initializer_range
+
+
+class Qwen3_5MoeConfig(PreTrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Qwen3_5MoeModel`]. It is used to instantiate a
+ Qwen3.5-MoE model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of
+ Qwen3.5-35B-A3B-Instruct [Qwen/Qwen3.5-35B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3.5-35B-A3B-Instruct).
+
+ Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PreTrainedConfig`] for more information.
+
+
+ Args:
+ text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3_5TextConfig`):
+ The config object or dictionary of the text backbone.
+ vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3_5VisionConfig`):
+ The config object or dictionary of the vision backbone.
+ image_token_id (`int`, *optional*, defaults to 248056):
+ The image token index to encode the image prompt.
+ video_token_id (`int`, *optional*, defaults to 248057):
+ The video token index to encode the image prompt.
+ vision_start_token_id (`int`, *optional*, defaults to 248053):
+ The start token index to encode the image prompt.
+ vision_end_token_id (`int`, *optional*, defaults to 248054):
+ The end token index to encode the image prompt.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether to tie the word embeddings.
+
+ ```python
+ >>> from transformers import Qwen3_5MoeForConditionalGeneration, Qwen3_5MoeConfig
+
+ >>> # Initializing a Qwen3.5-MoE style configuration
+ >>> configuration = Qwen3_5MoeConfig()
+
+ >>> # Initializing a model from the Qwen3.5-35B-A3B style configuration
+ >>> model = Qwen3_5MoeForConditionalGeneration(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "qwen3_5_moe"
+ sub_configs = {"vision_config": Qwen3_5MoeVisionConfig, "text_config": Qwen3_5MoeTextConfig}
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ text_config=None,
+ vision_config=None,
+ image_token_id=248056,
+ video_token_id=248057,
+ vision_start_token_id=248053,
+ vision_end_token_id=248054,
+ tie_word_embeddings=False,
+ **kwargs,
+ ):
+ if isinstance(vision_config, dict):
+ self.vision_config = self.sub_configs["vision_config"](**vision_config)
+ elif vision_config is None:
+ self.vision_config = self.sub_configs["vision_config"]()
+
+ if isinstance(text_config, dict):
+ self.text_config = self.sub_configs["text_config"](**text_config)
+ elif text_config is None:
+ self.text_config = self.sub_configs["text_config"]()
+
+ self.image_token_id = image_token_id
+ self.video_token_id = video_token_id
+ self.vision_start_token_id = vision_start_token_id
+ self.vision_end_token_id = vision_end_token_id
+ self.tie_word_embeddings = tie_word_embeddings
+ super().__init__(**kwargs)
+
+
+__all__ = ["Qwen3_5MoeConfig", "Qwen3_5MoeTextConfig"]
diff --git a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py
new file mode 100644
index 000000000000..09f64ce756c6
--- /dev/null
+++ b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py
@@ -0,0 +1,2414 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/qwen3_5_moe/modular_qwen3_5_moe.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_qwen3_5_moe.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections.abc import Callable
+from dataclasses import dataclass
+from typing import Any, Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from ... import initialization as init
+from ...activations import ACT2FN
+from ...cache_utils import Cache
+from ...generation import GenerationMixin
+from ...integrations import use_experts_implementation, use_kernelized_func
+from ...masking_utils import create_causal_mask
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import (
+ BaseModelOutputWithPast,
+ BaseModelOutputWithPooling,
+ ModelOutput,
+ MoeCausalLMOutputWithPast,
+ MoeModelOutputWithPast,
+)
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check
+from ...utils.generic import check_model_inputs, is_flash_attention_requested, maybe_autocast
+from ...utils.import_utils import is_causal_conv1d_available, is_flash_linear_attention_available
+from ...utils.output_capturing import OutputRecorder
+from .configuration_qwen3_5_moe import Qwen3_5MoeConfig, Qwen3_5MoeTextConfig, Qwen3_5MoeVisionConfig
+
+
+if is_causal_conv1d_available():
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
+else:
+ causal_conv1d_update, causal_conv1d_fn = None, None
+
+if is_flash_linear_attention_available():
+ from fla.modules import FusedRMSNormGated
+ from fla.ops.gated_delta_rule import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule
+else:
+ chunk_gated_delta_rule, fused_recurrent_gated_delta_rule = None, None
+ FusedRMSNormGated = None
+
+logger = logging.get_logger(__name__)
+
+
+class Qwen3_5MoeVisionRotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, dim: int, theta: float = 10000.0) -> None:
+ super().__init__()
+ self.dim = dim
+ self.theta = theta
+ inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+ def forward(self, seqlen: int) -> torch.Tensor:
+ seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+ freqs = torch.outer(seq, self.inv_freq)
+ return freqs
+
+
+class Qwen3_5MoeTextRotaryEmbedding(nn.Module):
+    """Rotary embedding for the text backbone with interleaved multimodal RoPE (mrope).
+
+    Produces cos/sin tables from 3D position ids (temporal / height / width grids) and merges
+    the three frequency streams into one via `apply_interleaved_mrope`.
+    """
+
+    inv_freq: torch.Tensor  # fix linting for `register_buffer`
+
+    def __init__(self, config: Qwen3_5MoeTextConfig, device=None):
+        super().__init__()
+        self.max_seq_len_cached = config.max_position_embeddings
+        self.original_max_seq_len = config.max_position_embeddings
+
+        self.config = config
+
+        # Pick the frequency-init function for the configured RoPE variant; "default" uses the
+        # local static method below instead of the shared registry.
+        self.rope_type = self.config.rope_parameters["rope_type"]
+        rope_init_fn: Callable = self.compute_default_rope_parameters
+        if self.rope_type != "default":
+            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
+
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        # Kept so dynamic RoPE variants can restore the unscaled frequencies.
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
+        # With the file's defaults (head_dim=256, partial_rotary_factor=0.25) the rotary half-dim
+        # is 32, matching 11 + 11 + 10.
+        self.mrope_section = config.rope_parameters.get("mrope_section", [11, 11, 10])
+
+    @staticmethod
+    def compute_default_rope_parameters(
+        config: Qwen3_5MoeTextConfig | None = None,
+        device: Optional["torch.device"] = None,
+        seq_len: int | None = None,
+    ) -> tuple["torch.Tensor", float]:
+        """
+        Computes the inverse frequencies according to the original RoPE implementation
+        Args:
+            config ([`~transformers.PreTrainedConfig`]):
+                The model configuration.
+            device (`torch.device`):
+                The device to use for initialization of the inverse frequencies.
+            seq_len (`int`, *optional*):
+                The current sequence length. Unused for this type of RoPE.
+        Returns:
+            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
+        """
+        base = config.rope_parameters["rope_theta"]
+        partial_rotary_factor = config.rope_parameters.get("partial_rotary_factor", 1.0)
+        head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
+        # Only a `partial_rotary_factor` fraction of the head dimension receives rotation.
+        dim = int(head_dim * partial_rotary_factor)
+
+        attention_factor = 1.0  # Unused in this type of RoPE
+
+        # Compute the inverse frequencies
+        inv_freq = 1.0 / (
+            base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
+        )
+        return inv_freq, attention_factor
+
+    @torch.no_grad()
+    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
+    def forward(self, x, position_ids):
+        # In contrast to other models, Qwen3_5Moe has different position ids for the grids
+        # So we expand the inv_freq to shape (3, ...)
+        if position_ids.ndim == 2:
+            # Text-only input: replicate the same positions for the T/H/W axes.
+            position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
+        inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
+        position_ids_expanded = position_ids[:, :, None, :].float()  # shape (3, bs, 1, positions)
+
+        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
+            # Merge the three axis-specific frequency streams into a single interleaved layout.
+            freqs = self.apply_interleaved_mrope(freqs, self.mrope_section)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos() * self.attention_scaling
+            sin = emb.sin() * self.attention_scaling
+
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+    def apply_interleaved_mrope(self, freqs, mrope_section):
+        """Apply interleaved MRoPE to 3D rotary embeddings.
+        Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
+        interleaved [THWTHWTHW...TT], preserving frequency continuity.
+        args:
+            x: (3, bs, seq_len, head_dim // 2)
+            mrope_section: (3,)
+        returns:
+            x_t: (bs, seq_len, head_dim // 2)
+        """
+        freqs_t = freqs[0]  # just overwrite the first dimension T
+        # Scatter H (offset 1) and W (offset 2) entries into every 3rd slot of the T stream.
+        for dim, offset in enumerate((1, 2), start=1):  # H, W
+            length = mrope_section[dim] * 3
+            idx = slice(offset, length, 3)
+            freqs_t[..., idx] = freqs[dim, ..., idx]
+        return freqs_t
+
+
+class Qwen3_5MoeDynamicCache:
+    """
+    A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the linear attention
+    cache (which has a constant shape regardless of seq_len).
+
+    This cache has two sets of lists of tensors: `key_cache` and `value_cache` for the attention cache, and
+    `conv_states` and `recurrent_states` for the gated-deltanet cache. Each list holds `num_hidden_layers` entries,
+    initialized to `None` and lazily allocated on first use (which also allows multi-GPU / `device_map` inference).
+    For attention layers, `key_cache` and `value_cache` grow to shape `(batch_size, num_heads, seq_len, head_dim)`
+    while the linear-state entries stay `None`. For linear attention layers, `conv_states` holds the convolution
+    state of shape `(batch_size, d_inner, d_conv)` and `recurrent_states` holds the recurrent state of shape
+    `(batch_size, d_inner, d_state)` while the key/value entries stay `None`.
+    """
+
+    # Mixed per-layer state shapes make this cache incompatible with torch.compile capture.
+    is_compileable = False
+
+    def __init__(self, config: Qwen3_5MoeConfig):
+        super().__init__()
+        self.layer_types = config.layer_types
+        # Indices of full-attention layers; used to answer seq-length queries.
+        self.transformer_layers = [
+            i for i in range(config.num_hidden_layers) if self.layer_types[i] == "full_attention"
+        ]
+        # Index of the last linear-attention layer: once it has a conv state, all linear layers do.
+        self.last_linear_layer = len(self.layer_types) - 1 - self.layer_types[::-1].index("linear_attention")
+
+        # Initialize everything to None -> will be lazy initialized to allow multi-gpu (device_map) inference
+        self.conv_states = [None for _ in range(config.num_hidden_layers)]
+        self.recurrent_states = [None for _ in range(config.num_hidden_layers)]
+        self.key_cache = [None for _ in range(config.num_hidden_layers)]
+        self.value_cache = [None for _ in range(config.num_hidden_layers)]
+
+    def __len__(self):
+        # Number of layers tracked by the cache.
+        return len(self.layer_types)
+
+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        cache_kwargs: dict[str, Any] | None = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Append new key/value states for an attention layer and return the full cached tensors."""
+        if self.key_cache[layer_idx] is None:
+            self.key_cache[layer_idx] = key_states
+            self.value_cache[layer_idx] = value_states
+        else:
+            # Concatenate along the sequence dimension.
+            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2)
+            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2)
+
+        return self.key_cache[layer_idx], self.value_cache[layer_idx]
+
+    def reorder_cache(self, beam_idx: torch.LongTensor):
+        """Reorders the cache for beam search, given the selected beam indices."""
+        for layer_idx in range(len(self.key_cache)):
+            if self.key_cache[layer_idx] is not None:
+                device = self.key_cache[layer_idx].device
+                beam_idx = beam_idx.to(device)
+                self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx)
+                self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx)
+
+            if self.conv_states[layer_idx] is not None:
+                device = self.conv_states[layer_idx].device
+                beam_idx = beam_idx.to(device)
+                self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx)
+                self.recurrent_states[layer_idx] = self.recurrent_states[layer_idx].index_select(0, beam_idx)
+
+    def get_seq_length(self, layer_idx: int | None = 0) -> int:
+        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+        # take any layer that contains cache and not empty tensor
+        layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx
+        if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx] is None:
+            return 0
+        return self.key_cache[layer_idx].shape[-2]
+
+    def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
+        """
+        Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
+        the given layer at `layer_idx`.
+        The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns for each layer.
+        """
+        kv_offset = 0
+        query_length = cache_position.shape[0]
+        past_seen_tokens = self.get_seq_length(layer_idx)
+        kv_length = query_length + past_seen_tokens
+        return kv_length, kv_offset
+
+    @property
+    def has_previous_state(self):
+        """We have a previous state if the last linear (conv) layer was already updated."""
+        return self.conv_states[self.last_linear_layer] is not None
+
+
+class Qwen3_5MoeRMSNormGated(nn.Module):
+    """RMSNorm followed by SiLU gating ("norm before gate"), computed in float32.
+
+    NOTE(review): appears to be the pure-PyTorch counterpart of FLA's `FusedRMSNormGated` used in
+    the fast path — confirm numerical parity against that kernel.
+    """
+
+    def __init__(self, hidden_size, eps=1e-6, **kwargs):
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states, gate=None):
+        # NOTE(review): `gate` defaults to None but is used unconditionally below — callers must
+        # always pass a gate tensor, otherwise this raises on `gate.to(...)`.
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        # Norm before gate
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        hidden_states = self.weight * hidden_states.to(input_dtype)
+        # SiLU gating in float32 for stability; result is cast back to the input dtype.
+        hidden_states = hidden_states * F.silu(gate.to(torch.float32))
+
+        return hidden_states.to(input_dtype)
+
+
+def apply_mask_to_padding_states(hidden_states, attention_mask):
+ """
+ Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66
+ """
+ # NOTE: attention mask is a 2D boolean tensor
+ if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
+ dtype = hidden_states.dtype
+ hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
+
+ return hidden_states
+
+
+is_fast_path_available = all(
+ (causal_conv1d_fn, causal_conv1d_update, chunk_gated_delta_rule, fused_recurrent_gated_delta_rule)
+)
+
+
+def torch_causal_conv1d_update(
+ hidden_states,
+ conv_state,
+ weight,
+ bias=None,
+ activation=None,
+):
+ _, hidden_size, seq_len = hidden_states.shape
+ state_len = conv_state.shape[-1]
+
+ hidden_states_new = torch.cat([conv_state, hidden_states], dim=-1).to(weight.dtype)
+ conv_state.copy_(hidden_states_new[:, :, -state_len:])
+ out = F.conv1d(hidden_states_new, weight.unsqueeze(1), bias, padding=0, groups=hidden_size)
+ out = F.silu(out[:, :, -seq_len:])
+ out = out.to(hidden_states.dtype)
+ return out
+
+
+def l2norm(x: torch.FloatTensor, dim: int = -1, eps: float = 1e-6):
+ """This function is intended to align with the l2norm implementation in the FLA library."""
+ inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps)
+ return x * inv_norm
+
+
+def torch_chunk_gated_delta_rule(
+    query,
+    key,
+    value,
+    g,
+    beta,
+    chunk_size=64,
+    initial_state=None,
+    output_final_state=False,
+    use_qk_l2norm_in_kernel=False,
+):
+    """Pure-PyTorch chunked fallback for FLA's `chunk_gated_delta_rule`.
+
+    Processes the sequence in chunks of `chunk_size`, carrying a per-head recurrent state of
+    shape `(batch, heads, k_dim, v_dim)` between chunks. Inputs arrive as
+    `(batch, seq, heads, dim)` and are transposed to `(batch, heads, seq, dim)` internally;
+    the output is returned in the original layout and dtype.
+    Returns `(core_attn_out, last_recurrent_state)`; the state is `None` unless
+    `output_final_state` is set.
+    """
+    initial_dtype = query.dtype
+    if use_qk_l2norm_in_kernel:
+        query = l2norm(query, dim=-1, eps=1e-6)
+        key = l2norm(key, dim=-1, eps=1e-6)
+    # Work in float32 with a (batch, heads, seq, dim) layout for numerical stability.
+    query, key, value, beta, g = [
+        x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)
+    ]
+
+    batch_size, num_heads, sequence_length, k_head_dim = key.shape
+    v_head_dim = value.shape[-1]
+    # Right-pad the sequence so it divides evenly into chunks.
+    pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
+    query = F.pad(query, (0, 0, 0, pad_size))
+    key = F.pad(key, (0, 0, 0, pad_size))
+    value = F.pad(value, (0, 0, 0, pad_size))
+    beta = F.pad(beta, (0, pad_size))
+    g = F.pad(g, (0, pad_size))
+    total_sequence_length = sequence_length + pad_size
+    scale = 1 / (query.shape[-1] ** 0.5)
+    query = query * scale
+
+    v_beta = value * beta.unsqueeze(-1)
+    k_beta = key * beta.unsqueeze(-1)
+    # reshape to chunks
+    query, key, value, k_beta, v_beta = [
+        x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) for x in (query, key, value, k_beta, v_beta)
+    ]
+    g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
+    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=0)
+
+    # chunk decay
+    # Cumulative log-decay within each chunk; decay_mask[i, j] = exp(g_i - g_j) on the lower triangle.
+    g = g.cumsum(dim=-1)
+    decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril()
+    attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0)
+    # Forward substitution: invert the implicit lower-triangular system row by row.
+    for i in range(1, chunk_size):
+        row = attn[..., i, :i].clone()
+        sub = attn[..., :i, :i].clone()
+        attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
+    attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
+    value = attn @ v_beta
+    k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
+    last_recurrent_state = (
+        torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value)
+        if initial_state is None
+        else initial_state.to(value)
+    )
+    core_attn_out = torch.zeros_like(value)
+    # Strictly-upper mask: within-chunk attention must stay causal.
+    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=1)
+
+    # for each chunk
+    for i in range(0, total_sequence_length // chunk_size):
+        q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i]
+        # Intra-chunk causal attention with decay, plus the contribution carried in from the state.
+        attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
+        v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
+        v_new = v_i - v_prime
+        attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
+        core_attn_out[:, :, i] = attn_inter + attn @ v_new
+        # Decay the carried state to the end of this chunk and fold in the chunk's delta update.
+        last_recurrent_state = (
+            last_recurrent_state * g[:, :, i, -1, None, None].exp()
+            + (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(-1, -2) @ v_new
+        )
+
+    if not output_final_state:
+        last_recurrent_state = None
+    # Flatten chunks back to a sequence, strip the padding, and restore layout/dtype.
+    core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1])
+    core_attn_out = core_attn_out[:, :, :sequence_length]
+    core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
+    return core_attn_out, last_recurrent_state
+
+
def torch_recurrent_gated_delta_rule(
    query, key, value, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel=False
):
    """Pure-PyTorch token-by-token (recurrent) gated delta rule.

    Reference fallback for the fused `fla` kernel, used for single-token
    decoding. Inputs are laid out `(batch, seq, heads, dim)`; the recurrence is
    run in float32 and the output is cast back to the input dtype.

    Args:
        query, key, value: `(batch, seq, heads, dim)` projections.
        g: `(batch, seq, heads)` log-decay per head and step.
        beta: `(batch, seq, heads)` mixing strength in [0, 1].
        initial_state: optional `(batch, heads, k_dim, v_dim)` recurrent state.
        output_final_state: when False, the returned state is None.
        use_qk_l2norm_in_kernel: L2-normalize q/k before the recurrence.

    Returns:
        `(output, final_state)` with output shaped like `value`.
    """
    orig_dtype = query.dtype
    if use_qk_l2norm_in_kernel:
        query = l2norm(query, dim=-1, eps=1e-6)
        key = l2norm(key, dim=-1, eps=1e-6)

    # Move to (batch, heads, seq, dim) and promote to float32 for stability.
    query, key, value, beta, g = (
        t.transpose(1, 2).contiguous().to(torch.float32) for t in (query, key, value, beta, g)
    )

    batch_size, num_heads, seq_len, k_dim = key.shape
    v_dim = value.shape[-1]
    inv_sqrt_dim = 1 / (query.shape[-1] ** 0.5)
    query = query * inv_sqrt_dim

    out = torch.zeros(batch_size, num_heads, seq_len, v_dim).to(value)
    state = (
        initial_state.to(value)
        if initial_state is not None
        else torch.zeros(batch_size, num_heads, k_dim, v_dim).to(value)
    )

    for step in range(seq_len):
        q_t = query[:, :, step]
        k_t = key[:, :, step]
        v_t = value[:, :, step]
        decay = g[:, :, step].exp().unsqueeze(-1).unsqueeze(-1)
        beta_t = beta[:, :, step].unsqueeze(-1)

        # Decay the memory, predict v from it, then write back the beta-scaled
        # correction (the "delta rule") along the key direction.
        state = state * decay
        predicted_v = (state * k_t.unsqueeze(-1)).sum(dim=-2)
        correction = (v_t - predicted_v) * beta_t
        state = state + k_t.unsqueeze(-1) * correction.unsqueeze(-2)
        out[:, :, step] = (state * q_t.unsqueeze(-1)).sum(dim=-2)

    final_state = state if output_final_state else None
    out = out.transpose(1, 2).contiguous().to(orig_dtype)
    return out, final_state
+
+
class Qwen3_5MoeGatedDeltaNet(nn.Module):
    """Gated DeltaNet token mixer used by the ``linear_attention`` decoder layers.

    Hidden states are projected to Q/K/V (plus a gate ``z``, a mixing strength
    ``b`` and a decay input ``a``), run through a short depthwise causal
    convolution, then through the gated delta rule — chunked for prefill,
    recurrent for single-token decoding — and finally through a gated RMSNorm
    and the output projection.
    """

    def __init__(self, config: Qwen3_5MoeConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_v_heads = config.linear_num_value_heads
        self.num_k_heads = config.linear_num_key_heads
        self.head_k_dim = config.linear_key_head_dim
        self.head_v_dim = config.linear_value_head_dim
        self.key_dim = self.head_k_dim * self.num_k_heads
        self.value_dim = self.head_v_dim * self.num_v_heads

        self.conv_kernel_size = config.linear_conv_kernel_dim
        self.layer_idx = layer_idx
        self.activation = config.hidden_act
        self.act = ACT2FN[config.hidden_act]
        self.layer_norm_epsilon = config.rms_norm_eps

        # QKV
        # Depthwise (grouped) conv over the concatenated Q/K/V channels;
        # padding of kernel_size - 1 keeps it causal (the overhang is trimmed
        # back to seq_len in `forward`).
        self.conv_dim = self.key_dim * 2 + self.value_dim
        self.conv1d = nn.Conv1d(
            in_channels=self.conv_dim,
            out_channels=self.conv_dim,
            bias=False,
            kernel_size=self.conv_kernel_size,
            groups=self.conv_dim,
            padding=self.conv_kernel_size - 1,
        )

        # time step projection (discretization)
        # instantiate once and copy inv_dt in init_weights of PretrainedModel
        self.dt_bias = nn.Parameter(torch.ones(self.num_v_heads))

        # Per-head decay magnitude stored in log space; re-initialized in
        # `Qwen3_5MoePreTrainedModel._init_weights`.
        A = torch.empty(self.num_v_heads).uniform_(0, 16)
        self.A_log = nn.Parameter(torch.log(A))

        # Gated RMSNorm over each value head; prefer the fused kernel when
        # flash-linear-attention provides one.
        self.norm = (
            Qwen3_5MoeRMSNormGated(self.head_v_dim, eps=self.layer_norm_epsilon)
            if FusedRMSNormGated is None
            else FusedRMSNormGated(
                self.head_v_dim,
                eps=self.layer_norm_epsilon,
                activation=self.activation,
                device=torch.cuda.current_device(),
                dtype=config.dtype if config.dtype is not None else torch.get_default_dtype(),
            )
        )

        self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False)

        # Fall back to the pure-torch implementations when the fused kernels
        # (causal-conv1d / flash-linear-attention) are not installed.
        self.causal_conv1d_fn = causal_conv1d_fn
        self.causal_conv1d_update = causal_conv1d_update or torch_causal_conv1d_update
        self.chunk_gated_delta_rule = chunk_gated_delta_rule or torch_chunk_gated_delta_rule
        self.recurrent_gated_delta_rule = fused_recurrent_gated_delta_rule or torch_recurrent_gated_delta_rule

        if not is_fast_path_available:
            logger.warning_once(
                "The fast path is not available because one of the required library is not installed. Falling back to "
                "torch implementation. To install follow https://github.com/fla-org/flash-linear-attention#installation and"
                " https://github.com/Dao-AILab/causal-conv1d"
            )

        self.in_proj_qkv = nn.Linear(self.hidden_size, self.key_dim * 2 + self.value_dim, bias=False)
        self.in_proj_z = nn.Linear(self.hidden_size, self.value_dim, bias=False)
        self.in_proj_b = nn.Linear(self.hidden_size, self.num_v_heads, bias=False)
        self.in_proj_a = nn.Linear(self.hidden_size, self.num_v_heads, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cache_params: Qwen3_5MoeDynamicCache | None = None,
        cache_position: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
    ):
        """Mix tokens with the gated delta rule.

        Args:
            hidden_states: `(batch, seq_len, hidden_size)` activations.
            cache_params: per-layer cache holding conv and recurrent states;
                updated in place when provided.
            cache_position: token positions; required to take the cached
                single-token decode path.
            attention_mask: optional padding mask zeroed into the inputs.

        Returns:
            `(batch, seq_len, hidden_size)` mixed activations.
        """
        hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)

        # Set up dimensions for reshapes later
        batch_size, seq_len, _ = hidden_states.shape

        # Single-token decode path: reuse the cached conv/recurrent states.
        use_precomputed_states = (
            cache_params is not None
            and cache_params.has_previous_state
            and seq_len == 1
            and cache_position is not None
        )

        # getting projected states from cache if it exists
        if cache_params is not None:
            conv_state = cache_params.conv_states[self.layer_idx]
            recurrent_state = cache_params.recurrent_states[self.layer_idx]

        # (batch, conv_dim, seq_len) layout for the depthwise convolution.
        mixed_qkv = self.in_proj_qkv(hidden_states)
        mixed_qkv = mixed_qkv.transpose(1, 2)

        z = self.in_proj_z(hidden_states)
        z = z.reshape(batch_size, seq_len, -1, self.head_v_dim)

        b = self.in_proj_b(hidden_states)
        a = self.in_proj_a(hidden_states)

        if use_precomputed_states:
            # 2. Convolution sequence transformation
            # NOTE: the conv state is updated in `causal_conv1d_update`
            mixed_qkv = self.causal_conv1d_update(
                mixed_qkv,
                conv_state,
                self.conv1d.weight.squeeze(1),
                self.conv1d.bias,
                self.activation,
            )
        else:
            if cache_params is not None:
                # Keep the last `conv_kernel_size` inputs (left-padded) so the
                # next decode step can continue the convolution.
                conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0))
                cache_params.conv_states[self.layer_idx] = conv_state
            if self.causal_conv1d_fn is not None:
                mixed_qkv = self.causal_conv1d_fn(
                    x=mixed_qkv,
                    weight=self.conv1d.weight.squeeze(1),
                    bias=self.conv1d.bias,
                    activation=self.activation,
                    seq_idx=None,
                )
            else:
                # Trim the causal-padding overhang back to seq_len.
                mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len])

        mixed_qkv = mixed_qkv.transpose(1, 2)
        query, key, value = torch.split(
            mixed_qkv,
            [
                self.key_dim,
                self.key_dim,
                self.value_dim,
            ],
            dim=-1,
        )

        query = query.reshape(batch_size, seq_len, -1, self.head_k_dim)
        key = key.reshape(batch_size, seq_len, -1, self.head_k_dim)
        value = value.reshape(batch_size, seq_len, -1, self.head_v_dim)

        beta = b.sigmoid()
        # If the model is loaded in fp16, without the .float() here, A might be -inf
        g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
        # Grouped heads: repeat Q/K so every value head has a matching K head.
        if self.num_v_heads // self.num_k_heads > 1:
            query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
            key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)

        if not use_precomputed_states:
            core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule(
                query,
                key,
                value,
                g=g,
                beta=beta,
                initial_state=None,
                output_final_state=cache_params is not None,
                use_qk_l2norm_in_kernel=True,
            )

        else:
            core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule(
                query,
                key,
                value,
                g=g,
                beta=beta,
                initial_state=recurrent_state,
                output_final_state=cache_params is not None,
                use_qk_l2norm_in_kernel=True,
            )

        # Update cache
        if cache_params is not None:
            cache_params.recurrent_states[self.layer_idx] = last_recurrent_state

        # reshape input data into 2D tensor
        core_attn_out = core_attn_out.reshape(-1, self.head_v_dim)
        z = z.reshape(-1, self.head_v_dim)
        core_attn_out = self.norm(core_attn_out, z)
        core_attn_out = core_attn_out.reshape(batch_size, seq_len, -1)

        output = self.out_proj(core_attn_out)
        return output
+
+
def rotate_half(x):
    """Rotate the last dimension by half: maps (x1, x2) -> (-x2, x1)."""
    half = x.shape[-1] // 2
    first_half, second_half = x[..., :half], x[..., half:]
    return torch.cat((-second_half, first_half), dim=-1)
+
+
+# Adapted from transformers.models.glm.modular_glm.apply_rotary_pos_emb
+def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Removes the interleaving of cos and sin from GLM
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+
+ # Keep half or full tensor for later concatenation
+ rotary_dim = cos.shape[-1]
+ q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
+ k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
+
+ # Apply rotary embeddings on the first half or full tensor
+ q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
+ k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
+
+ # Concatenate back to full shape
+ q_embed = torch.cat([q_embed, q_pass], dim=-1)
+ k_embed = torch.cat([k_embed, k_pass], dim=-1)
+ return q_embed, k_embed
+
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Expand KV heads for grouped-query attention.

    Equivalent to `torch.repeat_interleave(hidden_states, dim=1, repeats=n_rep)`:
    `(batch, num_key_value_heads, seq, head_dim)` becomes
    `(batch, num_key_value_heads * n_rep, seq, head_dim)`.
    """
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    expanded = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)
+
+
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """Plain (non-fused) scaled-dot-product attention.

    Expands KV heads per `module.num_key_value_groups` (GQA), adds the additive
    attention mask, softmaxes in float32, applies dropout, and returns both the
    attended values (transposed to `[batch, seq, heads, dim]`) and the weights.
    """
    n_rep = module.num_key_value_groups

    def _expand_heads(states):
        # Same as repeat_kv: interleave-repeat KV heads along the head axis.
        batch, kv_heads, seq_len, head_dim = states.shape
        if n_rep == 1:
            return states
        grown = states[:, :, None, :, :].expand(batch, kv_heads, n_rep, seq_len, head_dim)
        return grown.reshape(batch, kv_heads * n_rep, seq_len, head_dim)

    key_states = _expand_heads(key)
    value_states = _expand_heads(value)

    scores = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        # Mask is additive and may be longer than the key sequence; crop it.
        scores = scores + attention_mask[:, :, :, : key_states.shape[-2]]

    # Softmax in float32 for stability, then cast back to the query dtype.
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)
    attn_output = torch.matmul(probs, value_states).transpose(1, 2).contiguous()

    return attn_output, probs
+
+
@use_kernelized_func(apply_rotary_pos_emb)
class Qwen3_5MoeAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: Qwen3_5MoeConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True
        # q_proj is twice the usual width: half is the query, the other half is
        # a per-channel sigmoid gate applied to the attention output in forward.
        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim * 2, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )
        self.q_norm = Qwen3_5MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)  # unlike olmo, only on the head dim!
        self.k_norm = Qwen3_5MoeRMSNorm(
            self.head_dim, eps=config.rms_norm_eps
        )  # thus post q_norm does not need reshape

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: torch.Tensor | None,
        past_key_values: Cache | None = None,
        cache_position: torch.LongTensor | None = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """RoPE attention with per-head QK-norm and gated output.

        Returns the projected attention output and the attention weights
        (the latter may be None depending on the attention implementation).
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # Split the doubled q_proj output into the query and its output gate.
        query_states, gate = torch.chunk(
            self.q_proj(hidden_states).view(*input_shape, -1, self.head_dim * 2), 2, dim=-1
        )
        gate = gate.reshape(*input_shape, -1)

        query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2)
        key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        # Channel-wise sigmoid gating of the attention output before o_proj.
        attn_output = attn_output * torch.sigmoid(gate)

        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
+
+
class Qwen3_5MoeMLP(nn.Module):
    """SwiGLU-style feed-forward block: down(act(gate(x)) * up(x))."""

    def __init__(self, config: Qwen3_5MoeConfig, intermediate_size: int):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        # Gated linear unit: activation of the gate branch modulates the up branch.
        gated = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(gated)
+
+
@use_experts_implementation
class Qwen3_5MoeExperts(nn.Module):
    """Collection of expert weights stored as 3D tensors."""

    def __init__(self, config):
        super().__init__()
        self.num_experts = config.num_experts
        self.hidden_dim = config.hidden_size
        self.intermediate_dim = config.moe_intermediate_size
        # One fused (gate | up) projection and one down projection per expert.
        self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
        self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(
        self,
        hidden_states: torch.Tensor,
        top_k_index: torch.Tensor,
        top_k_weights: torch.Tensor,
    ) -> torch.Tensor:
        """Run each routed expert over its tokens and combine the results.

        Args:
            hidden_states: `(num_tokens, hidden_dim)` flattened token states.
            top_k_index: `(num_tokens, top_k)` selected expert ids per token.
            top_k_weights: `(num_tokens, top_k)` routing weights per token.

        Returns:
            `(num_tokens, hidden_dim)` routing-weighted sum of expert outputs.
        """
        final_hidden_states = torch.zeros_like(hidden_states)
        with torch.no_grad():
            # Dispatch mask shaped (num_experts, top_k, num_tokens); only
            # experts that received at least one token are visited below.
            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
            expert_mask = expert_mask.permute(2, 1, 0)
            expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

        for expert_idx in expert_hit:
            expert_idx = expert_idx[0]
            # Safety guard: one_hot over num_experts classes cannot yield this id.
            if expert_idx == self.num_experts:
                continue
            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
            current_state = hidden_states[token_idx]
            # SwiGLU expert: down(act(gate) * up), scaled by the routing weight.
            gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
            current_hidden_states = self.act_fn(gate) * up
            current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
            current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
            final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))

        return final_hidden_states
+
+
class Qwen3_5MoeTopKRouter(nn.Module):
    """Softmax-then-top-k router: scores every expert, keeps the top-k per
    token, and renormalizes the kept probabilities to sum to one."""

    def __init__(self, config):
        super().__init__()
        self.top_k = config.num_experts_per_tok
        self.num_experts = config.num_experts
        self.hidden_dim = config.hidden_size
        self.weight = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim))

    def forward(self, hidden_states):
        flat_states = hidden_states.reshape(-1, self.hidden_dim)
        # Float32 softmax over expert scores: (num_tokens, num_experts).
        router_logits = torch.nn.functional.softmax(
            F.linear(flat_states, self.weight), dtype=torch.float, dim=-1
        )
        # (num_tokens, top_k) winners, renormalized so each token's weights sum to 1.
        top_values, router_indices = torch.topk(router_logits, self.top_k, dim=-1)
        top_values = top_values / top_values.sum(dim=-1, keepdim=True)
        router_scores = top_values.to(router_logits.dtype)
        return router_logits, router_scores, router_indices
+
+
class Qwen3_5MoeSparseMoeBlock(nn.Module):
    """Sparse MoE block: top-k routed experts plus an always-on shared expert
    whose contribution is scaled by a learned sigmoid gate."""

    def __init__(self, config):
        super().__init__()
        self.gate = Qwen3_5MoeTopKRouter(config)
        self.experts = Qwen3_5MoeExperts(config)
        self.shared_expert = Qwen3_5MoeMLP(config, intermediate_size=config.shared_expert_intermediate_size)
        self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Route tokens through the experts and add the gated shared expert.

        Returns:
            `(expert_output, router_logits)`. Fix: the original returned only
            `expert_output`, although the annotation promised a tuple, the
            decoder layer unpacks a tuple, and `_can_record_outputs` records
            `router_logits` from this block's output at index 1 — so the
            router scores were silently dropped.
        """
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states_reshaped = hidden_states.view(-1, hidden_dim)

        # Shared expert runs on every token; its gate is applied below.
        shared_expert_output = self.shared_expert(hidden_states_reshaped)
        router_logits, routing_weights, selected_experts = self.gate(hidden_states_reshaped)
        expert_output = self.experts(hidden_states_reshaped, selected_experts, routing_weights)

        shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states_reshaped)) * shared_expert_output

        expert_output += shared_expert_output
        expert_output = expert_output.reshape(batch_size, sequence_length, hidden_dim)
        return expert_output, router_logits
+
+
class Qwen3_5MoeRMSNorm(nn.Module):
    """Zero-centered RMSNorm: the learned weight starts at 0 and the output is
    scaled by (1 + weight), so a freshly initialized layer is the identity
    (up to normalization)."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        # Divide by the root-mean-square over the last dimension.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        normed = self._norm(x.float())
        # Llama does x.to(float16) * w whilst Qwen3_5Moe is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        scaled = normed * (1.0 + self.weight.float())
        return scaled.type_as(x)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.eps}"
+
+
class Qwen3_5MoeDecoderLayer(GradientCheckpointingLayer):
    """One decoder layer: a token mixer — gated DeltaNet or full attention,
    chosen per layer by ``config.layer_types`` — followed by a sparse-MoE MLP,
    each inside a pre-norm residual branch."""

    def __init__(self, config: Qwen3_5MoeTextConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.layer_type = config.layer_types[layer_idx]
        if self.layer_type == "linear_attention":
            self.linear_attn = Qwen3_5MoeGatedDeltaNet(config, layer_idx)
        elif self.layer_type == "full_attention":
            self.self_attn = Qwen3_5MoeAttention(config, layer_idx)
        self.mlp = Qwen3_5MoeSparseMoeBlock(config)
        self.input_layernorm = Qwen3_5MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen3_5MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        cache_position: torch.LongTensor | None = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> torch.FloatTensor:
        """Apply the token mixer and the MoE feed-forward, each with residual."""
        # --- token-mixing branch (linear attention or full attention) ---
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        if self.layer_type == "linear_attention":
            hidden_states = self.linear_attn(
                hidden_states=hidden_states,
                cache_params=past_key_values,
                cache_position=cache_position,
                attention_mask=attention_mask,
            )
        elif self.layer_type == "full_attention":
            hidden_states, _ = self.self_attn(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )
        hidden_states = residual + hidden_states

        # --- feed-forward (sparse MoE) branch ---
        residual = hidden_states
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        if isinstance(mlp_out, tuple):
            # MoE blocks may additionally return router logits; keep the states.
            mlp_out, _ = mlp_out
        return residual + mlp_out
+
+
class Qwen3_5MoePreTrainedModel(PreTrainedModel):
    """Base class wiring config, weight init and output recording for all
    Qwen3.5-MoE model heads."""

    config: Qwen3_5MoeConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Qwen3_5MoeDecoderLayer", "Qwen3_5MoeVisionBlock"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn = True
    _supports_sdpa = True
    # Checkpoints may ship extra `mtp.*` weights that this model does not use.
    _keys_to_ignore_on_load_unexpected = [r"^mtp.*"]
    _can_record_outputs = {
        "router_logits": OutputRecorder(Qwen3_5MoeSparseMoeBlock, index=1),
        "hidden_states": Qwen3_5MoeDecoderLayer,
        "attentions": Qwen3_5MoeAttention,
    }
    # Linear-attention layers carry recurrent state between forward calls.
    _is_stateful = True

    @torch.no_grad()
    def _init_weights(self, module):
        """Module-specific initialization on top of the generic init."""
        super()._init_weights(module)
        if isinstance(module, Qwen3_5MoeGatedDeltaNet):
            # dt_bias to ones, A_log to log of Uniform(0, 16) per head.
            init.ones_(module.dt_bias)
            init.copy_(module.A_log, torch.empty_like(module.A_log).uniform_(0, 16).log_())
        # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
        elif isinstance(module, Qwen3_5MoeRMSNorm):
            init.zeros_(module.weight)
        elif isinstance(module, Qwen3_5MoeExperts):
            init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range)
            init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, Qwen3_5MoeSparseMoeBlock):
            init.normal_(module.gate.weight, mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, Qwen3_5MoeVisionRotaryEmbedding):
            # Standard RoPE inverse-frequency table over even channel indices.
            inv_freq = 1.0 / (module.theta ** (torch.arange(0, module.dim, 2, dtype=torch.float) / module.dim))
            init.copy_(module.inv_freq, inv_freq)
+
+
class Qwen3_5MoeVisionMLP(nn.Module):
    """Two-layer feed-forward block used inside the vision transformer blocks."""

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.linear_fc1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=True)
        self.linear_fc2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=True)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_state):
        # fc1 -> activation -> fc2
        expanded = self.linear_fc1(hidden_state)
        return self.linear_fc2(self.act_fn(expanded))
+
+
class Qwen3_5MoeVisionPatchEmbed(nn.Module):
    """Embed flattened (temporal, spatial) patches with a non-overlapping 3D
    convolution: each patch of `in_channels * temporal_patch_size * patch_size²`
    values becomes one `embed_dim` vector."""

    def __init__(self, config) -> None:
        super().__init__()
        self.patch_size = config.patch_size
        self.temporal_patch_size = config.temporal_patch_size
        self.in_channels = config.in_channels
        self.embed_dim = config.hidden_size

        # Kernel == stride, so patches do not overlap.
        patch_shape = [self.temporal_patch_size, self.patch_size, self.patch_size]
        self.proj = nn.Conv3d(self.in_channels, self.embed_dim, kernel_size=patch_shape, stride=patch_shape, bias=True)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Map `(num_patches, C * T * P * P)` flat patches to `(num_patches, embed_dim)`."""
        target_dtype = self.proj.weight.dtype
        patches = hidden_states.view(
            -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
        )
        # Cast to the conv weight dtype (e.g. half precision) before projecting.
        embeddings = self.proj(patches.to(dtype=target_dtype))
        return embeddings.view(-1, self.embed_dim)
+
+
class Qwen3_5MoeVisionPatchMerger(nn.Module):
    """Merge `spatial_merge_size²` neighboring patch features into one vector
    and project to `out_hidden_size` via a small LayerNorm + MLP.

    With `use_postshuffle_norm=True` the LayerNorm runs on the merged
    (concatenated) features; otherwise it runs per patch before merging.
    """

    def __init__(self, config: Qwen3_5MoeVisionConfig, use_postshuffle_norm=False) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size * (config.spatial_merge_size**2)
        self.use_postshuffle_norm = use_postshuffle_norm
        self.norm = nn.LayerNorm(self.hidden_size if use_postshuffle_norm else config.hidden_size, eps=1e-6)
        self.linear_fc1 = nn.Linear(self.hidden_size, self.hidden_size)
        self.act_fn = nn.GELU()
        self.linear_fc2 = nn.Linear(self.hidden_size, config.out_hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.use_postshuffle_norm:
            # Normalize after flattening the merged patch group.
            x = self.norm(x.view(-1, self.hidden_size))
        else:
            # Normalize per patch, then flatten the merged group.
            x = self.norm(x).view(-1, self.hidden_size)
        return self.linear_fc2(self.act_fn(self.linear_fc1(x)))
+
+
def apply_rotary_pos_emb_vision(
    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Rotary embedding for vision Q/K laid out as `(seq, heads, head_dim)`.

    Computation is done in float32 (cos/sin broadcast over the head axis) and
    the results are cast back to the original q/k dtypes.
    """
    q_dtype, k_dtype = q.dtype, k.dtype
    q, k = q.float(), k.float()
    cos = cos.unsqueeze(-2).float()
    sin = sin.unsqueeze(-2).float()

    def _rotate(t):
        # (t1, t2) -> (-t2, t1) on the last dimension.
        half = t.shape[-1] // 2
        return torch.cat((-t[..., half:], t[..., :half]), dim=-1)

    q_embed = (q * cos + _rotate(q) * sin).to(q_dtype)
    k_embed = (k * cos + _rotate(k) * sin).to(k_dtype)
    return q_embed, k_embed
+
+
class Qwen3_5MoeVisionAttention(nn.Module):
    """Self-attention over packed (variable-length) vision patch sequences.

    Multiple images/videos are concatenated along one sequence axis;
    `cu_seqlens` holds the cumulative boundaries so attention never crosses
    from one image into the next.
    """

    def __init__(self, config: Qwen3_5MoeVisionConfig) -> None:
        super().__init__()
        self.dim = config.hidden_size
        self.num_heads = config.num_heads
        self.head_dim = self.dim // self.num_heads
        self.num_key_value_groups = 1  # needed for eager attention
        self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True)
        self.proj = nn.Linear(self.dim, self.dim)
        self.scaling = self.head_dim**-0.5
        self.config = config
        self.attention_dropout = 0.0
        self.is_causal = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor | None = None,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        **kwargs,
    ) -> torch.Tensor:
        """Attend within each packed image segment.

        Args:
            hidden_states: `(seq_len, hidden_size)` packed patch features.
            cu_seqlens: cumulative sequence lengths delimiting each image/video.
            rotary_pos_emb: unused here; kept for interface compatibility.
            position_embeddings: `(cos, sin)` rotary tables applied to Q/K.

        Returns:
            `(seq_len, hidden_size)` attended and re-projected features.
        """
        seq_length = hidden_states.shape[0]
        # Fused QKV projection split into per-head query/key/value.
        query_states, key_states, value_states = (
            self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
        )
        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)

        # Add a batch axis of 1: (1, num_heads, seq_len, head_dim).
        query_states = query_states.transpose(0, 1).unsqueeze(0)
        key_states = key_states.transpose(0, 1).unsqueeze(0)
        value_states = value_states.transpose(0, 1).unsqueeze(0)

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        if is_flash_attention_requested(self.config):
            # Flash Attention: Use cu_seqlens for variable length attention
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
            attn_output, _ = attention_interface(
                self,
                query_states,
                key_states,
                value_states,
                attention_mask=None,
                scaling=self.scaling,
                dropout=0.0 if not self.training else self.attention_dropout,
                cu_seq_lens_q=cu_seqlens,
                cu_seq_lens_k=cu_seqlens,
                max_length_q=max_seqlen,
                max_length_k=max_seqlen,
                is_causal=False,
                **kwargs,
            )
        else:
            # Other implementations: Process each chunk separately
            lengths = cu_seqlens[1:] - cu_seqlens[:-1]
            splits = [
                torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states)
            ]

            attn_outputs = [
                attention_interface(
                    self,
                    q,
                    k,
                    v,
                    attention_mask=None,
                    scaling=self.scaling,
                    dropout=0.0 if not self.training else self.attention_dropout,
                    is_causal=False,
                    **kwargs,
                )[0]
                for q, k, v in zip(*splits)
            ]
            attn_output = torch.cat(attn_outputs, dim=1)

        attn_output = attn_output.reshape(seq_length, -1).contiguous()
        attn_output = self.proj(attn_output)
        return attn_output
+
+
class Qwen3_5MoeVisionBlock(GradientCheckpointingLayer):
    """Pre-norm vision transformer block: varlen self-attention plus MLP, each
    wrapped in a residual connection."""

    def __init__(self, config, attn_implementation: str = "sdpa") -> None:
        super().__init__()
        self.norm1 = nn.LayerNorm(config.hidden_size, eps=1e-6)
        self.norm2 = nn.LayerNorm(config.hidden_size, eps=1e-6)
        self.attn = Qwen3_5MoeVisionAttention(config=config)
        self.mlp = Qwen3_5MoeVisionMLP(config=config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor | None = None,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        **kwargs,
    ) -> torch.Tensor:
        # Attention branch with residual.
        attn_out = self.attn(
            self.norm1(hidden_states),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = hidden_states + attn_out
        # Feed-forward branch with residual.
        mlp_out = self.mlp(self.norm2(hidden_states))
        return hidden_states + mlp_out
+
+
class Qwen3_5MoeVisionModel(Qwen3_5MoePreTrainedModel):
    """Vision tower: patch embedding, interpolated absolute position
    embeddings, 2D rotary embeddings, a stack of transformer blocks, and a
    patch merger producing the language-model-sized outputs."""

    config: Qwen3_5MoeVisionConfig
    _no_split_modules = ["Qwen3_5MoeVisionBlock"]
    _can_record_outputs = {
        "hidden_states": Qwen3_5MoeVisionBlock,
        "attentions": Qwen3_5MoeVisionAttention,
    }

    def __init__(self, config, *inputs, **kwargs) -> None:
        super().__init__(config, *inputs, **kwargs)
        self.spatial_merge_size = config.spatial_merge_size
        self.patch_size = config.patch_size
        # Number of patches merged into one output token.
        self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size

        self.patch_embed = Qwen3_5MoeVisionPatchEmbed(
            config=config,
        )

        # Learned absolute position embeddings on a square grid
        # (num_grid_per_side x num_grid_per_side), interpolated to each input
        # resolution by `fast_pos_embed_interpolate`.
        self.pos_embed = nn.Embedding(config.num_position_embeddings, config.hidden_size)
        self.num_grid_per_side = int(config.num_position_embeddings**0.5)

        # Half the head dim: rotary frequencies are split between row and col.
        head_dim = config.hidden_size // config.num_heads
        self.rotary_pos_emb = Qwen3_5MoeVisionRotaryEmbedding(head_dim // 2)

        self.blocks = nn.ModuleList([Qwen3_5MoeVisionBlock(config) for _ in range(config.depth)])
        self.merger = Qwen3_5MoeVisionPatchMerger(
            config=config,
            use_postshuffle_norm=False,
        )

        self.gradient_checkpointing = False

        self.post_init()
+
    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
        """Compute rotary frequencies for every patch position in `grid_thw`.

        For each (frames, height, width) entry, builds per-patch (row, col)
        coordinates ordered by spatial-merge blocks, looks them up in the
        rotary frequency table, and flattens the row/col halves together.

        Args:
            grid_thw: `(num_entries, 3)` integer tensor of (t, h, w) patch counts.

        Returns:
            `(total_patches, 2 * freq_dim)` tensor, where `freq_dim` is the
            per-position width returned by `self.rotary_pos_emb`.
        """
        merge_size = self.spatial_merge_size

        # One shared frequency table large enough for the biggest h or w.
        max_hw = int(grid_thw[:, 1:].max().item())
        freq_table = self.rotary_pos_emb(max_hw)  # (max_hw, dim // 2)
        device = freq_table.device

        total_tokens = int(torch.prod(grid_thw, dim=1).sum().item())
        pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device)

        offset = 0
        for num_frames, height, width in grid_thw:
            merged_h, merged_w = height // merge_size, width // merge_size

            block_rows = torch.arange(merged_h, device=device)  # block row indices
            block_cols = torch.arange(merged_w, device=device)  # block col indices
            intra_row = torch.arange(merge_size, device=device)  # intra-block row offsets
            intra_col = torch.arange(merge_size, device=device)  # intra-block col offsets

            # Compute full-resolution positions
            row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None]
            col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :]

            # Flatten in (block_row, block_col, intra_row, intra_col) order so
            # patches of the same merge block are contiguous.
            row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1)
            col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1)

            coords = torch.stack((row_idx, col_idx), dim=-1)

            # All frames of a video share the same spatial coordinates.
            if num_frames > 1:
                coords = coords.repeat(num_frames, 1)

            num_tokens = coords.shape[0]
            pos_ids[offset : offset + num_tokens] = coords
            offset += num_tokens

        embeddings = freq_table[pos_ids]  # lookup rotary embeddings
        embeddings = embeddings.flatten(1)
        return embeddings
+
    def fast_pos_embed_interpolate(self, grid_thw):
        """Bilinearly interpolate the learned position-embedding grid to each item's (h, w) size.

        The learned table is a square grid of `num_grid_per_side ** 2` embeddings; for every
        target patch position we gather its four nearest grid neighbours and blend them with
        bilinear weights, then reorder tokens into spatial-merge block order.

        Args:
            grid_thw: tensor of shape `(num_images_or_videos, 3)` holding the temporal,
                height and width patch counts per item.

        Returns:
            `torch.Tensor` of shape `(total_patches, hidden_size)` of position embeddings,
            frames repeated along the temporal dimension.
        """
        grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2]
        device = self.pos_embed.weight.device

        # Four parallel lists: one per corner (floor/floor, floor/ceil, ceil/floor, ceil/ceil).
        idx_list = [[] for _ in range(4)]
        weight_list = [[] for _ in range(4)]

        for t, h, w in zip(grid_ts, grid_hs, grid_ws):
            # Fractional sample positions in the learned grid for this item's h x w patches.
            h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h)
            w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w)

            h_idxs_floor = h_idxs.int()
            w_idxs_floor = w_idxs.int()
            h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
            w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)

            # Fractional parts double as the interpolation weights toward the ceil neighbours.
            dh = h_idxs - h_idxs_floor
            dw = w_idxs - w_idxs_floor

            # Row bases into the flattened (side * side) embedding table.
            base_h = h_idxs_floor * self.num_grid_per_side
            base_h_ceil = h_idxs_ceil * self.num_grid_per_side

            indices = [
                (base_h[None].T + w_idxs_floor[None]).flatten(),
                (base_h[None].T + w_idxs_ceil[None]).flatten(),
                (base_h_ceil[None].T + w_idxs_floor[None]).flatten(),
                (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(),
            ]

            weights = [
                ((1 - dh)[None].T * (1 - dw)[None]).flatten(),
                ((1 - dh)[None].T * dw[None]).flatten(),
                (dh[None].T * (1 - dw)[None]).flatten(),
                (dh[None].T * dw[None]).flatten(),
            ]

            for i in range(4):
                idx_list[i].extend(indices[i].tolist())
                weight_list[i].extend(weights[i].tolist())

        idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=device)
        weight_tensor = torch.tensor(weight_list, dtype=self.pos_embed.weight.dtype, device=device)
        pos_embeds = self.pos_embed(idx_tensor).to(device) * weight_tensor[:, :, None]
        # Weighted sum over the four corners = bilinear interpolation.
        patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3]

        # NOTE(review): split sizes here are 0-d tensors (h * w), not Python ints —
        # presumably fine on current torch; confirm against the minimum supported version.
        patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)])

        # Repeat each item's embeddings per frame and permute into
        # (merge-block row, merge-block col, intra-block) token order.
        patch_pos_embeds_permute = []
        merge_size = self.config.spatial_merge_size
        for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws):
            pos_embed = pos_embed.repeat(t, 1)
            pos_embed = (
                pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1)
                .permute(0, 1, 3, 2, 4, 5)
                .flatten(0, 4)
            )
            patch_pos_embeds_permute.append(pos_embed)
        patch_pos_embeds = torch.cat(patch_pos_embeds_permute)
        return patch_pos_embeds
+
    @check_model_inputs
    def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> BaseModelOutputWithPooling:
        """
        Args:
            hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`):
                Flattened patch inputs; they are projected by `self.patch_embed`
                before entering the transformer blocks.
            grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
                The temporal, height and width of feature shape of each image in LLM.

        Returns:
            `BaseModelOutputWithPooling`: `last_hidden_state` carries the per-patch
            hidden states; `pooler_output` carries the spatially merged features
            produced by `self.merger`.
        """
        hidden_states = self.patch_embed(hidden_states)

        # Add learned absolute position embeddings, interpolated to each item's grid.
        pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
        hidden_states = hidden_states + pos_embeds

        rotary_pos_emb = self.rot_pos_emb(grid_thw)

        seq_len, _ = hidden_states.size()
        hidden_states = hidden_states.reshape(seq_len, -1)
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
        # Duplicate the half-dim frequencies so cos/sin cover the full head dimension.
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        position_embeddings = (emb.cos(), emb.sin())

        # Cumulative sequence lengths: one segment per frame (h * w tokens each),
        # used for varlen attention over the packed token sequence.
        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
            dim=0,
            # Select dtype based on the following factors:
            # - FA2 requires that cu_seqlens_q must have dtype int32
            # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
            # See https://github.com/huggingface/transformers/pull/34852 for more information
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

        for blk in self.blocks:
            hidden_states = blk(
                hidden_states,
                cu_seqlens=cu_seqlens,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        merged_hidden_states = self.merger(hidden_states)

        return BaseModelOutputWithPooling(
            last_hidden_state=hidden_states,
            pooler_output=merged_hidden_states,
        )
+
+
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Qwen3_5Moe model outputs, with hidden states and attentions.
    """
)
class Qwen3_5MoeModelOutputWithPast(ModelOutput):
    r"""
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
        The rope index difference between sequence length and multimodal rope.
    router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed):
        Tuple of `torch.FloatTensor` (one per MoE layer) of shape `(batch_size * sequence_length, num_experts)`.
        Raw router logits computed by the MoE routers, used to compute the auxiliary load-balancing loss.
    """

    # NOTE: the original `custom_intro` said "Base class for Llava outputs" — a
    # copy-paste leftover from another model; fixed to reference Qwen3_5Moe.
    last_hidden_state: torch.FloatTensor | None = None
    past_key_values: Cache | None = None
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
    rope_deltas: torch.LongTensor | None = None
    router_logits: tuple[torch.FloatTensor] | None = None
+
+
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Qwen3_5Moe causal language model (or autoregressive) outputs.
    """
)
class Qwen3_5MoeCausalLMOutputWithPast(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
        The rope index difference between sequence length and multimodal rope.
    aux_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
        Auxiliary load-balancing loss for the sparse MoE modules.
    """

    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    past_key_values: Cache | None = None
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
    rope_deltas: torch.LongTensor | None = None
    aux_loss: torch.FloatTensor | None = None
+
+
class Qwen3_5MoeTextModel(Qwen3_5MoePreTrainedModel):
    """Decoder-only text backbone: token embeddings, stacked decoder layers, final RMSNorm."""

    def __init__(self, config: Qwen3_5MoeTextConfig):
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
        self.layers = nn.ModuleList(
            [Qwen3_5MoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Qwen3_5MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen3_5MoeTextRotaryEmbedding(config=config)
        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    @check_model_inputs
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        cache_position: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # Lazily create a cache suited to this model's layer mix when caching is requested.
        if use_cache and past_key_values is None:
            past_key_values = Qwen3_5MoeDynamicCache(config=self.config)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        # mrope: the hard coded `3` is for temporal, height and width.
        if position_ids is None:
            position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
        elif position_ids.ndim == 2:
            position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)

        # A leading dim of 4 means row 0 carries text-only positions (used for mask
        # construction) while rows 1:4 carry the t/h/w multimodal rope positions.
        if position_ids.ndim == 3 and position_ids.shape[0] == 4:
            text_position_ids = position_ids[0]
            position_ids = position_ids[1:]
        else:
            text_position_ids = position_ids[0]

        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=text_position_ids,
        )
        linear_attn_mask = self._update_linear_attn_mask(attention_mask, cache_position)

        hidden_states = inputs_embeds
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
            # Linear-attention layers get the padding-only mask; full-attention layers
            # get the causal mask.
            layer_mask = linear_attn_mask if decoder_layer.layer_type == "linear_attention" else causal_mask

            hidden_states = decoder_layer(
                hidden_states,
                position_embeddings=position_embeddings,
                attention_mask=layer_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)

        return Qwen3_5MoeModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )

    def _update_linear_attn_mask(self, attention_mask, cache_position):
        """
        NOTE: Left-padding is used for linear attention mask.
        No need for zeroing states when
        1. Cached forward
        2. Attending to all inputs
        """
        linear_attn_mask = attention_mask
        if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)):
            linear_attn_mask = None
        return linear_attn_mask
+
+
@auto_docstring
class Qwen3_5MoeModel(Qwen3_5MoePreTrainedModel):
    """Multimodal backbone combining the vision tower and the text decoder."""

    base_model_prefix = "model"
    _checkpoint_conversion_mapping = {}
    # Reference: fix gemma3 grad acc #37208
    accepts_loss_kwargs = False
    config: Qwen3_5MoeConfig
    _no_split_modules = ["Qwen3_5MoeTextDecoderLayer", "Qwen3_5MoeVisionBlock"]

    def __init__(self, config):
        super().__init__(config)
        self.visual = Qwen3_5MoeVisionModel._from_config(config.vision_config)
        self.language_model = Qwen3_5MoeTextModel._from_config(config.text_config)
        self.rope_deltas = None  # cache rope_deltas here

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def get_rope_index(
        self,
        input_ids: torch.LongTensor | None = None,
        image_grid_thw: torch.LongTensor | None = None,
        video_grid_thw: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Different from the original implementation, Qwen3_5Moe use timestamps rather than absolute time position ids.

        Returns a `(3, batch, seq)` tensor of t/h/w position ids and the per-sample
        `mrope_position_deltas` (difference between the max multimodal position and
        the raw sequence length), which is cached for later decoding steps.
        """

        # Since timestamps separate video frames in the sequence, flatten the temporal
        # dimension: each frame becomes its own grid entry with t == 1.
        if video_grid_thw is not None:
            video_grid_thw = torch.repeat_interleave(video_grid_thw, video_grid_thw[:, 0], dim=0)
            video_grid_thw[:, 0] = 1

        spatial_merge_size = self.config.vision_config.spatial_merge_size
        image_token_id = self.config.image_token_id
        video_token_id = self.config.video_token_id
        vision_start_token_id = self.config.vision_start_token_id
        mrope_position_deltas = []
        if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
            total_input_ids = input_ids
            if attention_mask is None:
                attention_mask = torch.ones_like(total_input_ids)
            position_ids = torch.ones(
                3,
                input_ids.shape[0],
                input_ids.shape[1],
                dtype=input_ids.dtype,
                device=input_ids.device,
            )
            image_index, video_index = 0, 0
            attention_mask = attention_mask.to(total_input_ids.device)
            # NOTE: `input_ids` is deliberately rebound to the current unpadded row below.
            for i, input_ids in enumerate(total_input_ids):
                input_ids = input_ids[attention_mask[i] == 1]
                image_nums, video_nums = 0, 0
                vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
                vision_tokens = input_ids[vision_start_indices + 1]
                image_nums = (vision_tokens == image_token_id).sum()
                video_nums = (vision_tokens == video_token_id).sum()
                input_tokens = input_ids.tolist()
                llm_pos_ids_list: list = []
                st = 0
                remain_images, remain_videos = image_nums, video_nums
                # Walk the sequence, alternating text spans and vision blocks.
                for _ in range(image_nums + video_nums):
                    # Find the next image/video placeholder at or after `st`;
                    # `len + 1` acts as "not found" sentinel.
                    if image_token_id in input_tokens and remain_images > 0:
                        ed_image = input_tokens.index(image_token_id, st)
                    else:
                        ed_image = len(input_tokens) + 1
                    if video_token_id in input_tokens and remain_videos > 0:
                        ed_video = input_tokens.index(video_token_id, st)
                    else:
                        ed_video = len(input_tokens) + 1
                    if ed_image < ed_video:
                        t, h, w = (
                            image_grid_thw[image_index][0],
                            image_grid_thw[image_index][1],
                            image_grid_thw[image_index][2],
                        )
                        image_index += 1
                        remain_images -= 1
                        ed = ed_image

                    else:
                        t, h, w = (
                            video_grid_thw[video_index][0],
                            video_grid_thw[video_index][1],
                            video_grid_thw[video_index][2],
                        )
                        video_index += 1
                        remain_videos -= 1
                        ed = ed_video
                    llm_grid_t, llm_grid_h, llm_grid_w = (
                        t.item(),
                        h.item() // spatial_merge_size,
                        w.item() // spatial_merge_size,
                    )
                    text_len = ed - st

                    st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                    # Text tokens advance all three (t/h/w) rows in lockstep.
                    llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

                    # t_index is always 0 because llm_grid_t is always 1 (we use timestamps to encode the temporal information for videos)
                    t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
                    h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
                    w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
                    llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
                    st = ed + llm_grid_t * llm_grid_h * llm_grid_w

                # Trailing text after the last vision block.
                if st < len(input_tokens):
                    st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                    text_len = len(input_tokens) - st
                    llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

                llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
                # Scatter the computed positions back onto the non-padded slots of row i.
                position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
                mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
            mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
            return position_ids, mrope_position_deltas
        else:
            # Text-only path: plain sequential positions broadcast to the 3 mrope rows.
            if attention_mask is not None:
                position_ids = attention_mask.long().cumsum(-1) - 1
                position_ids.masked_fill_(attention_mask == 0, 1)
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
                max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
                mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
            else:
                position_ids = (
                    torch.arange(input_ids.shape[1], device=input_ids.device)
                    .view(1, 1, -1)
                    .expand(3, input_ids.shape[0], -1)
                )
                mrope_position_deltas = torch.zeros(
                    [input_ids.shape[0], 1],
                    device=input_ids.device,
                    dtype=input_ids.dtype,
                )

            return position_ids, mrope_position_deltas

    @can_return_tuple
    @auto_docstring
    def get_video_features(
        self,
        pixel_values_videos: torch.FloatTensor,
        video_grid_thw: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BaseModelOutputWithPooling:
        r"""
        pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
            The tensors corresponding to the input videos.
        video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
            The temporal, height and width of feature shape of each video in LLM.
        """
        # Same implementation as for images
        return self.get_image_features(pixel_values_videos, video_grid_thw, **kwargs)

    @can_return_tuple
    @auto_docstring
    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        image_grid_thw: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BaseModelOutputWithPooling:
        r"""
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
            The tensors corresponding to the input images.
        image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
            The temporal, height and width of feature shape of each image in LLM.
        """
        pixel_values = pixel_values.type(self.visual.dtype)
        vision_output: BaseModelOutputWithPooling = self.visual(
            pixel_values, grid_thw=image_grid_thw, return_dict=True, **kwargs
        )
        image_embeds = vision_output.pooler_output
        # One chunk per image: patch count divided by the spatial-merge factor squared.
        split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
        image_embeds = torch.split(image_embeds, split_sizes)
        vision_output.pooler_output = image_embeds

        return vision_output

    def get_placeholder_mask(
        self,
        input_ids: torch.LongTensor,
        inputs_embeds: torch.FloatTensor,
        image_features: torch.FloatTensor | None = None,
        video_features: torch.FloatTensor | None = None,
    ):
        """
        Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
        equal to the length of multimodal features. If the lengths are different, an error is raised.
        """
        if input_ids is None:
            # No token ids available: detect placeholders by comparing embeddings to the
            # embedding of the special tokens.
            special_image_mask = inputs_embeds == self.get_input_embeddings()(
                torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
            )
            special_image_mask = special_image_mask.all(-1)
            special_video_mask = inputs_embeds == self.get_input_embeddings()(
                torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
            )
            special_video_mask = special_video_mask.all(-1)
        else:
            special_image_mask = input_ids == self.config.image_token_id
            special_video_mask = input_ids == self.config.video_token_id

        n_image_tokens = special_image_mask.sum()
        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        if image_features is not None:
            torch_compilable_check(
                inputs_embeds[special_image_mask].numel() == image_features.numel(),
                f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {image_features.shape[0]}",
            )

        n_video_tokens = special_video_mask.sum()
        special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        if video_features is not None:
            torch_compilable_check(
                inputs_embeds[special_video_mask].numel() == video_features.numel(),
                f"Video features and video tokens do not match, tokens: {n_video_tokens}, features: {video_features.shape[0]}",
            )
        return special_image_mask, special_video_mask

    @auto_docstring
    @can_return_tuple
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        pixel_values: torch.Tensor | None = None,
        pixel_values_videos: torch.FloatTensor | None = None,
        image_grid_thw: torch.LongTensor | None = None,
        video_grid_thw: torch.LongTensor | None = None,
        cache_position: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | Qwen3_5MoeModelOutputWithPast:
        r"""
        image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
            The temporal, height and width of feature shape of each image in LLM.
        video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
            The temporal, height and width of feature shape of each video in LLM.
        """
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        # Encode images and scatter their embeddings over the image placeholder tokens.
        if pixel_values is not None:
            image_outputs: BaseModelOutputWithPooling = self.get_image_features(
                pixel_values, image_grid_thw, return_dict=True
            )
            image_embeds = image_outputs.pooler_output
            image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
            image_mask, _ = self.get_placeholder_mask(
                input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds
            )
            inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

        # Same for videos.
        if pixel_values_videos is not None:
            video_outputs: BaseModelOutputWithPooling = self.get_video_features(
                pixel_values_videos, video_grid_thw, return_dict=True
            )
            video_embeds = video_outputs.pooler_output
            video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
            _, video_mask = self.get_placeholder_mask(
                input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds
            )
            inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

        if position_ids is None:
            past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length()
            # Prefill (or no cached deltas): compute multimodal rope positions from scratch.
            if self.rope_deltas is None or past_key_values_length == 0:
                position_ids, rope_deltas = self.get_rope_index(
                    input_ids,
                    image_grid_thw,
                    video_grid_thw,
                    attention_mask=attention_mask,
                )
                self.rope_deltas = rope_deltas
            # then use the prev pre-calculated rope-deltas to get the correct position ids
            else:
                batch_size, seq_length, _ = inputs_embeds.shape
                delta = (past_key_values_length + self.rope_deltas).to(inputs_embeds.device)
                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                if cache_position is not None:  # otherwise `deltas` is an int `0`
                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
                position_ids = position_ids.add(delta)
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

        outputs = self.language_model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            **kwargs,
        )

        return Qwen3_5MoeModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            rope_deltas=self.rope_deltas,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
+
+
+def load_balancing_loss_func(
+ gate_logits: torch.Tensor | tuple[torch.Tensor] | None,
+ num_experts: int | None = None,
+ top_k=2,
+ attention_mask: torch.Tensor | None = None,
+) -> torch.Tensor | int:
+ r"""
+ Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
+
+ See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss
+ function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
+ experts is too unbalanced.
+
+ Args:
+ gate_logits:
+ Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
+ shape [batch_size X sequence_length, num_experts].
+ num_experts:
+ Number of experts
+ top_k:
+ The number of experts to route per-token, can be also interpreted as the `top-k` routing
+ parameter.
+ attention_mask (`torch.Tensor`, *optional*):
+ The attention_mask used in forward function
+ shape [batch_size X sequence_length] if not None.
+
+ Returns:
+ The auxiliary loss.
+ """
+ if gate_logits is None or not isinstance(gate_logits, tuple):
+ return 0
+
+ if isinstance(gate_logits, tuple):
+ compute_device = gate_logits[0].device
+ concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
+
+ routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
+
+ _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
+
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
+
+ if attention_mask is None:
+ # Compute the percentage of tokens routed to each experts
+ tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
+
+ # Compute the average probability of routing to these experts
+ router_prob_per_expert = torch.mean(routing_weights, dim=0)
+ else:
+ batch_size, sequence_length = attention_mask.shape
+ num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
+
+ # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
+ expert_attention_mask = (
+ attention_mask[None, :, :, None, None]
+ .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
+ .reshape(-1, top_k, num_experts)
+ .to(compute_device)
+ )
+
+ # Compute the percentage of tokens routed to each experts
+ tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
+ expert_attention_mask, dim=0
+ )
+
+ # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
+ router_per_expert_attention_mask = (
+ attention_mask[None, :, :, None]
+ .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
+ .reshape(-1, num_experts)
+ .to(compute_device)
+ )
+
+ # Compute the average probability of routing to these experts
+ router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
+ router_per_expert_attention_mask, dim=0
+ )
+
+ overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
+ return overall_loss * num_experts
+
+
@auto_docstring
class Qwen3_5MoeForCausalLM(Qwen3_5MoePreTrainedModel, GenerationMixin):
    """Text-only MoE causal LM: Qwen3_5MoeTextModel backbone plus a tied LM head."""

    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
    _tp_plan = {"lm_head": "colwise_gather_output"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
    config: Qwen3_5MoeTextConfig
    # Multimodal checkpoints contain vision / MTP weights this text-only head ignores.
    _keys_to_ignore_on_load_unexpected = [r"^mtp.*", r"^model.visual.*"]

    def __init__(self, config):
        super().__init__(config)
        self.model = Qwen3_5MoeTextModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # MoE auxiliary-loss hyperparameters, read in forward().
        self.router_aux_loss_coef = config.router_aux_loss_coef
        self.num_experts = config.num_experts
        self.num_experts_per_tok = config.num_experts_per_tok

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Qwen3_5MoeDynamicCache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        output_router_logits: bool | None = None,
        cache_position: torch.LongTensor | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> MoeCausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, Qwen3_5MoeForCausalLM

        >>> model = Qwen3_5MoeForCausalLM.from_pretrained("Qwen/Qwen3-Next-80B-A3B-Instruct")
        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Next-80B-A3B-Instruct")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        # NOTE(review): the usage example above references a Qwen3-Next checkpoint —
        # confirm the intended Qwen3.5 MoE repo id once checkpoints are published.

        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs: MoeModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_router_logits=output_router_logits,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

        # Add the router load-balancing penalty when training with router logits.
        aux_loss = None
        if output_router_logits:
            aux_loss = load_balancing_loss_func(
                outputs.router_logits,
                self.num_experts,
                self.num_experts_per_tok,
                attention_mask,
            )
            if labels is not None:
                loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device

        return MoeCausalLMOutputWithPast(
            loss=loss,
            aux_loss=aux_loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            router_logits=outputs.router_logits,
        )
+
+
+class Qwen3_5MoeForConditionalGeneration(Qwen3_5MoePreTrainedModel, GenerationMixin):
+ _checkpoint_conversion_mapping = {}
+ _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
+ # Reference: fix gemma3 grad acc #37208
+ accepts_loss_kwargs = False
+ config: Qwen3_5MoeConfig
+
    def __init__(self, config):
        super().__init__(config)
        # Multimodal backbone (vision tower + text decoder) with an LM head on top;
        # the head is tied to the text embeddings via _tied_weights_keys.
        self.model = Qwen3_5MoeModel(config)
        self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)

        self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.model.set_input_embeddings(value)
+
    @auto_docstring
    def get_video_features(
        self,
        pixel_values_videos: torch.FloatTensor,
        video_grid_thw: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BaseModelOutputWithPooling:
        r"""
        pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
            The tensors corresponding to the input videos.
        video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
            The temporal, height and width of feature shape of each video in LLM.
        """
        # Thin delegation: the actual video encoding lives on the inner model.
        return self.model.get_video_features(
            pixel_values_videos=pixel_values_videos, video_grid_thw=video_grid_thw, **kwargs
        )
+
    @auto_docstring
    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        image_grid_thw: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BaseModelOutputWithPooling:
        r"""
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
            The tensors corresponding to the input images.
        image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
            The temporal, height and width of feature shape of each image in LLM.
        """
        # Thin delegation: the actual image encoding lives on the inner model.
        return self.model.get_image_features(pixel_values=pixel_values, image_grid_thw=image_grid_thw, **kwargs)
+
+ @can_return_tuple
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ past_key_values: Cache | None = None,
+ inputs_embeds: torch.FloatTensor | None = None,
+ labels: torch.LongTensor | None = None,
+ pixel_values: torch.Tensor | None = None,
+ pixel_values_videos: torch.FloatTensor | None = None,
+ image_grid_thw: torch.LongTensor | None = None,
+ video_grid_thw: torch.LongTensor | None = None,
+ cache_position: torch.LongTensor | None = None,
+ logits_to_keep: int | torch.Tensor = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple | Qwen3_5MoeCausalLMOutputWithPast:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+ The temporal, height and width of feature shape of each video in LLM.
+
+ Example:
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, Qwen3_5MoeForConditionalGeneration
+
+ >>> model = Qwen3_5MoeForConditionalGeneration.from_pretrained("Qwen/Qwen3.5-35B-A3B-Instruct", dtype="auto", device_map="auto")
+ >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen3.5-35B-A3B-Instruct")
+
+ >>> messages = [
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "image",
+ "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
+ },
+ {"type": "text", "text": "Describe this image in short."},
+ ],
+ }
+ ]
+
+ >>> # Preparation for inference
+ >>> inputs = processor.apply_chat_template(
+ messages,
+ tokenize=True,
+ add_generation_prompt=True,
+ return_dict=True,
+ return_tensors="pt"
+ )
+ >>> inputs = inputs.to(model.device)
+
+ >>> # Generate
+ >>> generated_ids = model.generate(**inputs, max_new_tokens=128)
+ >>> generated_ids_trimmed = [
+ out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+ ]
+ >>> processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "A woman in a plaid shirt sits on a sandy beach at sunset, smiling as she gives a high-five to a yellow Labrador Retriever wearing a harness. The ocean waves roll in the background."
+ ```"""
+
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ pixel_values_videos=pixel_values_videos,
+ image_grid_thw=image_grid_thw,
+ video_grid_thw=video_grid_thw,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs[0]
+
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
+
+ aux_loss = None
+ if kwargs.get("output_router_logits", False):
+ aux_loss = load_balancing_loss_func(
+ outputs.router_logits,
+ self.config.text_config.num_experts,
+ self.config.text_config.num_experts_per_tok,
+ attention_mask,
+ )
+ if labels is not None:
+ loss += self.config.text_config.router_aux_loss_coef * aux_loss.to(
+ loss.device
+ ) # make sure to reside in the same device
+
+ return Qwen3_5MoeCausalLMOutputWithPast(
+ loss=loss,
+ aux_loss=aux_loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ rope_deltas=outputs.rope_deltas,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ attention_mask=None,
+ inputs_embeds=None,
+ cache_position=None,
+ position_ids=None,
+ use_cache=True,
+ pixel_values=None,
+ pixel_values_videos=None,
+ image_grid_thw=None,
+ video_grid_thw=None,
+ is_first_iteration=False,
+ **kwargs,
+ ):
+ # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
+
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ cache_position=cache_position,
+ position_ids=position_ids,
+ pixel_values=pixel_values,
+ pixel_values_videos=pixel_values_videos,
+ image_grid_thw=image_grid_thw,
+ video_grid_thw=video_grid_thw,
+ use_cache=use_cache,
+ is_first_iteration=is_first_iteration,
+ **kwargs,
+ )
+
+ # Qwen3_5Moe position_ids are prepared with rope_deltas
+ if position_ids is None:
+ # Calculate RoPE index once per generation in the pre-fill stage only.
+ # When compiling, we can't check tensor values thus we check only input length
+ # It is safe to assume that `length!=1` means we're in pre-fill because compiled
+ # models currently cannot do asssisted decoding
+ if model_inputs["cache_position"][0] == 0 or self.model.rope_deltas is None:
+ vision_positions, rope_deltas = self.model.get_rope_index(
+ model_inputs.get("input_ids", None),
+ image_grid_thw=image_grid_thw,
+ video_grid_thw=video_grid_thw,
+ attention_mask=attention_mask,
+ )
+ self.model.rope_deltas = rope_deltas
+ # then use the prev pre-calculated rope-deltas to get the correct position ids
+ elif "position_ids" in model_inputs:
+ batch_size, seq_length = model_inputs["position_ids"].shape
+ device = model_inputs["position_ids"].device
+ position_ids = torch.arange(seq_length, device=device)
+ position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1)
+ delta = cache_position[0] + self.model.rope_deltas
+ delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
+ vision_positions = position_ids + delta.expand_as(position_ids)
+
+ # Concatenate "text + vision" positions into [4, bs, seq-len]
+ text_positions = model_inputs["position_ids"][None, ...]
+ model_inputs["position_ids"] = torch.cat([text_positions, vision_positions], dim=0)
+
+ if not is_first_iteration and use_cache:
+ model_inputs["pixel_values"] = None
+ model_inputs["pixel_values_videos"] = None
+
+ return model_inputs
+
+ def _get_image_nums_and_video_nums(
+ self,
+ input_ids: torch.LongTensor | None,
+ inputs_embeds: torch.Tensor | None = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Get the number of images and videos for each sample to calculate the separation length of the sample tensor.
+ These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications.
+
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Returns:
+ image_nums (`torch.LongTensor` of shape `(batch_size, num_images_sample)`)
+ video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`)
+ """
+ image_token_id = self.config.image_token_id
+ video_token_id = self.config.video_token_id
+ vision_start_token_id = self.config.vision_start_token_id
+
+ if inputs_embeds is not None:
+ vision_start_mask = (
+ inputs_embeds
+ == self.get_input_embeddings()(
+ torch.tensor(vision_start_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ )[..., 0]
+ image_mask = (
+ inputs_embeds
+ == self.get_input_embeddings()(
+ torch.tensor(image_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ )[..., 0]
+ video_mask = (
+ inputs_embeds
+ == self.get_input_embeddings()(
+ torch.tensor(video_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ )[..., 0]
+ else:
+ vision_start_mask = input_ids == vision_start_token_id
+ image_mask = input_ids == image_token_id
+ video_mask = input_ids == video_token_id
+
+ vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1)
+ image_nums = torch.sum(vision_first_mask & image_mask, dim=1)
+ video_nums = torch.sum(vision_first_mask & video_mask, dim=1)
+
+ return image_nums, video_nums
+
+ def _expand_inputs_for_generation(
+ self,
+ expand_size: int = 1,
+ is_encoder_decoder: bool = False,
+ input_ids: torch.LongTensor | None = None,
+ **model_kwargs,
+ ) -> tuple[torch.LongTensor, dict[str, Any]]:
+ # Overwritten -- Qwen3_5Moe use timestamps and remove second_per_grid_ts
+ # Support for expanding tensors without a batch size dimension
+ # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw
+ # pixel_values.shape[0] is sum(seqlen_images for samples)
+ # image_grid_thw.shape[0] is sum(num_images for samples)
+
+ if expand_size == 1:
+ return input_ids, model_kwargs
+
+ visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw"]
+
+ def _expand_dict_for_generation_visual(dict_to_expand):
+ image_grid_thw = model_kwargs.get("image_grid_thw", None)
+ video_grid_thw = model_kwargs.get("video_grid_thw", None)
+ image_nums, video_nums = self._get_image_nums_and_video_nums(
+ input_ids, inputs_embeds=model_kwargs.get("inputs_embeds", None)
+ )
+
+ # video_nums: (batch_size,)
+ # since video_nums is the number of videos in the input dependent on the input_ids(vision_start),
+ # but Qwen3_5Moe append vision_start to each frame of each video, so we need to recover the real video_nums according to video_grid_thw
+ if video_grid_thw is not None:
+ cumulative_frame_counts = torch.cumsum(video_grid_thw[:, 0], dim=0)
+ cumulative_token_video_counts = torch.cumsum(video_nums, dim=0)
+ # Find video boundaries in cumulative_frame_counts
+ video_boundary_indices = torch.searchsorted(cumulative_frame_counts, cumulative_token_video_counts)
+ # example: video_boundary_indices = [3, 5] means video_nums = [4, 2]
+ video_nums = torch.diff(torch.cat([-video_boundary_indices.new_ones(1), video_boundary_indices]))
+
+ def _repeat_interleave_samples(x, lengths, repeat_times):
+ samples = torch.split(x, lengths)
+ repeat_args = [repeat_times] + [1] * (x.dim() - 1)
+ result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0)
+ return result
+
+ for key in dict_to_expand:
+ if key == "pixel_values":
+ # split images into samples
+ samples = torch.split(image_grid_thw, list(image_nums))
+ # compute the sequence length of images for each sample
+ lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
+ dict_to_expand[key] = _repeat_interleave_samples(
+ dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+ )
+ elif key == "image_grid_thw":
+ # get the num of images for each sample
+ lengths = list(image_nums)
+ dict_to_expand[key] = _repeat_interleave_samples(
+ dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+ )
+ elif key == "pixel_values_videos":
+ samples = torch.split(video_grid_thw, list(video_nums))
+ lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
+ dict_to_expand[key] = _repeat_interleave_samples(
+ dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+ )
+ elif key == "video_grid_thw":
+ lengths = list(video_nums)
+ dict_to_expand[key] = _repeat_interleave_samples(
+ dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+ )
+ return dict_to_expand
+
+ def _expand_dict_for_generation(dict_to_expand):
+ for key in dict_to_expand:
+ if (
+ key != "cache_position"
+ and dict_to_expand[key] is not None
+ and isinstance(dict_to_expand[key], torch.Tensor)
+ and key not in visual_keys
+ ):
+ dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
+ return dict_to_expand
+
+ model_kwargs = _expand_dict_for_generation_visual(model_kwargs)
+
+ if input_ids is not None:
+ input_ids = input_ids.repeat_interleave(expand_size, dim=0)
+
+ model_kwargs = _expand_dict_for_generation(model_kwargs)
+
+ if is_encoder_decoder:
+ if model_kwargs.get("encoder_outputs") is None:
+ raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
+ model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])
+
+ return input_ids, model_kwargs
+
+
# Public API of this modeling module.
__all__ = [
    "Qwen3_5MoeVisionModel",
    "Qwen3_5MoeTextModel",
    "Qwen3_5MoeModel",
    "Qwen3_5MoeForCausalLM",
    "Qwen3_5MoeForConditionalGeneration",
    "Qwen3_5MoePreTrainedModel",
]
diff --git a/src/transformers/models/qwen3_5_moe/modular_qwen3_5_moe.py b/src/transformers/models/qwen3_5_moe/modular_qwen3_5_moe.py
new file mode 100644
index 000000000000..3369cf363ee9
--- /dev/null
+++ b/src/transformers/models/qwen3_5_moe/modular_qwen3_5_moe.py
@@ -0,0 +1,464 @@
+# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Qwen3.5Moe model."""
+
+import torch
+
+from ... import initialization as init
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutputWithPooling
+from ...modeling_rope_utils import RopeParameters
+from ...modeling_utils import PreTrainedModel
+from ...utils import logging
+from ..qwen3_5.configuration_qwen3_5 import Qwen3_5VisionConfig
+from ..qwen3_5.modeling_qwen3_5 import (
+ Qwen3_5GatedDeltaNet,
+ Qwen3_5MLP,
+ Qwen3_5Model,
+ Qwen3_5TextModel,
+ Qwen3_5TextRotaryEmbedding,
+ Qwen3_5VisionModel,
+ Qwen3_5VisionRotaryEmbedding,
+)
+from ..qwen3_next.configuration_qwen3_next import Qwen3NextConfig
+from ..qwen3_next.modeling_qwen3_next import (
+ Qwen3NextAttention,
+ Qwen3NextDecoderLayer,
+ Qwen3NextDynamicCache,
+ Qwen3NextExperts,
+ Qwen3NextForCausalLM,
+ Qwen3NextPreTrainedModel,
+ Qwen3NextRMSNorm,
+ Qwen3NextSparseMoeBlock,
+)
+from ..qwen3_vl.configuration_qwen3_vl import Qwen3VLConfig
+from ..qwen3_vl_moe.modeling_qwen3_vl_moe import (
+ Qwen3VLMoeCausalLMOutputWithPast,
+ Qwen3VLMoeForConditionalGeneration,
+ Qwen3VLMoeModelOutputWithPast,
+ Qwen3VLMoeTextTopKRouter,
+)
+
+
+logger = logging.get_logger(__name__)
+
+
class Qwen3_5MoeTextConfig(Qwen3NextConfig):
    r"""
    This is the configuration class to store the configuration of a [`Qwen3_5MoeTextModel`]. It is used to instantiate a
    Qwen3.5-MoE model according to the specified arguments, defining the model architecture.
    Instantiating a configuration with the defaults will yield a similar configuration to that of
    Qwen3.5-35B-A3B-Instruct [Qwen/Qwen3.5-35B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3.5-35B-A3B-Instruct).

    Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PreTrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 248320):
            Vocabulary size of the model. Defines the number of different tokens that can be represented by the
            `inputs_ids`.
        hidden_size (`int`, *optional*, defaults to 2048):
            Dimension of the hidden representations.
        num_hidden_layers (`int`, *optional*, defaults to 40):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_key_value_heads (`int`, *optional*, defaults to 2):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `2`.
        hidden_act (`str`, *optional*, defaults to `"silu"`):
            The non-linear activation function in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 32768):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the model's input and output word embeddings should be tied.
        rope_parameters (`RopeParameters`, *optional*):
            Dictionary containing the configuration parameters for the RoPE embeddings. The dictionary should contain
            a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE
            with longer `max_position_embeddings`.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        head_dim (`int`, *optional*, defaults to 256):
            Projection weights dimension in multi-head attention.
        linear_conv_kernel_dim (`int`, *optional*, defaults to 4):
            Kernel size of the convolution used in linear attention layers.
        linear_key_head_dim (`int`, *optional*, defaults to 128):
            Dimension of each key head in linear attention.
        linear_value_head_dim (`int`, *optional*, defaults to 128):
            Dimension of each value head in linear attention.
        linear_num_key_heads (`int`, *optional*, defaults to 16):
            Number of key heads used in linear attention layers.
        linear_num_value_heads (`int`, *optional*, defaults to 32):
            Number of value heads used in linear attention layers.
        moe_intermediate_size (`int`, *optional*, defaults to 512):
            Intermediate size of the routed expert.
        shared_expert_intermediate_size (`int`, *optional*, defaults to 512):
            Intermediate size of the shared expert.
        num_experts_per_tok (`int`, *optional*, defaults to 8):
            Number of selected experts.
        num_experts (`int`, *optional*, defaults to 256):
            Number of routed experts.
        output_router_logits (`bool`, *optional*, defaults to `False`):
            Whether or not the router logits should be returned by the model. Enabling this will also
            allow the model to output the auxiliary loss, including load balancing loss and router z-loss.
        router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
            The aux loss factor for the total loss.
        layer_types (`list[str]`, *optional*):
            Types of each layer (attention or linear).
        pad_token_id (`int`, *optional*):
            Padding token id.
        bos_token_id (`int`, *optional*):
            Beginning of stream token id.
        eos_token_id (`int`, *optional*):
            End of stream token id.

    ```python
    >>> from transformers import Qwen3_5MoeTextModel, Qwen3_5MoeTextConfig

    >>> # Initializing a Qwen3.5-MoE style configuration
    >>> configuration = Qwen3_5MoeTextConfig()

    >>> # Initializing a model from the Qwen3.5-35B-A3B style configuration
    >>> model = Qwen3_5MoeTextModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    model_type = "qwen3_5_moe_text"
    base_config_key = "text_config"

    # Tensor-parallel sharding plan for the text backbone's attention and MoE projections.
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.experts.gate_up_proj": "packed_colwise",
        "layers.*.mlp.experts.down_proj": "rowwise",
        "layers.*.mlp.shared_expert.gate_proj": "colwise",
        "layers.*.mlp.shared_expert.up_proj": "colwise",
        "layers.*.mlp.shared_expert.down_proj": "rowwise",
    }

    def __init__(
        self,
        vocab_size=248320,
        hidden_size=2048,
        num_hidden_layers=40,
        num_attention_heads=16,
        num_key_value_heads=2,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        tie_word_embeddings=False,
        rope_parameters: RopeParameters | dict[str, RopeParameters] | None = None,
        attention_bias=False,
        attention_dropout=0.0,
        head_dim=256,
        linear_conv_kernel_dim=4,
        linear_key_head_dim=128,
        linear_value_head_dim=128,
        linear_num_key_heads=16,
        linear_num_value_heads=32,
        moe_intermediate_size=512,
        shared_expert_intermediate_size=512,
        num_experts_per_tok=8,
        num_experts=256,
        output_router_logits=False,
        router_aux_loss_coef=0.001,
        layer_types=None,
        pad_token_id: int | None = None,
        bos_token_id: int | None = None,
        eos_token_id: int | None = None,
        **kwargs,
    ):
        # Multimodal-RoPE keys are intentionally exempt from RoPE parameter validation.
        kwargs["ignore_keys_at_rope_validation"] = {"mrope_section", "mrope_interleaved"}
        # NOTE(review): only `tie_word_embeddings` and **kwargs are forwarded to the parent
        # __init__; the remaining signature arguments are presumably materialized by the
        # modular converter in the generated configuration file — confirm.
        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
        # Attributes inherited from Qwen3NextConfig that do not apply to this architecture.
        del self.intermediate_size
        del self.decoder_sparse_step
        del self.norm_topk_prob
        del self.mlp_only_layers
+
+
# Identical to the dense Qwen3.5 vision config; aliased for the MoE model family.
class Qwen3_5MoeVisionConfig(Qwen3_5VisionConfig):
    pass
+
+
class Qwen3_5MoeConfig(Qwen3VLConfig):
    r"""
    This is the configuration class to store the configuration of a [`Qwen3_5MoeModel`]. It is used to instantiate a
    Qwen3.5-MoE model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of
    Qwen3.5-35B-A3B-Instruct [Qwen/Qwen3.5-35B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3.5-35B-A3B-Instruct).

    Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PreTrainedConfig`] for more information.


    Args:
        text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3_5MoeTextConfig`):
            The config object or dictionary of the text backbone.
        vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3_5MoeVisionConfig`):
            The config object or dictionary of the vision backbone.
        image_token_id (`int`, *optional*, defaults to 248056):
            The image token index to encode the image prompt.
        video_token_id (`int`, *optional*, defaults to 248057):
            The video token index to encode the image prompt.
        vision_start_token_id (`int`, *optional*, defaults to 248053):
            The start token index to encode the image prompt.
        vision_end_token_id (`int`, *optional*, defaults to 248054):
            The end token index to encode the image prompt.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie the word embeddings.

    ```python
    >>> from transformers import Qwen3_5MoeForConditionalGeneration, Qwen3_5MoeConfig

    >>> # Initializing a Qwen3.5-MoE style configuration
    >>> configuration = Qwen3_5MoeConfig()

    >>> # Initializing a model from the Qwen3.5-35B-A3B style configuration
    >>> model = Qwen3_5MoeForConditionalGeneration(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "qwen3_5_moe"
    sub_configs = {"vision_config": Qwen3_5MoeVisionConfig, "text_config": Qwen3_5MoeTextConfig}

    def __init__(
        self,
        text_config=None,
        vision_config=None,
        image_token_id=248056,
        video_token_id=248057,
        vision_start_token_id=248053,
        vision_end_token_id=248054,
        tie_word_embeddings=False,
        **kwargs,
    ):
        # Pure pass-through to the Qwen3-VL composite config with Qwen3.5-MoE defaults.
        super().__init__(
            text_config=text_config,
            vision_config=vision_config,
            image_token_id=image_token_id,
            video_token_id=video_token_id,
            vision_start_token_id=vision_start_token_id,
            vision_end_token_id=vision_end_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
+
+
# Alias of the dense Qwen3.5 vision rotary embedding for the MoE variant.
class Qwen3_5MoeVisionRotaryEmbedding(Qwen3_5VisionRotaryEmbedding):
    pass
+
+
# Alias of the dense Qwen3.5 text rotary embedding for the MoE variant.
class Qwen3_5MoeTextRotaryEmbedding(Qwen3_5TextRotaryEmbedding):
    pass
+
+
# Reuses Qwen3-Next's hybrid (attention + linear-attention) cache implementation.
class Qwen3_5MoeDynamicCache(Qwen3NextDynamicCache):
    pass
+
+
# Linear-attention (gated delta net) layer, identical to the dense Qwen3.5 one.
class Qwen3_5MoeGatedDeltaNet(Qwen3_5GatedDeltaNet):
    pass
+
+
# Full-attention layer, reused unchanged from Qwen3-Next.
class Qwen3_5MoeAttention(Qwen3NextAttention):
    pass
+
+
# Dense MLP block, reused unchanged from the dense Qwen3.5 model.
class Qwen3_5MoeMLP(Qwen3_5MLP):
    pass
+
+
# Packed expert weights container, reused unchanged from Qwen3-Next.
class Qwen3_5MoeExperts(Qwen3NextExperts):
    pass
+
+
# Top-k expert router, reused unchanged from Qwen3-VL-MoE.
class Qwen3_5MoeTopKRouter(Qwen3VLMoeTextTopKRouter):
    pass
+
+
# Sparse MoE block (router + experts + shared expert), reused from Qwen3-Next.
class Qwen3_5MoeSparseMoeBlock(Qwen3NextSparseMoeBlock):
    pass
+
+
# RMSNorm variant reused from Qwen3-Next (weight is stored 1-centered, see _init_weights).
class Qwen3_5MoeRMSNorm(Qwen3NextRMSNorm):
    pass
+
+
class Qwen3_5MoeDecoderLayer(Qwen3NextDecoderLayer):
    """Decoder layer that is either a linear-attention or a full-attention layer, followed by a sparse MoE block."""

    def __init__(self, config: Qwen3_5MoeTextConfig, layer_idx: int):
        # Intentionally bypasses Qwen3NextDecoderLayer.__init__ (which builds the
        # Qwen3-Next submodules) and initializes only GradientCheckpointingLayer,
        # so this layer can wire up the Qwen3.5-MoE-specific submodules itself.
        GradientCheckpointingLayer.__init__(self)
        self.hidden_size = config.hidden_size
        # Per-layer attention kind is driven by config.layer_types.
        self.layer_type = config.layer_types[layer_idx]
        if self.layer_type == "linear_attention":
            self.linear_attn = Qwen3_5MoeGatedDeltaNet(config, layer_idx)
        elif self.layer_type == "full_attention":
            self.self_attn = Qwen3_5MoeAttention(config, layer_idx)
        self.mlp = Qwen3_5MoeSparseMoeBlock(config)
        self.input_layernorm = Qwen3_5MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen3_5MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+
class Qwen3_5MoePreTrainedModel(Qwen3NextPreTrainedModel):
    """Base class handling weight initialization for all Qwen3.5-MoE models."""

    _no_split_modules = ["Qwen3_5MoeDecoderLayer", "Qwen3_5MoeVisionBlock"]

    def _init_weights(self, module):
        # Generic init first, then overrides for the architecture-specific modules.
        PreTrainedModel._init_weights(self, module)
        if isinstance(module, Qwen3_5MoeGatedDeltaNet):
            init.ones_(module.dt_bias)
            # A_log initialized as log of Uniform(0, 16) samples.
            init.copy_(module.A_log, torch.empty_like(module.A_log).uniform_(0, 16).log_())
        # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
        elif isinstance(module, Qwen3_5MoeRMSNorm):
            init.zeros_(module.weight)
        elif isinstance(module, Qwen3_5MoeExperts):
            init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range)
            init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, Qwen3_5MoeSparseMoeBlock):
            init.normal_(module.gate.weight, mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, Qwen3_5MoeVisionRotaryEmbedding):
            # Standard RoPE inverse frequencies: theta^(-2i/dim) for i in [0, dim/2).
            inv_freq = 1.0 / (module.theta ** (torch.arange(0, module.dim, 2, dtype=torch.float) / module.dim))
            init.copy_(module.inv_freq, inv_freq)
+
+
# Vision tower, identical to the dense Qwen3.5 vision model.
class Qwen3_5MoeVisionModel(Qwen3_5VisionModel):
    pass
+
+
class Qwen3_5MoeModelOutputWithPast(Qwen3VLMoeModelOutputWithPast):
    # Extends the Qwen3-VL-MoE output with per-layer router logits (for the aux loss).
    router_logits: tuple[torch.FloatTensor] | None = None
+
+
# Causal-LM output container, reused unchanged from Qwen3-VL-MoE.
class Qwen3_5MoeCausalLMOutputWithPast(Qwen3VLMoeCausalLMOutputWithPast):
    pass
+
+
# Text backbone, identical to the dense Qwen3.5 text model (MoE comes from the decoder layer).
class Qwen3_5MoeTextModel(Qwen3_5TextModel):
    pass
+
+
# Combined vision+text model, identical to the dense Qwen3.5 composite model.
class Qwen3_5MoeModel(Qwen3_5Model):
    pass
+
+
class Qwen3_5MoeForCausalLM(Qwen3NextForCausalLM):
    """Text-only causal LM head over the Qwen3.5-MoE text backbone."""

    config: Qwen3_5MoeTextConfig
    # Ignore MTP heads and vision weights when loading a full multimodal checkpoint.
    _keys_to_ignore_on_load_unexpected = [r"^mtp.*", r"^model.visual.*"]

    def __init__(self, config):
        super().__init__(config)
        # Replace the parent's backbone with the Qwen3.5-MoE text model.
        self.model = Qwen3_5MoeTextModel(config)
+
+
class Qwen3_5MoeForConditionalGeneration(Qwen3VLMoeForConditionalGeneration):
    """Qwen3.5-MoE conditional generation model, inheriting its implementation from Qwen3-VL-MoE."""

    def forward(self, **super_kwargs):
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
            The temporal, height and width of feature shape of each image in LLM.
        video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
            The temporal, height and width of feature shape of each video in LLM.

        Example:
        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, Qwen3_5MoeForConditionalGeneration

        >>> model = Qwen3_5MoeForConditionalGeneration.from_pretrained("Qwen/Qwen3.5-35B-A3B-Instruct", dtype="auto", device_map="auto")
        >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen3.5-35B-A3B-Instruct")

        >>> messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
                    },
                    {"type": "text", "text": "Describe this image in short."},
                ],
            }
        ]

        >>> # Preparation for inference
        >>> inputs = processor.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt"
        )
        >>> inputs = inputs.to(model.device)

        >>> # Generate
        >>> generated_ids = model.generate(**inputs, max_new_tokens=128)
        >>> generated_ids_trimmed = [
            out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        >>> processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "A woman in a plaid shirt sits on a sandy beach at sunset, smiling as she gives a high-five to a yellow Labrador Retriever wearing a harness. The ocean waves roll in the background."
        ```"""
        # Bug fix: propagate the parent's output. Without `return`, this override would
        # always yield None when called as plain Python.
        return super().forward(**super_kwargs)

    def get_video_features(
        self,
        **super_kwargs,
    ) -> tuple | BaseModelOutputWithPooling:
        # Re-declared (presumably so the modular converter includes the method in the
        # generated modeling file — confirm); behavior is the parent's unchanged.
        return super().get_video_features(**super_kwargs)

    def get_image_features(
        self,
        **super_kwargs,
    ) -> tuple | BaseModelOutputWithPooling:
        # Same pass-through pattern as get_video_features.
        return super().get_image_features(**super_kwargs)
+
+
# Names this modular definition exposes; presumably consumed by the modular
# converter when generating the configuration/modeling files — confirm.
__all__ = [
    "Qwen3_5MoeConfig",
    "Qwen3_5MoeTextConfig",
    "Qwen3_5MoeVisionModel",
    "Qwen3_5MoeTextModel",
    "Qwen3_5MoeModel",
    "Qwen3_5MoeForCausalLM",
    "Qwen3_5MoeForConditionalGeneration",
    "Qwen3_5MoePreTrainedModel",
]
diff --git a/tests/models/qwen3_5/__init__.py b/tests/models/qwen3_5/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/models/qwen3_5/test_modeling_qwen3_5.py b/tests/models/qwen3_5/test_modeling_qwen3_5.py
new file mode 100644
index 000000000000..da025439acb4
--- /dev/null
+++ b/tests/models/qwen3_5/test_modeling_qwen3_5.py
@@ -0,0 +1,402 @@
+# Copyright 2026 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Testing suite for the PyTorch Qwen3.5 model."""
+
+import unittest
+
+from transformers import is_torch_available
+from transformers.testing_utils import (
+ require_torch,
+ torch_device,
+)
+
+from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
+from ...generation.test_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import (
+ ModelTesterMixin,
+ floats_tensor,
+ ids_tensor,
+)
+
+
+if is_torch_available():
+ import torch
+
+ from transformers import (
+ Qwen3_5Config,
+ Qwen3_5ForCausalLM,
+ Qwen3_5ForConditionalGeneration,
+ Qwen3_5Model,
+ Qwen3_5TextConfig,
+ Qwen3_5TextModel,
+ )
+ from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5DynamicCache
+
+
+class Qwen3_5TextModelTester(CausalLMModelTester):
+ if is_torch_available():
+ base_model_class = Qwen3_5TextModel
+ causal_lm_class = Qwen3_5ForCausalLM
+
+ def __init__(self, parent):
+ super().__init__(parent=parent)
+ self.layer_types = ["full_attention", "linear_attention"]
+ self.linear_conv_kernel_dim = 2
+ self.linear_key_head_dim = 16
+ self.linear_value_head_dim = 16
+ self.linear_num_key_heads = 4
+ self.linear_num_value_heads = 8
+
+
+@require_torch
+class Qwen3_5TextModelTest(CausalLMModelTest, unittest.TestCase):
+ model_tester_class = Qwen3_5TextModelTester
+ config_class = Qwen3_5TextConfig
+ model_split_percents = [0.5, 0.8, 0.9]
+
+ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config):
+ "Qwen3.5 has a special Cache as it alternates with gated deltanet layers"
+ self.assertIsInstance(past_key_values, Qwen3_5DynamicCache)
+
+ # (batch, kv heads, seq_length, head_dim)
+ num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
+ head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+ expected_shape = (batch_size, num_heads, seq_length, head_dim)
+
+ attention_layer_indices = past_key_values.transformer_layers
+ self.assertListEqual(
+ [past_key_values.key_cache[idx].shape for idx in attention_layer_indices],
+ [expected_shape] * len(attention_layer_indices),
+ )
+ self.assertListEqual(
+ [past_key_values.value_cache[idx].shape for idx in attention_layer_indices],
+ [expected_shape] * len(attention_layer_indices),
+ )
+
+ def _check_caches_are_equal(self, cache1, cache2):
+ "Qwen3.5 has a special Cache as it alternates with gated deltanet layers"
+ if not len(cache1) == len(cache2):
+ raise ValueError("Both caches do not have the same number of layers.")
+
+ num_layers = len(cache1)
+ for idx in range(num_layers):
+ if cache1.key_cache[idx] is not None:
+ torch.testing.assert_close(cache1.key_cache[idx], cache2.key_cache[idx])
+ torch.testing.assert_close(cache1.value_cache[idx], cache2.value_cache[idx])
+
+ def test_attention_outputs(self):
+ "Needs to be overwritten as Qwen3.5 alternates between attention layers and gated deltanet layers."
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.return_dict = True
+ # force eager attention to support output attentions
+ config._attn_implementation = "eager"
+ seq_len = getattr(self.model_tester, "seq_length", None)
+
+ for model_class in self.all_model_classes:
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = False
+ config.return_dict = True
+ model = model_class._from_config(config, attn_implementation="eager")
+ config = model.config
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.attentions
+ self.assertEqual(len(attentions), sum(layer == "full_attention" for layer in config.layer_types))
+
+ # check that output_attentions also work using config
+ del inputs_dict["output_attentions"]
+ config.output_attentions = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.attentions
+ self.assertEqual(len(attentions), sum(layer == "full_attention" for layer in config.layer_types))
+ self.assertListEqual(list(attentions[0].shape[-3:]), [config.num_attention_heads, seq_len, seq_len])
+ out_len = len(outputs)
+
+ # Check attention is always last and order is fine
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ self_attentions = outputs.attentions
+
+ self.assertEqual(out_len + 1, len(outputs))
+ self.assertEqual(len(self_attentions), sum(layer == "full_attention" for layer in config.layer_types))
+ self.assertListEqual(list(self_attentions[0].shape[-3:]), [config.num_attention_heads, seq_len, seq_len])
+
+ @unittest.skip("The specific cache format cannot be instantiated from dp/ddp data.")
+ def test_multi_gpu_data_parallel_forward(self):
+ pass
+
+ @unittest.skip("Intentionally not reversible (no changes) as only load time within a VLM depends on this")
+ def test_reverse_loading_mapping(self, check_keys_were_modified=True):
+ pass
+
+
+class Qwen3_5VisionText2TextModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=3,
+ seq_length=7,
+ num_channels=3,
+ ignore_index=-100,
+ image_size=16,
+ text_config={
+ "bos_token_id": 0,
+ "eos_token_id": 1,
+ "pad_token_id": 2,
+ "hidden_act": "silu",
+ "head_dim": 8,
+ "hidden_size": 32,
+ "vocab_size": 99,
+ "intermediate_size": 37,
+ "max_position_embeddings": 512,
+ "model_type": "qwen3_vl",
+ "num_attention_heads": 4,
+ "num_hidden_layers": 2,
+ "layer_types": ["full_attention", "linear_attention"],
+ "num_key_value_heads": 2,
+ "rope_theta": 10000,
+ "tie_word_embeddings": True,
+ "rope_parameters": {"rope_type": "default", "mrope_section": [16, 8, 8], "mrope_interleaved": True},
+ "linear_conv_kernel_dim": 2,
+ "linear_key_head_dim": 16,
+ "linear_value_head_dim": 16,
+ "linear_num_key_heads": 4,
+ "linear_num_value_heads": 8,
+ },
+ vision_config={
+ "depth": 2,
+ "in_chans": 3,
+ "hidden_act": "gelu_pytorch_tanh",
+ "intermediate_size": 32,
+ "out_hidden_size": 32,
+ "hidden_size": 32,
+ "num_heads": 4,
+ "patch_size": 16,
+ "spatial_merge_size": 1,
+ "temporal_patch_size": 2,
+ "num_position_embeddings": 16,
+ },
+ image_token_id=3,
+ video_token_id=4,
+ vision_start_token_id=5,
+ vision_end_token_id=6,
+ tie_word_embeddings=True,
+ is_training=True,
+ ):
+ self.parent = parent
+ self.ignore_index = ignore_index
+ self.is_training = is_training
+
+ self.vision_config = vision_config
+ self.text_config = text_config
+
+ self.vocab_size = text_config["vocab_size"]
+ self.bos_token_id = text_config["bos_token_id"]
+ self.eos_token_id = text_config["eos_token_id"]
+ self.pad_token_id = text_config["pad_token_id"]
+ self.head_dim = text_config["head_dim"]
+ self.hidden_size = text_config["hidden_size"]
+ self.intermediate_size = text_config["intermediate_size"]
+ self.num_hidden_layers = text_config["num_hidden_layers"]
+ self.num_attention_heads = text_config["num_attention_heads"]
+ self.num_key_value_heads = text_config["num_key_value_heads"]
+ self.rope_theta = text_config["rope_theta"]
+ self.rope_parameters = text_config["rope_parameters"]
+ self.hidden_act = text_config["hidden_act"]
+ self.max_position_embeddings = text_config["max_position_embeddings"]
+ self.model_type = text_config["model_type"]
+
+ self.vision_start_token_id = vision_start_token_id
+ self.vision_end_token_id = vision_end_token_id
+ self.image_token_id = image_token_id
+ self.video_token_id = video_token_id
+ self.tie_word_embeddings = tie_word_embeddings
+
+ self.batch_size = batch_size
+ self.num_channels = num_channels
+ self.image_size = image_size
+ self.num_image_tokens = 32
+ self.seq_length = seq_length + self.num_image_tokens
+
+ def get_config(self):
+ return Qwen3_5Config(
+ text_config=self.text_config,
+ vision_config=self.vision_config,
+ image_token_id=self.image_token_id,
+ video_token_id=self.video_token_id,
+ vision_start_token_id=self.vision_start_token_id,
+ vision_end_token_id=self.vision_end_token_id,
+ tie_word_embeddings=self.tie_word_embeddings,
+ )
+
+ def prepare_config_and_inputs(self):
+ config = self.get_config()
+ patch_size = config.vision_config.patch_size
+ temporal_patch_size = config.vision_config.temporal_patch_size
+ pixel_values = floats_tensor(
+ [
+ self.batch_size * (self.image_size**2) // (patch_size**2),
+ self.num_channels * (patch_size**2) * temporal_patch_size,
+ ]
+ )
+
+ return config, pixel_values
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ config, pixel_values = config_and_inputs
+ input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+ attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
+
+ input_ids[:, -1] = self.pad_token_id
+ input_ids[input_ids == self.video_token_id] = self.pad_token_id
+ input_ids[input_ids == self.image_token_id] = self.pad_token_id
+ input_ids[input_ids == self.vision_start_token_id] = self.pad_token_id
+ input_ids[:, self.num_image_tokens] = self.image_token_id
+ input_ids[:, self.num_image_tokens - 1] = self.vision_start_token_id
+ inputs_dict = {
+ "pixel_values": pixel_values,
+ "image_grid_thw": torch.tensor([[1, 1, 1]] * self.batch_size, device=torch_device),
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ }
+ return config, inputs_dict
+
+
+@require_torch
+class Qwen3_5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
+ """
+ Model tester for `Qwen3_5ForConditionalGeneration`.
+ """
+
+ all_model_classes = (
+ (
+ Qwen3_5Model,
+ Qwen3_5ForConditionalGeneration,
+ )
+ if is_torch_available()
+ else ()
+ )
+ model_split_percents = [0.5, 0.8, 0.9]
+
+ def setUp(self):
+ self.model_tester = Qwen3_5VisionText2TextModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=Qwen3_5Config, has_text_modality=False)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config):
+ "Qwen3.5 has a special Cache as it alternates with gated deltanet layers"
+ self.assertIsInstance(past_key_values, Qwen3_5DynamicCache)
+
+ # (batch, kv heads, seq_length, head_dim)
+ num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
+ head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+ expected_shape = (batch_size, num_heads, seq_length, head_dim)
+
+ attention_layer_indices = past_key_values.transformer_layers
+ self.assertListEqual(
+ [past_key_values.key_cache[idx].shape for idx in attention_layer_indices],
+ [expected_shape] * len(attention_layer_indices),
+ )
+ self.assertListEqual(
+ [past_key_values.value_cache[idx].shape for idx in attention_layer_indices],
+ [expected_shape] * len(attention_layer_indices),
+ )
+
+ def _check_caches_are_equal(self, cache1, cache2):
+ "Qwen3.5 has a special Cache as it alternates with gated deltanet layers"
+ if not len(cache1) == len(cache2):
+ raise ValueError("Both caches do not have the same number of layers.")
+
+ num_layers = len(cache1)
+ for idx in range(num_layers):
+ if cache1.key_cache[idx] is not None:
+ torch.testing.assert_close(cache1.key_cache[idx], cache2.key_cache[idx])
+ torch.testing.assert_close(cache1.value_cache[idx], cache2.value_cache[idx])
+
+ def test_attention_outputs(self):
+ "Needs to be overwritten as Qwen3.5 alternates between attention layers and gated deltanet layers."
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.return_dict = True
+ # force eager attention to support output attentions
+ config._attn_implementation = "eager"
+ seq_len = getattr(self.model_tester, "seq_length", None)
+
+ for model_class in self.all_model_classes:
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = False
+ config.return_dict = True
+ model = model_class._from_config(config, attn_implementation="eager")
+ config = model.config
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.attentions
+ self.assertEqual(
+ len(attentions), sum(layer == "full_attention" for layer in config.text_config.layer_types)
+ )
+
+ # check that output_attentions also work using config
+ del inputs_dict["output_attentions"]
+ config.text_config.output_attentions = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.attentions
+ self.assertEqual(
+ len(attentions), sum(layer == "full_attention" for layer in config.text_config.layer_types)
+ )
+ self.assertListEqual(
+ list(attentions[0].shape[-3:]), [config.text_config.num_attention_heads, seq_len, seq_len]
+ )
+ out_len = len(outputs)
+
+ # Check attention is always last and order is fine
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ self_attentions = outputs.attentions
+
+ self.assertEqual(out_len + 1, len(outputs))
+ self.assertEqual(
+ len(self_attentions), sum(layer == "full_attention" for layer in config.text_config.layer_types)
+ )
+ self.assertListEqual(
+ list(self_attentions[0].shape[-3:]), [config.text_config.num_attention_heads, seq_len, seq_len]
+ )
+
+ @unittest.skip("The specific cache format cannot be instantiated from dp/ddp data.")
+ def test_multi_gpu_data_parallel_forward(self):
+ pass
diff --git a/tests/models/qwen3_5_moe/__init__.py b/tests/models/qwen3_5_moe/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/models/qwen3_5_moe/test_modeling_qwen3_5_moe.py b/tests/models/qwen3_5_moe/test_modeling_qwen3_5_moe.py
new file mode 100644
index 000000000000..edd88860b596
--- /dev/null
+++ b/tests/models/qwen3_5_moe/test_modeling_qwen3_5_moe.py
@@ -0,0 +1,491 @@
+# Copyright 2026 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Testing suite for the PyTorch Qwen3.5 Moe model."""
+
+import copy
+import os
+import re
+import tempfile
+import unittest
+
+from safetensors.torch import load_file
+
+from transformers import is_torch_available
+from transformers.conversion_mapping import get_model_conversion_mapping
+from transformers.core_model_loading import WeightRenaming, process_target_pattern
+from transformers.testing_utils import (
+ require_torch,
+ torch_device,
+)
+
+from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
+from ...generation.test_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import (
+ ModelTesterMixin,
+ compare_state_dicts,
+ floats_tensor,
+ ids_tensor,
+)
+
+
+if is_torch_available():
+ import torch
+
+ from transformers import (
+ Qwen3_5MoeConfig,
+ Qwen3_5MoeForCausalLM,
+ Qwen3_5MoeForConditionalGeneration,
+ Qwen3_5MoeModel,
+ Qwen3_5MoeTextConfig,
+ Qwen3_5MoeTextModel,
+ )
+ from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeDynamicCache
+
+
+class Qwen3_5MoeTextModelTester(CausalLMModelTester):
+ if is_torch_available():
+ base_model_class = Qwen3_5MoeTextModel
+ causal_lm_class = Qwen3_5MoeForCausalLM
+
+ def __init__(self, parent):
+ super().__init__(parent=parent)
+ self.layer_types = ["full_attention", "linear_attention"]
+ self.linear_conv_kernel_dim = 2
+ self.linear_key_head_dim = 16
+ self.linear_value_head_dim = 16
+ self.linear_num_key_heads = 4
+ self.linear_num_value_heads = 8
+
+
+@require_torch
+class Qwen3_5MoeTextModelTest(CausalLMModelTest, unittest.TestCase):
+ model_tester_class = Qwen3_5MoeTextModelTester
+ config_class = Qwen3_5MoeTextConfig
+
+ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config):
+ "Qwen3.5 Moe has a special Cache as it alternates with gated deltanet layers"
+ self.assertIsInstance(past_key_values, Qwen3_5MoeDynamicCache)
+
+ # (batch, kv heads, seq_length, head_dim)
+ num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
+ head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+ expected_shape = (batch_size, num_heads, seq_length, head_dim)
+
+ attention_layer_indices = past_key_values.transformer_layers
+ self.assertListEqual(
+ [past_key_values.key_cache[idx].shape for idx in attention_layer_indices],
+ [expected_shape] * len(attention_layer_indices),
+ )
+ self.assertListEqual(
+ [past_key_values.value_cache[idx].shape for idx in attention_layer_indices],
+ [expected_shape] * len(attention_layer_indices),
+ )
+
+ def _check_caches_are_equal(self, cache1, cache2):
+ "Qwen3.5 Moe has a special Cache as it alternates with gated deltanet layers"
+ if not len(cache1) == len(cache2):
+ raise ValueError("Both caches do not have the same number of layers.")
+
+ num_layers = len(cache1)
+ for idx in range(num_layers):
+ if cache1.key_cache[idx] is not None:
+ torch.testing.assert_close(cache1.key_cache[idx], cache2.key_cache[idx])
+ torch.testing.assert_close(cache1.value_cache[idx], cache2.value_cache[idx])
+
+ def test_attention_outputs(self):
+ "Needs to be overwritten as Qwen3.5 Moe alternates between attention layers and gated deltanet layers."
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.return_dict = True
+ # force eager attention to support output attentions
+ config._attn_implementation = "eager"
+ seq_len = getattr(self.model_tester, "seq_length", None)
+
+ for model_class in self.all_model_classes:
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = False
+ config.return_dict = True
+ model = model_class._from_config(config, attn_implementation="eager")
+ config = model.config
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.attentions
+ self.assertEqual(len(attentions), sum(layer == "full_attention" for layer in config.layer_types))
+
+ # check that output_attentions also work using config
+ del inputs_dict["output_attentions"]
+ config.output_attentions = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.attentions
+ self.assertEqual(len(attentions), sum(layer == "full_attention" for layer in config.layer_types))
+ self.assertListEqual(list(attentions[0].shape[-3:]), [config.num_attention_heads, seq_len, seq_len])
+ out_len = len(outputs)
+
+ # Check attention is always last and order is fine
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ self_attentions = outputs.attentions
+
+ self.assertEqual(out_len + 1, len(outputs))
+ self.assertEqual(len(self_attentions), sum(layer == "full_attention" for layer in config.layer_types))
+ self.assertListEqual(list(self_attentions[0].shape[-3:]), [config.num_attention_heads, seq_len, seq_len])
+
+ def test_reverse_loading_mapping(self, check_keys_were_modified=True):
+ """
+ Overwritten to check for the moe portion but ignore the prefix as it results into a noop
+ (except we have a VLM struct initially)
+ """
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ # Some MoE models alternate between a classic MLP and a MoE layer, in which case we want to have at
+ least one MoE layer here to check the mapping
+ config_to_set = config.get_text_config(decoder=True)
+ config_to_set.first_k_dense_replace = 1 # means that the first layer (idx 0) will be MLP, then MoE
+ config_to_set.moe_layer_start_index = 1 # same as above but for Ernie 4.5...
+ config_to_set.mlp_only_layers = [0] # same but for qwens
+ config_to_set.num_dense_layers = 1 # lfm2_moe
+
+ for model_class in self.all_model_classes:
+ # Each individual model is a subtest
+ with self.subTest(model_class.__name__):
+ model = model_class(copy.deepcopy(config))
+ # Skip if no conversions
+ conversions = get_model_conversion_mapping(model, add_legacy=False)
+ if len(conversions) == 0:
+ self.skipTest("No conversion found for this model")
+
+ # Find the model keys, so the targets according to the conversions
+ model_keys = list(model.state_dict().keys())
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ # Serialize with reverse mapping
+ model.save_pretrained(tmpdirname)
+ state_dict = load_file(os.path.join(tmpdirname, "model.safetensors"))
+ # Get all the serialized keys that we just saved according to the reverse mapping
+ serialized_keys = list(state_dict.keys())
+
+ if check_keys_were_modified:
+ # They should be different, otherwise we did not perform any mapping
+ self.assertNotEqual(sorted(serialized_keys), sorted(model_keys), "No key mapping was performed!")
+
+ # Check that for each conversion entry, we at least map to one key
+ for conversion in conversions:
+ for source_pattern in conversion.source_patterns:
+ # Sometimes the mappings specify keys that are tied, so absent from the saved state dict
+ if isinstance(conversion, WeightRenaming):
+ # We need to revert the target pattern to make it compatible with regex search
+ target_pattern_reversed = conversion.target_patterns[0]
+ captured_group = process_target_pattern(source_pattern)[1]
+ if captured_group:
+ target_pattern_reversed = target_pattern_reversed.replace(r"\1", captured_group)
+ if any(re.search(target_pattern_reversed, k) for k in model.all_tied_weights_keys.keys()):
+ continue
+ num_matches = sum(re.search(source_pattern, key) is not None for key in serialized_keys)
+
+ # Key change: special case to load causal lm within vlm
+ if source_pattern == "^model.language_model":
+ continue
+
+ self.assertTrue(
+ num_matches > 0,
+ f"`{source_pattern}` in `{conversion}` did not match any of the source keys. "
+ "This indicates either that the pattern is not properly written, or that it could not be reversed correctly",
+ )
+
+ # If everything is still good at this point, let's test that we perform the same operations both when
+ # reverting ops from `from_pretrained` and from `__init__`
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ # The model was instantiated from __init__ before being saved
+ model.save_pretrained(tmpdirname)
+ state_dict_saved_from_init = load_file(os.path.join(tmpdirname, "model.safetensors"))
+
+ # Now reload it
+ model_reloaded = model_class.from_pretrained(tmpdirname)
+
+ # Make sure both loaded state_dict are identical
+ self.assertTrue(compare_state_dicts(model_reloaded.state_dict(), model.state_dict()))
+
+ # The model was instantiated from `from_pretrained` before being saved
+ model_reloaded.save_pretrained(tmpdirname)
+ state_dict_saved_from_pretrained = load_file(os.path.join(tmpdirname, "model.safetensors"))
+
+ # Make sure both saved state_dict are identical
+ self.assertTrue(compare_state_dicts(state_dict_saved_from_init, state_dict_saved_from_pretrained))
+
+ @unittest.skip("The specific cache format cannot be instantiated from dp/ddp data.")
+ def test_multi_gpu_data_parallel_forward(self):
+ pass
+
+
+class Qwen3_5MoeVisionText2TextModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=3,
+ seq_length=7,
+ num_channels=3,
+ ignore_index=-100,
+ image_size=16,
+ text_config={
+ "bos_token_id": 0,
+ "eos_token_id": 1,
+ "pad_token_id": 2,
+ "hidden_act": "silu",
+ "head_dim": 8,
+ "hidden_size": 32,
+ "vocab_size": 99,
+ "intermediate_size": 37,
+ "max_position_embeddings": 512,
+ "model_type": "qwen3_vl",
+ "num_attention_heads": 4,
+ "num_hidden_layers": 2,
+ "layer_types": ["full_attention", "linear_attention"],
+ "num_key_value_heads": 2,
+ "rope_theta": 10000,
+ "tie_word_embeddings": True,
+ "rope_parameters": {"rope_type": "default", "mrope_section": [16, 8, 8], "mrope_interleaved": True},
+ "linear_conv_kernel_dim": 2,
+ "linear_key_head_dim": 16,
+ "linear_value_head_dim": 16,
+ "linear_num_key_heads": 4,
+ "linear_num_value_heads": 8,
+ "moe_intermediate_size": 16,
+ "shared_expert_intermediate_size": 36,
+ "num_experts_per_tok": 2,
+ "num_experts": 8,
+ },
+ vision_config={
+ "depth": 2,
+ "in_chans": 3,
+ "hidden_act": "gelu_pytorch_tanh",
+ "intermediate_size": 32,
+ "out_hidden_size": 32,
+ "hidden_size": 32,
+ "num_heads": 4,
+ "patch_size": 16,
+ "spatial_merge_size": 1,
+ "temporal_patch_size": 2,
+ "num_position_embeddings": 16,
+ },
+ image_token_id=3,
+ video_token_id=4,
+ vision_start_token_id=5,
+ vision_end_token_id=6,
+ tie_word_embeddings=True,
+ is_training=True,
+ ):
+ self.parent = parent
+ self.ignore_index = ignore_index
+ self.is_training = is_training
+
+ self.vision_config = vision_config
+ self.text_config = text_config
+
+ self.vocab_size = text_config["vocab_size"]
+ self.bos_token_id = text_config["bos_token_id"]
+ self.eos_token_id = text_config["eos_token_id"]
+ self.pad_token_id = text_config["pad_token_id"]
+ self.head_dim = text_config["head_dim"]
+ self.hidden_size = text_config["hidden_size"]
+ self.intermediate_size = text_config["intermediate_size"]
+ self.num_hidden_layers = text_config["num_hidden_layers"]
+ self.num_attention_heads = text_config["num_attention_heads"]
+ self.num_key_value_heads = text_config["num_key_value_heads"]
+ self.rope_theta = text_config["rope_theta"]
+ self.rope_parameters = text_config["rope_parameters"]
+ self.hidden_act = text_config["hidden_act"]
+ self.max_position_embeddings = text_config["max_position_embeddings"]
+ self.model_type = text_config["model_type"]
+
+ self.vision_start_token_id = vision_start_token_id
+ self.vision_end_token_id = vision_end_token_id
+ self.image_token_id = image_token_id
+ self.video_token_id = video_token_id
+ self.tie_word_embeddings = tie_word_embeddings
+
+ self.batch_size = batch_size
+ self.num_channels = num_channels
+ self.image_size = image_size
+ self.num_image_tokens = 32
+ self.seq_length = seq_length + self.num_image_tokens
+
+ def get_config(self):
+ return Qwen3_5MoeConfig(
+ text_config=self.text_config,
+ vision_config=self.vision_config,
+ image_token_id=self.image_token_id,
+ video_token_id=self.video_token_id,
+ vision_start_token_id=self.vision_start_token_id,
+ vision_end_token_id=self.vision_end_token_id,
+ tie_word_embeddings=self.tie_word_embeddings,
+ )
+
+ def prepare_config_and_inputs(self):
+ config = self.get_config()
+ patch_size = config.vision_config.patch_size
+ temporal_patch_size = config.vision_config.temporal_patch_size
+ pixel_values = floats_tensor(
+ [
+ self.batch_size * (self.image_size**2) // (patch_size**2),
+ self.num_channels * (patch_size**2) * temporal_patch_size,
+ ]
+ )
+
+ return config, pixel_values
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ config, pixel_values = config_and_inputs
+ input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+ attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
+
+ input_ids[:, -1] = self.pad_token_id
+ input_ids[input_ids == self.video_token_id] = self.pad_token_id
+ input_ids[input_ids == self.image_token_id] = self.pad_token_id
+ input_ids[input_ids == self.vision_start_token_id] = self.pad_token_id
+ input_ids[:, self.num_image_tokens] = self.image_token_id
+ input_ids[:, self.num_image_tokens - 1] = self.vision_start_token_id
+ inputs_dict = {
+ "pixel_values": pixel_values,
+ "image_grid_thw": torch.tensor([[1, 1, 1]] * self.batch_size, device=torch_device),
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ }
+ return config, inputs_dict
+
+
+@require_torch
+class Qwen3_5MoeModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
+ """
+ Model tester for `Qwen3_5MoeForConditionalGeneration`.
+ """
+
+ all_model_classes = (
+ (
+ Qwen3_5MoeModel,
+ Qwen3_5MoeForConditionalGeneration,
+ )
+ if is_torch_available()
+ else ()
+ )
+
+ def setUp(self):
+ self.model_tester = Qwen3_5MoeVisionText2TextModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=Qwen3_5MoeConfig, has_text_modality=False)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config):
+ "Qwen3.5 Moe has a special Cache as it alternates with gated deltanet layers"
+ self.assertIsInstance(past_key_values, Qwen3_5MoeDynamicCache)
+
+ # (batch, kv heads, seq_length, head_dim)
+ num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
+ head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+ expected_shape = (batch_size, num_heads, seq_length, head_dim)
+
+ attention_layer_indices = past_key_values.transformer_layers
+ self.assertListEqual(
+ [past_key_values.key_cache[idx].shape for idx in attention_layer_indices],
+ [expected_shape] * len(attention_layer_indices),
+ )
+ self.assertListEqual(
+ [past_key_values.value_cache[idx].shape for idx in attention_layer_indices],
+ [expected_shape] * len(attention_layer_indices),
+ )
+
+ def _check_caches_are_equal(self, cache1, cache2):
+ "Qwen3.5 Moe has a special Cache as it alternates with gated deltanet layers"
+ if not len(cache1) == len(cache2):
+ raise ValueError("Both caches do not have the same number of layers.")
+
+ num_layers = len(cache1)
+ for idx in range(num_layers):
+ if cache1.key_cache[idx] is not None:
+ torch.testing.assert_close(cache1.key_cache[idx], cache2.key_cache[idx])
+ torch.testing.assert_close(cache1.value_cache[idx], cache2.value_cache[idx])
+
+ def test_attention_outputs(self):
+ "Needs to be overwritten as Qwen3.5 Moe alternates between attention layers and gated deltanet layers."
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.return_dict = True
+ # force eager attention to support output attentions
+ config._attn_implementation = "eager"
+ seq_len = getattr(self.model_tester, "seq_length", None)
+
+ for model_class in self.all_model_classes:
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = False
+ config.return_dict = True
+ model = model_class._from_config(config, attn_implementation="eager")
+ config = model.config
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.attentions
+ self.assertEqual(
+ len(attentions), sum(layer == "full_attention" for layer in config.text_config.layer_types)
+ )
+
+ # check that output_attentions also work using config
+ del inputs_dict["output_attentions"]
+ config.text_config.output_attentions = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.attentions
+ self.assertEqual(
+ len(attentions), sum(layer == "full_attention" for layer in config.text_config.layer_types)
+ )
+ self.assertListEqual(
+ list(attentions[0].shape[-3:]), [config.text_config.num_attention_heads, seq_len, seq_len]
+ )
+ out_len = len(outputs)
+
+ # Check attention is always last and order is fine
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ self_attentions = outputs.attentions
+
+ self.assertEqual(out_len + 1, len(outputs))
+ self.assertEqual(
+ len(self_attentions), sum(layer == "full_attention" for layer in config.text_config.layer_types)
+ )
+ self.assertListEqual(
+ list(self_attentions[0].shape[-3:]), [config.text_config.num_attention_heads, seq_len, seq_len]
+ )
+
+ @unittest.skip("The specific cache format cannot be instantiated from dp/ddp data.")
+ def test_multi_gpu_data_parallel_forward(self):
+ pass
diff --git a/utils/check_repo.py b/utils/check_repo.py
index fb5e54097bc4..598ae9e90a36 100644
--- a/utils/check_repo.py
+++ b/utils/check_repo.py
@@ -72,6 +72,8 @@
"Qwen2_5_VisionTransformerPretrainedModel",
"Qwen3VLVisionModel",
"Qwen3VLMoeVisionModel",
+ "Qwen3_5VisionModel",
+ "Qwen3_5MoeVisionModel",
"SwitchTransformersStack",
"SiglipTextTransformer",
"Siglip2TextTransformer",
@@ -176,6 +178,8 @@
"Qwen3VLMoeModel", # Building part of bigger (tested) model. Tested implicitly through Qwen3VLMoeForConditionalGeneration.
"Qwen3VLTextModel", # Building part of bigger (tested) model.
"Qwen3VLMoeTextModel", # Building part of bigger (tested) model.
+ "Qwen3_5TextModel", # Building part of bigger (tested) model. Tested implicitly through Qwen3_5ForConditionalGeneration.
+ "Qwen3_5MoeTextModel", # Building part of bigger (tested) model. Tested implicitly through Qwen3_5MoeForConditionalGeneration.
"Qwen2_5OmniForConditionalGeneration", # Not a regular model. Testted in Qwen2_5OmniModelIntergrationTest
"Qwen2_5OmniTalkerForConditionalGeneration", # Building part of bigger (tested) model. Tested implicitly through Qwen2_5OmniModelIntergrationTest.
"Qwen2_5OmniTalkerModel", # Building part of bigger (tested) model. Tested implicitly through Qwen2_5OmniModelIntergrationTest.
diff --git a/utils/models_to_deprecate.py b/utils/models_to_deprecate.py
index 9d6c38cda388..ed1bc84a376d 100644
--- a/utils/models_to_deprecate.py
+++ b/utils/models_to_deprecate.py
@@ -114,6 +114,8 @@
"qwen2_vl": ["qwen2_vl_text"],
"qwen3_vl_moe": ["qwen3_vl_moe_text"],
"qwen3_vl": ["qwen3_vl_text"],
+ "qwen3_5": ["qwen3_5_text"],
+ "qwen3_5_moe": ["qwen3_5_moe_text"],
"rt_detr": ["rt_detr_resnet"],
"sam2": ["sam2_hiera_det_model", "sam2_vision_model"],
"sam": ["sam_hq_vision_model", "sam_vision_model"],